如何蒸馏 Deepseek-R1

news/2025/2/6 17:04:34/文章来源:https://www.cnblogs.com/little-horse/p/18701373

如何蒸馏 Deepseek-R1

深度学习模型已经彻底改变了人工智能领域,但其庞大的规模和计算需求可能成为现实世界应用的瓶颈。模型蒸馏是一种强大的技术,通过将知识从大型复杂模型(教师)转移到较小、更高效的模型(学生)来解决这一挑战。

在这篇博客中,这里将介绍如何使用 LoRA (Low-Rank Adaptation)等专门技术,将 DeepSeek-R1 的推理能力蒸馏成一个更小的模型,比如微软的 Phi-3-Mini。

什么是蒸馏?

蒸馏是一种机器学习技术,其中一个较小的模型(“学生”)被训练来模仿一个较大的预训练模型(“老师”)的行为。其目标是在显著降低计算成本和内存占用的同时,保留大部分教师的表现。

这个想法最早是在Geoffrey Hinton 关于知识蒸馏的开创性论文中提出的。它不是直接在原始数据上训练学生模型,而是从教师模型的输出或中间表示中学习。这实际上是受到了人类教育的启发。

为什么重要:

  • 成本效率:较小的模型需要更少的计算资源。
  • 速度:非常适合对延迟敏感的应用(例如 api、边缘设备)。
  • 专业化:在不重新训练巨头的情况下为特定领域量身定制模型。

蒸馏类型

有几种方法可以模拟蒸馏,每种方法都有自己的优点:

  1. 数据蒸馏
  • 在数据蒸馏中,教师模型生成合成数据或伪标签,然后用于训练学生模型。
  • 这种方法可以应用于广泛的任务,甚至是那些逻辑信息较少的任务(例如,开放式推理任务)。
  • Logits蒸馏
  • Logits 是应用 softmax 函数之前神经网络的原始输出分数。
  • 在logits 蒸馏中,学生模型被训练成匹配老师的logits,而不仅仅是最终的预测。
  • 这种方法保留了更多关于教师信心水平和决策过程的信息。
  • 特征蒸馏
  • 特征蒸馏涉及到将知识从教师模型的中间层传递给学生。
  • 通过对齐两个模型的隐表征,学生可以学习到更丰富、更抽象的特征。

Deepseek 的蒸馏模型

DeepSeek AI 发布了六个基于流行架构的蒸馏模型,如 Qwen

(Qwen,2024b)和 Llama (AI@Meta,2024)他们直接使用 DeepSeek-R1

收集的 80 万样本对开源模型进行微调。

尽管比 DeepSeek-R1 小得多,但经过蒸馏过的模型在各种基准测试中表现出了令人印象深刻的性能,通常与更大的模型相匹配甚至超越。如下图所示

