浅析注意力(Attention)机制

Attention顾名思义,说明这项机制是模仿人脑的注意力机制建立的,我们不妨从这个角度展开理解

2.1 人脑的注意力机制

人脑的注意力机制,就是将有限的注意力资源分配到当前关注的任务,或关注的目标之上,暂时忽略其他不重要的因素,这是人类利用有限的注意力资源从大量信息中快速筛选出高价值信息的手段,是人类在长期进化中形成的一种生存机制,极大地提高了信息处理的效率与准确性。

举个栗子,就以上班为例,今天本该又是摸鱼的一天,但你的“恩人”突然交给你一项任务——查找关于“注意力机制”的资料并总结,并于下班之前向她汇报。于是你不得不放下手上的娱乐节目,转而应付恩人派下的工作。你选定了“注意力机制”作为关键词开始搜索,在搜索引擎的推送下阴差阳错的看到了这篇博文(这是不可能的),又因为这篇博文关键信息太少而选择忽略了它,努力一番后又查到了一些资料,汇总的大量初步结果并提交恩人,按时下班,happy ending!

死喽

上面的例子中其实出现了多次“识别关键要素”或“筛选重要信息”的动作,这便是注意力机制的体现。而深度学习中的注意力机制从本质上讲和人类的选择性注意力机制类似,核心目标也是从众多信息中选择出对当前任务目标更关键的信息。

2.2 为什么需要Attention

在之前的博文《理解LSTM》中提到过,LSTM通过引入逻辑门,从结构层面上有效解决了序列长距离依赖问题(梯度消失)。然而,面对超长序列时(例如一段500多词的文本),LSTM也可能失效。而 Attention 机制可以更好地解决序列长距离依赖问题,并且具有并行计算能力

我们还是以文本问题举例, 看一看RNN或LSTM处理超长文本序列时会发生什么?

死喽

可以看到, 为了理解当前文本,我们有时需要获得很久之前的历史状态下的某些信息。而RNNs从结构层面上无形中添加了一种假设,那就是当前的文本只和临近区域的文本具有较强的关联性,而和距离较远的上下文关联不大或没有关联。很明显,这样的假设是不恰当的,这就限制了RNNs处理文本的长度和理解文本的精度,而Attention的出现则几乎打破了模型对于文本长度的限制。

采用RNN架构的网络均具有这种局限, 包括LSTM, GRU等等

为了进一步理解,让我们从循环神经网络的老大难问题——机器翻译问题入手。
在翻译任务中,源语言和目标语言的单词数和语序往往不是一一对应的,这种输入和输出都是不定长序列的任务,称为 Seq2Seq,以英语和德语为例,如下图所示。

翻译

为了解决这个问题,我们创造了Encoder-Decoder结构的循环神经网络。

  • 它先通过一个Encoder循环神经网络读入所有的待翻译句子中的单词,得到一个包含原文所有信息的中间隐藏层,接着把中间隐藏层状态输入Decoder网络,一个词一个词的输出翻译句子。
  • 这样子,无论输入中的关键词语有着怎样的先后次序,由于都被打包到中间层一起输入后方网络,我们的Encoder-Decoder网络都可以很好地处理这些词的输出位置和形式了。

问题在于,由于中间状态\(C\)来自输入网络最后的隐藏层,一般来说它是一个大小固定的向量。既然是大小固定的向量,那么它能储存的信息就是有限的,当句子长度不断变长,由于后方的decoder网络的所有信息都来自中间状态,中间状态需要表达的信息就越来越多。在语句信息量过大时,中间状态就作为一个信息的瓶颈阻碍翻译了。这时我们很容易联想到,如果网络能够在处理长文本时懂得筛选关键信息, 而不是将全部文本都作为都作为中间状态储存,是不是就可以突破文本长度的限制了?这便是注意力机制的由来。

Encoder-Decoder(编码-解码)是深度学习中非常常见的一个模型框架,比如无监督算法的auto-encoding就是用编码-解码的结构设计并训练的;比如这两年比较热的image caption的应用,就是CNN-RNN的编码-解码框架;再比如神经网络机器翻译NMT模型,往往就是LSTM-LSTM的编码-解码框架。因此,准确的说,Encoder-Decoder并不是一个具体的模型,而是一类框架。Encoder和Decoder部分可以是任意的文字,语音,图像,视频数据,模型可以采用CNN,RNN,BiRNN、LSTM、GRU等等。所以基于Encoder-Decoder架构,我们可以设计出各种各样的应用算法。

