note
(1)近似注意力:
- Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。
- Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。
(2)在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存并消耗了大量的计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型面临的重要问题。
- 方法一:从近似注意力出发,旨在减少注意力计算和内存需求,提出了稀疏近似、低秩近似等方法。
- 方法二:从计算加速设备本身的特性出发,研究如何更好地利用硬件特性对Transformer 中的注意力层进行高效计算。
(3)FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵。
文章目录
- note
- 一、近似注意力
- 1. 基于位置的稀疏注意力机制
- 2. 基于内容的稀疏注意力机制
- (1)Routing Transformer:使用聚类
- (2)Reformer:使用LSH哈希
- 二、计算加速
- 1. GPU硬件基础知识
- 2. flashattention
- 3. 多查询注意力MQA
- (1)MHA和MQA的区别
- (2)MHA和MQA的具体代码
- (3)使用矩阵乘法matmul广播实现参数共享
- (4)tgi框架中的MQA
- Reference
一、近似注意力
对一些训练好的Transformer 结构中的注意力矩阵进行分析时发现,其中很多是稀疏的,因此可以通过限制Query-Key 对的数量来降低计算复杂度。这类方法称为稀疏注意力(SparseAttention)机制。可以将稀疏化方法进一步分成基于位置的和基于内容信息的两类。
1. 基于位置的稀疏注意力机制
基于位置的稀疏注意力机制的基本类型如下图,主要包含如下五种类型:全局注意力(Global Attention)、带状注意力(Band Attention)、膨胀注意力(Dilated Attention)、随机注意力(Random Attention)、局部块注意力(Block Local Attention)。
这些注意力机制的区别主要在于它们如何选择序列中的元素来计算注意力权重,这直接影响计算复杂度、处理长距离依赖的能力以及对不同类型任务的适用性。每种注意力机制的关键区别和特点:
-
全局注意力(Global Attention):
- 关键特点:在计算每个位置的注意力时,考虑序列中的所有其他位置。
- 优点:能够捕获全局依赖性,理论上可以处理任意距离的关系。
- 缺点:计算复杂度高,随序列长度的平方增长,不适合处理长序列。
-
带状注意力(Band Attention):
- 关键特点:仅在每个位置的一个固定宽度的带内计算注意力权重,通常集中在序列的对角线附近。
- 优点:减少了计算量,适合捕获局部依赖性。
- 缺点:可能忽略重要的长距离依赖。
-
膨胀注意力(Dilated Attention):
- 关键特点:通过引入膨胀因子来间隔地选择序列中的元素进行注意力计算,从而覆盖更广的范围。和CNN中的Dilated Conv类似,通过增加空隙以获取更大的感受野
- 优点:在降低计算复杂度的同时,能够捕获更远的依赖性。
- 缺点:可能不如全局注意力在捕捉所有长距离依赖上有效。
-
随机注意力(Random Attention):
- 关键特点:随机选择序列中的位置来计算注意力权重。即通过随机采样,提升非局部的交互。
- 优点:显著降低计算需求,引入随机性可能帮助模型探索更多的依赖关系。
- 缺点:随机性可能导致忽略一些关键的依赖关系。
-
局部块注意力(Block Local Attention):
- 关键特点:将序列分割成多个块,在这些局部块内计算注意力权重。使用多个不重叠的块Block来限制信息交互。
- 优点:大幅降低计算复杂度,适合处理长序列。
- 缺点:如果不允许跨块计算,则可能忽略块间的依赖关系。
总结来说,这些注意力机制通过不同的策略平衡计算复杂度和模型的捕获依赖能力。选择哪种注意力机制取决于特定任务的需求,例如处理长序列数据时可能更倾向于使用带状、膨胀、随机或局部块注意力机制,而在不那么受限于计算资源的情况下,全局注意力可能是最好的选择,因为它能够捕获全局依赖性。
下面给出带状注意力的栗子:
# query-shape: [bs, seq_len, emb_dim]
def band_attention(query, key, value, band_width):"""Args:query, key, value: standard attention inputsband_width: The width of the band around the diagonal to compute attention.Returns:Tensor: The output of the attention mechanism."""batch_size, seq_len, d_k = query.size()scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)# Create a mask to zero out attention scores outside the bandidxs = torch.arange(seq_len).unsqueeze(0).to(query.device)mask = (idxs - idxs.transpose(0, 1)).abs().ge(band_width).to(scores.dtype)scores.masked_fill_(mask, float('-inf'))attention = F.softmax(scores, dim=-1)output = torch.matmul(attention, value)return output# 测试的case
def band_attention_test():import torch# 假设输入数据的维度batch_size = 2seq_length = 10embed_size = 128heads = 8# 生成随机数据作为输入values = torch.rand(batch_size, seq_length, embed_size)keys = torch.rand(batch_size, seq_length, embed_size)queries = torch.rand(batch_size, seq_length, embed_size)# 定义带宽band_width = 3# 使用相同的随机数据输入band_attention_output = band_attention(queries, keys, values, band_width)# Band Attention Output Shape: torch.Size([2, 10, 128])print("Band Attention Output Shape:", band_attention_output.shape)
可以看到上面的mask
矩阵确实是带状的:
现有的稀疏注意力机制,通常是基于上述五种基于位置的稀疏注意力机制的复合模式,下图给出了一些典型的稀疏注意力模型:
- star-transformer:使用带状注意力和全局注意力的组合,只包括一个全局注意力节点和宽度为3的带状注意力,其中任意两个非相邻节点通过一个共享的全局注意力连接,而相邻节点则直接相连。
- longformer:将上层中的一些带状注意力头部替换为具有扩张窗口的注意力,在增加感受野同时不增加计算量
- ETC(Extended Transformer Construction):利用带状注意力和外部全局节点注意力(External Global-node Attention)的组合。ETC 稀疏注意力还包括一种掩码机制来处理结构化输入,并采用对比预测编码(Contrastive Predictive Coding,CPC)进行预训练。
- BigBird:使用带状和全局注意力,还使用额外的随机注意力来近似全连接注意力,此外还揭示了稀疏编码器和稀疏解码器的使用可以模拟任何图灵机
2. 基于内容的稀疏注意力机制
基于内容的稀疏注意力机制根据输入数据创建稀疏注意力,其中一种很简单的方法是选择和给定查询 (Query) 有很高相似度的键 (Key)。
(1)Routing Transformer:使用聚类
(1)Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。中心向量采用滑动平均的方法进行更新:
μ ~ ← μ ~ + ( 1 − λ ) ( ∑ i : μ ( q i ) = μ q i + ∑ j : μ ( k j ) = μ k j ) c μ ← λ c μ + ( 1 − λ ) ∣ μ ∣ μ ← μ ~ c μ \begin{gathered} \widetilde{\boldsymbol{\mu}} \leftarrow \tilde{\boldsymbol{\mu}}+(1-\lambda)\left(\sum_{i: \mu\left(\boldsymbol{q}_i\right)=\mu} \boldsymbol{q}_i+\sum_{j: \mu\left(\boldsymbol{k}_j\right)=\mu} \boldsymbol{k}_j\right) \\ c_\mu \leftarrow \lambda c_\mu+(1-\lambda)|\mu| \\ \mu \leftarrow \frac{\widetilde{\boldsymbol{\mu}}}{c_\mu} \end{gathered} μ ←μ~+(1−λ) i:μ(qi)=μ∑qi+j:μ(kj)=μ∑kj cμ←λcμ+(1−λ)∣μ∣μ←cμμ
(2)Reformer:使用LSH哈希
(2)Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。假设 b b b 是桶的个数,给定一个大小为 [ D k , b / 2 ] [D k , b / 2] [Dk,b/2] 的随机矩阵 R R R , LSH 函数的定义为:
h ( x ) = arg max ( [ x R ; − x R ] ) h(\boldsymbol{x})=\arg \max ([\boldsymbol{x} R ;-\boldsymbol{x} R]) h(x)=argmax([xR;−xR])
如果 h q i = h k j h \boldsymbol{q}_i=h \boldsymbol{k}_j \quad hqi=hkj 时, q i \boldsymbol{q}_i qi 才可以与相应的Key-Value对进行交互。
二、计算加速
1. GPU硬件基础知识
NVIDIA GPU中的内存(显存)按照它们物理上是在GPU芯片内部还是板卡RAM存储芯片上,决定了它们的速度、大小以及访问限制。GPU显存分为:
- 全局内存(Global memory)
- 本地内存(Local memory)
- 共享内存(Shared memory,SRAM)
- 寄存器内存(Register memory)
- 常量内存(Constant memory)
- 纹理内存(Texture memory)
全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM)位于板卡RAM存储芯片上,该部分内存容量很大。全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100中全局内存有80GB空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低。
共享内存和寄存器位于GPU芯片上,因此容量很小,并且共享内存只有在同一个GPU线程块(Thread Block)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100中每个GPU线程块在流式多处理器(Stream Multi-processor,SM)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。
根据自注意力机制的原理,在GPU中进行计算时,传统的方法还需要引入两个中间矩阵 S 和 P 并存储到全局内存中。具体计算过程如下:
S = Q × K , P = Softmax ( S ) , O = P × V \boldsymbol{S}=\boldsymbol{Q} \times \boldsymbol{K}, \boldsymbol{P}=\operatorname{Softmax}(\boldsymbol{S}), \boldsymbol{O}=\boldsymbol{P} \times \boldsymbol{V} S=Q×K,P=Softmax(S),O=P×V
按照上述计算过程,需要:
- 首先从全局内存中读取矩阵 Q Q Q 和 K K K ,并将计算好的矩阵 S S S再写入全局内存
- 之后再从全局内存中获取矩阵 S S S ,计算Softmax得到矩阵 P P P 再写入全局内存
- 之后读取矩阵 P P P 和矩阵 V V V ,计算得到矩阵 O O O 。
这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多 ,因此计算效率越来越多地受到全局内存访问的瓶颈。
2. flashattention
FlashAttention就是通过利用GPU硬件中的特殊设计,针对全局内存和共享存储的I/O速度的不同,尽可能地避免HBM中读取或写入注意力矩阵。
FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵。
FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存消耗的方法。
FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入,存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内存中读取中间注意力矩阵的标准方法更快。
虽然大幅减少了全局内存的访问量,重新计算也导致FLOP(FLOPS指标,Floating Point Operations per Second 指每秒浮点运算次数) 增加,但其运行的速度更快且使用的内存更少。
3. 多查询注意力MQA
多查询注意力(Multi Query Attention)是多头注意力的一种变体。其特点是,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数,因此键和值的矩阵仅有一份,这大幅减少了显存占用,使其更高效。
由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。文献研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约5% 的原始训练数据量就可以达到不错的效果。
包括Falcon[64]、SantaCoder[65]、StarCoder[66] 在内的很多模型都采用了多查询注意力机制。
(1)MHA和MQA的区别
MHA 和 MQA 之间的区别主要在于建立 Wqkv Layer
上(如下代码)。在MQA中,除了query向量还保存8个头,key和value向量都只剩下1个【公共头】,即前面说的所有head之间共享一份key和value参数。
# Multi Head Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Head Attention 的创建方法self.d_model, 3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_modeldevice=device
)
query, key, value = qkv.chunk( # 【关键】每个 tensor 都是 (1, 512, 768)3, dim=2
)# Multi Query Attention
self.Wqkv = nn.Linear( # 【关键】Multi-Query Attention 的创建方法d_model,d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_modeldevice=device, # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split( # query -> (1, 512, 768)[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)dim=2 # value -> (1, 512, 96)
)
(2)MHA和MQA的具体代码
其中MultiheadAttention
和MultiQueryAttention
类完整的代码如下。
class MultiheadAttention(nn.Module):def __init__(self,d_model: int,n_heads: int,device: str):"""Multi Head init func.Args:d_model (int): hidden state size, e.g. 768n_heads (int): 设定的注意力头数, e.g. 8device (str): _description_"""super().__init__()self.d_model = d_modelself.n_heads = n_headsself.Wqkv = nn.Linear( # Multi-Head Attention 的创建方法self.d_model,3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_modeldevice=device) # (d_model, 3 * d_model)self.attn_fn = scaled_multihead_dot_product_attentionself.out_proj = nn.Linear(self.d_model,self.d_model,device=device)def forward(self,x):"""forward func.Args:x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)Returns:_type_: _description_"""qkv = self.Wqkv(x) # (1, 768, 3 * 768)query, key, value = qkv.chunk( # 每个 tensor 都是 (1, 512, 768)3,dim=2) context, attn_weights, past_key_value = self.attn_fn(query,key,value,self.n_heads) # (1, 512, 768)return self.out_proj(context), attn_weights, past_key_valueclass MultiQueryAttention(nn.Module):"""Multi-Query self attention.Using torch or triton attention implemetation enables user to also useadditive bias."""def __init__(self,d_model: int,n_heads: int,device: Optional[str] = None,):super().__init__()self.d_model = d_modelself.n_heads = n_headsself.head_dim = d_model // n_headsself.Wqkv = nn.Linear( # Multi-Query Attention 的创建方法d_model,d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_modeldevice=device, # 而 key 和 value 则只共享各自的一个 head_dim 的向量)self.attn_fn = scaled_multihead_dot_product_attentionself.out_proj = nn.Linear(self.d_model,self.d_model,device=device)self.out_proj._is_residual = True # type: ignoredef forward(self,x,):qkv = self.Wqkv(x) # (1, 512, 960)query, key, value = qkv.split( # query -> (1, 512, 768)[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)dim=2 # value -> (1, 512, 96))context, attn_weights, past_key_value = self.attn_fn(query,key,value,self.n_heads,multiquery=True,)return self.out_proj(context), attn_weights, past_key_value
(1)初始化函数 __init__
__init__(self, d_model: int, n_heads: int, device: Optional[str] = None)
: 这是类的初始化函数,用于创建类的实例时初始化其属性。它接受三个参数:模型的维度d_model
、注意力头的数量n_heads
,以及设备device
(可选),用于指定模块运行的硬件(CPU或GPU)。self.d_model = d_model
和self.n_heads = n_heads
: 这两行代码将传入的模型维度和头的数量保存为类的属性。self.head_dim = d_model // n_heads
: 计算每个头的维度,即将模型维度均分到每个头上。self.Wgkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
: 创建一个线性层Wgkv
,用于生成查询(Q)、键(K)和值(V)。这个线性层的输出维度是d_model + 2 * self.head_dim
,意味着查询的维度保持为d_model
,而键和值的维度为self.head_dim
。这种设计减少了模型参数,因为它没有为键和值分别创建额外的线性变换。self.attn_fn = scaled_multihead_dot_product_attention
和self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
: 定义了一个注意力函数attn_fn
和一个输出投影层out_proj
。attn_fn
负责计算多头点积注意力,而out_proj
用于将注意力机制的输出转换回原始输入的维度。
(2)前向传播函数 forward
def forward(self, X)
: 定义了前向传播函数,它接收一个输入张量X
。gkv = self.Wgkv(X)
: 首先,输入通过Wgkv
线性层,产生了合并的查询、键、值矩阵。query, key, value = gkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
: 然后,将gkv
拆分为查询、键和值三部分。注意拆分的维度与Wgkv
层的输出设计相匹配。context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, multiquery=True)
: 使用定义的注意力函数计算注意力,multiquery=True
参数指示使用多查询注意力机制。return self.out_proj(context), attn_weights, past_key_value
: 最后,将注意力的输出通过out_proj
投影层,然后将结果、注意力权重和过去的键值对返回。
(3)使用矩阵乘法matmul广播实现参数共享
其中注意上面的scaled_multihead_dot_product_attention
函数就是实现刚才说的一份key和value参数让多个头使用,使用矩阵乘法matmul
进行广播,实现参数共享。
def scaled_multihead_dot_product_attention(query,key,value,n_heads,multiquery=False,):q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) # (1, 512, 768) -> (1, 8, 512, 96)kv_n_heads = 1 if multiquery else n_headsk = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery # (1, 512, 96) -> (1, 1, 96, 512) if multiqueryv = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery # (1, 512, 96) -> (1, 1, 512, 96) if multiqueryattn_weight = q.matmul(k) * softmax_scale # (1, 8, 512, 512)attn_weight = torch.softmax(attn_weight, dim=-1) # (1, 8, 512, 512)out = attn_weight.matmul(v) # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)out = rearrange(out, 'b h s d -> b s (h d)') # (1, 512, 768)return out, attn_weight, past_key_value
(4)tgi框架中的MQA
具体还可以参考tgi框架中的MQA代码:
class MultiQueryAttention(nn.Module):"""Multi-Query self attention.Using torch or triton attention implementation enables user to also useadditive bias."""def __init__(self, config, prefix, weights):super().__init__()attn_impl = config.attn_config["attn_impl"]self.attn_impl = config.attn_config["attn_impl"]self.clip_qkv = config.attn_config["clip_qkv"]self.qk_ln = config.attn_config["qk_ln"]self.d_model = config.d_modeld_model = config.d_modelself.n_heads = config.n_headsself.softmax_scale = config.attn_config["softmax_scale"]if self.softmax_scale is None:self.softmax_scale = 1 / math.sqrt(self.head_dim)self.attn_dropout_p = config.attn_config["attn_pdrop"]# self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)self.Wqkv = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias)fuse_splits = (d_model, d_model + self.head_dim)if self.qk_ln:raise NotImplementedError("qk_ln not supported")if self.attn_impl == "flash":self.attn_fn = flash_attn_fnelif self.attn_impl == "triton":self.attn_fn = triton_flash_attn_fnif verbose:warnings.warn("While `attn_impl: triton` can be faster than `attn_impl: flash` "+ "it uses more memory. When training larger models this can trigger "+ "alloc retries which hurts performance. If encountered, we recommend "+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.")elif self.attn_impl == "torch":self.attn_fn = scaled_multihead_dot_product_attentionif torch.cuda.is_available() and verbose:warnings.warn("Using `attn_impl: torch`. If your model does not use `alibi` or "+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "+ "we recommend using `attn_impl: triton`.")else:raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")self.out_proj = TensorParallelRowLinear.load(config,prefix=f"{prefix}.out_proj",weights=weights,bias=not config.no_bias,)# self.out_proj._is_residual = Truedef forward(self,x,past_key_value=None,attn_bias=None,attention_mask=None,is_causal=True,needs_weights=False,):qkv = self.Wqkv(x)if self.clip_qkv:qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)(query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)key_padding_mask = attention_maskif self.qk_ln:dtype = query.dtypequery = self.q_ln(query).to(dtype)key = self.k_ln(key).to(dtype)(context, attn_weights, past_key_value) = self.attn_fn(query,key,value,self.n_heads,past_key_value=past_key_value,softmax_scale=self.softmax_scale,attn_bias=attn_bias,key_padding_mask=key_padding_mask,is_causal=is_causal,dropout_p=self.attn_dropout_p,training=self.training,needs_weights=needs_weights,multiquery=True,)return (self.out_proj(context), attn_weights, past_key_value)
Reference
[1] https://github.com/huggingface/text-generation-inference
[2] LLM 加速技巧:Muti Query Attention
[3] 训练模型算力的单位:FLOPs、FLOPS、Macs 与 估算模型(FC, CNN, LSTM, Transformers&&LLM)的FLOPs
[4] FlashAttention 的速度优化原理是怎样的?
[5] FlashAttention图解(如何加速Attention)
[6] flashattention论文:https://arxiv.org/pdf/2205.14135.pdf