多头注意力机制运算过程详解

news/2025/3/19 12:19:01/文章来源:https://www.cnblogs.com/yanyeeee/p/18780781

记录一下用于学习多头注意力机制的计算过程的实验脚本

from TransUNet.networks.vit_seg_modeling import Attention
import TransUNet.networks.vit_seg_configs as Config
import torch, math
import torch.nn as nn
from torch.nn import Linear, Dropout
from torch.nn.functional import softmaxconfig = Config.get_r50_b16_config()
batch_size = 1
seq_len = 16
attention = Attention(config, True)
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)print("输入形状:", hidden_states.shape)
# [batch_size, seq_len, hidden_size]
mixed_query = attention.query(hidden_states)
mixed_key = attention.key(hidden_states)
mixed_value = attention.value(hidden_states)
print("K 线性变换后:", mixed_key.shape)
print("Q 线性变换后:", mixed_query.shape)
print("V 线性变换后:", mixed_value.shape)# [batch_size, num_attention_heads, seq_len, attention_head_size]
query_layer = attention.transpose_for_scores(mixed_query)
key_layer = attention.transpose_for_scores(mixed_key)
value_layer = attention.transpose_for_scores(mixed_value)
print("Q 拆分多头后:", query_layer.shape)
print("K 拆分多头后:", key_layer.shape)
print("V 拆分多头后:", value_layer.shape)
# 计算注意力分数# query_layer: [batch_size, num_attention_heads, seq_len, attention_head_size]# key_layer.transpose(-1, -2): [batch_size, num_attention_heads, attention_head_size, seq_len]# 矩阵乘法在最后两个维度,即query_layer的行(token)与key_layer(token)的列做点乘
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
print("注意力分数形状:", attention_scores.shape)  # [batch_size, num_attention_heads, seq_len, seq_len]
# 依照attention_head_size,缩放注意力分数
attention_scores = attention_scores / math.sqrt(attention.attention_head_size)
# softmax
attention_probs = attention.softmax(attention_scores)
# 对Query0 对 Key0~Key_n 的注意力分数求和
# 在经过了缩放与softmax后,应为1
print("注意力概率和:", attention_probs[0,0,0].sum())
# 将注意力分数与value_layer进行矩阵乘法计算上下文
# attention_probs: [batch_size, num_attention_heads, seq_len, seq_len]
# value_layer: [batch_size, num_attention_heads, seq_len, attention_head_size]
context_layer = torch.matmul(attention_probs, value_layer)
print("上下文形状:", context_layer.shape)  # [batch_size, num_attention_heads, seq_len, attention_head_size]
# [batch_size, seq_len, num_attention_heads, attention_head_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()# contiguous()确保张量在内存中连续存储,为后续 view 操作做准备
print("合并前形状:", context_layer.shape)  #
# 开始合并多头
# [batch_size, seq_len, all_head_size(hidden_size)]
new_context_layer_shape = context_layer.size()[:-2] + (attention.all_head_size,)
# [batch_size, seq_len, num_heads * head_size] → [batch_size, seq_len, hidden_size]
context_layer = context_layer.view(*new_context_layer_shape)
print("合并后形状:", context_layer.shape)
# 输出线性变换(Projection),形状不变
attention_output = attention.out(context_layer)
print("输出形状:", attention_output.shape)
# 应用Dropout,以一定概率(如 0.1)随机将部分神经元输出置零,防止过拟合
attention_output = attention.proj_dropout(attention_output)
print("dropout后输出形状:", attention_output.shape)  #

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/901464.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

yum install -y devtoolset-8-gcc*

如果执行结果为上面这个结果的话,需要执行以下操作 yum install centos-release-scl*修改CentOS-SCLo-scl.repo文件 baseurl=https://mirrors.aliyun.com/centos/7/sclo/x86_64/rh/ 和 gpgcheck=0修改CentOS-SCLo-scl-rh.repo文件和上面一样查看 [root@iZbp153shsqfoddljmkit4…

几个技巧,教你去除文章的 AI 味!

