Transformer的代码实现 day03(Positional Encoding)

Positional Encoding的理论部分

  • 注意力机制是不含有位置信息,这也就表明:“我爱你”,“你爱我”这两者没有区别,而在现实世界中,这两者有区别。
  • 所以位置编码是在进行注意力计算之前,给输入加上一个位置信息,如下图:
    在这里插入图片描述
  • 位置编码的公式如下:
    • 注意,pos表示该单词在句子中的位置,i表示该单词的输入向量的第i维度
      在这里插入图片描述
  • 由此我们可以得出不同位置之间的位置编码关系:
    在这里插入图片描述

Positional Encoding代码

  • 由于位置编码的公式固定,所以对于相同位置的位置编码也固定,即“我爱你”中的我,和“你爱我”中的你的位置编码相同
  • 所以我们可以一次将所有要输入信息的位置编码都生成出来,之后需要哪个就传哪个
class PositionalEncoding(nn.Module):def __init__(self, dim, dropout, max_len=5000):super(PositionalEncoding, self).__init__()# 确保每个单词的输入维度为偶数,这样sin和cos能配对if dim % 2 != 0:raise ValueError("Cannot use sin/cos positional encoding with ""odd dim (got dim={:d})".format(dim))"""构建位置编码pepe公式为:PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})"""pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10,dim为单词的输入维度# 将位置序号从一维变为只有一列的二维,方便与div_term进行运算,# 如将[0, 1, 2, 3, 4]变为:#[  #  [0],  #  [1],  #  [2],  #  [3],  #  [4]  #]position = torch.arange(0, max_len).unsqueeze(1)# 这里使用a^b = e^(blna)公式,来简化运算# torch.arange(0, dim, 2, dtype=torch.float)表示从0到dim-1,步长为2的一维张量# 通过以下公式,我们可以得出全部2i的(pos/10000^2i/dim)方便接下来的pe计算div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *-(math.log(10000.0) / dim)))# 得出的div_term为从0开始,到dim-1,长度为dim/2,步长为2的一维张量# 将position与div_term做张量乘法,得到的张量形状为(max_len,dim/2)# 将结果取sin赋给pe中偶数维度,取cos赋给pe中奇数维度pe[:, 0::2] = torch.sin(position.float() * div_term)pe[:, 1::2] = torch.cos(position.float() * div_term)# 将pe的形状从(max_len,dim)变成(max_len,1,dim),在第二个维度上增加一个大小为1的新维度# 如从原始 pe 张量形状: (5, 4)  #[  # [a1, b1, c1, d1],  # [a2, b2, c2, d2],  # [a3, b3, c3, d3],  # [a4, b4, c4, d4],  # [a5, b5, c5, d5]  #]# 转换为:执行 unsqueeze(1) 后的 pe 张量形状: (5, 1, 4)  #[  # [[a1, b1, c1, d1]],  # [[a2, b2, c2, d2]],  # [[a3, b3, c3, d3]],  # [[a4, b4, c4, d4]],  # [[a5, b5, c5, d5]]  #]pe = pe.unsqueeze(1)# 将pe张量注册为模块的buffer。在PyTorch中,buffer是模型的一部分,但不包含可学习的参数(即不需要梯度)。# 这样做是因为位置编码在训练过程中是固定的,不需要更新。self.register_buffer('pe', pe)self.drop_out = nn.Dropout(p=dropout)self.dim = dimdef forward(self, emb, step=None):# 做乘法是因为在 Transformer 模型中,位置编码被加到输入张量中,而输入张量通常是词嵌入的向量,其值通常在较小的范围内。# 但是,在将位置编码添加到输入张量之前,我们希望将其值扩大到一个较大的范围,以便位置编码对输入的影响更加显著。# 注意:emb为输入张量,形状为(seq_len, dim),seq_len 表示输入的句子的长度,dim为单词的输入维度emb = emb * math.sqrt(self.dim)# 根据step来选择加入pe的哪一部分if step is None:# 如果pe的形状为(max_len, dim),那么pe[:a]表示:取pe的第0行到第a-1行的全部元素,得到的新二维张量的形状为(a, dim)# 而pe[:, a]表示:取pe的第a-1列的全部元素,得到的新一维张量的形状为(max_len)# 而pe[:, :a]表示:取pe的第0列到第a-1列的全部元素,得到的新二维张量的形状为(max_len,a)emb = emb + self.pe[:emb.size(0)]else:emb = emb + self.pe[step]emb = self.drop_out(emb)return emb