2.3 Attention的核心思想

在正式介绍注意力机制之前,我们先要明确以下几个概念:

  • 查询(Query):用于记录模型当前关注的任务信息,向量形式
  • 键(Key):用于记录输入序列中每个信息单元的标识符或标签, 用于与Query进行比较,以决定哪些信息是相关的, 在机器翻译任务中,Key可能是源语言的每个单词或短语的特征向量
  • 值(Value):Value通常包含输入序列的实际信息,当Query和Key匹配时,相应的Value值被用于计算输出
  • 分数(Score): Score又称为注意力分数,用于表示Query和Key的匹配程度,Score越高,模型对当前信息单元的关注度越高

我们仍以机器翻译为例,通过引入注意力机制,让生成词不是只能关注全局的语义编码向量c,而是增加了一个“注意力范围”,表示接下来输出词时候要重点关注输入序列中的哪些部分,然后根据关注的区域来产生下一个输出,如下图所示。

翻译
此时生成目标句子单词的过程就成了下面的形式: $$ \begin{aligned}&\mathbf{y}_{1}=\mathbf{f}\mathbf{1}(\mathbf{C}_{1})\\&\mathbf{y}_{2}=\mathbf{f}\mathbf{1}(\mathbf{C}_{2},\mathbf{y}_{1})\\&\mathbf{y}_{3}=\mathbf{f}\mathbf{1}(\mathbf{C}_{3},\mathbf{y}_{1},\mathbf{y}_{2})\end{aligned} $$ 这样一来,由于每个生成词关注的语义编码向量都各不相同,且信息容量都被限定在了一个范围内,无需一次性关注局部特征,也就解决了序列长度过长带来的问题。

在理解了注意力机制的作用之后,我们就可以对其具体步骤加以描述了(正片开始)。

翻译

如上图所示,Attention 通常可以进行如下描述,表示为将 Query(Q) 和 key-value pairs(把 Values 拆分成了键值对的形式) 映射到输出上,其中 query、每个 key、每个 value 都是向量,输出是 \(V\) 中所有 values 的加权,其中权重是由 Query 和每个 key 计算出来的,计算方法分为三步:

  1. 第一步:计算并比较 Q 和 K 的相似度,用 f 来表示:\(f(Q,K_i)\quad i=1,2,\cdots,m\), 一般第一步计算方法包括四种
  • 点乘(transformer使用):\(f(Q,K_i)=Q^TK_i\)
  • 加权:\(f(Q,K_i)=Q^TWK_i\)
  • 拼接权重:\(f(Q,K_i)=W[Q^T;K_i]\)
  • 感知器:\(f(Q,K_i)=V^T\tanh(WQ+UK_i)\)
  1. 将得到的相似度进行 softmax 操作,进行归一化,得到注意力分数:\(\alpha_i=softmax(\frac{f(Q,K_i)}{\sqrt{d}_k})\)
  2. 针对计算出来的权重 \(\alpha_{i}\),对 \(V\) 中的所有 values 进行加权求和计算,得到 Attention 向量:\(Attention=\sum_{i=1}^m\alpha_iV_i\)

2.4 Attention代码实现

最后附一个Attention机制的代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleAttention(nn.Module):def __init__(self, input_dim):super(SimpleAttention, self).__init__()self.input_dim = input_dimself.query = nn.Linear(input_dim, input_dim)self.key = nn.Linear(input_dim, input_dim)self.value = nn.Linear(input_dim, input_dim)def forward(self, x):Q = self.query(x)K = self.key(x)V = self.value(x)# Compute attention scores (dot product of queries and keys)attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.input_dim ** 0.5# Apply softmax to get attention weightsattention_weights = F.softmax(attention_scores, dim=-1)# Weighted sum of valuesoutput = torch.matmul(attention_weights, V)return output, attention_weights# Example usage
input_dim = 64
seq_length = 10
batch_size = 5# Dummy input tensor (batch_size, seq_length, input_dim)
x = torch.rand(batch_size, seq_length, input_dim)# Initialize the attention module
attention = SimpleAttention(input_dim)# Forward pass
output, attention_weights = attention(x)print("Output shape:", output.shape)  # Expected: (batch_size, seq_length, input_dim)
print("Attention weights shape:", attention_weights.shape)  # Expected: (batch_size, seq_length, seq_length)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/835479.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HBase架构与基础命令

