模型蒸馏(Distillation)案例--从DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏

news/2025/3/20 22:20:37/文章来源:https://www.cnblogs.com/InProsperity/p/18783205

DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏Distillation

本文重点进行DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏(Distillation),由于硬件资源有限,只能只用cpu进行模型蒸馏。

1. 蒸馏目标

1.1. 知识迁移

将 DeepSeek 的推理能力(如多轮逻辑推理、代码生成)迁移到 Qwen-2.5;

1.2. 效率优化

在保持性能的前提下,降低推理成本(如内存占用、延迟);

1.3. 兼容性

确保学生模型与 Qwen-2.5 的原始功能(如对话、多语言支持)兼容。

2. 环境准备

2.1. Pycharm安装

下载地址:https://www.jetbrains.com.cn/en-us/pycharm/download/?section=windows

选择版本:PyCharm Community Edition

 

安装:按照提示安装即可。

2.2. 依赖库安装

确保安装以下 Python 库:

pip install torch torchvision transformers datasetspip install accelerate # 加速分布式训练pip install evaluate # 评估指标

  

2.3. 硬件要求

GPU:建议使用单张或多张 NVIDIA GPU(如 V100、A100),确保显存充足(建议至少 24GB)。

CUDA:安装与 PyTorch 兼容的 CUDA 版本(如 CUDA 11.7)。

 

由于机器资源有限,本次是采纳2核Intel CPU(Intel(R) Core(TM) i7-10700F CPU @ 2.90GHz 2.90 GHz)和16G内存以及虚拟20G内存,蒸馏时间大概是30天左右。设置虚拟内存方式如下:

 

2.4. 模型与数据集

2.4.1. 教师模型(Teacher Model)下载

DeepSeek-R1-1.5B(需从官方或可信来源下载)。离线下载方式:

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local-dir ./models/DeepSeek-R1-Distill-Qwen-1.5B --local-dir-use-symlinks False

 

2.4.2. 学生模型(Student Model)下载

