transformer中的attention机制详解

transformer中用到的注意力机制包括self-attention(intra-attention)和传统的attention(cross-attention),本篇文章将在第一节简述这两者的差别,第二节详述self-attention机制,第三节介绍其实现

self-attention和attention的区别

传统attention机制

发生在decoder和encoder之间,decoder可以更多的参考encoder中相关的信息,以便指导其输出。attention机制可以分为以下三步

  • 计算algnment score

其中hj时encoder输出的隐状态, si是decoder 输出的隐状态, eij描述的时输入位置j和输出位置i的匹配度

  • 匹配度归一化,这里使用softmax进行计算

  • 计算context vector

从表达式可看出,decoder 计算所需的context vector 实际上就是输入隐状态的加权和。

具体到RNN中,每个时间步应用attention机制的计算步骤如下

  • decoder RNN 接收 token 的嵌入和初始解码器隐藏状态。
  • RNN 处理其输入,产生输出和新的隐藏状态向量 si。输出被丢弃。
  • attention计算:我们使用encoder输出的所有的隐藏状态和 decoder 输出的si 向量来计算此时间步骤的context vector ci。
  • 我们将 si 和 ci 连接成一个向量。
  • 我们将此向量传递给前馈神经网络(与模型联合训练)。
  • 前馈神经网络的输出表示此时间步骤的输出词。
  • 对下一个时间步骤重复此操作

self-attention

发生在decoder或者encoder内部,将输出或者输入序列内部不同位置关联起来,以计算序列表征

self-attention机制

self-attention的实现步骤和attention类似,在attention中计算align score时用到了输入和输出的hidden state,但是对于self-attention只需要用到一种,即在encoder中的self-attention只用到encoder层输出的hidden state, decoder中的self-attention只用到decoder层的hidden state

我们将self-attention拆解为两部分,1. self-attention计算 2. multi-head attention

self attention计算:scaled dot-product attention

  • 获取encoder输入的embeding,并计算每个embedding 的query,key,value,下文简写为q,k,v。

其中WQ, WK, WV为去要学习的权重矩阵

  • 接下来我们要计算不同位置之间的关联度。例如我们要计算位置0处的embedding和其他位置embedding的关联度,参考传统attention机制align score的计算方法,我们要用位置0处计算得到的hidden states即query 和其他位置处计算的key进行计算。在self-attention中计算过程如下

相较于传统attention计算align score,self-attention中多了一步scale,即用key维度的开方对qk结果进行缩放。论文中提出这样做的理由是

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/sqrt(dk)

即缩放的目的是为了保证softmax有更加稳定的梯度

  • 得到其他位置对于位置0的关联度/权重之后,我们就可以计算位置0处包含有上下文信息的context vector:

  • 利用矩阵运算可以同步求出其他位置的context vector
    获取输入向量的Q,K,V矩阵

    运用矩阵计算得到每个位置的结果

multi-head attention

相较于单头注意力,使用多头注意力的目的在于

  1. 从不同表征空间挖掘不同位置之间的关联。

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

  1. 单头注意力在计算不同位置间关联时用到了加权平均,这在一定程度上影响了特征计算的准确性,因此要用多头注意力来抵消这种影响

In these models, the number of operations required to relate signals from two arbitrary input or output positions grows in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions [12]. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

单头注意力包含WQ,WK,WV, 产生一个output,多头注意力则包含n个WQ,WK,WV,这些参数的权重不共享,产生n个output

这n个output被拼接到一起,并对拼接结果再次进行projection得到最终结果

self-attention实现

  1. 首先是single attention
def attention(query, key, value, mask=None, dropout=None):# 获取维度, query, key, value 的size 均为(batch_size,  n_head, seq_length, hidden_state_length)d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(d_k)if mask:scores = score.masked_fill(mask==0, -1e9)p_atten = scores.softmax(dim = -1)if dropout:p_atten = dropout(p_atten)return torch.matmul(p_atten, vakue), p_atten

在transformer decoder中会在self-attention中使用mask,在encoder中不会用到。因为本篇文章主要讲解self-attention因此没有讲解mask的使用,下一篇讲解transformer的文章中会具体分析self-attention在decoder和encoder中的区别。

  1. multi head attention实现
class MultiHeadAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super(MultiHeadAttention, self).__init__()assert d_model % h == 0self.d_k = d_model / hself.h = h self.wq =  nn.Linear(d_model, d_model)self.wk =  nn.Linear(d_model, d_model)self.wv =  nn.Linear(d_model, d_model)self.wo =  nn.Linear(d_model, d_model)self.atten = None self.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):if mask:# same mask applied to all headsmask = mask.unsqueeze(1)nbatches = query.size(0)query = self.wq(query).view(nbatches, -1, self.h, self.dk).transpose(1, 2)key = self.wk(key).view(nbatches, -1, self.h, self.dk).transpose(1, 2)value = self.wv(value).view(nbatches, -1, self.h, self.dk).transpose(1, 2)x, self.atten = attention(query, key, value, mask, self.dropout)# concat n heads outputs x = (x.transpose(1,2).contiguous().view(nbatches, -1, self.h*self.d_k))del querydel keydel valuereturn self.wo(x)

从multihead attention的实现中看出,实际上是将维度为(nbatches, seq_length, d_model)的矩阵,利用矩阵变换,得到了一个 (nbatches,h, seq_length,dk)的矩阵,且h*dk = d_model。

在attention中计算的时候所有head并行计算,得到一个(nbatches,h, seq_length,dk)的输出,对这个输出结果在进行矩阵变换得到 (nbatches, seq_length, d_model)的矩阵。完成了所谓的‘矩阵拼接’

拼接后的矩阵经过wo计算得到最终结果

ref:
Attention is all you need
The Illustrated Transformer
Visualizing A Neural Machine Translation Model (Mechanics of Seq2seq Models With Attention)
The Annotated Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/736442.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Profibus转Modbus网关在智能化水处理系统优化改造的应用

Profibus协议和Modbus协议作为两种常见的工业通信协议,各自具有一定的优势和适用范围。而通过Profibus转Modbus网关(XD-MDPB100)的加入将两者结合使用,可以实现不同设备之间的无缝连接和数据传输,为罐内压载水处理系统的监控和控制提供了更为便利的解决方案。通过Profibus…

代码随想录算法训练营第四十四天 | 322.零钱兑换 279.完全平方数 139.单词拆分

