LLM 加速技巧:Muti Query Attention

MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。

在介绍MQA 之前,我们先回顾一下传统的多头注意力

Multi-Head Attention(MHA)

多头注意力是transformer 模型的默认注意力机制,如下图所示:

在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。

但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。

下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:

 defMHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):q=tf.einsum("bd, hdk−>bhk", x, P_q)new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis=2)], axis=2)new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis=2)], axis=2)logits=tf.einsum("bhk, bhmk−>bhm", q, new_K)weights=tf.softmax(logits)O=tf.einsum("bhm, bhmv−>bhv", weights, new_V)Y=tf.einsum("bhv, hdv−>bd", O, P_o)returnY, new_K, new_V

其中:

X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]

P_q, P_k:查询和键投影张量,形状为[h, d, k]

P_v:值投影张量,形状为[h, d, v]

P_o:学习到的线性投影,形状为[h, d, v]

Prev_K:上一步的关键张量,形状为[b, h, m, k]

Prev_V:前一步的Value张量,形状为[b, h, m, v]

new_K:加上当前步的键张量,形状为[b, h, m+1, k]

new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]

维度表示如下:

M:先前执行的步骤数

B:批量大小

D:输入和输出的尺寸

H:注意力头数

k:Q,K张量的另一个维度

v: v张量的另一个维度

Multi-Query Attention(MQA)

MQA是多头注意的一种变体。

MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:

从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:

 defMQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):q=tf.einsum("bd, hdk−>bhk", x, P_q)new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis=2)], axis=2)new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis=2)], axis=2)logits=tf.einsum("bhk, bmk−>bhm", q, new_K)weights=tf.softmax(logits)O=tf.einsum("bhm, bmv−>bhv", weights, new_V)Y=tf.einsum("bhv, hdv−>bd", O, P_o)returnY, new_K, new_V

上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目

 classMultiheadAttention(nn.Module):def__init__(self,d_model: int,n_heads: int,device: str):"""Multi Head init func.Args:d_model (int): hidden state size, e.g. 768n_heads (int): 设定的注意力头数, e.g. 8device (str): _description_"""super().__init__()self.d_model=d_modelself.n_heads=n_headsself.Wqkv=nn.Linear(                       # Multi-Head Attention 的创建方法self.d_model, 3*self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_modeldevice=device)                                            # (d_model, 3 * d_model)self.attn_fn=scaled_multihead_dot_product_attentionself.out_proj=nn.Linear(self.d_model, self.d_model, device=device)defforward(self,x):"""forward func.Args:x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)Returns:_type_: _description_"""qkv=self.Wqkv(x)                            # (1, 768, 3 * 768)query, key, value=qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)3, dim=2)     context, attn_weights, past_key_value=self.attn_fn(query,key,value,self.n_heads)                                             # (1, 512, 768)returnself.out_proj(context), attn_weights, past_key_valueclassMultiQueryAttention(nn.Module):"""Multi-Query self attention.Using torch or triton attention implemetation enables user to also useadditive bias."""def__init__(self,d_model: int,n_heads: int,device: Optional[str] =None,):super().__init__()self.d_model=d_modelself.n_heads=n_headsself.head_dim=d_model//n_headsself.Wqkv=nn.Linear(                           # Multi-Query Attention 的创建方法d_model,d_model+2*self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_modeldevice=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量)self.attn_fn=scaled_multihead_dot_product_attentionself.out_proj=nn.Linear(self.d_model, self.d_model, device=device)self.out_proj._is_residual=True  # type: ignoredefforward(self,x,):qkv=self.Wqkv(x)                                           # (1, 512, 960)query, key, value=qkv.split(                               # query -> (1, 512, 768)[self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)dim=2                                                    # value -> (1, 512, 96))context, attn_weights, past_key_value=self.attn_fn(query,key,value,self.n_heads,multiquery=True,)returnself.out_proj(context), attn_weights, past_key_value

从代码中可以看到所有 头之间共享一份 key 和 value 的参数,但是如何将这 1 份参数同时让 8 个头都使用呢?

代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享,主要是这个函数:scaled_multihead_dot_product_attention

 defscaled_multihead_dot_product_attention(query,key,value,n_heads,past_key_value=None,softmax_scale=None,attn_bias=None,key_padding_mask=None,is_causal=False,dropout_p=0.0,training=False,needs_weights=False,multiquery=False,):q=rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)kv_n_heads=1ifmultiqueryelsen_headsk=rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery # (1, 512, 96) -> (1, 1, 96, 512)  if multiqueryv=rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery # (1, 512, 96) -> (1, 1, 512, 96)  if multiqueryattn_weight=q.matmul(k) *softmax_scale                       # (1, 8, 512, 512)attn_weight=torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)out=attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)out=rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)returnout, attn_weight, past_key_value

MQA指标测试

MQA能在多大程度上提高速度?让我们看看原文中提供的结果图表:

从上表可以看出,MQA在编码器上的速度提升不是很显著,但在解码器上的速度提升是相当显著的。

论文中也有关于质量的实验,结果表明MQA的性能与基线相比只是稍微低一些。降低应该是肯定的因为毕竟共享了参数,但是只要再可接受范围内并且能够大量提升速度这个降低就是可以接受的,对吧。

为什么MQA可以实现推理加速?

