使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速

Open AI 推出的 Whisper 是一个通用语音转录模型,在各种基准和音频条件下都取得了非常棒的结果。最新的 large-v3 模型登顶了 OpenASR 排行榜,被评为最佳的开源英语语音转录模型。该模型在 Common Voice 15 数据集的 58 种语言中也展现出了强大的多语言性能,在 42 种语言上的单词错误率 (WER) 低于 30%。

尽管转录准确度非常优秀,但推理速度非常缓慢。即使利用 flash attention 、半精度和 分块 等优化推理技术,1 小时长度的音频在 16GB T4 GPU 上也需要超过 6 分钟的转录时间。

在本文中,我们将演示如何运用推测解码将 Whisper 的推理时间缩减 2 倍,同时在数学上确保完全取得与原模型 相同的输出。因此,这种方法可以完美地替换现有的 Whisper 流水线,因为它可以在不降低准确性的情况下免费获得 2 倍的加速。想要看附带有更简洁解释的全部代码,请参阅配套的 Google Colab。

推测解码

推测解码由 Yaniv Leviathan 等人在 Fast Inference from Transformers via Speculative Decoding 中提出。其思想是,一个更快的 辅助模型 通常会生成和更大的 主模型 相同的 token。

首先,辅助模型会通过自回归生成 个 候选 token 序列: 。在下图中,辅助模型生成了一个包含 5 个候选 token 的序列: The quick brown sock jumps

bcf438baca9f6baac7e55bc49f5fbc5e.gif

尽管这些候选 token 可以快速生成,但它们可能与主模型预测的 token 不同。因此,在第二步中,候选 token 被传入主模型以进行“验证”。主模型将候选 token 作为输入,并执行 单次前馈传播。主模型的输出是每个步骤中“正确”token 的序列 。

98105453dbbb75fafb7a6821afc154a6.gif

在上图中,我们看到主模型预测的前三个 token 与辅助模型的 token 一致: <span style="color:green"> The quick brown 但是,辅助模型的第四个候选 token: “ <span style="color:red"> sock”与主模型的正确 token: “ <span style="color:green"> fox”不一致。

我们知道,所有候选 token 一直到第一个不匹配之前都是正确的 ( <span style="color:green"> The quick brown),因为这些与主模型的预测一致。但是,在第一个不匹配之后,候选 token 开始偏离主模型实际预测的 token。因此,我们可以用主模型的正确 token ( <span style="color:green"> fox) 替换第一个不正确的候选 token ( <span style="color:red"> sock),并放弃之后所有预测的 token,因为这些已经逐渐偏离主模型的预测。经过校正的序列 The quick brown fox 现在成为辅助模型的新输入:

92c60802e33e9ad87389af497ca51a31.gif

然后,辅助模型再次通过自回归推理,生成一组新的 个候选 token,这些 token 再次通过主模型的单次前馈传播进行验证。

cfe2e6634856e33d8cd3629efbeda612.gif

由于我们在生成的时候使用的快速的辅助模型进行自回归,并且缓慢的主模型仅用于验证前馈传播,解码过程将大大加快。此外,经过主模型前馈传播验证后可以确保与仅使用主模型时获得完全相同的输出。这使得推测解码可以完美地替换现有的 Whisper 流水线,因为我们可以确定会取得相同质量的输出。

为了最大限度地减少延迟,辅助模型应该比主模型快得多,同时尽可能频繁地预测相同的 token 分布。实际上,这两个属性之间需要权衡: 模型越快,其准确度越低。然而,由于所有预测 token 中的 70-80% 往往是“较易”的 token,此权衡倾向于选择一个更快的模型,而不是一个更准确的模型。因此,辅助模型应该至少比主模型快 3 倍 (越快越好),同时在示例中正确预测所有较“易”token。剩余的 20-30% 更“难”的 token 可以由更大的主模型进行验证。