参考文献

  1. 04 Transformer 中的位置编码的 Pytorch 实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/589651.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 234. 回文链表

给你一个单链表的头节点 head ,请你判断该链表是否为 回文链表 。如果是,返回 true ;否则,返回 false 。 示例 1: 输入:head [1,2,2,1] 输出:true 示例 2: 输入:he…

Maven依赖管理项目构建工具

一、Maven简介 1、为什么学习Maven 1.1、Maven是一个依赖管理工具 ①jar 包的规模 随着我们使用越来越多的框架,或者框架封装程度越来越高,项目中使用的jar包也越来越多。项目中,一个模块里面用到上百个jar包是非常正常的。 比如下面的例…

jvm总结学习

四种加载器 1.启动类加载器 2.拓展类加载器 3.应用程序加载器 4.自定义加载器 沙箱机制 就是为了保证安全,增加的一些权限。 native方法区(静态变量,常量,类信息(构造方法,接口定义)&…

基于深度学习的钢材表面缺陷检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

摘要:本文深入研究了基于YOLOv8/v7/v6/v5的钢材表面缺陷检测系统,核心采用YOLOv8并整合了YOLOv7、YOLOv6、YOLOv5算法,进行性能指标对比;详述了国内外研究现状、数据集处理、算法原理、模型构建与训练代码,及基于Strea…

算法基础--递推

😀前言 递推算法在计算机科学中扮演着重要的角色。通过递推,我们可以根据已知的初始条件,通过一定的规则推导出后续的结果,从而解决各种实际问题。本文将介绍递推算法的基础知识,并通过一些入门例题来帮助读者更好地理…

代码随想录第29天|491.递增子序列 46.全排列 47.全排列 II

目录: 491.递增子序列 46.全排列 47.全排列 II 491.递增子序列 491. 非递减子序列 - 力扣(LeetCode) 代码随想录 (programmercarl.com) 回溯算法精讲,树层去重与树枝去重 | LeetCode:491.递增子序列_哔哩哔哩_bili…

PS从入门到精通视频各类教程整理全集,包含素材、作业等(7)复发

PS从入门到精通视频各类教程整理全集,包含素材、作业等 最新PS以及插件合集,可在我以往文章中找到 由于阿里云盘有分享次受限制和文件大小限制,今天先分享到这里,后续持续更新 PS敬伟01——90集等文件 https://www.alipan.com/s…

盲盒一番赏小程序搭建:打造神秘与惊喜的赏玩新体验

随着移动互联网的快速发展,小程序因其便捷、轻量级的特点,逐渐成为了连接商家与消费者的新桥梁。盲盒一番赏小程序的搭建,旨在为用户带来一种全新的赏玩体验,满足他们对神秘与惊喜的追求。 盲盒一番赏小程序将传统的盲盒概念与一…

【智能算法】猎豹优化器(CO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2022年,MA Akbari等人受到自然界中猎豹捕猎行为启发,提出了猎豹优化器(The Cheetah Optimizer,CO)。 2.算法原理 2.1算法思想 CO法对猎…

P1102 A-B 数对 (非二分,不开龙永远的痛,用map解决)

可是我真的会伤心 题目链接 思路:1.本来想的是暴力,两层循环模拟每个数。 2.后来想先把每个数字的个数求出来放在数组nums【】中,并把不重复的数字存到数组b,再两层循环b数组应该时间复杂度会好些,如果b数组中的两个数…

欧拉路径欧拉回路

欧拉回路,指遍历图时通过图中每条边且仅通过一次,最终回到起点的一条闭合回路,适用于有向图与无向图,如果不强制要求回到起点,则被称为欧拉路径。 欧拉图:具备欧拉回路的图 无向图:图的所有顶…

全球范围内2nm晶圆厂建设加速

随着人工智能浪潮席卷而来,先进制程芯片的重要性日益凸显。当前,3nm工艺节点是行业内最先进的节点。与此同时,台积电、三星、英特尔、Rapidus等厂商正积极布局建设2nm晶圆厂。台积电与三星此前计划于2025年量产2nm芯片,而Rapidus则…