transformer之KV Cache

一、为什么要研究KV Cache

非常有效的加速推理速度,效果如下所示:

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
NAME_OR_PATH = r'***************'
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(NAME_OR_PATH)
model = AutoModelForCausalLM.from_pretrained(NAME_OR_PATH).to(device)
model.config.pad_token_id = tokenizer.eos_token_id
for use_cache in (True, False):times = []for _ in range(10):  # measuring 10 generationsstart = time.time()model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=10)times.append(time.time() - start)print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")#===================================max_new_tokens=1=======================================================
with KV caching: 0.072 +- 0.008 seconds
without KV caching: 0.081 +- 0.02 seconds
#===================================max_new_tokens=10=======================================================
with KV caching: 0.428 +- 0.021 seconds
without KV caching: 0.751 +- 0.04 seconds
#===================================max_new_tokens=100=======================================================
with KV caching: 4.606 +- 0.072 seconds
without KV caching: 19.257 +- 1.663 seconds
#===================================max_new_tokens=1000=======================================================
with KV caching: 42.941 +- 0.711 seconds
without KV caching: 349.954 +- 1.523 seconds

使用KV Cache的推理速度是明显优于没有使用KV Cache的,而且生成的token越长速度提升就越明显,当最大生成token数为1000时,近10倍的加速,一次推理近6分钟。

二、为什么KV Cache能加速

2.1 原理是什么

最本质的原理是避免重复计算,将需要重复计算的结果进行缓存,需要缓存的值为历史token对应的KV值,所以叫KV Cache。

2.2 为什么只需要KV

预测新的token只与输入的最后一个token相关,输入的最后一个token因为只需要计算注意力值,而注意力的值需要将输入token的V值进行加权即得到结果,进行加权就需要将当前的Q与与所有的K进行计算得到权重,所以只需要缓存历史token的KV值。

2.2 为什么会存在重复计算

首先,生成式模型每生成一个新token都需要调用整个模型进行一次推理,历史token计算得到的中间激活值在Decoder架构的模型中每次推理时都是一样的,所以可以进行缓存。
这是因为Decoder架构中,当前token只用之前的token计算得到注意力值,通过Causal Mask实现,换句话说,在推理的时候前面已经生成的字符不需要与后面的字符产生attention,从而使得前面已经计算的K和V可以缓存起来。

2.3 预测新token的计算步骤

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
由于Causal Mask矩阵的存在,预测下一个token只与输入的最后一个token的QKV和历史token的KV有关;
如果没有Causal Mask,比如说是encoder架构,每次推理时每个token需要考虑所有输入的token,所以得到的注意力值都会变化,就不存在重复计算的情况。

三、结论

1、KV Cache是通过空间换时间,避免重复计算进而提升推理速度
2、预测新token只与输入的最后一个token的注意力值相关,而注意力值与最后一个token的Q和所有输入token的KV相关,每一层的注意力不变,进而每一层的KV都是不变的
3、只适用于Decoder架构,因为只与之前的token进行计算,得到的注意力值不会变化,第一层、第二层、第三层到第 l l l层; 如果只有一层,那就不仅仅适用于Decoder架构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/208666.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于docker实现JMeter分布式压测

为什么需要分布式? 在工作中经常需要对一些关键接口做高QPS的压测,JMeter是由Java 语言开发,没创建一个线程(虚拟用户),JVM默认会为每个线程分配1M的堆栈内存空间。受限于单台试压机的配置很难实现太高的并…

[网鼎杯 2018]Fakebook

[网鼎杯 2018]Fakebook 打开环境出现一个登录注册的页面 在登录和注册中发现 了地址栏出现变化&#xff0c;扫一波看看 看看robots.txt和flag.php 访问robots.txt看看 再访问user.php.bak <?php class UserInfo { public $name ""; public …

扩散模型实战(十一):剖析Stable Diffusion Pipeline各个组件