选择辅助模型的唯一约束是它必须与主模型使用相同的词汇表。也就是说,辅助模型必须使用与主模型完全一对一相同的分词器。因此,如果我们想对诸如 large-v2 (多语言) 的 Whisper 多语言版本使用推测解码,我们需要选择诸如 tiny 的 Whisper 多语言版本作为辅助模型。而如果我们想对诸如 medium.en 的 Whisper 英文版本使用推测解码,我们需要选择诸如 tiny.en 的 Whisper 英文版本作为辅助模型。目前,large-v3 是唯一一个扩展了词汇量的 Whisper 检查点,因此与以前的 Whisper 检查点不兼容。

现在我们已经了解了推测解码背后的原理,我们准备实际实现它。在 🤗 Transformers 库中,推测解码被实现为“辅助生成 (Assisted Generation)”推理策略。欲了解更多实现细节,建议读者阅读 Joao Gante 关于 辅助生成 的精彩博文。

英文语音转录

基准实现

我们首先使用 Whisper large-v2 进行基准测试,以获得推理速度的基准数值。我们可以通过便捷的 AutoModelForSpeechSeq2SeqAutoProcessor 类加载主模型及其对应的处理器。我们将以 float16 精度加载模型,并通过传递 low_cpu_mem_usage=True 确保加载时间尽可能少。此外,我们要确保模型以 safetensors 格式加载,方法是传递 use_safetensors=True。最后,我们将传递参数 attn_implementation="sdpa" ,以通过 PyTorch 的 SDPA 注意力内核 进行 Flash 注意力加速。

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessordevice = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32model_id = "openai/whisper-large-v2"model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id,torch_dtype=torch_dtype,low_cpu_mem_usage=True,use_safetensors=True,attn_implementation="sdpa",
)
model.to(device)processor = AutoProcessor.from_pretrained(model_id)

让我们加载将用于基准测试的英语语音转录数据集。我们将加载 LibriSpeech ASR 中验证数据集的 clean 分组中的 73 个样本组成的小型数据集。这大约有 9MB 的数据,因此非常轻量且可以快速下载到设备上。

from datasets import load_datasetdataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

对于基准测试,我们只想测量生成时间,所以让我们编写一个简短的辅助函数来测量此步骤运行的时间。下面的函数将同时返回解码的 token 和运行模型所需的时间:

import timedef generate_with_time(model, inputs, **kwargs):start_time = time.time()outputs = model.generate(**inputs, **kwargs)generation_time = time.time() - start_timereturn outputs, generation_time

现在我们可以迭代语音数据集中的音频样本,并统计整体生成时间:

from tqdm import tqdmall_time = 0
predictions = []
references = []for sample in tqdm(dataset):audio = sample["audio"]inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")inputs = inputs.to(device=device, dtype=torch.float16)output, gen_time = generate_with_time(model, inputs)all_time += gen_timepredictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])references.append(processor.tokenizer._normalize(sample["text"]))print(all_time)

Output:

100%|██████████| 73/73 [01:37<00:00, 1.33s/it]
72.99542546272278

很好!我们看到转录 73 个样本花了 73 秒。让我们检查一下预测的 WER:

from evaluate import loadwer = load("wer")
print(wer.compute(predictions=predictions, references=references))

Output:

0.03507271171941831

我们的最终基准数值为 73 秒,WER 为 3.5%。

推测解码

现在让我们加载推测解码的辅助模型。在此示例中,我们将使用 Whisper 蒸馏后的版本 distil-large-v2。蒸馏模型只使用了 Whisper 中 32 个解码器层中的 2 个编码器。因此,它比 Whisper 快 6 倍,同时在分布测试集上的 WER 性能相比于蒸馏前仅下降了 1%。这使其成为理想的辅助模型,因为它在转录准确性和生成速度方面都非常优秀。


我们即将发布 Distil-Whisper 的改进版本,在 token 分布中具有更佳的对齐性,这将进一步提高推测解码性能。关注 Distil-Whisper 存储库 来追踪最新的更新信息。


由于 Distil-Whisper 使用与 Whisper 模型完全相同的编码器,我们可以在主模型和辅助模型之间共享编码器。然后,我们只需要从 Distil-Whisper 加载 2 层解码器作为“仅解码器”模型。我们可以通过便捷的 AutoModelForCausalLM 自动类实现这一点。在实践中,相比于仅使用主模型,这仅增加了 8%的 VRAM 占用量。

