LayerSkip: 使用自推测解码加速大模型推理

news/2025/3/11 13:47:39/文章来源:https://www.cnblogs.com/huggingface/p/18764887

自推测解码是一种新颖的文本生成方法,它结合了推测解码 (Speculative Decoding) 的优势和大语言模型 (LLM) 的提前退出 (Early Exit) 机制。该方法出自论文 LayerSkip: Enabling Early-Exit Inference and Self-Speculative Decoding。它通过使用 同一个模型 的早期层来生成候选词元 (token),并使用后期层进行验证,从而实现高效生成。

这项技术不仅加快了文本生成速度,还显著节省了内存并降低了计算延迟。为了实现端到端的加速,早期层的输出需要与最终层的输出足够接近。正如论文中所述,这可以通过一种训练方法来实现,该方法可以在预训练期间应用,也可以在特定领域进行微调时应用。自推测解码对于实际应用特别高效,它可以在较小的 GPU 上部署,并降低 大规模推理 所需的整体硬件资源。

在本博客中,我们将探讨自推测解码的概念、其实现方式以及在 🤗 transformers 库中的实际应用。您将了解到其技术原理,包括 提前退出层 (Early-Exit Layers)反嵌入 (Unembedding)训练修改 (Training Modifications)。为了将这些概念付诸实践,我们提供了代码示例、与传统推测解码的基准比较,以及对性能权衡的见解。

您还可以直接查看以下 Hugging Face 资源,了解更多关于该方法的信息并亲自尝试:

  1. Hugging Face 论文讨论论坛
  2. LayerSkip 模型集合
  3. 展示自推测解码深入工作原理的 Colab 笔记本

推测解码与自推测解码

LayerSkip 演示 GIF

facebook/layerskip-llama2-7B 上的 LayerSkip 推理演示 (使用 LayerSkip 方法持续预训练的 Llama2 7B)。

传统的推测解码 使用 两个 模型: 一个较小的模型 (草稿模型) 用于生成一系列候选词元,一个较大的模型 (验证模型) 用于验证草稿的准确性。较小的模型执行大部分生成工作,而较大的模型则负责改进结果。这提高了文本生成速度,因为较大的模型一次性验证完整序列,而不是逐个生成词元。

在自推测解码中,作者在此概念的基础上,使用大模型的早期层来生成草稿词元,然后由模型的更深层进行验证。这种推测解码的“自洽”特性需要特定的训练,使模型能够同时执行草稿生成和验证。这反过来又比传统的推测解码提高了速度并降低了计算成本。

transformers 中的使用

为了在 🤗 transformers 库中启用提前退出自推测解码,我们只需在 generate() 函数中添加 assistant_early_exit 参数。

以下是一个简单的代码片段,展示了该功能:

pip install transformersfrom transformers import AutoTokenizer, AutoModelForCausalLMearly_exit_layer = 4
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B"tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)

注意: 虽然 assistant_early_exit 参数可以为任何仅解码器的 transformer 启用提前退出自推测解码,但除非模型经过专门训练,否则无法反嵌入 (通过 LM 头进行解码的过程,在博客文章后面有描述) 中间层的 logits。只有对检查点进行这样的训练,以提高早期层的准确性,您才能获得加速。LayerSkip 论文提出了一种训练方法来实现这一点 (即应用提前退出损失,并逐步增加层丢弃率)。这里 提供了使用 LayerSkip 训练方法持续预训练的 Llama2、Llama3 和 Code Llama 检查点的集合。

基准测试

我们进行了一系列广泛的基准测试,以衡量 LayerSkip 的自推测解码相对于自回归解码在各种模型上的加速情况。我们还将自推测解码 (基于提前退出) 与标准推测解码技术进行了比较。要复现这些结果,您可以在 这里 找到代码,并在 此电子表格 中找到运行每个实验的命令。所有实验均在单个 80GB A100 GPU 上运行,除了 Llama2 70B 实验在 8 个 A100 GPU 的节点上运行。

Llama3.2 1B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
facebook/layerskip-llama3.2-1B 1 Early Exit @ Layer 4 summarization 1 1195.28 9.96 2147.7 17.9 1.80

Llama3 8B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-1B 1 summarization 9 1872.46 19.04 2859.35 29.08 1.53
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-3B 3 summarization 11 2814.82 28.63 2825.36 28.73 1.00
facebook/layerskip-llama3-8B 8 Early Exit @ Layer 4 summarization 8 1949.02 15.75 3571.81 28.87 1.83