在MQA中,键张量和值张量的大小分别为b * k和b * v,而在MHA中,键张量和值张量的大小分别为b * h * k和b * h * v,其中h表示头的个数。

MQA通过以下方法实现推理加速:

1、KV缓存大小减少了h(头数量),这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批大小,从而提高效率。

2、减少了从内存中读取的数据量,从而减少了计算单元的等待时间,提高了计算利用率。

3、MQA有一个相对较小的KV数量,可以放入缓存(SRAM)中。MHA则需要较大的KV数量,不能完全存储在缓存中,需要从GPU内存(DRAM)读取,这很耗时。

总结

MQA是在2019年提出的,当时的应用还没有那么广泛。这是因为以前的模型不需要关心这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

当transformer最初被提出时,它主要用于Seq2Seq任务,特别是在Encoder-Decoder模型中。由于模型的规模不是很大,也并且没有太多的实际需求,所以MQA并没有引起太多的关注。

直到近年来(尤其是2023年开始)基于transformer的大型语言模型(如GPT)得到广泛应用后,推理的瓶颈才被人们重视。所以MQA才被发现非常有用,这主要是由于对大规模gpt式生成模型的实际需求。

最后我们再回顾以下这个论文:

https://avoid.overfit.cn/post/877de0f5a56d478d8133d75a05064e7e

作者:Florian June

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/522609.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

穷人想赚钱该怎么选打工VS创业?2024年如何把握新机遇?

在贫穷的困境中,打工与创业似乎成为了两条截然不同的道路,摆在每一个渴望改变命运的人面前。然而,这并非简单的选择题,而是一场关于勇气、智慧与机遇的较量。打工,对于许多人来说,是稳定且相对安全的收入来…

遗传算法理解与代码实战(二)- demo(python+deap)

前文介绍了遗传算法,并且手动python代码进行了实践,但是在遇到复杂的问题时(遗传算法理解与代码实战(三)会介绍),手写代码很麻烦,所以需要借助专门的遗传算法库来实现,这…

社区医院智慧管理:Java+SpringBoot新实践

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

酷炫!向数字世界 AGI 迈进!让智能体直接控制键盘、鼠标,与一切软件交互

信息革命催生了数字世界,这个世界为大模型提供了海量数据,同时也为通用人工智能(AGI)的实现提供了可能。在迈向数字世界的 AGI 的过程中,北京智源人工智能研究院、新加坡南洋理工大学和北京大学联合提出了一种名为 Gen…

数据结构——lesson7二叉树 堆的介绍与实现

前言💞💞 啦啦啦~这里是土土数据结构学习笔记🥳🥳 💥个人主页:大耳朵土土垚的博客 💥 所属专栏:数据结构学习笔记 💥对于数据结构顺序表链表有疑问的都可以在上面数据结…

ai直播数字人:AI大模型应用开发的神奇世界

当AI技术的发展走向一个新的高峰,AI直播数字人逐渐成为人们关注的焦点。这种全新的数字人形态,通过大模型应用开发,带来了一个神奇世界。 在这个神奇世界里,AI直播数字人可以展现出与真实人类相媲美的外貌和声音。通过先进的图像…

[递归、搜索、回溯]----递归

前言 作者:小蜗牛向前冲 专栏:小蜗牛算法之路 专栏介绍:"蜗牛之道,攀登大厂高峰,让我们携手学习算法。在这个专栏中,将涵盖动态规划、贪心算法、回溯等高阶技巧,不定期为你奉上基础数据结构…

ROS 2基础概念#6:服务(Service)| ROS 2学习笔记

服务(Service)是 ROS 2 计算图中节点通信的另一种方法。 服务基于调用和响应模型,而不是主题的发布者-订阅者模型。 虽然主题允许节点订阅数据流并获取持续更新,但服务仅在客户端专门调用时才提供数据。 ROS 2服务的基本概念 ROS…

UE4升级UE5 蓝图节点变更汇总(4.26/27-5.2/5.3)

一、删除部分 Ploygon Editing删除 Polygon Editing这个在4.26、4.27中的插件,在5.1后彻底失效。 相关的蓝图,如编辑器蓝图 Generate mapping UVs等,均失效。 如需相关功能,请改成Dynamic Mesh下的方法。 GetSupportedClass删…

微服务超大Excel文件导出方案优化

1、在导出Excel时经常会碰到文件过大,导出特别慢 2、微服务限制了请求超时时间,文件过大情况必然超时 优化思路: 1、文件过大时通过文件拆分、打包压缩zip,然后上传到oss,并设置有效期(30天过期) 2、把…

便捷在线导入:完整Axure元件库集合,让你的设计更高效!

Axure元件库包含基本的工具组件,可以使原型绘制节省大量的重复工作,保持整个设计页面的一致性和标准化,同时显得专业。Axure元件库就像我们日常生活中的门把手、自行车踏板和桌子上的螺丝钉,需要组装才能使用。作为一名成熟的产品…

搜索引擎都没流量啦,官网建设还有啥意义?

百度等搜索引擎都没啥流量了,再建设官网还有啥用?如果你把官网定位于获客,那真的没啥太大用处,但是官网不仅仅是用来获客的。 一、搜索引擎的流量被稀释了 搜索引擎流量减少的原因有多个, 1. 社交媒体的崛起&#xf…