给大家分享一些快速去除文章 AI 味的小技巧,有些是网上被分享过的,也有些是我个人的经验。学会之后,无论是写工作文案、毕业设计、自媒体文章,还是平时生活中写写好评,都是非常轻松的。最近有不少朋友在利用 AI 写毕业设计论文,几秒钟一篇文章就刷出来的,爽的飞起。 结果…

Sci Chart中的XyDataSeries与UniformXyDataSeries

在 SciChart 中,XyDataSeries 和 UniformXyDataSeries 是两种用于处理数据序列的核心类,主要差异体现在数据存储方式、性能优化及适用场景上。 以下是具体对比: 1. 数据存储与结构差异 **XyDataSeries<TX, TY>** 需要同时存储 X 和 Y 值的完整坐标对。例如,对于每个数…

强化学习基础_基于价值的强化学习

Action-Value Functions 动作价值函数 折扣回报(Discounted Return) 折扣回报 Ut 是从时间步 t 开始的累积奖励,公式为: Rt 是在时间步 t 获得的奖励。γ 是折扣因子(0<γ<1),用于减少未来奖励的权重。这是因为未来的奖励通常不如当前奖励重要,例如在金融领域,未…

USB杂谈

一、USB控制器 OHCI 1.0、1.1控制器 UHCI:1.0、1.1控制器 EHCI 2.0控制器 XHCI 3.0控制器 EHCI 2.0控制器 HID:人机交互接口,鼠标、手柄 、键盘、扫描枪USB协议中对集线器的层数是有限制的,USB1.1规定最多为5层,USB2.0规定最多为7层。 理论上,一个USB主控制器最多可接127个…

2025年3月中国数据库排行榜:PolarDB夺魁傲群雄,GoldenDB晋位入三强

2025年3月排行榜解读出炉,榜单前四现波动,PolarDB时隔半年重返榜首、GoldenDB进入前三,此外更有一些新星产品表现亮眼!欢迎阅读、一起盘点~阳春三月,万物复苏。2025年3月中国数据库流行度排行榜的发布,不仅展现了中国数据库企业在技术创新、生态建设和应用深化方面的显著…

# 20241902 2024-2025-2 《网络攻防实践》第四周作业

1.实验内容 通过本次实验,在搭建的实验环境中完成TCP/IP协议栈重点协议的攻击实验,包括ARP缓存欺骗攻击、ICMP重定向攻击、SYN Flood攻击、TCP RST攻击、TCP会话劫持攻击,并熟悉wireshark、netwox和ettercap等软件的操作。 2.实验过程 实验1 ARP缓存欺骗攻击 本实验中Kali为…

【Azure Fabric Service】分享使用Visual Studio 2022发布中国区Service Fabric服务应用的办法

问题描述 使用Visual Studio 2022如何发布Service Fabric到中国区云服务呢? 因为使用VS2022中的插件无法创建Service Fabric Cluster服务。那么,是否又比较好的替代方案呢?问题解答 是的,有替代方案。 除了昨天介绍使用的Powershell命令外( 【Azure Fabric Service】演示使…

如何让GameObject销毁时无论是否Active过,都调用OnDestroy

1)如何让GameObject销毁时无论是否Active过,都调用OnDestroy2)升级到URP画面会提升吗3)如何用Dynamic Mesh做出在墙上打洞的效果4)UE可以把烘焙好的光照贴图导出吗这是第424篇UWA技术知识分享的推送,精选了UWA社区的热门话题,涵盖了UWA问答、社区帖子等技术知识点,助力…

测序芯片-不同键合工艺对比-flowcell-代加工-外协加工-委外加工-激光代加工-河南郑州-芯晨微纳(河南)

基因测序(包括DNA测序和RNA测序)是研究生命信息的重要方法之一。DNA测序(DNA sequencing,或译DNA定序)是指分析特定DNA片段的碱基序列, 也就是腺嘌呤(A)、胸腺嘧啶(T)、胞嘧啶(C)与鸟嘌呤(G)的排列方式。同理,RNA测序是指分析特定RNA片段的碱基序列,也就是腺嘌呤(A)、鸟嘌呤…