【Transformer】笔记

主要参考
https://zhuanlan.zhihu.com/p/366592542
https://mp.weixin.qq.com/s/b-_M8GPK7FD7nbPlN703HQ

其他参考
原理 https://zhuanlan.zhihu.com/p/627448301
多头注意力机制 https://zhuanlan.zhihu.com/p/611684065
https://blog.csdn.net/shizheng_Li/article/details/131721198

面试概念

https://zhuanlan.zhihu.com/p/425336990

RNN

RNN 循环神经网络(Rerrent Neural Network,RNN),能够将之前的信息储存在隐藏层中,从而与后面的信息进行计算。
问题是:不能并行计算,而且对于长序列,容易出现记忆丢失的问题,也就是梯度消失
在这里插入图片描述
RNN 的关键点之一就是他们可以用来连接先前的信息到当前的任务上,例如使用过去的视频段来推测对当前段的理解。

LSTM

长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。
输入门,记忆门,遗忘门。

后来还提出了双向LSTM ,BILLSTM,来解决后面序列信息对前面的影响

Transformer

attention

注意力机制, 分为self-attention, multi-head attention等。
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility),每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出。
在这里插入图片描述

class ScaledDotProductAttention(nn.Module):""" Scaled Dot-Product Attention """def __init__(self, scale):super(ScaledDotProductAttention,self).__init__()self.scale = scaleself.softmax = nn.Softmax(dim=2)def forward(self, q, k, v, mask=None):u = torch.bmm(q, k.transpose(1, 2)) # 1.Matmulu = u / self.scale # 2.Scaleif mask is not None:u = u.masked_fill(mask, -np.inf) # 3.Maskattn = self.softmax(u) # 4.Softmaxoutput = torch.bmm(attn, v) # 5.Outputreturn attn, outputif __name__ == "__main__":batch = 2n_q, n_k, n_v = 2, 4, 4d_q, d_k, d_v = 128, 128, 64q = torch.randn(batch, n_q, d_q)k = torch.randn(batch, n_k, d_k)v = torch.randn(batch, n_v, d_v)mask = torch.zeros(batch, n_q, n_k).bool()attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))attn, output = attention(q, k, v, mask=mask)print(attn)print(output)

mask = torch.zeros(batch, n_q, n_k).bool()
这行代码是在使用 PyTorch 创建一个布尔型的零张量。具体来说,它创建了一个形状为 (batch, n_q, n_k) 的张量,其中的所有元素都被初始化为 False(因为在 Python 中,False 等价于 0,True 等价于 1)。

u = u.masked_fill(mask, -np.inf)
masked_fill 是一个 PyTorch 张量的方法,它将 mask 中为 True 的元素的对应位置上的 u 中的元素替换为 -np.inf
这里的关键在于理解 mask 的作用。mask 是一个布尔型张量,其中的 TrueFalse 值表示我们希望保留还是忽略对应的 u 中的元素。在这种情况下,我们希望忽略 mask 中为 True 的元素,因此在 u 中将这些位置的值设置为负无穷大(-np.inf)。

这样做的目的可能是为了在接下来的操作中排除这些被标记的元素。例如,如果我们接下来要对 u 进行 softmax 操作,由于负无穷大在softmax 运算中会被视为 0,这样我们就可以有效地忽略掉那些在 mask 中被标记为 True 的元素。

self attention

self -attention 就是QKV 都是本身的注意力机制,比如transformer模型中的Encoder部分。self-attention 在文本序列中,能够挖掘出文本中不同字词之间的联系。不同与LSTM是有向性的记忆与遗忘字词之间的关系。

multi-head attention

注意力并行化的代表,多头注意力不仅计算一次注意力,而是并行化计算多次注意力,这样模型可以同时关注多个子空间的信息。

class MultiHeadAttention(nn.Module):""" Multi-Head Attention """def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):super().__init__()self.n_head = n_headself.d_k = d_kself.d_v = d_vself.fc_q = nn.Linear(d_k_, n_head * d_k)self.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.fc_o = nn.Linear(n_head * d_v, d_o)def forward(self, q, k, v, mask=None):n_head, d_q, d_k, d_v = self.n_head, self.d_k, self.d_k, self.d_vbatch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()q = self.fc_q(q) # 1.单头变多头k = self.fc_k(k)v = self.fc_v(v)q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)if mask is not None:mask = mask.repeat(n_head, 1, 1)attn, output = self.attention(q, k, v, mask=mask) # 2.当成单头注意力求输出output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1) # 3.Concatoutput = self.fc_o(output) # 4.仿射变换得到最终输出return attn, outputif __name__ == "__main__":n_q, n_k, n_v = 2, 4, 4d_q_, d_k_, d_v_ = 128, 128, 64q = torch.randn(batch, n_q, d_q_)k = torch.randn(batch, n_k, d_k_)v = torch.randn(batch, n_v, d_v_)    mask = torch.zeros(batch, n_q, n_k).bool()mha = MultiHeadAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)attn, output = mha(q, k, v, mask=mask)print(attn.size())print(output.size())

soft attention 与 hard attention

Soft attention, NLP中尝试用的注意力方式,取值为[0, 1]的权重概率分布,使用了所有编码层的隐层状态,与上两节的介绍相同,可以直接在模型训练过程中,通过后向传播优化对参数进行优化。

