Self-Attention 机制和多头注意力机制

  •    🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

Self-Attention 机制和多头注意力机制是Transformer 模型中的核心组件

1. Self-Attention 机制

Self-Attention 机制是一种能够根据序列中不同位置的重要性来分配权重的注意力机制。它允许模型在一个序列中的不同位置之间进行交互,从而实现对序列的全局依赖建模,而不是像循环神经网络(RNN)那样依赖于顺序迭代。

原理: 在 Self-Attention 中,输入序列被视为一组向量,每个向量表示序列中的一个位置或词。通过计算每对位置之间的关联度得到一个注意力矩阵,该矩阵表征了每个位置对其他位置的重要性。然后,利用注意力矩阵对每个位置的向量进行加权求和,得到该位置的输出向量。

工作流程:

  1. 输入序列表示: 将输入序列表示为一组向量,通常通过嵌入层(embedding layer)将输入的符号序列(如词或字符)转换为向量表示。
  2. 计算注意力权重: 对于每个位置的输入向量,计算它与其他位置的关联度。这通常通过计算查询(Query)、键(Key)和值(Value)之间的点积来实现。
  3. 加权求和: 根据计算得到的注意力权重,对每个位置的值向量进行加权求和,得到该位置的输出向量。
  4. 输出: 得到所有位置的输出向量,作为 Self-Attention 层的输出。

优点:

  • 能够捕捉长距离依赖关系,使得模型在处理长序列时表现优异。
  • 允许并行计算,提高了计算效率。

2. 多头注意力机制

多头注意力机制是在 Self-Attention 的基础上进行扩展的一种机制,它允许模型同时关注输入序列的不同子空间,从而提高了模型的表征能力。

原理: 在多头注意力机制中,将 Self-Attention 操作应用多次,每次使用不同的查询、键和值投影矩阵,得到多组注意力权重和输出。然后将这些输出拼接在一起,并通过一个线性变换层(dense layer)进行处理,得到最终的多头注意力输出。

工作流程:

  1. 多头投影: 对输入向量分别应用多组投影矩阵,得到多组查询、键和值。
  2. 多头注意力计算: 对每组查询、键和值分别计算注意力权重,并得到多组输出向量。
  3. 拼接与线性变换: 将多组输出向量拼接在一起,并通过一个线性变换层进行处理,得到最终的多头注意力输出。

3.Attention代码实例

import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, hid_dim, n_heads, dropout):super().__init__()self.hid_dim = hid_dimself.n_heads = n_heads# hid_dim必须整除assert hid_dim % n_heads == 0# 定义wqself.w_q = nn.Linear(hid_dim, hid_dim)# 定义wkself.w_k = nn.Linear(hid_dim, hid_dim)# 定义wvself.w_v = nn.Linear(hid_dim, hid_dim)self.fc = nn.Linear(hid_dim, hid_dim)self.do = nn.Dropout(dropout)self.scale = torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))def forward(self, query, key, value, mask=None):# Q与KV在句子长度这一个维度上数值可以不一样bsz = query.shape[0]Q = self.w_q(query)K = self.w_k(key)V = self.w_v(value)# 将QKV拆成多组,方案是将向量直接拆开了# (64, 12, 300) -> (64, 12, 6, 50) -> (64, 6, 12, 50)# (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)# (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)Q = Q.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)K = K.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)V = V.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)# 第1步,Q x K / scale# (64, 6, 12, 50) x (64, 6, 50, 10) -> (64, 6, 12, 10)attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale# 需要mask掉的地方,attention设置的很小很小if mask is not None:attention = attention.masked_fill(mask == 0, -1e10)# 第2步,做softmax 再dropout得到attentionattention = self.do(torch.softmax(attention, dim=-1))# 第3步,attention结果与k相乘,得到多头注意力的结果# (64, 6, 12, 10) x (64, 6, 10, 50) -> (64, 6, 12, 50)x = torch.matmul(attention, V)# 把结果转回去# (64, 6, 12, 50) -> (64, 12, 6, 50)x = x.permute(0, 2, 1, 3).contiguous()# 把结果合并# (64, 12, 6, 50) -> (64, 12, 300)x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))x = self.fc(x)return xquery = torch.rand(64, 12, 300)
key = torch.rand(64, 10, 300)
value = torch.rand(64, 10, 300)
attention = MultiHeadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/623866.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Wireshark TS | 再谈应用传输缓慢问题