Llama2 70B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-13b-hf 13 summarization 83 5036.54 46.3 12289.01 112.97 2.44
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-7b-hf 7 summarization 77 4357.55 40.06 12324.19 113.3 2.83
meta-llama/Llama-2-70b-hf 70 TinyLlama/TinyLlama_v1.1 1 summarization 71 4356.21 40.05 12363.22 113.66 2.84
facebook/layerskip-llama2-70B 70 Early Exit @ Layer 10 summarization 70 6012.04 54.96 1283.34 113.2 2.06

Llama2 13B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-13b-hf 13 meta-llama/Llama-2-7b-hf 7 summarization 20 3557.07 27.79 4088.48 31.94 1.15
meta-llama/Llama-2-13b-hf 13 TinyLlama/TinyLlama_v1.1 1 summarization 14 2901.92 22.67 4190.42 32.74 1.44
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-270M 0.27 summarization 13.27 2883.33 22.53 4521.12 35.32 1.57
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-450M 0.45 summarization 13.45 3267.69 25.53 4321.75 33.76 1.32
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 4 summarization 13 4238.45 33.11 4217.78 32.95 0.995
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 8 summarization 13 2459.61 19.22 4294.98 33.55 1.746

Llama2 7B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-7b-hf 7 TinyLlama/TinyLlama_v1.1 1 summarization 8 2771.54 21.65 3368.48 26.32 1.22
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-270M 0.27 summarization 7.27 2607.82 20.37 4221.14 32.98 1.62
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-450M 0.45 summarization 7.45 3324.68 25.97 4178.66 32.65 1.26
facebook/layerskip-llama2-7B 7 Early Exit @ Layer 4 summarization 7 2548.4 19.91 3306.73 25.83 1.297

我们可以观察到以下几点:

  • 从“ 总参数数量”列可以看出,自推测解码消耗的内存更少,因为它不需要单独的草稿模型,并且草稿阶段层的权重被重用。
  • 对于除 Llama2 70B 之外的所有模型大小和生成,提前退出自推测解码比常规的两模型推测解码更快。
  • 与其它模型相比,Llama2 70B 的自推测解码速度提升相对有限,可能有不同的原因,例如,Llama2 70B 的 LayerSkip 检查点持续预训练的 token 较少 (Llama2 70B 为 328M token,而 Llama2 7B 为 52B token)。但这是未来研究需要改进的一个方面。尽管如此,70B 的自推测解码明显快于自回归解码。

自生成和自验证

自推测解码过程从自生成开始,其中词元是通过从某个中间层提前退出来生成的。推测词元的数量定义了在此阶段生成多少草稿词元,而我们退出的层定义了草稿阶段的规模和准确性。这两个参数都可以在推理时根据草稿阶段的速度和准确性之间的权衡来指定。

下一步是自验证,其中使用完整模型来验证草稿词元。验证模型重用草稿模型中的缓存部分。如果草稿词元与验证的词元一致,则将它们添加到最终输出中,从而更好地利用我们系统中的内存带宽,因为使用完整模型生成一系列词元比验证草稿要昂贵得多,只要有几个词元匹配即可。

在自验证阶段,只有剩余的层才会被计算以进行验证,因为早期层的结果在草稿阶段已被缓存。

提前退出和反嵌入

自推测解码中的一项关键技术是提前退出,即生成过程可以在预先指定的层停止。为了实现这一点,我们通过将这些层的 logits 投影到语言模型 (LM) 头上来反嵌入它们,以预测下一个词元。这允许模型跳过后续层并提高推理时间。

可以在任何 transformer 层执行反嵌入,将提前退出转变为一种高效的词元预测机制。一个自然而然的问题出现了: 当 LM 头最初被训练为仅与最终层一起工作时,如何使其适应反嵌入较早层的 logits?这就是训练修改发挥作用的地方。

训练修改

在训练阶段,我们引入了层丢弃,它允许模型在训练期间跳过某些层。丢弃率在较深的层中逐渐增加,使模型不太依赖其后面的层,并增强模型的泛化能力并加快训练速度。

除了层丢弃之外,还应用了提前退出损失,以确保 LM 头学习反嵌入不同的层。使用每个出口 (中间层) 的归一化损失的总和来给出使用提前出口训练模型的总损失函数。这种技术通过在所有层之间分配学习任务来实现高效训练。

优化: 共享权重、共享 KV 缓存和共享计算

自推测解码显著受益于缓存重用,特别是 KV 缓存,它存储在草稿阶段计算的键值对。此缓存允许模型跳过冗余计算,因为草稿和验证阶段都使用相同的早期层。此外,退出查询缓存存储来自退出层的查询向量,允许验证从草稿阶段无缝继续。