推荐阅读列表&#xff1a; 扩散模型实战&#xff08;一&#xff09;&#xff1a;基本原理介绍 扩散模型实战&#xff08;二&#xff09;&#xff1a;扩散模型的发展 扩散模型实战&#xff08;三&#xff09;&#xff1a;扩散模型的应用 扩散模型实战&#xff08;四&#xff…

JVM 之 class文件详解

目录 一. 前言 二. class文件结构 2.1. 文件格式 2.2. 魔数与版本号 2.3. 常量池 2.4. 访问标志 2.5. 类索引、父类索引和接口索引集合 2.6. 字段表集合 2.7. 方法表集合 2.8. 属性表集合 2.8.1. Code 属性表 2.8.2. Exceptions 属性 2.8.3. LineNumberTable 属性…

[Unity+OpenAI TTS] 集成openAI官方提供的语音合成服务,构建海王暖男数字人

1.简述 最近openAI官方发布了很多新功能&#xff0c;其中就包括了最新发布的TTS语音合成服务的api接口。说到这个语音合成接口&#xff0c;大家可能会比较陌生&#xff0c;但是说到chatgpt官方应用上的聊天机器人&#xff0c;那个台湾腔的海王暖男的声音&#xff0c;可能就有印…

Oracle与Redis Enterprise协同,作为企业缓存解决方案

来源&#xff1a;虹科云科技 虹科干货丨Oracle与Redis Enterprise协同&#xff0c;作为企业缓存解决方案 欢迎关注虹科&#xff0c;为您提供最新资讯&#xff01; 单独使用Oracle作为企业缓存数据库时&#xff0c;会出现哪些问题呢&#xff1f;使用Redis Enterprise与Oracle共…

汇编-CALL和RET指令

CALL指令调用一个过程&#xff0c; 使处理器从新的内存位置开始执行。过程使用RET(从过程返回) 指令将处理器转回到该过程被调用的程序点上。 CALL指令的动作&#xff1a; 1.将CALL指令的下一条指令地址压栈(作为子过程返回的地址) 2.将被调过程的地址复制到指令指针寄存器E…

七、通过libfdk_aac编解码器实现aac音频和pcm的编解码

前言 测试环境&#xff1a; ffmpeg的4.3.2自行编译版本windows环境qt5.12 AAC编码是MP3格式的后继产品&#xff0c;通常在相同的比特率下可以获得比MP3更高的声音质量&#xff0c;是iPhone、iPod、iPad、iTunes的标准音频格式。 AAC相较于MP3的改进包含&#xff1a; 更多的采…

Eclipse常用设置-乱码

在用Eclipse进行Java代码开发时&#xff0c;经常会遇到一些问题&#xff0c;记录下来&#xff0c;方便查看。 一、properties文件乱码 常用的配置文件properties里中文的乱码&#xff0c;不利于识别。 处理流程&#xff1a;Window -> Preferences -> General -> Ja…

「快学Docker」监控和日志记录容器的健康和性能

「快学Docker」监控和日志记录容器的健康和性能 1. 容器健康状态监控2. 性能监控3. 日志记录几种采集架构图 4. 监控工具和平台cAdvisor&#xff08;Container Advisor&#xff09;PrometheusGrafana 5. 自动化运维 1. 容器健康状态监控 方法1&#xff1a;需要实时监测容器的运…

机器人算法—ROS TF坐标变换

1.TF基本概念 &#xff08;1&#xff09;什么是TF&#xff1f; TF是Transformations Frames的缩写。在ROS中&#xff0c;是一个工具包&#xff0c;提供了坐标转换等方面的功能。 tf工具包&#xff0c;底层实现采用的是一种树状数据结构&#xff0c;根据时间缓冲并维护多个参考…

2022年03月 Scratch(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

Scratch等级考试(1~4级)全部真题・点这里 一、单选题(共25题,每题2分,共50分) 第1题 红框中加入哪个选项积木,不能阻止气球下落? A: B: C: D: