超过GPT3.5?Mixtral 8*7B 模型结构分析

 Datawhale干货 

作者:宋志学,Datawhale成员

前言

2023年12月11日,Mistral AI团队发布了一款高质量的稀疏专家混合模型Mixtral 8x7B。

Mistral AI继续致力于向开发者社区提供最优秀的开放模型。在人工智能领域向前发展,需要采取超越重用众所周知的架构和训练范式的新技术路径。最重要的是,它需要让社区从原创模型中受益,以促进新的发明和用途。

Mixtral 8x7B是一款高质量的稀疏专家混合模型(SMoE),具有开放权重。采用Apache 2.0开源软件许可证。Mixtral在大多数基准测试中表现超过了Llama 2 70B,推断速度快6倍。它是目前拥有宽松许可证最强大的开放权重模型,并在成本/性能权衡方面是最佳模型。特别是在大多数标准基准测试中,它的表现匹配或超过了GPT3.5

9c087e22cce3da17cfb863d88ddfb8a3.png

Mixtral具有以下特点:

  • 优雅地处理32k标记的上下文。

  • 支持英语、法语、意大利语、德语和西班牙语。

  • 在代码生成方面表现出色。

  • 可以微调为一个遵循指令的模型,在MT-Bench上达到8.3的分数。


transformers 仓库中可以看到 mixtral 的源码,首先是 MixtralModel 类,继承自 PreTrainedModel ,这个类是所有模型的基类,包含了一些通用的方法,比如保存模型、加载模型、初始化权重等。具体目录是:src\transformers\models\mixtral\modeling_mixtral.py

继承关系为:MixtralModel -> MixtralPreTrainedModel -> PreTrainedModel

d84780d01dc5387228450fce9f5d5bd7.png

MixtralConfig

MixtralConfig 类继承自 PretrainedConfig ,这个类是所有配置类的基类,包含了一些通用的方法,比如保存配置、加载配置、初始化配置等。具体路径在 transformers 仓库的 src\transformers\models\mixtral\configuration_mixtral.py目录下。

可以使用如下代码直接创建模型的config对象:

config = MixtralConfig()

MixtralModel

1c88f66d464345c48ebcbce795f114d2.png

MixtralModel 初始化

如果你看过我上一篇 LLaMA开源大模型源码分析!的话,就会发现这里的初始化和llama模型的初始化非常相似,都是先初始化embed_tokens,然后初始化layers,最后初始化norm

  • 设置了模型的两个属性:padding_idx(用于指定填充标记的索引),vocab_size(词汇表的大小)

  • 初始化了模型的嵌入层、解码器层、归一化层

  • 嵌入层(nn.Embedding):模型使用嵌入层将输入的标记映射成密集的向量表示。

  • 解码器层(nn.ModuleList()):模型包含多个解码器层,这些层都是由 MixtralDecoderLayer 定义

  • 归一化层 MixtralRMSNorm:归一化层使用的是 Root Mean Square Layer Normalization(RMS Layer Norm),和llama使用的是一样的。

  • 设置了是否使用 gradient_checkpoint 主要是用来节省显存

  • 调用 post_init() 完成一些初始化和准备检查的代码

class MixtralModel(MixtralPreTrainedModel):"""Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]Args:config: MixtralConfig"""def __init__(self, config: MixtralConfig):super().__init__(config)self.padding_idx = config.pad_token_idself.vocab_size = config.vocab_sizeself.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)self.layers = nn.ModuleList([MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])self._attn_implementation = config._attn_implementationself.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)self.gradient_checkpointing = False# Initialize weights and apply final processingself.post_init()

可以看一下 post_init() 的代码,主要是初始化权重和gradient_checkpointing相关的一些事情。该方法在PreTrainedModel基类中,transformers中所有模型基本都继承这个类。

def post_init(self):"""A method executed at the end of each Transformer model initialization, to execute code that needs the model'smodules properly initialized (such as weight initialization)."""self.init_weights()self._backward_compatibility_gradient_checkpointing()

MixtralModel Forward

forward 部分的代码有点长,但其实大部分都是张量并行或者是节省显存相关的代码,对于理解模型结构来说可以直接忽略。

首先进来就是把 inputs_ids 进行向量化,然后拿到 hidden_states 。然后是存起来所有的hidden_states 进入 decoder_layer 再拿一个 hidden_states,作为下一轮 decoder_layerhidden_states 输入,最后给 hidden_states norm一下。如下代码所示:

# 向量化
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embedsfor decoder_layer in self.layers:#存起来所有的 hidden_statesif output_hidden_states:all_hidden_states += (hidden_states,)# 这里是decoder_layer 的forwardlayer_outputs = decoder_layer(hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_values,output_attentions=output_attentions,output_router_logits=output_router_logits,use_cache=use_cache,)# # 再拿一个 hidden_states,作为下一轮 decoder_layer 的 hidden_states 输入hidden_states = layer_outputs[0]# norm 一下
hidden_states = self.norm(hidden_states)