与传统的双模型推测解码相比,提前退出自推测解码可以从以下节省中受益:

  • 共享权重: 为草稿和验证重用前 E 层 的权重。
  • 共享 KV 缓存: 为草稿和验证重用前 E 层的键值对
  • 共享计算: 通过使用仅保存退出层 E-1 的查询向量的退出查询缓存来重用前 E 层的计算,以便验证过程无需计算层 0 到 E-1。

KV 和退出查询缓存的组合称为 KVQ 缓存,可减少内存开销并提高推理延迟。

到目前为止,🤗 transformers 库已在此 pull request 中实现了第一个优化 (共享权重)。随着使用此方法的模型数量增加,我们将考虑其他优化。如果您有兴趣,请随时提出 PR!

提前退出层的选择策略

草稿阶段的提前退出层是一个超参数,我们可以在推理期间调整或修改:

  • 我们越早退出,生成草稿词元的速度就越快,但它们的准确性就越低。
  • 我们越晚退出,生成的草稿词元就越准确,但它们的速度就越慢。

我们编写了一个脚本来遍历不同的提前退出层并测量 A100 GPU 上的每秒词元数。在下面的表格中,我们绘制了针对不同 Llama 模型的 LayerSkip 和基线检查点的每秒词元数与提前退出层的关系图 (您可以在 此处 查看完整日志)。

Llama3.2 1B

Normal (常规模型) LayerSkip (LayerSkip 模型)
llama 3.2 1b layer skip llama 3.2 1b

Llama3 8B

Normal (常规模型) LayerSkip (LayerSkip 模型)
llama 3 8b layer skip llama 3 8b

Code Llama3 34B

Normal (常规模型) LayerSkip (LayerSkip 模型)
code llama 3 34b code layer skip llama 3 34b

Code Llama3 7B

Normal (常规模型) LayerSkip (LayerSkip 模型)
code llama 3 7b code layer skip llama 3 7b

Llama2 70B

Normal (常规模型) LayerSkip (LayerSkip 模型)
llama 2 70b layer skip llama 2 70b

Llama2 13B

Normal (常规模型) LayerSkip (LayerSkip 模型)
llama 2 13b layer skip llama 2 13b

Llama2 7B

Normal (常规模型) LayerSkip (LayerSkip 模型)
llama 2 7b layer skip llama 2 7b

我们可以观察到以下几点:

  • 对于没有使用 LayerSkip 训练方法进行预训练或持续预训练的基线检查点,提前退出自推测解码比自回归解码更慢。这是因为在大多数 LLM 的训练过程中,早期层并没有被激励去学习预测输出,因此使用早期层生成词元的接受率会非常低。
  • 另一方面,对于使用 LayerSkip 训练方法持续预训练的 Llama 检查点,提前退出自推测解码在至少一部分层中比自回归解码具有更高的加速比。
    • 对于大多数模型 (除了 Llama3.2 1B),当我们遍历各层时,我们注意到一个规律模式: 加速比在前几层较低,逐渐增加到一个最佳点,然后再次下降。
    • 提前退出层的最佳点是在预测的高准确性和生成词元的低开销之间达到最佳权衡时。这个最佳点取决于每个模型,也可能取决于提示或提示的领域。

这些观察为进一步的实验和探索提供了有趣的机会。我们鼓励读者在这些想法的基础上进行构建,测试变体,并进行自己的研究。这些努力可以带来有价值的见解,并为该领域做出有意义的贡献。

结论

LayerSkip 利用提前退出、层丢弃和缓存重用之间的协同作用,创建了一个快速高效的文本生成流程。通过训练模型从不同层反嵌入输出,并使用缓存优化验证过程,这种方法在速度和准确性之间取得了平衡。因此,它显著改善了大语言模型的推理时间,同时保持了高质量的输出。由于使用单个模型作为草稿和验证模型,它还比传统的推测解码技术减少了内存使用。

自推测是一个令人兴奋的领域,同一个 LLM 可以创建草稿词元并自我修正。其他自推测方法包括:

  • Draft & Verify: 其中草稿阶段涉及跳过预定的注意力和前馈层。
  • MagicDec: 其中草稿阶段使用 KV 缓存的子集,这对长上下文输入很有用。
  • Jacobi Decoding 和 Lookahead Decoding: 其中草稿阶段是一系列“猜测词元”,可以是随机的或从 n-gram 查找表中获得的。

英文原文: https://huggingface.co/blog/layerskip

原文作者: Aritra Roy Gosthipaty, Mostafa Elhoushi, Pedro Cuenca, Vaibhav Srivastav