问题背景 来自于朋友分享的一个案例,某某客户反馈电脑应用软件使用时打开很慢,并提供了一个慢时所捕获的数据包文件以及服务端 IP。以前也说过,所谓的慢有很多种现象,也会有很多原因引起,在没有更多输入条件的情况下&…

嵌入式基础知识学习:DA/AD—数模/模数转换

AD/DA—数模/模数转换概念 数字电路只能处理二进制数字信号,而声音、温度、速度和光线等都是模拟量,利用相应的传感器(如声音用话筒)可以将它们转换成模拟信号,然后由A/D转换器将它们转换成二进制数字信号&#xff0c…

27.8k Star,AI智能体项目GPT Pilot:第一个真正的人工智能开发者(附部署视频教程)

作者:Aitrainee | AI进修生 排版太难了,请点击这里查看原文:27.8k Star,AI智能体项目GPT Pilot:第一个真正的人工智能开发者(附部署视频教程) 今天介绍一下一个人工智能智能体的项目GPT Pilot。…

IO流高级流

前言 缓冲区能够提升输入输出的效率 虽然FileReader和FileWriter中也有缓冲区 但是BufferedReader和BufferWriter有两个非常好用的方法. 缓冲流 字节缓冲流 import java.io.*;public class BufferedStreamDemo {public static void main(String[] args) throws IOExceptio…

小阳同学嵌入式学习日记-QFile读写文件

一、QFile简介 在Qt中,QFile是一个用于文件I/O操作的类。它提供了一种方便的方式来读取和写入文件内容,以及获取有关文件的信息。 QFile类提供了许多函数,用于打开、关闭、读取和写入文件。一些常用的QFile函数包括: open(): 打开…

工作的第五天了

1.今天内容 1.现在的基本都增删改查都有 2.下一步做规格商品添加规格的方式 3.商品规格比较特殊 4.我们添加一个商品。通用一个商品,然后下面添加规格信息 5.如何做 6.第一个是添加商品 7.商品对应多个属性方式,简单来说是一个一对多的方式&#x…

想自学网络安全_Web安全,一般人我还是劝你算了吧

由于我之前写了不少网络安全技术相关的文章,不少读者朋友知道我是从事网络安全相关的工作,于是经常有人私信问我: 我刚入门网络安全,该怎么学? 要学哪些东西? 有哪些方向? 怎么选?…

如何将Oracle 中的部分不兼容对象迁移到 OceanBase

本文总结分析了 Oracle 迁移至 OceanBase 时,在出现三种不兼容对象的情况时的处理策略以及迁移前的预检方式,通过提前发现并处理这些问题,可以有效规避迁移过程中的报错风险。 作者:余振兴,爱可生 DBA 团队成员&#x…

基于SSM的游戏攻略管理系统

游戏攻略管理系统的构建与实现 一、系统概述二、系统架构与技术选型三、系统功能模块四、系统特点五、总结与展望 随着网络游戏的普及和发展,游戏攻略成为玩家们提升游戏技能、了解游戏机制的重要途径。为了更好地满足玩家需求,提高游戏攻略的管理效率和…

Java——static成员

目录 一.再谈学生类 二.static修饰成员变量 三.static修饰成员方法 四.static成员变量初始化 1.就地初始化 2.静态代码块初始化 一.再谈学生类 使用前面推文(Java——类和对象)中介绍的学生类实例化三个对象s1、s2、s3,每个对象都有自…

记录 OpenHarmony 使用 request.uploadFile 时踩的坑

​ 开发环境 设备环境:OpenHarmony 4.1.x SDK 版本:API 10 开发模型:Stage 模型 IDLE: Dev Eco 4.1 官方文档 踩坑一:后台服务地址 上传文件依赖后台服务器,如果使用本地搭建的服务,是无法访问的&…

Flex弹性盒子布局案例(认识弹性布局)

一、导航菜单 此示例创建了一个水平导航菜单&#xff0c;其中链接在 Flex 容器中等距分布。 HTML结构&#xff1a; <nav class"nav-menu"><a href"#">Home</a><a href"#">About</a><a href"#">…