深度学习理论基础(六)多头注意力机制

目录

  • 一、自定义多头注意力机制
    • 1. Scaled Dot-Product Attention
    • 2. 多头注意力机制框图
      • (1)计算公式
      • (2)具体计算过程
      • (3)具体代码
  • 二、pytorch中的子注意力机制模块

  
  深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息,提高模型的性能和泛化能力。
  下图 展示了人类在看到一幅图像时如何高效分配有限注意力资源的,其中红色区域表明视觉系统更加关注的目标,从图中可以看出:人们会把注意力更多的投入到人的脸部。文本的标题以及文章的首句等位置。而注意力机制就是通过机器来找到这些重要的部分。
在这里插入图片描述

一、自定义多头注意力机制

1. Scaled Dot-Product Attention

  在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。
Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.
在这里插入图片描述

2. 多头注意力机制框图

  多头注意力机制是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。
在这里插入图片描述

(1)计算公式

在这里插入图片描述

(2)具体计算过程

①计算注意力得分:根据Query和Key计算两者的相似性或相关性。常见方法:求两者的向量点积(内积)。
②对注意力得分进行softmax归一化处理。
③输出:根据权重系数对value进行加权求和。
在这里插入图片描述

(3)具体代码

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):#embedding_dim:输入向量的维度,num_heads:注意力机制头数def __init__(self, embedding_dim, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_heads       #总头数self.embedding_dim = embedding_dim   #输入向量的维度self.d_k= self.embedding_dim// self.num_heads  #每个头 分配的输入向量的维度数self.softmax=nn.Softmax(dim=-1)self.W_query = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)self.W_key = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)self.W_value = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)self.fc_out = nn.Linear(embedding_dim, embedding_dim)#输入张量 x 中的特征维度分成 self.num_heads 个头,并且每个头的维度为 self.d_k。def split_head(self, x, batch_size):x = x.reshape(batch_size, -1, self.num_heads, self.d_k)return x.permute(0,2,1,3)   #x  (N_size, self.num_heads, -1, self.d_k)def forward(self, x):batch_size=x.size(0)  #获取输入张量 x 的批量(batch size)大小q= self.W_query(x)  k= self.W_key(x)  v= self.W_value(x)#使用 split_head 函数对 query、key、value 进行头部切分,将其分割为多个注意力头。q= self.split_head(q, batch_size)k= self.split_head(k, batch_size)v= self.split_head(v, batch_size)##attention_scorce = q*k的转置/根号d_kattention_scorce=torch.matmul(q, k.transpose(-2,-1))/torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))attention_weight= self.softmax(attention_scorce)## output = attention_weight * Voutput = torch.matmul(attention_weight, v)  # [h, N, T_q, num_units/h]output  = out.permute(0,2,1,3).contiguous() # [N, T_q, num_units]output  = out.reshape(batch_size,-1, self.embedding_dim)output  = self.fc_out(output)return output

  

二、pytorch中的子注意力机制模块

  nn.MultiheadAttention是PyTorch中用于实现多头注意力机制的模块。它允许你在输入序列之间计算多个注意力头,并且每个头都学习到了不同的注意力权重。
  创建了一些随机的输入数据,包括查询(query)、键(key)、值(value)。接着,我们使用multihead_attention模块来计算多头注意力,得到输出和注意力权重。
  请注意,你可以调整num_heads参数来控制多头注意力的头数,这将会影响到模型的复杂度和表达能力。

import torch
import torch.nn as nn# 假设我们有一些输入数据
# 输入数据形状:(序列长度, 批量大小, 输入特征维度)
input_seq_length = 10
batch_size = 3
input_features = 32# 假设我们的输入序列是随机生成的
input_data = torch.randn(input_seq_length, batch_size, input_features)# 定义多头注意力模块
# 参数说明:
#   - embed_dim: 输入特征维度
#   - num_heads: 多头注意力的头数
#   - dropout: 可选,dropout概率,默认为0.0
#   - bias: 可选,是否在注意力计算中使用偏置,默认为True
#   - add_bias_kv: 可选,是否添加bias到key和value,默认为False
#   - add_zero_attn: 可选,是否在注意力分数中添加0,默认为False
multihead_attention = nn.MultiheadAttention(input_features, num_heads=4)# 假设我们有一个query,形状为 (查询序列长度, 批量大小, 输入特征维度)
query = torch.randn(input_seq_length, batch_size, input_features)# 假设我们有一个key和value,形状相同为 (键值序列长度, 批量大小, 输入特征维度)
key = torch.randn(input_seq_length, batch_size, input_features)
value = torch.randn(input_seq_length, batch_size, input_features)# 计算多头注意力
# 返回值说明:
#   - output: 注意力计算的输出张量,形状为 (序列长度, 批量大小, 输入特征维度)
#   - attention_weights: 注意力权重,形状为 (批量大小, 输出序列长度, 输入序列长度)
output, attention_weights = multihead_attention(query, key, value)# 输出结果
print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/595890.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据挖掘实战-基于LSTM算法的HCV检测者分类模型研究

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