MixtralDecoderLayer

e1c1cedd81666cb4e9076fab33d20d99.png

MixtralDecoderLayer 初始化

好,来到了 moe 模型和 llama 模型最大区别的地方了,Mixtral 使用 MixtralSparseMoeBlock 模块代替了原有的 MLP 层, MLP 层还是在的,待会在后面我们再说。先来看初始化部分 DecoderLayer 做了什么事情。

  • hidden_size : 也就是在上面说的输入输出。

  • self_attn : 别看它写这么多啊,其实就是选一下用什么 attention 。看见大写字母不要怕,直接点进去看看怎么个事!

MIXTRAL_ATTENTION_CLASSES = {"eager": MixtralAttention,"flash_attention_2": MixtralFlashAttention2,"sdpa": MixtralSdpaAttention,
}
  • block_sparse_moe : moe稀疏矩阵,这个待会后面再说,输入输出都是 hidden_size 大小。

  • input_layernorm : MixtralRMSNorm 层,输入时候的norm

  • post_attention_layernorm : 丢入稀疏矩阵 block_sparse_moe 之前的操作。

class MixtralDecoderLayer(nn.Module):def __init__(self, config: MixtralConfig, layer_idx: int):super().__init__()self.hidden_size = config.hidden_size  # 隐藏层的大小self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)  # 自注意力机制self.block_sparse_moe = MixtralSparseMoeBlock(config)  # 稀疏混合块self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # 输入层归一化self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # 注意力之后的层归一化

MixtralDecoderLayer Forward

首先复制一份 hidden_statesresidual。然后 hidden_states 进入 input_layernorm 进行norm。

然后进入 self_attn 进行 attention 操作,拿到 hidden_statesself_attn_weightspresent_key_value

而后 hidden_statesresidual 相加,得到 hidden_states。此时再复制一份 residual 。然后 hidden_states 进入 post_attention_layernorm 进行norm。

来了,来了!这里 hidden_states 进入稀疏矩阵 block_sparse_moe 得到 hidden_states, router_logitshidden_statesresidual 相加,得到 hidden_states。最后输出 hidden_states

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,)hidden_states = residual + hidden_statesresidual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_statesoutputs = (hidden_states,)if output_attentions:outputs += (self_attn_weights,)if use_cache:outputs += (present_key_value,)if output_router_logits:outputs += (router_logits,)return outputs

MixtralAttention

我们先来看 Attention 部分嗷,稀疏矩阵留到最后压轴再看。

549d5c95caf55d8e94eca5400bdef207.png

MixtralAttention 初始化

好好好,首先映入眼帘的还是 Attention Is All You Need ,不忘初心,可以可以!

先来看 init 部分叭。

  • layer_idx : 这个就是第几个 DecoderLayers 层。不用关心。

  • attention_dropout : 用于dropout的概率。

  • hidden_size : 输入输出大小。

  • num_attention_heads : 多头注意力的头数。

  • head_dim : 多头注意力的维度 self.hidden_size // self.num_heads,和transformers中的一样。

  • num_key_value_heads : 用于key和value的头数。

其他的参数都在 MixtralConfig 中有默认值,可以直接使用,也可以直接去MixtralConfig的源码中看具体的解释,这里就不再多说。

再往下就是 q_projk_projv_projo_proj 四个矩阵(全连接层),耳熟能详了。

class MixtralAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformerand "Generating Long Sequences with Sparse Transformers"."""def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxif layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will ""to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.hidden_size = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.hidden_size // self.num_headsself.num_key_value_heads = config.num_key_value_headsself.num_key_value_groups = self.num_heads // self.num_key_value_headsself.max_position_embeddings = config.max_position_embeddingsself.rope_theta = config.rope_thetaself.is_causal = Trueself.attention_dropout = config.attention_dropoutif (self.head_dim * self.num_heads) != self.hidden_size:raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {self.num_heads}).")self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)self.rotary_emb = MixtralRotaryEmbedding(self.head_dim,max_position_embeddings=self.max_position_embeddings,base=self.rope_theta,)

MixtralAttention Forward

这里的 forward 函数就是 Attention 的核心部分了,我们来一点一点看。

注意:其中有关于张量并行或者显存节省的部分我就直接省略了,直接看主要代码。这个笔记主要是分析mixtral的模型结构,并不讨论如何节省显存。

首先获取 batch_sizeseq_len ,然后把 hidden_states 丢入 q_projk_projv_proj 三个矩阵,得到 query_stateskey_statesvalue_states 。然后把 query_stateskey_statesvalue_states reshape 为下一步计算做准备。

获取 kv_seq_len ,其实我觉得这步挺多余的,因为 kv_seq_len 就等于 self.num_key_value_heads

将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果。

key_statesvalue_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化。

然后softmaxdropout。然后 attn_weightsvalue_states 相乘,把 attn_output reshape 为下一步计算做准备,最后把 attn_output 丢入 o_proj ,然后return就行了。

# 获取 batch_size 和 seq_len
bsz, q_len, _ = hidden_states.size()# 把 hidden_states 丢入 q_proj、k_proj、v_proj
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)# 把 q_proj、k_proj、v_proj 的输出 reshape 为下一步计算做准备
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# 获取 kv_seq_len,其实我觉得这步挺多余的,因为 kv_seq_len 就等于 self.num_key_value_heads
kv_seq_len = key_states.shape[-2]# 将旋转位置嵌入应用于查询和键张量。使用了旋转位置嵌入的余弦和正弦部分,将它们与查询和键张量相乘,并将结果相加,从而实现旋转位置嵌入的效果
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)# 首先,它将key_states和value_states重复self.num_key_value_groups次。然后,使用torch.matmul()函数计算query_states和转置后的key_states之间的矩阵乘法。最后,将结果除以math.sqrt(self.head_dim)进行归一化
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# softmax + dropout
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 然后 attn_weights 和 value_states 相乘
attn_output = torch.matmul(attn_weights, value_states)# 然后把 attn_output reshape 为下一步计算做准备
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# 最后把 attn_output 丢入 o_proj
attn_output = self.o_proj(attn_output)# 返回 attn_output、attn_weights、past_key_value
return attn_output, attn_weights, past_key_value

MixtralSparseMoeBlock

来了,来了。MoE模型的核心,稀疏矩阵!

c34afb4d4b15ae3461b4a43828336b8b.png

MixtralSparseMoeBlock 初始化

首先来看看在初始化中,init做了什么事情。

  • hidden_dim : 输入输出维度大小。

  • ffn_dim : MLP 层的维度大小。

  • num_experts : 本地专家的数量。

  • top_k : 选择的专家数量。

  • gate : 门控层,输入是 hidden_dim ,输出是 num_experts

  • experts : 专家层,八个 MixtralBLockSparseTop2MLP 模块。(就是八个原来的MLP层)

class MixtralSparseMoeBlock(nn.Module):def __init__(self, config):super().__init__()self.hidden_dim = config.hidden_sizeself.ffn_dim = config.intermediate_sizeself.num_experts = config.num_local_expertsself.top_k = config.num_experts_per_tok# gatingself.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])

MixtralSparseMoeBlock Forward

  • 首先,输入的隐藏状态hidden_states经过重塑,以适应后续处理。

  • 使用门控层gate计算出每个隐藏状态对于各个专家的重要程度,得到router_logits

  • router_logits应用softmax函数,得到路由权重routing_weights

  • routing_weights中选出最相关的top_k个专家,并进行归一化。

  • 初始化最终的隐藏状态final_hidden_states

  • 对每个专家进行遍历,根据专家掩码expert_mask选出分配给当前专家的隐藏状态,经过专家层处理后,将结果累加到最终隐藏状态中。

  • 最后,将最终隐藏状态的形状重塑回原始形状,并返回。

看完了稀疏矩阵的数据流向,现在你还觉得MoE模型在推理的之后只有两个模型在运行嘛?哈哈哈,其实就是八个MLP层作为专家模型,实际上所有的八个MLP层都是在运行的。

# 首先获取隐藏状态的维度信息
batch_size, sequence_length, hidden_dim = hidden_states.shape
# 将隐藏状态的形状重塑为二维,便于后续处理
hidden_states = hidden_states.view(-1, hidden_dim)# router_logits用于计算每个专家对每个隐藏状态的重要程度
router_logits = self.gate(hidden_states)# 使用softmax函数计算路由权重,这些权重决定每个隐藏状态分配给每个专家的比例
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# 选择top_k个最相关的专家
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
# 对路由权重进行归一化处理
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)# 将路由权重转换回输入数据类型
routing_weights = routing_weights.to(hidden_states.dtype)# 初始化最终隐藏状态
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)# 生成专家掩码,用于确定哪些隐藏状态分配给哪些专家
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)# 遍历所有的专家
for expert_idx in range(self.num_experts):# 获取当前专家的处理层expert_layer = self.experts[expert_idx]# 找出选中当前专家的隐藏状态索引idx, top_x = torch.where(expert_mask[expert_idx])# 如果没有隐藏状态被分配给当前专家,则继续下一个专家if top_x.shape[0] == 0:continue# 将索引转换为列表形式,以便高效处理top_x_list = top_x.tolist()idx_list = idx.tolist()# 获取并处理当前专家应处理的隐藏状态current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]# 将计算结果累加回最终隐藏状态中final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))# 将最终隐藏状态的形状重塑回原始的三维形状
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)# 返回最终的隐藏状态和路由逻辑结果
return final_hidden_states, router_logits

MixtralBLockSparseTop2MLP

这个就是所谓的专家模型,其实就是原来的MLP层而已。

首先初始胡三个线性层和一个激活层,然后就是前向传播部分了。hidden_states 经过第一个线性层,然后经过激活层,再与经过第三个线性层的hiden_states相乘,得到current_hidden_states

然后current_hidden_states经过第二个线性层,最后返回current_hidden_states

d3fd524060e144054084349b9b3cc693.png
class MixtralBLockSparseTop2MLP(nn.Module):def __init__(self, config: MixtralConfig):super().__init__()self.ffn_dim = config.intermediate_sizeself.hidden_dim = config.hidden_sizeself.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, hidden_states):current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)current_hidden_states = self.w2(current_hidden_states)return current_hidden_states

897714ac53dec5c184575cdf8ce66da3.png

干货学习,三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/418505.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC - 视频生成模型的相关算法进展

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/135688206 视频生成技术确实是一个很有潜力的颠覆性技术领域,可以作为企业创新梯队的重点关注方向,最近发展很快&#xff…

SaaS多租户篇

文章目录 1. 多租户是什么2. 技术组件2.1 如何实现多租户的DB封装2.2 如何实现多租户的redis封装2.3 如何实现多租户的Web和Security封装 1. 多租户是什么 2. 技术组件 2.1 如何实现多租户的DB封装 2.2 如何实现多租户的redis封装 2.3 如何实现多租户的Web和Security封装

DBA技术栈MongoDB: 索引和查询优化

2.1 批量插入数据 单条数据插入db.collection.insertOne()多条数据插入db.collection.insertMany() db.inventory.insertMany( [{ item: "journal", qty: 25, size: { h: 14, w: 21, uom: "cm" }, status: "A" },{ item: "notebook"…

Mac book air 重新安装系统验证显示 untrusted_cert_title

环境: Mac Book Air macOS Sierra 问题描述: Mac book air 重新安装系统验证显示 untrusted_cert_title 解决方案: 1.终端输入命令行输入 date 会看到一个非常旧的日期 2.更改日期为当前时间 使用以下命令来设置日期和时间&#xff1a…

第7章面向对象设计常用的设计模式

7.1 设计模式概述 7.2 单例模式 (1)模式名称 单例模式。 (2)问题与分析 问: 对于调用者,如何才能做到确保代码中的某个类只存在一个实例,而且实例一旦创建,就可以向整个运行程序提供…

std::atomic

一、概述 std::atomic 是C11引入的一个模板类,用于提供原子操作的类型。在多线程编程中,当多个线程同时访问同一块数据时,可能会导致数据竞争和不确定的行为。std::atomic 可以用来创建原子类型的变量,保证对该变量的操作是原子的…

蓝桥杯练习题-穷举模拟

📑前言 本文主要是【穷举模拟】——蓝桥杯练习题-穷举模拟的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄…

Spring DI

目录 什么是依赖注入 属性注入 构造函数注入 Setter 注入 依赖注入的优势 什么是依赖注入 依赖注入是一种设计模式,它通过外部实体(通常是容器)来注入一个对象的依赖关系,而不是在对象内部创建这些依赖关系。这种方式使得对象…

基于C++11的数据库连接池【C++/数据库/多线程/MySQL】

一、概述 概述:数据库连接池可提前把多个数据库连接建立起来,然后把它放到一个池子里边,就是放到一个容器里边进行维护。这样的话就能够避免数据库连接的频繁的创建和销毁,从而提高程序的效率。线程池其实也是同样的思路&#xf…

OCR识别网络CRNN理解与Pytorch实现

CRNN是2015年的论文“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”提出的图像字符识别网络,也是目前工业界使用较为广泛的一个OCR网络。论文地址:https://arxiv.org/…

OpenHarmony AI框架开发指导

一、概述 1、 功能简介 AI业务子系统是OpenHarmony提供原生的分布式AI能力的子系统。AI业务子系统提供了统一的AI引擎框架,实现算法能力快速插件化集成。 AI引擎框架主要包含插件管理、模块管理和通信管理模块,完成对AI算法能力的生命周期管理和按需部…

检索增强(RAG)的方式---重排序re-ranking

提升RAG:选择最佳嵌入Embedding&重排序Reranker模型 检索增强生成(RAG)技术创新进展:自我检索、重排序、前瞻检索、系统2注意力、多模态RAG RAG的re-ranking指的是对初步检索出来的候选段落或者文章,通过重新排序的方式来提升检索质量。…