MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

  • 注意力机制详解及代码
    • 前言:
    • MHA
    • MQA
    • GQA

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

在这里插入图片描述

在这里插入图片描述

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 对注意力输出进行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
## 分组注意力查询
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiGroupAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩阵self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 输出线性层self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 计算注意力分数attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 对注意力分数进行归一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/702802.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速找出文件夹里的全部带有中文纯中文的文件

首先,需要用到的这个工具: 度娘网盘 提取码:qwu2 蓝奏云 提取码:2r1z 步骤 1、打开工具,切换到批量复制文件 2、鼠标移到右侧,点击搜索添加 3、设定查找范围、指定为文件、勾选 包含全部子文件夹&#x…

vue 点击平滑到指定位置并绑定页面滑动效果

1.html元素 写出对应的数据块&#xff08;注意添加ref) 用于获取元素位置 <template><div class"index-page" ><div class"top-head" ref"index"><img src"logo.png" style"height: 40px;margin-right: 2…

一图流解释Java中线程状态的转换

目录 一.Java中的几大线程状态 二.线程之间的相互转换 ▐ NEW --> RUNNABLE ▐ RUNNABLE <--> WAITING ▐ RUNNABLE <--> Timed Waiting ▐ RUNNABLE<--> BLOCKED ▐ RUNNABLE<-->TERMINATED 一.Java中的几大线程状态 简单来说线程可以处于…

Unity里的Time

Time and frame rate management Time类&#xff1a; Time script reference page. 一些常见的属性有&#xff1a; Time.time 返回从游戏开始经历的时间.Time.deltaTime 返回从上帧结束到现在经历的时间&#xff0c;和帧率成反比Time.timeScale 控制时间流逝的因子Time.fixe…

STM32_HAL_TIM_通用计时器_实现计时

项目思路 1使用定时器计数每秒一次 2使用一个变量记录定时器响应多少次 3使用UART将记录的次数发出 1STM32Cude设置 1配置时钟源 2打开UART 3打开TIM2 3.1界面介绍 3.2选项介绍 Slave Mode&#xff08;从模式&#xff09;&#xff1a;当设备被设置为从模式时&#xff0c…

推荐网站(10)storyset故事集 |免费定制、制作动画和下载插图

今天介绍一个免费定制、制作动画和下载插图的网站storyset故事集。Storyset故事集是一个提供丰富动画插图资源的在线平台&#xff0c;它不仅免费&#xff0c;而且用户友好&#xff0c;可以极大地丰富你的演示文稿&#xff08;PPT&#xff09;和其他视觉内容。 通过Storyset故事…

替补url

检查图片数量是否为6张,如果图片数量不足6张&#xff0c;则用本地图片补充当前位缺失的URL <div class"right-pic-box" v-for"(i, index) in imagesWithReplacement" :key"index"><img :src"i" width"316" height&…

Axure “情形”的使用

这篇笔记的主要内容是如果在Axure中使用“情形”&#xff0c;对应在我们的研发中就是“判断条件”的使用 Axure情形的使用Axure添加caseAxure的if &#xff0c;sele if 条件判断 条件判断不管是在研发代码中还是实际生活中&#xff0c;无处不在&#xff0c;只是表现形式不同罢…

牛客热题:二叉树的最大深度

&#x1f4df;作者主页&#xff1a;慢热的陕西人 &#x1f334;专栏链接&#xff1a;力扣刷题日记 &#x1f4e3;欢迎各位大佬&#x1f44d;点赞&#x1f525;关注&#x1f693;收藏&#xff0c;&#x1f349;留言 文章目录 牛客热题&#xff1a;二叉树的最大深度题目链接方法一…

C++ 程序员常用的VScode的插件

vscode中好用的插件 Better CommentsBookmarksC/C ThemeChinese (Simplified) (简体中文) Language Pack for Visual Studio CodeclangdClang-FormatCodeLLDBCMakeCMake ToolsCode RunnerCode Spell CheckerCodeSnapColor Highlightvscode-mindmapDraw.io IntegrationError Len…

图文详解JUC:Wait与Sleep的区别与细节

目录 一.Wait() 二.Sleep() 三.总结Wait()与Sleep()的区别 一.Wait() 在Java中&#xff0c;wait() 方法是 Object类中的一个方法&#xff0c;用于线程间的协作。当一个线程调用wait() 方法时&#xff0c;它会释放对象的锁并进入等待状态&#xff0c;直到其他线程调用相同对…

常用的一些字符转换工具--web(在线进制转换)

十进制 转2进制&#xff0c; 16进制 十进制浮点数转16进制&#xff08;4个Byte) http://www.speedfly.cn/tools/hexconvert/