VPDN(L2TP、PPTP)

1、虚拟专用拨号网络 远程接入VPN,客户端可以是PC机 技术:L2TP、PPTP 术语:LAC:L2TP的访问集中器 --- 提供用户的接入 LNS:L2TP的网络服务器 --- 提供L2TP服务的服务器 2、技术 1)PPTP 点对点隧道…

Mysql启动失败解决过程

报错内容如下: Mar 05 18:40:49 VM-0-12-centos systemd[1]: Failed to start MySQL Server. Mar 05 18:40:49 VM-0-12-centos systemd[1]: Unit mysqld.service entered failed state. Mar 05 18:40:49 VM-0-12-centos systemd[1]: mysqld.service failed. Mar 05…

创建和启动线程

概述 Java语言的JVM允许程序运行多个线程,使用java.lang.Thread类代表线程,所有的线程对象都必须是Thread类或其子类的实例。 Thread类的特性 每个线程都是通过某个特定Thread对象的run()方法来完成操作的,因此把run()方法体称为线程执行体。…

刷题日记——由浅入深的大数加法(高精度加法)

例题 代码 #include <cstdio>int main(){long long a,b;scanf("%lld %lld",&a,&b);printf("%lld\n",ab);}例题——高精度加法 编程计算&#xff1a;12345678912345678912121211231212121212121212121222222111112121&#xff1f; 分析 加…

工程师必备:PW1558 12V/20V过流限压保护芯片,短路无忧,运行更稳定

在电力电子领域&#xff0c;寻找一款能够提供全面保护且性能卓越的电源开关至关重要。PW1558正是这样一款产品&#xff0c;它凭借出色的性能和广泛的应用领域&#xff0c;赢得了业界的广泛认可。下面&#xff0c;我们将从描述、特点和应用三个方面&#xff0c;详细解读PW1558的…

Ideal的使用技巧

一、springcloud项目如何将多个服务放到services中一起启动 1、打开ideal&#xff0c;再view -> Tool Windows -> services 2、在services界面 找到 run configuration type -> springboot即可 二、配置临时的启动参数 1、在edit configurations中 2、选择相应的服务…

C. MEX Game 1

本题如果我们去模拟这个算法的话会很麻烦&#xff0c;也会TLE&#xff0c;首先我们想 1&#xff0c;对于alice来说&#xff0c;先取小的&#xff0c;对于bob来说先删除alic想取的下一个小的 2&#xff0c;那如果这个数多于两个&#xff0c;那也就是说&#xff0c;alice肯定能…

【详细讲解0基础如何进入IT行业】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

UE4_如果快速做出毛玻璃效果_假景深

UE4_如果快速做出毛玻璃效果_假景深 2022-08-20 15:02 一个SpiralBlur-SceneTexture材质节点完成效果&#xff0c;启用半透明材质通过修改BlurAmount数值大小调整效果spiralBlur-SceneTexture custom节点&#xff0c;HLSL语言float3 CurColor 0;float2 BaseUV MaterialFloa…

树状数组相关题目

题目一 方法一 归并分治 代码&#xff1a; # include <stdio.h>int arr[100]; int help[100];int n;//归并分治 // 1.统计i、j来自 l~r 范围的情况下&#xff0c;逆序对数量 // 2.统计完成后&#xff0c;让arr[l...r]变成有序的 int f(int l, int r) {if (l r)return…

Mybatis 之 useGeneratedKeys

数据库中主键基本都设置为自增&#xff0c;当我们要插入一条数据想要获取这条数据的 Id 时&#xff0c;就可使用 Mybatis 中的 useGeneratedKeys 属性。 背景 这里以 苍穹外卖 中的 新增菜品 功能为例&#xff0c;有 菜品表(dish table)和 口味表(dish_flavor table)&#xf…