Notes: an experimental script for walking, step by step, through the computation performed by the multi-head attention mechanism.
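The script follows the standard scaled dot-product attention, computed independently for every head:

Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V

where d_k is the per-head dimension (attention_head_size). The multi-head part only splits the hidden_size channels into num_heads groups of size d_k before this computation and concatenates the per-head results afterwards.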
from TransUNet.networks.vit_seg_modeling import Attention
import TransUNet.networks.vit_seg_configs as Config
import torch, math
import torch.nn as nn
from torch.nn import Linear, Dropout
from torch.nn.functional import softmax

config = Config.get_r50_b16_config()
batch_size = 1
seq_len = 16
attention = Attention(config, True)
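# A quick look at the head geometry. attention_head_size and all_head_size are attributes used
# further below; the number of heads is derived here instead of being read off the module.
num_heads = attention.all_head_size // attention.attention_head_size
print("num_heads:", num_heads)
print("attention_head_size:", attention.attention_head_size)
print("all_head_size:", attention.all_head_size)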
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
print("Input shape:", hidden_states.shape)
# [batch_size, seq_len, hidden_size]
mixed_query = attention.query(hidden_states)
mixed_key = attention.key(hidden_states)
mixed_value = attention.value(hidden_states)
print("K 线性变换后:", mixed_key.shape)
print("Q 线性变换后:", mixed_query.shape)
print("V 线性变换后:", mixed_value.shape)# [batch_size, num_attention_heads, seq_len, attention_head_size]
query_layer = attention.transpose_for_scores(mixed_query)
key_layer = attention.transpose_for_scores(mixed_key)
value_layer = attention.transpose_for_scores(mixed_value)
print("Q 拆分多头后:", query_layer.shape)
print("K 拆分多头后:", key_layer.shape)
print("V 拆分多头后:", value_layer.shape)
# Compute the attention scores
# query_layer: [batch_size, num_attention_heads, seq_len, attention_head_size]
# key_layer.transpose(-1, -2): [batch_size, num_attention_heads, attention_head_size, seq_len]
# The matmul acts on the last two dimensions: each row of query_layer (one token's query) is
# dotted with each column of the transposed key_layer (one token's key).
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
print("注意力分数形状:", attention_scores.shape) # [batch_size, num_attention_heads, seq_len, seq_len]
# Scale the attention scores by sqrt(attention_head_size) so that the raw dot products
# do not grow with the head dimension
attention_scores = attention_scores / math.sqrt(attention.attention_head_size)
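# A small demonstration of why the scaling matters: without the division by sqrt(d_k) the scores
# spread much wider, and the softmax output is typically far more peaked.
scaled_max = softmax(attention_scores, dim=-1)[0, 0, 0].max().item()
unscaled_max = softmax(attention_scores * math.sqrt(attention.attention_head_size), dim=-1)[0, 0, 0].max().item()
print("max probability with vs. without scaling:", scaled_max, unscaled_max)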
# softmax
attention_probs = attention.softmax(attention_scores)
# Sum Query 0's attention weights over Key 0 ~ Key n
# After the scaling and the softmax, the sum should be 1
print("Sum of attention probabilities:", attention_probs[0,0,0].sum())
# Multiply the attention probabilities with value_layer to compute the context
# attention_probs: [batch_size, num_attention_heads, seq_len, seq_len]
# value_layer: [batch_size, num_attention_heads, seq_len, attention_head_size]
context_layer = torch.matmul(attention_probs, value_layer)
print("上下文形状:", context_layer.shape) # [batch_size, num_attention_heads, seq_len, attention_head_size]
# [batch_size, seq_len, num_attention_heads, attention_head_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# contiguous() ensures the tensor is stored contiguously in memory, preparing it for the view() below
print("Shape before merging:", context_layer.shape)
# Merge the heads back together
# [batch_size, seq_len, all_head_size(hidden_size)]
new_context_layer_shape = context_layer.size()[:-2] + (attention.all_head_size,)
# [batch_size, seq_len, num_heads * head_size] → [batch_size, seq_len, hidden_size]
context_layer = context_layer.view(*new_context_layer_shape)
print("合并后形状:", context_layer.shape)
# Output linear projection; the shape stays the same
attention_output = attention.out(context_layer)
print("输出形状:", attention_output.shape)
# Apply dropout: with some probability (e.g. 0.1) individual activations are randomly zeroed
# (and the surviving ones rescaled) to reduce overfitting
attention_output = attention.proj_dropout(attention_output)
print("dropout后输出形状:", attention_output.shape) #