322.零钱兑换 题目链接 文章讲解 视频讲解class Solution { public:int coinChange(vector<int>& coins, int amount) {// dp[j]: 表示能凑成面额j所需的最少硬币个数vector<int> dp(amount + 1, 0);// 递推公式: dp[j] = min(dp[j - coins[i]] + 1, dp[j])// …

中奖与抽奖次序无关

中奖与抽奖次序无关前言 典例剖析 【人教 2019A 版教材\(P_{262}\) 页习题10.3 第 6 题改编】在一个袋子中放 \(6\) 个白球,\(4\) 个红球,摇匀后随机摸球 \(3\) 次,采用放回和不放回两种方式摸球 . 设事件 \(A_{i}\)=“第 \(i\) 次摸到红球”,\(i=1,2,3\) . (1). 在两种摸球…

为什么 [] == ![] 为 true?

🧑‍💻 写在开头 点赞 + 收藏 === 学会🤣🤣🤣前言面试官问我,[] == ![] 的结果是啥,我:蒙一个true; 面试官:你是对的;我:内心非常高兴; 面试官:解释一下为什么; 我:一定要冷静,要不就说不会吧;这个时候,面试官笑了,同学,感觉你很慌的一批啊!不必慌张…

odoo学习-2

1. 新加自定义模块odoo同级目录下新建my_addons文件夹加入自己的模块(注意:views中也要创建一个xml文件) 2. model代码-写在models下面的py文件中from odoo import api, fields, modelsclass EpidemicRecord(models.Model):_name = epidemic.record # 数据库表明name = fie…

C++定义函数指针,回调C#

C++定义函数指针,回调C#C++定义函数指针。 typedef int (__stdcall * delegate_func)(int a, int b); 暴露接口:int __stdcall CPPcallCSharp(delegate_func func); 方法实现:int __stdcall CPPcallCSharp(delegate_func func) { return func(1,2); } 头文件calculator.h#if…

《DNK210使用指南 -CanMV版 V1.0》第七章 基于CanMV的MicroPython语法开发环境搭建

第七章 基于CanMV的MicroPython语法开发环境搭建章节摘自【正点原子】DNK210使用指南 - CanMV版 V1.03)购买链接:https://detail.tmall.com/item.htm?&id=782801398750 4)全套实验源码+手册+视频下载地址:http://www.openedv.com/docs/boards/k210/ATK-DNK210.html 5)…

金蝶云星空字段之间连续触发值更新

场景说明字段A配置了字段B的计算公式,字段B配置了自动C的计算公式,修改A的时候,触发了B的重算,但是C触发不到。 具体需求:配置值更新事件:料本,料本系数, PCBA加工费,整机装配费,税率%【字段A】公式:供应链含税报价 = ( 料本 * 料本系数 + PCBA加工费 + 整…

PaddleNLP UIE 实体关系抽取

目录环境依赖配置SSH克隆代码训练定制代码结构数据标注准备语料库数据标注导出数据数据转换doccanoLabel Studio模型微调问题处理找不到 paddlenlp.trainer找不到GPUprotobuf==3.20.2CUDA/cuDNN/paddle PaddleNLP UIE 实体关系抽取 PaddlePaddle用户可领取免费Tesla V100在线算…

Python对历年高考分数线数据用聚类、决策树可视化分析一批、二批高校专业、位次、计划人数数据|附代码数据

全文链接:https://tecdat.cn/?p=36626 原文出处:拓端数据部落公众号 随着高等教育的普及与竞争的日益激烈,高考作为通往高等教育的重要门槛,其分数线的波动、高校及专业的选择成为了社会广泛关注的焦点。考生和家长在面临众多高校和专业的选择时,往往需要综合考虑多种因素…

阿里228x82y还原之递归数组解密

声明 本文章中所有内容仅供学习交流,抓包内容、敏感网址、数据接口均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关,若有侵权,请联系我立即删除! 目标网站 某里228 分析逆向流程 228递归函数str解密 原理就是用数组push最后填充下,然…

ffmpeg在Windows上的安装

首先进入官网Download FFmpeg 选择windows版本下载想要的版本 Gyan.dev的版本可能会更符合Windows标准,而BtbN的版本可能会更加开放和跨平台往下拉选择想要的版本进行下载 我下载的是第一个下载好之后解压文件复制bin目录的路径 接着按照下面的顺序进行环境配置,结束后一路确…

leaflet如何把低层级瓦片在高层级显示

https://leafletjs.cn/reference.html#gridlayer使用了maxNativeZoom属性 示例 let map = L.map("map", {attributionControl: false,maxZoom: 18, }).setView([62, -82], 6);let layer_keepLevel_16 = L.tileLayer("url", {minZoom: 1,maxZoom: 18,maxNat…

Java JVM——11. 执行引擎

1.概述执行引擎属于JVM的下层,里面包括:解释器、即时编译器、垃圾回收器。执行引擎是Java虚拟机核心的组成部分之一。“虚拟机”是一个相对于“物理机”的概念,这两种机器都有代码执行能力,其区别是物理机的执行引擎是直接建立在处理器、缓存、指令集和操作系统层面上的,而…

vue3+vite打包优化

1、清除console和debugger 安装 terser插件npm install terser -Dbuild里添加terserOptions配置// 打包环境移除console.log,debugger terserOptions: { compress: { drop_console: true, drop_debugger: true } }, 二、gzip静态资源压缩 第一步:客户端打包开启首先下载 vit…

Java JVM 执行引擎深入解析

1.执行引擎概述执行引擎属于JVM的下层,里面包括:解释器、即时编译器、垃圾回收器。执行引擎是Java虚拟机核心的组成部分之一。“虚拟机”是一个相对于“物理机”的概念,这两种机器都有代码执行能力,其区别是物理机的执行引擎是直接建立在处理器、缓存、指令集和操作系统层面…

Vuex

Vuex 什么是Vuex? 概念:专门在Vue中实现集中式状态(数据)管理的一个Vue插件,对应用中多个组件的共享状态进行集中式管理(读/写),也是组件间通信的方式,且适用于任意组件间通信 之前想要传递数据,可以使用全局事件总线/消息订阅去实现,但是如果有很多组件都想要去读和写…

27-String类

String字符串是常量,创建之后不可改变 字符串字面值存储在字符串池中,可以共享 String s = "hello"; 产生一个对象,字符串池中存储 String s = new String("hello");//产生两个对象,堆、池各存储一个String name = "hello";//"hello…

01字典树和可持久化01字典树

01字典树 01字典树是一种只有0和1两种边的字典树。可以解决查询第 \(k\) 小,查询 \(x\) 是第几小等问题。 查询第 \(k\) 小 可以把输入的数转成等长二进制,然后插入01字典树。比如将 \([0,0,1,3,3]\) 插入字典树:这里红色数字表示以该段为前缀的数的个数,黑色表示对应的数。…

c# , net 创建树形结构,创建树形节点

/// <summary> /// 生成树形结构 /// </summary> public void GetTreeNode() {//SqlHelper.GetSqlDataReader是封装的查询数据库语句,可根据自己需求封装//假设获取所有一级节点List<Products> products = SqlHelper.GetSqlDataReader(sql);for (int i = 0; …