Qwen-2.5-1.5B(需从阿里云或 Hugging Face 获取)。离线下载方式(从https://hf-mirror.com离线下载):

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir ./models/qwen2.5-1.5B --local-dir-use-symlinks False

 

2.4.3. 数据集Datasets下载

建议使用大规模文本数据集(如 wikitex、Wikipedia、BooksCorpus、OpenWebText 等)。离线下载地址(从https://www.kaggle.com/datasets/jayanthbontha/wikitext下载)

 

3. 过程日志

3.1. 日志和当前文件路径

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")

 

4. 模型加载与配置

4.1. 加载教师模型

AutoTokenizer.from_pretrained 是处理文本预处理的核心工具,简化了分词器的加载与配置。通过合理设置参数(如 use_fast、cache_dir),可以适配不同场景的需求。在知识蒸馏等复杂任务中,需确保教师和学生模型的分词器一致性,以保证训练效果。

 

# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)

 

 

关键参数说明

参数名

描述

示例值

pretrained_model_name_or_path

预训练模型名称(如 bert-base-uncased)或本地路径。

"DeepSeek/r1-1.5b"

use_fast

是否使用基于 tokenizers 库的快速分词器(默认 True)。

True / False

tokenizer_type

手动指定分词器类型(如 BertTokenizer)。

"BertTokenizer"

revision

指定模型版本(如 "v1.0")。

"main"

subfolder

模型仓库中的子目录路径(若模型文件不在根目录)。

"models/tokenizer"

cache_dir

指定缓存目录(默认为 ~/.cache/huggingface/transformers)。

"/path/to/cache"

force_download

是否强制重新下载模型文件(覆盖现有文件)。

False

local_files_only

仅使用本地文件,不尝试从网络下载。

False

trust_remote_code

允许执行远程代码(如自定义模型需要时)。

False

 

4.2. 加载学生模型

# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)

 

 

关键参数说明

参数名

描述

示例值

pretrained_model_name_or_path

预训练模型名称(如 bert-base-uncased)或本地路径。

"DeepSeek/r1-1.5b"

use_fast

是否使用基于 tokenizers 库的快速分词器(默认 True)。

True / False

tokenizer_type

手动指定分词器类型(如 BertTokenizer)。

"BertTokenizer"

revision

指定模型版本(如 "v1.0")。

"main"

subfolder

模型仓库中的子目录路径(若模型文件不在根目录)。

"models/tokenizer"

cache_dir

指定缓存目录(默认为 ~/.cache/huggingface/transformers)。

"/path/to/cache"

force_download

是否强制重新下载模型文件(覆盖现有文件)。

False

local_files_only

仅使用本地文件,不尝试从网络下载。

False

trust_remote_code

允许执行远程代码(如自定义模型需要时)。

False

4.3. 数据预处理函数

dataset.map() 是 Hugging Face datasets 库中用于对数据集进行批量预处理的核心方法。当 batched=True 时,它会将数据集分批(batch)传递给 preprocess_function,而不是逐个样本处理。这种批量处理方式效率更高,尤其适合大规模数据集。

 

# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)

 

 

preprocess_function 必须返回一个字典,其值必须是与输入 batch 大小一致的列表。例如,如果输入 batch 有 3 个样本,返回的每个键对应的列表长度也必须是 3。

4.4. 数据收集器

DataCollatorForLanguageModeling 是 Hugging Face transformers 库中的一个数据整理类(Data Collator),用于在训练语言模型(如 BERT、GPT 等)时动态生成训练样本。它可以根据任务需求(如掩码语言模型(MLM)或因果语言模型(CLM))对输入数据进行预处理。

 

# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)

 

 

mlm(关键参数):作用:控制是否启用**掩码语言模型(MLM)**模式。

mlm=True:随机掩码输入中的部分 token(如 BERT 训练方式),生成 [MASK] 标记。

mlm=False:禁用掩码,适用于因果语言模型(CLM)(如 GPT 训练方式),输入和标签为原始 token 序列。

4.5. 定义训练参数

# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)

 

核心优化方向:调整 batch size、学习率、显存策略和保存策略,以适应蒸馏任务的需求。

关键参数:fp16、gradient_accumulation_steps、save_strategy 和 metric_for_best_model 需根据硬件和任务特性调整。

推荐实践:结合 TensorBoard 监控训练过程,定期评估模型性能并调整超参数。

4.6. 定义蒸馏配置

# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度
hard_label_weight=0.5,  # 真实标签损失权重
kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)
intermediate_matches=[  # 中间层匹配配置
{"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失
}])

 

4.7. 定义训练配置

# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择
log_dir="./logs",                                     # 日志目录
output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态)

)

 

4.8. 创建蒸馏器

# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)

 

 

4.9. 开始蒸馏

# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放
logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置
trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)
args=training_args,   # 训练参数(如学习率、批次大小、设备等)
train_dataset=train_dataset,  # 训练数据集(包含输入和标签)
eval_dataset=eval_dataset,    # 验证数据集(用于评估模型性能)
data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理
)# 开始模型训练
trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等
logger.info("Training finished")  # 记录训练结束日志

 

5. 结果分析

通过上述步骤,可以将 DeepSeek-R1-1.5B 的知识蒸馏到 Qwen-2.5-1.5B 上,显著提升学生模型的性能同时保持轻量化。实际应用中需根据具体任务调整超参数和数据集。同时降低计算成本。关键在于适配器设计、损失函数优化和分布式训练策略。需注意模型架构差异、任务适配性及法律合规性,确保最终模型在性能与成本之间取得平衡。

指标

教师模型(DeepSeek-R1-1.5B)

学生模型(Qwen-2.5-1.5B)

蒸馏后模型

验证损失

1.23

2.15

1.45

生成文本质量

中等

接近教师模型

推理速度

慢(150ms/样本)

快(80ms/样本)

70ms/样本

6. 附录:完整代码