from transformers import AutoModelForCausalLMassistant_model_id = "distil-whisper/distil-large-v2"assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id,torch_dtype=torch_dtype,low_cpu_mem_usage=True,use_safetensors=True,attn_implementation="sdpa",
)assistant_model.to(device)

我们可以为推测解码的基准测试定义一个新的函数。与前面的函数唯一的区别是,我们在对 .generate 的调用中传递辅助模型:

def assisted_generate_with_time(model, inputs, **kwargs):start_time = time.time()outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)generation_time = time.time() - start_timereturn outputs, generation_time

让我们使用 Distil-Whisper 作为 Whisper 的助手运行推测解码的基准测试:

all_time = 0
predictions = []
references = []for sample in tqdm(dataset):audio = sample["audio"]inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")inputs = inputs.to(device=device, dtype=torch.float16)output, gen_time = assisted_generate_with_time(model, inputs)all_time += gen_timepredictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])references.append(processor.tokenizer._normalize(sample["text"]))print(all_time)

Outputs:

100%|██████████| 73/73 [00:38<00:00, 1.88it/s]
32.69683289527893

使用推测解码,推理时间仅为 33 秒,比之前快 2.2 倍!让我们验证一下 WER 是否相同:

print(wer.compute(predictions=predictions, references=references))

Outputs:

0.03507271171941831

太完美了!再次达到 3.5%的 WER,因为我们的输出与仅使用主模型的时候完全相同。

推测解码也可以与基础的 🤗 Transformers pipeline API 一起用于推理。下面,我们使用模型和处理器实例化管道,然后使用它来转录测试数据集中的第一个样本。这可以扩展为转录任意长度的音频样本,包括进行批处理:

from transformers import pipelinepipe = pipeline("automatic-speech-recognition",model=model,tokenizer=processor.tokenizer,feature_extractor=processor.feature_extractor,max_new_tokens=128,chunk_length_s=15,batch_size=4,generate_kwargs={"assistant_model": assistant_model},torch_dtype=torch_dtype,device=device,
)sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

Outputs:

Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

使用 Whisper 和 Distil-Whisper 运行推测解码的端到端代码示例可在 Distil-Whisper 模型卡 中找到。它将本文中涵盖的推理阶段组合成一个代码示例。

多语言语音转录

Distil-Whisper 是英语语音转录的最佳辅助模型,因为它与原始 Whisper 模型的 WER 误差率仅相差 1%,而对短长语音样本的推理速度提高了 6 倍。然而,官方的 Distil-Whisper 检查点仅支持英语,这意味着它们无法用于多语言语音转录。

要使用推测解码进行多语言语音转录,您可以使用 官方 Whisper 多语言检查点 之一,或者 Whisper 的微调版本。在撰写本文时,Hugging Face Hub 上已有超过 5000 个微调过的 Whisper 检查点,支持超过 100 种语言。这些为选择表现出色的辅助模型提供了极好的起点。在此示例中,我们将使用最小的官方多语言检查点 Whisper tiny。您可以使用任意一个您的语言中微调过的不同检查点!

让我们为新的辅助模型 Whisper tiny 加载权重。由于 Whisper tiny 的编码器与 large-v2 不同,这次我们将使用 AutoModelForSpeechSeq2Seq 类同时加载编码器和解码器:

assistant_model_id = "openai/whisper-tiny"assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(assistant_model_id,torch_dtype=torch_dtype,low_cpu_mem_usage=True,use_safetensors=True,attn_implementation="sdpa",
)assistant_model.to(device);

我们的基准数据集,将从 VoxPopuli 数据集的荷兰语 (“nl”) 部分中加载 73 个样本:

dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

非常好!现在我们可以像前面一样重新运行我们的 Whisper large-v2 模型的基准测试。我们所做的唯一更改是在 generate 函数中传递语言和任务参数,以确保执行语音转录 (而不是语音翻译)。推测解码完全兼容语音转录和翻译任务。只需如下所示设置任务参数即可:

