旋转位置编码原理及代码

旋转位置编码原理及代码

旋转位置编码

  • RoPE(Rotary Positional Encoding)
  • 当位置发生偏移时,只需要旋转角度
  • 外推性,指大模型输入长度超过预训练文本长度时,输出表现变化情况。
    • 使用绝对位置编码具有外推性上的限制,旋转位置编码则没有
  • 旋转位置编码还具有一个相对位置编码的优点
    • 两个token之间如果具有的相对位置,无论两个token存在句子的哪个位置都会有相同的表示,从下图中可以看出,角度即为相对位置偏移量,token在句中的位置不同可以在坐标系上体现
  • 在二维数据中可以看出是在qk上乘了一个旋转矩阵
  • 旋转矩阵的特质
    • R(a)的转置 = R(-a)(正反两个方向的转动)
    • R(a)R(b) = R(a+b)(转动了a+b的角度)

旋转位置编码的核心是找到对应的旋转矩阵
在这里插入图片描述
LLaMA中旋转矩阵相关代码

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):# 计算词向量元素两两分组之后,每组元素对应的旋转角度freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))# 生成 token 序列索引 t = [0, 1,..., seq_len-1]t = torch.arange(seq_len, device=freqs.device)# freqs.shape = [seq_len, dim // 2] freqs = torch.outer(t, freqs).float()# torch.polar 的文档# https://pytorch.org/docs/stable/generated/torch.polar.html# 计算结果是个复数向量# 假设 freqs = [x, y]# 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:# xq.shape = [batch_size, seq_len, dim]# xq_.shape = [batch_size, seq_len, dim // 2, 2]xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)# 转为复数域xq_ = torch.view_as_complex(xq_)xk_ = torch.view_as_complex(xk_)# 应用旋转操作,然后将结果转回实数域# xq_out.shape = [batch_size, seq_len, dim]xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)return xq_out.type_as(xq), xk_out.type_as(xk)class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.wq = Linear(...)self.wk = Linear(...)self.wv = Linear(...)self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)def forward(self, x: torch.Tensor):bsz, seqlen, _ = x.shapexq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(batch_size, seq_len, dim)xk = xk.view(batch_size, seq_len, dim)xv = xv.view(batch_size, seq_len, dim)# attention 操作之前,应用旋转位置编码xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)# scores.shape = (bs, seqlen, seqlen)scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)scores = F.softmax(scores.float(), dim=-1)output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)

注:代码部分为复制粘贴,后续会对代码进行整理总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/469628.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ps:焦点堆栈

焦点堆栈 Focus Stacking是一种摄影和图像处理技术,通过合并多张在不同焦距拍摄的照片来创建一张具有更大景深的图像,特别适用于微距摄影、风景摄影和任何需要在整个场景中保持尖锐对焦的情况。 ◆ ◆ ◆ 拍摄注意事项 1、使用三脚架 为了确保图像之间…

《小强升职记:时间管理故事书》阅读笔记

目录 前言 一、你的时间都去哪儿了 1.1 你真的很忙吗 1.2 如何记录和分析时间日志 1.3 如何找到自己的价值观 二、无压工作法 2.1 传说中的“四象限法则 2.2 衣柜整理法 三、行动时遇到问题怎么办? 3.1 臣服与拖延 3.2 如何做到要事第一? 3.…

【一周年】我的创作纪念日

今天,是我成为创作者的第366天,不知不觉,来CSDN已经一年啦~ 在这个特殊的日子,也给大家讲讲我的创作故事。 一、机缘 起初,刚认识CSDN时,我的高中生涯刚结束,顺利从一名懵懂的高中生变身为一名懵…

Linux第53步_移植ST公司的linux内核第5步_系统镜像打包并烧录到EMMC

本节主要学习系统镜像打包,然后将打包文件烧录到EMMC测试。 1、创建bootfs文件夹 1)、打开第1个终端 输入“ls回车” 输入“cd linux/回车”,切换到“linux”目录 输入“ls回车”,列出“linux”目录下的文件和文件夹 输入“cd atk-mp1/…

Go语言中的加密艺术:深入解析crypto/subtle库

Go语言中的加密艺术:深入解析crypto/subtle库 引言crypto/subtle库概览ConstantTimeCompare函数深入解析ConstantTimeSelect函数应用详解ConstantTimeLessOrEq函数实践指南安全编程实践性能优化与最佳实践与其他加密库的比较总结 引言 在当今快速发展的互联网时代&…

localStorage、sessionStorage、cookie区别

localStorage: localStorage 的生命周期是永久的,关闭页面或浏览器之后 localStorage 中的数据也不会消失。localStorage 除非主动删除数据,否则数据永远不会消失 sessionStorage: sessionStorage 的生命周期是仅在当前会话下有效。sessionStorage 引入…

C语言第二十四弹---指针(八)

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】 指针 1、数组和指针笔试题解析 1.1、字符数组 1.1.1、代码1: 1.1.2、代码2: 1.1.3、代码3: 1.1.4、代码4: 1…

C++的进阶泛型编程学习(1):函数模板的基本概念和机制

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、模板1.1 模板的概念1.1.1 形象的解释:模板就是通用的模具,目的是提高通用性1.1.1 模板的特点:1.1.2 综述模板的作用 1.2…

OpenGL-ES 学习(1)---- AlphaBlend

AlphaBlend OpenGL-ES 混合本质上是将 2 个片元的颜色进行调和(一般是求和操作),产生一个新的颜色 OpenGL ES 混合发生在片元通过各项测试之后,准备进入帧缓冲区的片元和原有的片元按照特定比例加权计算出最终片元的颜色值,不再是新&#xf…

Codeforces Round 920 (Div. 3)

D. Very Different Array(贪心双指针/前缀和) 思路:绝对值就是线段-->让线段最长(肯定是越在最短端找最右端的 越最右端找最左端的)-->判断怎么连哪段最长(采用双指针的策略去判断) (左红…

Swift Combine 通过用户输入更新声明式 UI 从入门到精通十五

Combine 系列 Swift Combine 从入门到精通一Swift Combine 发布者订阅者操作者 从入门到精通二Swift Combine 管道 从入门到精通三Swift Combine 发布者publisher的生命周期 从入门到精通四Swift Combine 操作符operations和Subjects发布者的生命周期 从入门到精通五Swift Com…

uniapp前端手机获取安全区域css值 防止按键不能被点击

引入 再编写小程序和移动端的时候可能会出现这种情况,页面中的按键刚好才手机中按不到的位置 如下 这是苹果手机的home按键 如果刚好我们的按钮再这个位置,用户是点击不到的 我们就需要一个办法,能够自动的让我们的按键移动到安全可点击的区域 解决 我们可以使用…