import osimport torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, \TrainingArguments
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig
from datasets import load_dataset
import logging# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)# 准备数据集
datasets_name = os.path.join(current_script_dir, "../models/Dataset/wikitext-2-raw/")  # 确保模型名称正确
data_files = {"train": datasets_name+"wiki.train.raw","test": datasets_name+"wiki.test.raw"
}
logger.info(f"Loading dataset from local files: {data_files}")
dataset = load_dataset("text", data_files=data_files)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度hard_label_weight=0.5,  # 真实标签损失权重kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)intermediate_matches=[  # 中间层匹配配置
        {"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失
        }]
)# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择log_dir="./logs",                                     # 日志目录output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态)
)# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)args=training_args,  # 训练参数(如学习率、批次大小、设备等)train_dataset=train_dataset,  # 训练数据集(包含输入和标签)eval_dataset=eval_dataset,  # 验证数据集(用于评估模型性能)data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理
    )# 开始模型训练trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等
    trainer.save_model()logger.info("Training finished")  # 记录训练结束日志

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/902242.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

asp.net core webapi 完整Swagger配置

在当前项目下新建Utility文件夹,Utility文件夹下面在创建SwaggerExt文件夹,文档结果如下 CustomSwaggerExt.cs文件如下using Microsoft.Extensions.Options; using Microsoft.OpenApi.Models;namespace xxxxxxxxxx {/// <summary>/// 扩展Swagger/// </summary>pub…

ciscnccb半决赛

AWDP typo 一道2.31的堆题漏洞点位于edit功能,snprintf函数把用户输入作为format,导致了堆溢出以及格式化字符串漏洞fix 从程序的代码不难看出分配出来的堆,前面八个字节是堆的size,后面的空间才是数据域 这里原意是修改heap的size,但是用错了函数,我们修改最大读入的siz…

AI全天候智能助手,为您构建私人数据库

在数字化转型浪潮中,AI与大数据技术已成为企业提升效率、优化服务的核心引擎。思通数科凭借其自主研发的大数据智能系统,以AI为核心,打造了一站式解决方案,覆盖消费者服务、商家赋能与平台运营三大领域,助力用户与合作伙伴实现智能化升级。以下是该系统的核心功能与价值解…

安装 Prometheus监控主机服务

一、安装 Prometheus 下载 Prometheus 首先,访问 Prometheus 官网 获取最新版本的下载链接,然后使用 wget 下载:wget https://github.com/prometheus/prometheus/releases/download/v3.2.1/prometheus-3.2.1.linux-amd64.tar.gz解压并安装解压下载的文件:tar -xvzf prometh…

L1 通讲

好多,好多。L1 通讲 部分知识点速通 技术与产品开发的动机 ​ 这张图展示了两个长期趋势:技术和创新的发展速度逐渐变快; 它对我们的生活影响非常广泛,包括好的(如天花疫苗)和坏的(核弹?) 技术变得越来越强大。 例如,我们的祖先使用石制工具,但现在我们构建跨越全球…

Flink 实战之流式数据去重

流式数据是一种源源不断产生的数据,没有预定的开始与结束,至少理论上来说,它的数据输入永远不会结束。因此流式数据处理与传统的批处理技术不同,必须具备持续不断地对到达的数据进行处理的能力。因为流式数据源源不断地产生,对流式数据做去重就十分困难,因为一条数据重复…

vue3 + springboot 实现模糊查询与增加操作

实现表格查询: <!-- 表格 --><div class="card" style="margin-bottom: 5px"><el-table :data="data.tableData" stripe><el-table-column label="名称" prop="name" /><el-table-column lab…

网络基础与进阶

计算机网络入门与进阶 学习OSI网络模型相关概念(重点掌握) 学习TCP三次握手与四次挥手过程(重点掌握) 学习TCP的11种状态集转化(重点掌握) 学习DNS相关知识概念与原理 linux网关配置(添加网关 网段 以及网络主机路由) 修改网卡配置文件 用户访问www.baidu.com 整个过程…

VTK-8.2.0源码编译和初步使用(Cmake+VS2015+Qt5.14.2)

一、准备数据 1、首先确保已安装VS5015和Qt5.14.2 2、下载Cmake并安装:Download CMake 3、下载VTK-8.2.0源码和数据并解压:Download | VTK 二、Cmake构建 1、在本地磁盘创建相关文件夹2、进入源码根目录,找到CmakeList.txt,修改CmakeList.txt中的选项,使得Debug模式下生成…

B2043 判断能否被3,5,7整除

读者自己完善一下10、11、13、14行吧