Pytorch手撸Attention

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(SelfAttention, self).__init__()self.embed_size = embed_size# 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)# 用Linear线性层让q、k、y能更好的拟合实际需求self.value = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.query = nn.Linear(embed_size, embed_size)def forward(self, x):# x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)batch_size, seq_len, embed_size = x.shapeQ = self.query(x)K = self.key(x)V = self.value(x)# 计算注意力分数矩阵# 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数# 注意力分数的形状为 [batch_size, seq_len, seq_len]# K.transpose(1,2)转置后[batch_size, embed_size, seq_len]# 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_size, dtype=torch.float32))# 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmaxattention_weight = F.softmax(attention_scores, dim=-1)# 应用注意力权重到 V 矩阵,得到加权和# 输出的形状为 [batch_size, seq_len, embed_size]output = torch.matmul(attention_weight, V)return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, num_heads):super().__init__()self.embed_size = embed_sizeself.num_heads = num_heads# 整除来确定每个头的维度self.head_dim = embed_size // num_heads# 加入断言,防止head_dim是小数,必须保证可以整除assert self.head_dim * num_heads == embed_sizeself.q = nn.Linear(embed_size, embed_size)self.k = nn.Linear(embed_size, embed_size)self.v = nn.Linear(embed_size, embed_size)self.out = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):# N就是batch_size的数量N = query.shape[0]# *_len是序列长度q_len = query.shape[1]k_len = key.shape[1]v_len = value.shape[1]# 通过线性变换让矩阵更好的拟合queries = self.q(query)keys = self.k(key)values = self.v(value)# 重新构建多头的queries,permute调整tensor的维度顺序# 结合下文demo进行理解queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 计算多头注意力分数attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention = F.softmax(attention_scores, dim=-1)# 整合多头注意力机制的计算结果out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)# 过一遍线性函数out = self.out(out)return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)# 创建自注意力模型实例
model = SelfAttention(embed_size)# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)# 运行自注意力模型
output_tensor = model(input_tensor)# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],[0.8250, 0.0362, 0.8953, 0.1687],[0.8254, 0.8506, 0.9826, 0.0440]],[[0.0700, 0.4503, 0.1597, 0.6681],[0.8587, 0.4884, 0.4604, 0.2724],[0.5490, 0.7795, 0.7391, 0.9113]]])输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],[-0.3748,  0.6389, -0.0861, -0.0706],[-0.3694,  0.6388, -0.0855, -0.0660]],[[-0.2365,  0.4541, -0.1811, -0.0354],[-0.2338,  0.4455, -0.1871, -0.0370],[-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/624843.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个开源的全自动视频生成软件MoneyPrinterTurbo

只需提供一个视频 主题 或 关键词 &#xff0c;就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐&#xff0c;然后合成一个高清的短视频。 一&#xff1a;功能特性 完整的 MVC架构&#xff0c;代码 结构清晰&#xff0c;易于维护&#xff0c;支持 API 和 Web界面…

python生成二维码

要在Python中生成二维码&#xff0c;可以使用第三方库qrcode。首先&#xff0c;确保已经安装了qrcode库&#xff1a; pip install qrcode然后&#xff0c;使用以下代码生成二维码&#xff1a; import qrcodedata "https://mp.csdn.net/mp_blog/creation/editor?spm100…

Adobe Premiere Pro将加入AI生成式功能,以提高视频编辑的效率;OpenAI宣布在东京设立亚洲首个办事处

&#x1f989; AI新闻 &#x1f680; Adobe Premiere Pro将加入AI生成式功能&#xff0c;以提高视频编辑的效率 摘要&#xff1a;Adobe宣布&#xff0c;将为Premiere Pro引入由生成式AI驱动的新功能&#xff0c;以提高视频编辑的效率。这些功能包括“生成扩展”&#xff0c;能…

人类连接的桥梁:探索Facebook如何连接世界

随着技术的发展和全球化的进程&#xff0c;我们的世界正在变得越来越紧密相连。在这个过程中&#xff0c;社交媒体平台扮演了一个至关重要的角色&#xff0c;为人们提供了一个跨越国界、文化和语言的交流平台。其中&#xff0c;Facebook作为全球最大的社交媒体平台&#xff0c;…

Redis从入门到精通(十八)多级缓存(三)OpenResty请求参数处理、Lua脚本查询Redis和Tomcat

文章目录 前言6.5 实现多级缓存6.5.3 请求参数处理6.5.3.1 获取参数API6.5.3.2 获取参数并返回 6.5.4 查询Tomcat6.5.4.1 发送HTTP请求的API6.5.4.2 封装HTTP工具6.5.4.3 实现商品查询6.5.4.4 使用CJSON工具类6.5.4.5 基于商品ID实现负载均衡 6.5.5 查询Redis6.5.5.1 Redis缓存…

盲盒商城小程序(有米就出)

一款前端采用uniapp&#xff0c;后端采用Django框架开发的小程序&#xff0c;包含后台管理&#xff0c;如有人需要可联系演示功能&#xff08;个人开发&#xff0c;可商用/学习&#xff09;。 部分截图如下&#xff1a;

记录一下易语言post get使用WinHttp的操作

最近在学易语言&#xff0c;在进行通讯的时候&#xff0c;出现一些问题&#xff0c;现在记录下来&#xff0c;避免以后继续忘记&#xff0c; 先声明文本型变量jsonPostData jsonPostData &#xff1d; “{hostname:” &#xff0b; hostnameTxt &#xff0b; “,hardcode:” &…

游戏前摇后摇Q闪E闪QE闪QA等操作

备注&#xff1a;未经博主允许禁止转载 个人笔记&#xff08;整理不易&#xff0c;有帮助&#xff0c;收藏点赞评论&#xff0c;爱你们&#xff01;&#xff01;&#xff01;你的支持是我写作的动力&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_w…

AR、VR、MR 和 XR——它们的含义以及它们将如何改变生活

我们的工作、娱乐和社交方式正在发生巨大变化。远程工作的人比以往任何时候都多,屏幕已成为学习和游戏的领先平台。这种演变为元宇宙铺平了道路——如今,像 Meta Quest 2 这样的流行设备将您无缝地带入一个身临其境的世界,您可以在其中购物、创作和玩游戏、与同事协作、探索…

RAKsmart:硅谷裸机云多IP服务器性能评测

在云计算领域&#xff0c;裸机云作为一种结合了传统物理服务器与云计算优势的服务模式&#xff0c;近年来备受关注。硅谷裸机云作为业界佼佼者&#xff0c;以其出色的性能和稳定性赢得了众多用户的青睐。今天&#xff0c;我们就来评测一下硅谷裸机云的多IP服务器性能。 首先&am…

vscode i18n Ally插件配置项

.vscode文件&#xff1a; {"i18n-ally.localesPaths": ["src/lang"], //显示语言&#xff0c; 这里也可以设置显示英文为en,// 如下须要手动配置"i18n-ally.keystyle": "nested", // 翻译路径格式 (翻译后变量格式 nested&#xff1a…

C语言100题练习打卡(2)

14&#xff0c;将一个正整数分解质因数。 例如&#xff1a;输入90&#xff0c;打印出902*3*3*5 #include<stdio.h> /*分析&#xff1a; * 1&#xff0c;如果这话质数恰巧等于&#xff08;小于的时候&#xff0c;继续执行循环&#xff09;n&#xff0c; 则说明分解质因数…