Hard attention, Hard attention 在原文中被称为随机硬注意力(Stochastic hard attention),这里的随机是指对编码层隐状体的采样过程,Hard attention 没有使用到所有的隐层状态,而是使用one-hot的形式对某个区域提取信息,使用这种方式无法直接进行后向传播(梯度计算),需要蒙特卡洛采样的方法来估计梯度。就好比python中的简单字典取值

相对位置编码 与 绝对位置编码

缩放因子

Transfomer中使用到的缩放点积注意力, 是点积计算的延申,增加了一个缩放因子。

在论文中我们注意到作者在做了 QK^T 时还除以一个sqrt(d_k)d_kdim的维度,作者给出的解释如:

We suspect that for large values of d_k , the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by d_k .

梯度消失问题:神经网络的权重与损失的梯度成比例地更新。问题是,在某些情况下,梯度会很小,有效地阻止了权重更新。简单来说就是这样可以优化结果

Unnormalized softmax:考虑一个正态分布。分布的 softmax 值在很大程度上取决于它的标准差。由于标准偏差很大,softmax 只存在一个峰值,其他全部几乎为0。

我们在注意力中做了一个softmax,假定说当前的数据分布方差较大,那么除了某几个位置是1,其它位置可能都接近0,而那些接近0的位置这样计算过后,在梯度反向传播时,我们只能获得一个很小的更新,不利于网络进行学习,所以我们应该降低整个分布的方差,这样可以让网络进行更好的训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/504786.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

双周回顾#006 - 这三个月

断更啦~~ 上次更新时间 2023/11/23, 断更近三个月的时间。 先狡辩下,因为忙、着实忙。因为忙,心安理得给断更找了个借口,批评下自己~~ 这三个月在做啥?跨部门援助,支援公司互联网的 ToC 项目,一言难尽。 …

破解SQL Server迷局,彻底解决“管道的另一端无任何进程错误233”

问题描述:在使用 SQL Server 2014的时候,想用 SQL Server 身份方式登录 SQL Servcer Manager,结果报错: 此错误消息:表示SQL Server未侦听共享内存或命名管道协议。 问题原因:此问题的原因有多种可能 管道…

用冒泡排序模拟C语言中的内置快排函数qsort!

目录 ​编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…

配置MySQL与登录模块

使用技术 MySQL,Mybatis-plus,spring-security,jwt验证,vue 1. 配置Mysql 1.1 下载 MySQL :: Download MySQL Installer 1.2 安装 其他页面全选默认即可 1.3 配置环境变量 将C:\Program Files\MySQL\MySQL Server 8.0\bin…

嵌入式学习31-指针和函数知识回顾

1.指针: 1.提供一种间接访问数据的方法 2.空间没有名字,只有一个地址编号 2.指针: 1.地址:区分不同内存空间的编号 2.指针:指针就是地址,地址就是指针 3.指针变量:存放指针的变量称为指针变量,简称为指针 3.指针的定义: int *p NULL; …

学生云服务器_学生云主机_学生云数据库_云+校园特惠套餐

2024年腾讯云学生服务器优惠活动「云校园」,学生服务器优惠价格:轻量应用服务器2核2G学生价30元3个月、58元6个月、112元一年,轻量应用服务器4核8G配置191.1元3个月、352.8元6个月、646.8元一年,CVM云服务器2核4G配置842.4元一年&…

‘conda‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件

如果你在运行 conda 命令时收到了 ‘conda’ 不是内部或外部命令,也不是可运行的程序或批处理文件。 的错误消息,这可能意味着 Anaconda 并没有正确地添加到你的系统路径中。 1.你可以尝试手动添加 Anaconda 到系统路径中。以下是在 Windows 系统上添加…

【风格迁移】DSM-GANs:为不同的域(照片和绘画风格)创建特定的映射函数,以改善风格转换的质量和准确性

DSM-GANs:为不同的域(照片和绘画风格)创建特定的映射函数,以改善风格转换的质量和准确性 提出背景DSM-GANs 域特定映射 域特定内容空间 针对性损失函数设计模型如何进行风格转换和图像到图像翻译 提出背景 论文:ht…

超详细的 pytest 钩子函数 之初始钩子和引导钩子来啦

前几篇文章介绍了 pytest 点的基本使用,学完前面几篇的内容基本上就可以满足工作中编写用例和进行自动化测试的需求。从这篇文章开始会陆续给大家介绍 pytest 中的钩子函数,插件开发等等。 仔细去看过 pytest 文档的小伙伴,应该都有发现 pyt…

前端学习第一天-html基础

达标要求 网页的形成过程 常用的浏览器及常见的浏览器内核 web 标准三层组成 什么是HTML 熟练掌握HTML文档结构 熟练掌握HTML常用标签 1. 初识web前端 Web前端是创建Web页面或App等前端界面呈现给用户的过程。 Web前端开发是从网页制作演变而来,早期网站主…

【计算复杂性理论】证明复杂性(九):命题鸽巢原理的指数级归结下界——更简短的证明

往期文章: 【计算复杂性理论】证明复杂性(Proof Complexity)(一):简介 【计算复杂性理论】证明复杂性(二):归结(Resolution)与扩展归结&#xff…

什么是VR数字文化遗产保护|元宇宙文旅

VR数字文化遗产保护是指利用虚拟现实(VR)技术来保护和传承文化遗产。在数字化时代,许多珍贵的文化遗产面临着自然衰退、人为破坏或其他因素造成的威胁。通过应用VR技术,可以以全新的方式记录、保存和展示文化遗产,从而…