LLaMA开源大模型源码分析!

 Datawhale干货 

作者:宋志学,Datawhale成员

花了一晚上照着transformers仓库的LLaMA源码,把张量并行和梯度保存的代码删掉,只留下模型基础结构,梳理了一遍LLaMA的模型结构。

今年四月份的时候,我第一次接触深度学习,也是今年第一次接触Datawhale,在Datawhale和小伙伴一起学习、讨论了大半年,不知不觉已经可以做到看源码的程度了。

Datawhale才是一个没有围墙的大学,在这里无论你有什么想法💡,只要你愿意前进,总会有小伙伴和你一起。

博客地址:

https://flowus.cn/kmno4/share/527055be-464f-4f0f-98c5-8b8f72a1fc2e

LLaMA-Model

在transformers仓库中可以看到llama的源码,首先是LlamaModel类,继承自PreTrainedModel,这个类是所有模型的基类,包含了一些通用的方法,比如保存模型、加载模型、初始化权重等。

继承关系为:LlamaModel-> LlamaPreTrainedModel-> PreTrainedModel

LlamaConfig

LlamaConfig 中主要是定义一些参数,比如vocab_size、hidden_size、num_hidden_layers、num_attention_heads等。所有的参数有默认值,可以直接创建cofing就能用。

config = LlamaConfig()

LlamaModel

6783830017ed6f9a29869202cd8218ff.jpeg

LlamaModel 初始化

  • 设置了模型的两个属性:padding_idx(用于指定填充标记的索引),vocab_size(词汇表的大小)

  • 初始化了模型的嵌入层、解码器层、归一化层

  • 嵌入层(nn.Embedding):模型使用嵌入层将输入的标记映射成密集的向量表示。

  • 解码器层(nn.ModuleList()):模型包含多个解码器层,这些层都是由 LlamaDecoderLayer 定义

  • 归一化层 LlamaRMSNorm:归一化层使用的是 Root Mean Square Layer Normalization(RMS Layer Norm)

  • 设置了是否使用 gradient_checkpoint 主要是用来节省显存

  • 调用 post_init() 完成一些初始化和准备检查的代码

def __init__(self, config: LlamaConfig):super().__init__(config)self.padding_idx = config.pad_token_idself.vocab_size = config.vocab_size# embedding 层self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)# 中间的一堆 decoderlayers 层self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._use_sdpa = config._attn_implementation == "sdpa"self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.gradient_checkpointing = False# Initialize weights and apply final processingself.post_init()

可以看一下 post_init() 的代码,主要是初始化权重和gradient_checkpointing相关的一些事情。该方法在PreTrainedModel基类中,transformers中所有模型基本都继承这个类。

def post_init(self):"""A method executed at the end of each Transformer model initialization, to execute code that needs the model'smodules properly initialized (such as weight initialization)."""self.init_weights()self._backward_compatibility_gradient_checkpointing()

LlamaModel forward

forward 部分的代码有点长,但其实大部分都是张量并行或者是节省显存相关的代码,对于理解模型结构来说可以直接忽略。

首先进来就是把 inputs_ids 进行向量化,然后拿到 hidden_states 。然后是存起来所有的hidden_states 进入 decoder_layer 再拿一个 hidden_states,作为下一轮 decoder_layerhidden_states 输入,最后给 hidden_states norm一下。如下代码所示:

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embedsfor decoder_layer in self.layers:# 存起来所有的 hidden_statesif output_hidden_states:all_hidden_states += (hidden_states,)# 这里是 decoder_layer 的 forwardlayer_outputs = decoder_layer(hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_values,output_attentions=output_attentions,use_cache=use_cache,)# 再拿一个 hidden_states,作为下一轮 decoder_layer 的 hidden_states 输入hidden_states = layer_outputs[0]hidden_states = self.norm(hidden_states)

最后就是以 BaseModelOutputWithPast 的形式输出。ok,接下来继续看decoder_layer中的其他代码。

LlamaDecoderLayer

Embedding层不用多说,用的就是torch中的nn.Embedding。那就直接来看DecoderLayer。

8e444bd665786abd4f4a471a8a72b244.png

DecoderLayers 初始化

先来看初始化。

  • hidden_size : 也就是在上面说的输入输出。

  • self_attn : 别看它写这么多啊,其实就是选一下用什么 attention 。看见大写字母不要怕,直接点进去看看怎么个事!

    LLAMA_ATTENTION_CLASSES = {"eager": LlamaAttention,"flash_attention_2": LlamaFlashAttention2,"sdpa": LlamaSdpaAttention,
    }
  • mlp : 一个全连接层 LlamaMLP 这个待会后面再说,输入输出都是 hidden_size 大小。

  • input_layernorm : LlamaRMSNorm 层,输入时候的norm

  • post_attention_layernorm : 丢入 mlp 之前的操作。