Deepseek 蒸馏模型基准(https://arxiv.org/html/2501.12948v1)

蒸馏自己的模型

  1. 特定任务优化
    预蒸馏模型在广泛的数据集上进行训练,以便在广泛的任务中表现良好。然而,现实世界的应用往往需要专业化
    示例场景:
    你正在构建一个财务预测聊天机器人。
    在这种情况下,使用 DeepSeek-R1 来生成金融数据集的推理痕迹(例如,股票价格预测,风险分析),并将这些知识蒸馏到一个已经知道金融细微差别的较小模型中(例如:financial - llm)。
  2. 规模成本效益
    虽然预蒸馏模型是有效的,但对于你的特定工作量来说,它们可能仍然是多余的。蒸馏自己的模型允许针对自己的确切资源约束进行优化
  3. 基准性能≠真实世界性能
    预蒸馏模型在基准测试上表现出色,但基准测试往往不能代表现实世界的任务。所以你经常需要一个模型,它在现实场景中的表现比任何预蒸馏模型都要好。
  4. 迭代改进

预蒸馏模型是静态的——它们不会随着时间的推移而改进。通过蒸馏自己的模型,你可以随着新数据的出现而不断完善它

蒸馏DeepSeek-R1 知识到自定义小模型

步骤 1:安装库

pip install -q torch transformers datasets accelerate bitsandbytes flash-attn --no-build-isolation

步骤 2:生成和格式化数据集

可以通过在你的环境中使用ollama 或任何其他部署框架部署deepseek-r1 来生成自定义的领域相关数据集。但是,对于本教程,这里将使用Magpie-Reasoning-V2数据集,其中包含由 DeepSeek-R1 生成的 250K CoT 推理样本。这些样本涵盖了数学推理、编码和一般问题解决等不同的任务。

数据集结构

每个样本包括:

  • 指令:任务描述(例如,“解决这道数学题”)。
  • 回应:DeepSeek-R1 的逐步推理(CoT)。例子:
{"instruction": "Solve for x: 2x + 5 = 15","response": "<think>First, subtract 5 from both sides: 2x = 10. Then, divide by 2: x = 5.</think>"
}
from datasets import load_dataset# Load the dataset
dataset = load_dataset("Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B", token="YOUR_HF_TOKEN")
dataset = dataset["train"]# Format the dataset
def format_instruction(example):return {"text": ("<|user|>\n"f"{example['instruction']}\n""<|end|>\n""<|assistant|>\n"f"{example['response']}\n""<|end|>")}formatted_dataset = dataset.map(format_instruction, batched=False, remove_columns=subset_dataset.column_names)
formatted_dataset = formatted_dataset.train_test_split(test_size=0.1)  # 90-10 train-test split

将数据集构造成 Phi-3 的聊天模板格式:

<|user|>:用户询问的开始。

<|assistant|>:模型响应的开始。

<|end|>:一轮结束。

每个 LLM 使用特定的指令跟随任务格式。将数据集与这种结构对齐可以确保模型学习到正确的会话模式。所以一定要根据你想要蒸馏的模型来格式化数据。

步骤 3:加载 Model 和 Tokenizer

为了增强模型的推理能力,这里向tokenizer添加特殊tokens <think>和</think>。

<think>:推理的开始。

</think>:推理结束。

这些tokens帮助模型学习生成结构化的、逐步的解决方案。

from transformers import AutoTokenizer, AutoModelForCausalLMmodel_id = "microsoft/phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)# Add custom tokens
CUSTOM_TOKENS = ["<think>", "</think>"]
tokenizer.add_special_tokens({"additional_special_tokens": CUSTOM_TOKENS})
tokenizer.pad_token = tokenizer.eos_token# Load model with flash attention
model = AutoModelForCausalLM.from_pretrained(model_id,trust_remote_code=True,device_map="auto",torch_dtype=torch.float16,attn_implementation="flash_attention_2"
)
model.resize_token_embeddings(len(tokenizer))  # Resize for custom tokens

步骤 4:为高效微调配置 LoRA

LoRA 通过冻结基本模型和只训练小的适配器层来减少内存使用。

from peft import LoraConfigpeft_config = LoraConfig(r=8,  # Rank of the low-rank matriceslora_alpha=16,  # Scaling factorlora_dropout=0.2,  # Dropout ratetarget_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Target attention layersbias="none",  # No bias termstask_type="CAUSAL_LM" # Task type
)

第 5 步:设置训练参数

from transformers import TrainingArgumentstraining_args = TrainingArguments(output_dir="./phi-3-deepseek-finetuned",num_train_epochs=3,per_device_train_batch_size=2,per_device_eval_batch_size=2,gradient_accumulation_steps=4,eval_strategy="epoch",save_strategy="epoch",logging_strategy="steps",logging_steps=50,learning_rate=2e-5,fp16=True,optim="paged_adamw_32bit",max_grad_norm=0.3,warmup_ratio=0.03,lr_scheduler_type="cosine"
)

第 6 步:训练模型

SFTTrainer 简化了指令遵循模型的监督微调。data_collator 对示例进行批处理,

peft_config 支持基于lora 的训练。

from trl import SFTTrainer
from transformers import DataCollatorForLanguageModeling# Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)# Trainer
trainer = SFTTrainer(model=model,args=training_args,train_dataset=formatted_dataset["train"],eval_dataset=formatted_dataset["test"],data_collator=data_collator,peft_config=peft_config
)# Start training
trainer.train()
trainer.save_model("./phi-3-deepseek-finetuned")
tokenizer.save_pretrained("./phi-3-deepseek-finetuned")

第 7 步:合并并保存最终模型

训练后,LoRA 适配器必须与base模型合并进行推理。这一步确保了模型可以在没有 PEFT 的情况下独立使用。

final_model = trainer.model.merge_and_unload()
final_model.save_pretrained("./phi-3-deepseek-finetuned-final")
tokenizer.save_pretrained("./phi-3-deepseek-finetuned-final")

第 8 步:推理

from transformers import pipeline# Load fine-tuned model
model = AutoModelForCausalLM.from_pretrained("./phi-3-deepseek-finetuned-final",device_map="auto",torch_dtype=torch.float16
)tokenizer = AutoTokenizer.from_pretrained("./phi-3-deepseek-finetuned-final")
model.resize_token_embeddings(len(tokenizer))# Create chat pipeline
chat_pipeline = pipeline("text-generation",model=model,tokenizer=tokenizer,device_map="auto"
)# Generate response
prompt = """<|user|>
What's the probability of rolling a 7 with two dice?
<|end|>
<|assistant|>
"""output = chat_pipeline(prompt,max_new_tokens=5000,temperature=0.7,do_sample=True,eos_token_id=tokenizer.eos_token_id
)print(output[0]['generated_text'])

下图可以看到 phi 模型在蒸馏前后的响应。

question: what’s the probability of rolling a 7 with two dice?

问题:用两个骰子摇到 7 的概率是多少?

蒸馏前的推理:回答直白简洁。直接提供了计算答案的步骤。

蒸馏后的推理:蒸馏后的回答引入了一种更详细和结构化的方法,包括一个明确的“思考”部分,概述了思维过程和推理,这将对复杂问题产生准确的回答非常有帮助。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/879724.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UU 跑腿云原生化,突围同城配送赛道

我们起初是把业务部署在 IDC,但经历过频繁的服务器网线意外断掉,震网病毒在无通知的情况下封禁一批端口,其中包含数据库 alwayson 的端口,导致大量的同步日志挤压,最终数据库崩溃,无法启动。这些都严重制约了我们业务的发展,于是真正决定要开始上云。作者:袁沼&望宸…

UUbox-FastCMD:windows下自定义快捷指令

【快捷指令 UUbox-fastcmd】 windows版,绿色免费免安装。 适用于PC端经常需要切换工作场景的需求,可以将高频次操作自定义位快捷文字指令。 主页面:【下载地址】https://files.cnblogs.com/files/blogs/837238/UUbox-FastCMD_V0.1.rar?t=1738832053&download=true 【功…

微软发布基于PostgreSQL的开源文档数据库平台DocumentDB

我们很高兴地宣布正式发布DocumentDB——一个开源文档数据库平台,以及基于 vCore、基于 PostgreSQL 构建的 Azure Cosmos DB for MongoDB 的引擎。过去,NoSQL 数据库提供云专用解决方案,而没有通用的互操作性标准。这导致对可互操作、可移植且完全支持生产就绪的文档数据存储…

CTFShow-Web151:文件上传漏洞

CTFShow-Web151:文件上传漏洞 🛠️ Web151 题解 本题考察 文件上传漏洞,仅在前端进行了文件类型验证,允许上传 .png 图片文件。我们可以通过抓包修改文件后缀的方式绕过限制并获取 WebShell。 🔍 源码分析 在 upload.php 代码中,我们发现了以下 HTML 代码片段: <bu…

云大使 X 函数计算 FC 专属活动上线!享返佣,一键打造 AI 应用

通过函数计算 FC 一键部署 Flux 模型,快速生成毛茸茸萌宠风格图像。我们将为您提供预置的工作流文件+内置大模型+Lora 模型,让您基于函数计算部署 ComfyUI 快速体验AI生图。如今,AI 技术已经成为推动业务创新和增长的重要力量。但对于许多企业和开发者来说,如何高效、便捷地…

标准化管理数字化转型的实践与价值

在当下数字化转型的汹涌浪潮中,企业所处的竞争环境变得愈发复杂且瞬息万变。 标准化管理作为企业实现高质量发展的稳固基石,正切实面临着前所未有的机遇与严峻挑战。 从机遇层面来看,数字化技术的迅猛发展为标准化管理带来了全新的手段和方法,使管理效率和精准度提升成为可…

JS-52 定时器之setTimeout()

JavaScript提供定时执行代码的功能,叫做定时器(timer),主要由setTimeout和setlnterval()这两个函数来完成。他们向任务队列添加定时任务 setTimeout函数用来指定某个函数或某段代码,在多少毫秒之后执行。它返回一个整数,表示定时器的编号,以后可以用来取消这个定时器。…

spark实验一

使用 Linux 系统的常用命令 启动 Linux 虚拟机,进入 Linux 系统,通过查阅相关 Linux 书籍和网络资料,或者参考 本教程官网的“实验指南”的“Linux 系统常用命令”,完成如下操作: (1) 切换到目录 /usr/bin; (2) 查看目录/usr/local 下所有的文件;(3)…

博客园-awescnb插件-geek皮肤优化-Markdown样式支持

💖简介 博客园-awescnb插件-geek皮肤下,Markdown语法中对部分样式未正常支持,可以通过自定义CSS进行完善。 ✨定义列表定义自定义CSS 博客园->管理->设置->页面定制 CSS 代码 添加代码/* 定义列表 */ dl dt{font-size: 14px;font-weight: bold;font-style: italic…

uniapp vue3 路由传参 利用props获取参数

A页面跳转B页面 A页面 function toDetail(value) {console.log(click);let chuansVal = decodeURIComponent(JSON.stringify(value));console.log(chuansVal);uni.navigateTo({url: "/pages/material/receiveDetail?data=" + chuansVal}); }B页面 const props = def…

团队协作工具私有化部署优选:板栗看板的安全与高效之道

在进行企业私有化选择时,建议详细咨询软件供应商或查看其官方文档以获取最准确的信息。板栗看板是一款非常适合中小团队的协作工具,尤其在任务管理、项目进度跟踪和沟通协作方面表现出色。如果你正在寻找一款简洁高效、功能强大的团队协作工具进行企业私有化,板栗看板无疑是…

06 软件安全测试

13. 软件安全性测试 黑客、病毒、蠕虫、间谍软件、后门程序、木马、拒绝服务攻击等。 安全产品:指在系统的所有者或者管理员的控制下,保护用户信息的保密性、完整性、可获得性,以及处理资源的完整性和可获得性。 安全漏洞:产品不可行的缺陷,正确使用产品时来防止攻击者窃取…