all_time = 0
predictions = []
references = []for sample in tqdm(dataset):audio = sample["audio"]inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")inputs = inputs.to(device=device, dtype=torch.float16)output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")all_time += gen_timepredictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])references.append(processor.tokenizer._normalize(sample["normalized_text"]))wer_result = wer.compute(predictions=predictions, references=references)print("Time:", all_time)
print("WER:", wer_result)

Outputs:

100%|██████████| 73/73 [02:05<00:00, 1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146

没错!我们的基准时间为 117 秒,WER 为 12.8%。让我们使用推测解码重新运行生成过程:

all_time = 0
predictions = []
references = []for sample in tqdm(dataset):audio = sample["audio"]inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")inputs = inputs.to(device=device, dtype=torch.float16)output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")all_time += gen_timepredictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])references.append(processor.tokenizer._normalize(sample["normalized_text"]))wer_result = wer.compute(predictions=predictions, references=references)print("Time:", all_time)
print("WER:", wer_result)

Outputs:

100%|██████████| 73/73 [01:08<00:00, 1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146

Nice!我们达到了 12.8% 的 WER,但这次的推理时间只有 62 秒,表示速度提高了 1.9 倍。考虑到加载辅助模型的低开销和确保获得完全相同输出的数学证明,推测解码为现有的 Whisper 管道提供了完美的即插即用的替代方案。

高效推测解码的策略

在本最终部分,我们将介绍两种策略,以确保使用推测解码时获得可能最快的推理时间。

辅助模型

我们的目标是选择一个至少比主模型快 3 倍 并且 正确转录至少 70-80% 的预测 token (通常是示例中的“更简单”token) 的辅助模型。如果您想要转录某种特定语言,一种有效的策略是训练两个不同大小的 Whisper 模型,并将其中一个用作另一个的辅助模型:

  • 首先,微调 Whisper large-v3 以用作主模型

  • 其次,在同一数据集上蒸馏 Whisper large-v3 以用作快速的辅助模型

微调和蒸馏都可以提高主模型和辅助模型在您选择的语言上的 WER 性能,同时最大化 token 分布的对齐。有关 Whisper 微调的完整指南,请参阅 此处,有关蒸馏的指南请参阅 此处。

批次大小

值得注意的是,使用推测解码获得的最大速度提升来自批次大小为 1。对于批处理推测解码,批处理中的所有候选 token 必须与验证 token 相匹配,才能被接受。如果批处理中给定位置的 token 不一致,则所有在该位置之前的候选 token 将被丢弃。因此,推测解码更倾向于较小的批次大小。在实践中,我们发现推测解码可以提供速度提升,直到批次大小达到 4 为止。当批次大小超过 4 时,推测解码的推理速度比仅用主模型还要慢。有关完整结果,请参阅 Distil-Whisper 论文 的第 D.3 节。

结论

在本博文中,我们介绍了推测解码的推理策略,以及如何将其应用于语音转录的 Whisper 模型。我们展示了如何实现 2 倍的速度提升,同时数学上确保获得与仅使用原始模型相同的输出。我们鼓励您尝试将推测解码用作现有 Whisper 管道的即插即用替代方案,因为使用额外的辅助模型的开销很小,并且可以保证获得相同的转录结果。

致谢

本博客由 Sanchit Gandhi 撰写。非常感谢 Patrick von Platen 和 Pedro Cuenca 的建设性意见,以及 Joao Gante 在 🤗 Transformers 中实现辅助生成的贡献。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/whisper-speculative-decoding

作者: Sanchit Gandhi

译者: Hu Yaoqi (yaoqi)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/341646.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024最新适用于 Windows 、Mac 的最佳屏幕录制软件

屏幕录制软件可以帮助我们录制 PC 和MacBook的实时屏幕视频。如果您想为 优酷录制视频&#xff0c;或者您正在为您的公司制作基于视频的项目&#xff0c;并且需要捕获屏幕的实时视频录制&#xff0c;那么我们在此列出了 一 款适合您的 Windows 、Mac的 2024 年最佳屏幕录制软件…

【K8S 云原生】Kurbernets集群的调度策略

目录 一、Kubernetes的list-watch机制 1、List-watch 2、创建pod的过程&#xff1a; 二、scheduler调度的过程和策略&#xff1a; 1、简介 2、预算策略&#xff1a;predicate 3、优先策略&#xff1a; 3.1、leastrequestedpriority&#xff1a; 3.2、balanceresourceal…

深信服态势感知一体机SIP-1000 Y2100 3.0.1Y升级3.0.3Y步骤

当前版本&#xff1a;3.0.1Y 升级后版本&#xff1a;3.0.3Y PS&#xff1a;3.0.1Y不能直升3.0.3Y&#xff0c;需要先通过升级工具升级到3.0.2Y&#xff0c;再安装前置补丁从3.0.2Y升级到3.0.3Y&#xff1b;每一次升级时间为20-30分钟&#xff0c;设备升级会重启&#xff0c;需提…

【一、测试基础】Java基础语法

Java 的用法及注意事项有很多&#xff0c;今天的目标是了解Java基础语法&#xff0c;且能够输出"hello world" 几个基础的概念 对象&#xff1a;对象是类的一个实例&#xff0c;有状态和行为。一只猫是一个对象&#xff0c;猫的状态有&#xff1a;颜色、名字、品种&…

Linux第25步_在虚拟机中备份“ST官方的TF-A源码”

TF-A是ARM公司提供的&#xff0c;ST公司通过修改它&#xff0c;做了一个自己的TF-A代码。因为在后期开发中&#xff0c;若硬件被改变了&#xff0c;我们需要通过修改"ST官方的TF-A源码"就可以自己的TF-A代码了。为了防止源文件被误改了&#xff0c;我们需要将"S…

合并 K 个升序链表[困难]

一、题目 给你一个链表数组&#xff0c;每个链表都已经按升序排列。请你将所有链表合并到一个升序链表中&#xff0c;返回合并后的链表。 示例 1&#xff1a; 输入&#xff1a;lists [[1,4,5],[1,3,4],[2,6]] 输出&#xff1a;[1,1,2,3,4,4,5,6] 解释&#xff1a;链表数组如…

PDF-XChange Editor v10.2.0.384

软件介绍 PDF-XChange Editor&#xff0c;号称打开速度最快最强大的PDF编辑器/PDF阅读器&#xff0c;PDF-XChange专注于PDF文档的编辑&#xff0c;打开PDF文件速度快&#xff0c;软件小功能强大&#xff0c;可以自定义制作PDF电子文档&#xff0c;具有创建&#xff0c;查看&am…

文件与流(C++)

六、文件与流6.1 IO流库概念6.2 istream 与 ostream6.2.1 istream6.2.2 ostream6.2.3 输入输出的格式控制 6.3 string流6.4 文件流 六、文件与流 6.1 IO流库概念 C建立了一个十分庞大的流类库来实现数据的输入输 出操作。其中的每个流类实现不同的功能&#xff0c;这些类通过…

C++学习笔记(三十五):c++ 函数指针及lambda表达式

本节介绍c函数指针。在一些源码中经常能看到c函数指针&#xff0c;但之前一直觉着这一块比较复杂&#xff0c;就一直没去仔细研究&#xff0c;终于有时间去仔细研究这一块内容了。 c风格的函数指针 函数指针是指将一个函数赋值给一个变量的方法&#xff0c;可以将函数作为一个参…

linux命令--管道详解

关于linux命令的管道符的理解 管道&#xff1a;是进程之间的一种通信方式&#xff0c;主要适用的典型类型&#xff1a;父子进程 创建管道的函数c底层接口&#xff1a;int pipe(int f[2]) 返回的int类型&#xff0c;f[1]代表写&#xff0c;f[0]代表读 父进程使用fork函数创建副…

2024美赛数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米&#xff0c;宽为12米&…

2024年1月11日 主题:非枪人生

2024年1月11日15:58:29 2024年1月11日15:35:13 2024年1月11日15:57:51 对物理进行大致预 2024年1月11日20:27:14 结论&#xff1a;不适合进行数据结构的训练和对电路的模拟感受 2024年1月11日20:28:32 今天也平静的结束了 不需要键盘的支持也就这么结束了我也不知道…