class LlamaDecoderLayer(nn.Module):def __init__(self, config: LlamaConfig, layer_idx: int):super().__init__()self.hidden_size = config.hidden_sizeself.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)self.mlp = LlamaMLP(config)self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\

DecoderLayers forward

首先复制一份 hidden_statesresidual。然后 hidden_states 进入 input_layernorm 进行norm。然后进入 self_attn 进行 attention 操作,拿到 hidden_statesself_attn_weightspresent_key_value。然后 hidden_statesresidual 相加,得到 hidden_states

然后 hidden_states 进入 post_attention_layernorm 进行norm。最后 hidden_states 进入 mlp 进行全连接操作,拿到 hidden_states。然后 hidden_statesresidual 相加,得到 hidden_states。最后输出 hidden_states

residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,**kwargs,
)
hidden_states = residual + hidden_states# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)return outputs

Llama Attention

31aaddd17134c8360db0117ba21d30b6.png

看代码首先映入眼帘的就是  Attention Is All You Need  好好好,很有精神!那我们接着往下看。

先来看 init 部分叭。

  • layer_idx : 这个就是第几个 DecoderLayers 层。不用关心。

  • attention_dropout : 用于dropout的概率。

  • hidden_size : 输入输出大小。

  • num_attention_heads : 多头注意力的头数。

  • head_dim : 多头注意力的维度 self.hidden_size // self.num_heads,和transformers中的一样。

  • num_key_value_heads : 用于key和value的头数。

其他的参数都在 LlamaConfig 中有默认值,可以直接使用,也可以直接去LlamaConfig的源码中看具体的解释,这里就不再多说。

再往下就是 q_projk_projv_projo_proj 四个矩阵(全连接层),耳熟能详了。

class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxif layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ""to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.attention_dropout = config.attention_dropoutself.hidden_size = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.hidden_size // self.num_headsself.num_key_value_heads = config.num_key_value_headsself.num_key_value_groups = self.num_heads // self.num_key_value_headsself.max_position_embeddings = config.max_position_embeddingsself.rope_theta = config.rope_thetaself.is_causal = Trueif (self.head_dim * self.num_heads) != self.hidden_size:raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {self.num_heads}).")self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)self._init_rope()

LlamaAttention forward

重头戏来了,attention forward 部分。

注意:其中有关于张量并行或者显存节省的部分我就直接省略了,直接看主要代码。这个笔记主要是分析llama的模型结构,并不讨论如何节省显存。

首先拿到 hidden_statesbatch_sizeseq_len 。然后把 hidden_states 丢入 q_projk_projv_proj 三个矩阵(全连接层),拿到 query_stateskey_statesvalue_states 。然后把 query_stateskey_statesvalue_states reshape 为下一步计算做准备。

将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果

key_statesvalue_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化

然后 attn_weights 加上 attention_mask,再 softmaxdropout。然后 attn_weightsvalue_states 相乘,把 attn_output reshape 为下一步计算做准备,最后把 attn_output 丢入 o_proj ,然后return就行了。

好了,至此。我觉得llama最激动人心的地方已经结束了。

# 获取 batch_size 和 seq_len
bsz, q_len, _ = hidden_states.size()# 把 hidden_states 丢入 q_proj、k_proj、v_proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)# 把 q_proj、k_proj、v_proj 的输出 reshape 为下一步计算做准备
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# 将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)# 首先,它将key_states和value_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# 然后 attn_weights 加上 attention_mask
attn_weights = attn_weights + attention_mask# softmax + dropout
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 然后 attn_weights 和 value_states 相乘
attn_output = torch.matmul(attn_weights, value_states)# 然后把 attn_output reshape 为下一步计算做准备
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)# 最后把 attn_output 丢入 o_proj
attn_output = self.o_proj(attn_output)# 返回 attn_output、attn_weights、present_key_value
return attn_output, attn_weights, past_key_value

LlamaMLP

c1cd4cf6f3c0e2a88c2f2b536e2fd10f.png

看完 attention 再看 MLP ,突然就觉得好简单了,哈哈哈。这部分代码比较少,就直接放到一起了。

x进来之后先进去up_proj和gate_proj,gate_proj进行激活,然后这俩再乘起来,丢进 down_proj。那直接放个图叭,这个过程有点简单了。

class LlamaMLP(nn.Module):def __init__(self, config):super().__init__()# 这俩不必多说self.config = configself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_size# 三个全连接层self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))return down_proj

LlamaRMSNorm

RMSNorm函数可以用以下数学公式表示:

其中:

  • 是层的输入。

  • 代表层的权重。

  • 是权重的数量。

  • 是一个小常数,用于数值稳定性(以避免除以零的情况)。

这种归一化有助于通过确保权重的规模不会变得过大或过小来稳定学习过程,这在具有许多层的深度学习模型中特别有用。

class LlamaRMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):"""LlamaRMSNorm is equivalent to T5LayerNorm"""super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):input_dtype = hidden_states.dtypehidden_states = hidden_states.to(torch.float32)variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.to(input_dtype)

参考:https://space.bilibili.com/45156039

8c2428ec014740c4d5aa5ee991a3cb14.png

干货学习,三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/297062.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java内存区域与内存溢出异常

Java与C++之间有一堵由内存分配和垃圾收集技术所围成的高墙,墙外面的人想进去,墙里面的人却想出来。 2.1 概述 对于从事C、C++程序开发的开发人员来说,在内存管理领域,他们即是拥有最高权力的“皇帝”,又是从事最基础工作的劳动人民——即拥有每一个对象的“所有权”,又…

MySQL部署之yum安装

MySQL https://www.mysql.com //mysql官网 yum安装步骤 yum安装 清理环境 [rootmysql ~]# yum erase mariadb mariadb-server mariadb-libs mariadb-devel -yuserdel -r mysql[rootmysql ~]# rm -rf /etc/my* && rm -rf /var/lib/mysql && rm -rf /use/bin/m…

计算机视觉基础(10)——深度学习与图像分类

前言 传统视觉算法采用手工设计特征与浅层模型,而手工设计特征依赖于专业知识,且泛化能力差。深度学习的出现改变了这一状况,为视觉问题提供了端到端的解决方案。在之前的课程中,我们已经学习了图像分类的传统知识。在本节课中&am…

便捷记账本,批量筛选需要的账目明细

你是否曾因为无法管理自己的财务而感到烦恼?如果你正在寻找一款简单易用的记账软件,那么晨曦记账本将是你的不二之选,一起来看看这款软件是怎么使用的吧。 所需工具: 一个【晨曦记账本】软件 操作步骤: 步骤1&…

机器学习之实验过程01

import pandas as pd import numpy as np import matplotlib.pyplot as plt data_path = /home/py/Work/labs/data/SD.csv # 请确保您的数据文件路径是正确的 df = pd.read_csv(data_path) df.head() # 创建散点图 # 创建散点图 plt.figure(figsize=(10, 6)) plt.scat…

拓扑排序算法总结

知识概览 求图的拓扑序是图的宽搜的一个很经典的应用,拓扑序列是针对有向图来说的。 拓扑序列的定义是: 如果说一个点的序列满足对于图中的每条有向边(x, y),x都出现在y的前面,那就称这个序列是这个图的拓扑序列。 备注&#xff…

【Linux】Linux常见指令解析上

目录 1. 前言2. ls指令3. pwd指令4. cd指令3.1 cd常见快捷指令 4. touch指令5. mkdir指令6. rmdir指令 && rm指令 (重要)6.1 rmdir指令6.2 rm指令 7. man指令 1. 前言 这篇文章我们将详细介绍一下Linux下常见的基本指令。 2. ls指令 语法: ls [选…

FPFA.一种二倍频电路代码描述以及测量详情

一、前言 1、因为需要倍频电路所以找了个二倍频的电路,通过fpga实际测量发现经过倍频后的电路峰值降低。不过这个也正常,因为该电路只要过触发点就会开始发生波形变化,而电路的触发值不是峰值。​​​​​​​ 2、继续对电路做倍频后信号做二…

linux操作系统——进程(二) 进程状态

进程状态 你真正的理解了进程的状态嘛?特别是操作系统教材中学过的进程状态,你真的理解了吗? 教材上关于进程状态的说明 下面我们以下图为例: 这是教材上对操作系统的说明,但是它并没有详细的说明,这些状态具体是什么&#xf…

51单片机拆字程序实验

一、实验内容 1.基本要求 熟悉51仿真系统;设计并单步调试,实现将R5中数值(初值为本人学号后两位)拆分成两位独立的数据分别存于R6,R7中; 2.扩展要求 将R6,R7中的被拆出来的一位HEX数据转换为可显示的ASCII编码&…

安全实践:保障 Kubernetes 生产环境的安全性

▲ 点击上方"DevOps和k8s全栈技术"关注公众号 Kubernetes(简称 K8s)是一个强大的容器编排平台,广泛应用于生产环境中。然而,与其功能强大相对应的是对安全性的高要求。在生产环境中,我们必须采取一系列措施来…

【微服务】springboot整合kafka-stream使用详解

目录 一、前言 二、kafka stream概述 2.1 什么是kafka stream 2.2 为什么需要kafka stream 2.2.1 对接成本低 2.2.2 节省资源 2.2.3 使用简单 2.3 kafka stream特点 2.4 kafka stream中的一些概念 2.5 Kafka Stream应用场景 三、环境准备 3.1 搭建zk 3.1.1 自定义d…