译者: smartisan

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/897240.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年我用 Compose 写了一个 Todo App

标题党嫌疑犯实锤 序言 从2月12日到3月4日这整整三周时间里,我从零开始又学习了一次 Compose。 为什么说又,是因为这已经是我第二次学习这套课程了。 故事从 4 年前说起,2021 年在意外获悉扔物线朱凯老师准备发布一套名为 Compose 的新课程,意识到这是 Android 未来的方向,…

Ubuntu设置静态IP——NetworkManager方式

1、直接在系统界面上设置静态IP的方式,不再赘述 2、命令行方式查看已经有哪些工具#查看状态 sudo systemctl status Netplan sudo systemctl status NetworkManager sudo systemctl status systemd-networkd sudo systemctl status NetworkManager出现Active,说明电脑已经安装…

《Quick Start Kubernetes》读后感

一、 为什么选择这本书? 面试的时候经常被问到 kubernetes(下称 k8s),所以打算学习 k8s。看到《Quick Start Kubernetes》的作者对自己所写的书持续地更新,被这种认真打动了,外加这本书只有100多页,所以选择了这本书作为入门 k8s 的教材。 二、这本书写了什么? 这本书介绍…

正交实验法python实现

1.等水平正交表 每个条件下的种类一样多 例1: 这是一个7因子2状态 列表里内部每一个[]表示一个因子,然后每个因子都有2种类型 #7因子2状态 from allpairspy import AllPairs parameters = [["Chrome", "Firefox"],#因子1有"Chrome"或"Fir…

如何调用 DeepSeek 的自然语言处理 API 接口并集成到在线客服系统

本文将提供一个详细的示例,展示如何调用 DeepSeek 的自然语言处理 API 接口。我们将以情感分析为例,演示如何发送请求、处理响应以及处理可能的错误。我在业余时间开发了一款自己的独立产品:升讯威在线客服与营销系统。陆陆续续开发了几年,从一开始的偶有用户尝试,到如今线…

使用 Pixi.js 插件实现探险者小游戏(一)

什么是 Pixi Pixi 是一个非常快的 2D sprite 渲染引擎。使用它你可以轻松的利用 JavaScript 和其他 HTML5 技术制作游戏和应用程序。 Pixi 的官网地址:https://pixijs.com/ 本游戏使用的是 Pixi 的 V4.5.5 版本,官网最新版本更新到了 V8.x,两个版本 API 相差很大,建议大家学…

【设计模式】利用组合模式带你走进树形结构的世界

概述对于这个图片肯定会非常熟悉,上图我们可以看做是一个文件系统,对于这样的结构我们称之为树形结构。在树形结构中可以通过调用某个方法来遍历整个树,当我们找到某个叶子节点后,就可以对叶子节点进行相关的操作。可以将这颗树理解成一个大的容器,容器里面包含很多的成员…

20250311

1. 沪镍还有两个上涨波段

【设计模式】从智能音箱到软件设计:探索外观模式的实际应用案例

概述 有些人可能炒过股票,但其实大部分人都不太懂,这种没有足够了解证券知识的情况下做股票是很容易亏钱的,刚开始炒股肯定都会想,如果有个懂行的帮手就好,其实基金就是个好帮手,支付宝里就有许多的基金,它将投资者分散的资金集中起来,交由专业的经理人进行管理,投资于…

PMC必须要懂的四个关键流程:生产、库存、交期全过程解析!

PMC(生产计划与物料控制)这个岗位,看起来就是三个字,但实际干起来,简直是让人上蹿下跳、手忙脚乱。一边要盯着生产线, 一边要和供应商、采购、销售对接,稍微一个环节没控好,就可能导致生产停滞、库存爆仓、交期延误,直接影响公司运营。 很多PMC天天在救火,但其实掌握…

JavaScript HTML DOM - 改变 HTML 功能 用法运用 详解

JavaScript中的HTML DOM提供了强大的功能来改变HTML文档的内容和结构。通过JavaScript,我们可以动态地更新网页上的文本、属性、样式以及整个HTML结构。以下是对这些功能的详细解释和用法示例: 一、改变HTML内容使用innerHTML:innerHTML属性用于获取或设置元素的HTML内容。这…

20241905 2024-2025-2 《网络攻防实践》 第2次作业

1. 实验内容 本次实验为网络信息收集技术,主要有以下五个任务选择一个DNS域名进行查询获取信息 通过IP地址查询地理位置的信息 使用nmap扫描靶机环境 使用nessus扫描靶机环境 通过搜索引擎查询自己的隐私和信息泄露问题结合实验内容阅读书本,总结知识如下:网络踩点:攻击者通…