Llama模型结构解析(源码阅读)

目录

  • 1. LlamaModel整体结构流程图
  • 2. LlamaRMSNorm
  • 3. LlamaMLP
  • 4. LlamaRotaryEmbedding

  • 参考资料:
    https://zhuanlan.zhihu.com/p/636784644
    https://spaces.ac.cn/archives/8265 ——《Transformer升级之路:2、博采众长的旋转式位置编码》

前言:本次阅读代码位置,在transformers库底下的modeling_llama.py,具体位置在:transformers/models/llama/modeling_llama.py,如下图所示:在这里插入图片描述

1. LlamaModel整体结构流程图

在这里插入图片描述

2. LlamaRMSNorm

  • 代码如下
class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypevariance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return (self.weight * hidden_states).to(input_dtype)
  • RMSNorm的公式如下所示:
    x i 1 n ∑ i = 1 n x i 2 + e p s ∗ w e i g h t i \frac{x_i}{\sqrt{\frac{1}{n}\sum\limits_{i=1}^{n}{x_i}^2 + eps}} * weight_i n1i=1nxi2+eps xiweighti

    • 其中,公式与代码的对应关系如下:
      在这里插入图片描述

3. LlamaMLP

  • 代码如下:
class LlamaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,hidden_act: str,):super().__init__()self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)self.act_fn = ACT2FN[hidden_act]def forward(self, x):return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  • 流程图:
    在这里插入图片描述

  • 其中输入为x,输出为y

  • 代码中intermediate_size一般比hidden_size大,我们通过在jupyter notebook中打印Llama-13B的模型,可以看到如下所示:
    在这里插入图片描述

  • 总结:MLP模块就是几个nn.Linear的组合

4. LlamaRotaryEmbedding

  • 代码如下

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)
  • 具体的使用,还调用了另外两个函数,如下所示:
def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed
  • 注意这里的实现跟原始推导有点区别,这里实现的方式如下图所示:
    在这里插入图片描述

  • 原始推导如下图所示:
    在这里插入图片描述
    具体可以查看作者的博客:👉戳我👈

  • 总结:RoPE就是在attention计算时,K跟Q做内积之前,先给各自注入位置信息。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/95974.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringCloud(十)——ElasticSearch简单了解(二)DSL查询语句及RestClient查询文档

文章目录 1. DSL查询文档1.1 DSL查询分类1.2 全文检索查询1.3 精确查询1.4 地理查询1.5 查询算分1.6 布尔查询1.7 结果排序1.8 分页查询1.9 高亮显示 2. RestClient查询文档2.1 查询全部2.2 其他查询语句2.3 排序和分页2.4 高亮显示 1. DSL查询文档 1.1 DSL查询分类 查询所有…

使用Windbg动态调试排查软件启动不了的问题

目录 1、问题说明 2、初步分析 3、使用Windbg启动程序进行动态调试 4、进一步分析 5、何时使用Windbg静态分析?何时使用Windbg进行动态调试? 6、最后 VC常用功能开发汇总(专栏文章列表,欢迎订阅,持续更新...&…

14:00面试,14:08就出来了,问的问题有点变态

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到8月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%,…

Kubernetes可视化管理工具Kuboard部署使用及k8s常用命令梳理记录

温故知新 📚第一章 前言📗背景📗目的📗总体方向 📚第二章 安装 Kubernetes 多集群管理工具 - Kuboard v3📗部署方式📗通过Kuboard v3 - Kubernetes安装(在master节点执行)&#x1f4…

SpringBoot复习:(60)文件上传的自动配置类MultipartAutoConfiguration

可以看到,定义了一个类型为StandartServletMultipartResolver的bean 用来进行文件上传,定义了一个类型为MultipartConfigElement的bean用来进行上传相关的配置,其中使用了MultipartProperties中的属性,这个类的定义如下&#xff1…

vue+element-ui el-table组件二次封装实现虚拟滚动,解决数据量大渲染DOM过多而卡顿问题

一、此功能已集成到TTable组件中 二、最终效果 三、需求 某些页面不做分页时,当数据过多,会导致页面卡顿,甚至卡死 四、虚拟滚动 一、固定一个可视区域的大小并且其大小是不变的,那么要做到性能最大化就需要尽量少地渲染 DOM 元素…

梯度下降算法入门

提到梯度下降我们知道梯度下降算法是很多机器学习算法、深度学习算法的基础。 首先我们需要明确一些概念什么是梯度: 梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处…

软件测试/测试开发丨Python 内置库 正则表达式

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27058 python 内置库 正则表达式 目录 正则表达式使用re模块实现正则表达式操作 正则表达式 正则表达式就是记录文本规则的代码可以查找操作符合某些复…

HBuilderX修改manifest.json设置,解决跨域问题(CORS、Cross-Origin)

搭建一个前台uniapp,后台springboot的开发环境时,遇到了跨域问题。 console提示错误信息: Access to XMLHttpRequest at http://10.0.180.203/api/cms/getAdList?apId1 from origin http://localhost:8080 has been blocked by CORS policy…

ROS-5.自定义topic消息格式

自定义topic消息格式 1. 定义消息1.1. 定义msg文件1.2. 在package.xml中添加功能包依赖1.3. 在CMakeList.txt添加编译选项1.4. 编译 2.定义发布者和订阅者2.1 定义发布者2.2. 定义订阅者2.3. 修改CMakeList.txt2.4 编译 3. 使用消息3.1 启动ros主程序3.2. 启动发布者3.3 启动订…

如何制作并运行 jar 程序

以下是用 Intellij 制作 jar 程序,并运行的方法。 【1】新建工程,保持默认选项,Next 【2】保持默认选项,Next 【3】给工程命名,设置保存位置,Finish 【4】新建工程结束,进入开发界面 【5】展开…

ArcGIS将两个相同范围但不同比例或位置的矢量数据移动到相同位置

有两个市图层,一个是正确经纬度的市行政范围图层,另一个是其他软件导出获取的不正确经纬度信息或缺失信息。 如果单纯的依靠移动图层,使不正确的移动到正确位置需要很久。尝试定义投影等也不能解决。 使用ArcMap 的空间校正工具条&#xff…