HBase架构与基础命令 一、了解HBase 官方文档:https://hbase.apache.org/1.1 HBase概述HBase 是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统,用于存储海量的结构化或者半结构化,非结构化的数据(底层是字节数组做存储的) HBase是Hadoop的生态系统之一,是建立在…

Blender 效果制作:制作起伏不平的路面

前置准备正常纹理贴图,置换贴图,法线贴图,粗糙贴图方法一首先用UV坐标,纹理贴图,法线贴图,粗糙贴图构建材质将网格细分多一点,并采用置换修改器,置换修改器使用置换贴图获得成图方法二首先用UV坐标,纹理贴图,法线贴图,粗糙贴图构建材质细分网格将【置换】改成“置换…

团队作业4-第5篇Scrum博客

团队作业4-第5篇Scrum博客 1 站立式会议 1.1 会议照片1.2 会议内容 昨天已完成的工作:已初步完成数据库记录的备份、恢复和退出功能及账目记录的增删改功能今天计划完成的工作项目模块 需要实现的功能 负责人 预计用时主界面模块 右键实现增删改功能 黄锐 2h主界面模块 报告界…

PCFN

import torch import torch.nn as nn import torch.nn.functional as Fclass PCFN(nn.Module):使用带有GELU的激活函数的1*1卷积对扩展的隐藏空间进行跨信道交互。 然后将隐藏特征分割成两块 对其中一块使用3*3卷积核GELU激活函数 编码局部上下文将处理后的结果和另一块合并def…

Linux 内核如何装载和启动一个可执行程序

张晓攀+原创作品转载请注明出处+《Linux内核分析》MOOC课程https://mooc.study.163.com/course/1000029000 实验七——Linux 内核如何装载和启动一个可执行程序 一、实验过程 1.从github上下载相关代码2.然后用test_exec.c 替换test.c,再重新编译生成根文件系统3.启动调试内核…

java Runtime.exec()执行shell/cmd命令:常见的几种陷阱与一种完善实现

java Runtime.exec()执行shell/cmd命令:常见的几种陷阱与一种完善实现@目录背景说明前言Runtime.exec()常见的几种陷阱以及避免方法陷阱1:IllegalThreadStateException陷阱2:Runtime.exec()可能hang住,甚至死锁陷阱3:不同平台上,命令的兼容性陷阱4:错把Runtime.exec()的…

昆工891数据库系统原理强化课程

--昆工昆明理工大学、计算机技术、人工智能、软件工程、网络空间安全、891计算机专业核心综合、计算机系统结构、计算机软件与理论、网络与信息安全、计算机应用技术、综合程序设计、通信工程、817信号与系统、信号与信息处理、通信与信息系统

第7篇Scrum博客

1.站立式会议 1.1 会议照片1.2 会议内容 昨天已完成的工作: 昨天已基本实现用条形图,折线图,饼图展示数据界面功能。 今天计划完成的工作项目模块 需要实现的功能 负责人 预计用时主界面模块 整合代码,查漏补缺 王伊若 5h主界面模块 主界面设计 王伊若 2h主界面模块 查询界…

Ant Design Vue组件安装

https://www.antdv.com/docs/vue/getting-started-cn

书生共学大模型实战营L1G6000 XTuner微调

任务描述:使用XTuner微调InternLM2-Chat-7B实现自己的小助手认知 该任务分为数据集处理、微调训练、合并部署三个环节。数据处理:主要是将目标json文件中的字段替换为和自己用户名相关的字段,这里我们将“尖米”替换为“科研狗1031”:微调训练:采用教程中的XTuner框架,在…

request to https://registry.npm.taobao.org/ant-design-vue failed, reason: certificate has expire

一、原因分析 其实早在 2021 年,淘宝就发文称,npm 淘宝镜像已经从 http://registry.npm.taobao.org 切换到了 http://registry.npmmirror.com。旧域名也将于 2022 年 5 月 31 日停止服务(直到 HTTPS 证书到期才真正不能用了)2024年1 月 22 日,淘宝原镜像域名(http…