- 将数据压平
- 通过 nn.Linear(cfg.in_feature, cfg.expert_num) 得到专家权重
- 通过 torch.topk 得到 每个top上的权重,以及 以专家索引为value的 专家 (batch_size*seq_len,top_k) ,
它表示每个token在不同top_k 上的权重,以及每个token在不同topk_上对应的专家
- 通过 F.one_hot “以专家索引为value” 将value变换为索引,值为0,1
- 通过 变换将专家索引提到第一位
- 通过 torch.nonzero 得到 tok_k 索引列表和 token 索引列表。
它表示不同的top_k与 token之间的一一对应关系
- 通过token索引列表从1中拿到需要处理的数据, 并通过专家去执行 (token, output_featue)
- 通过 tok_k 索引列表和 token 索引列表 去 top权重上去获得权重 (此时数据为1维度,需要unsqueeze)
9 通过7和8的数据逐乘计算权重 - 用token 索引列表 和最终计算的权重 添加到目标输出结果中。
最终代码:
import torch
from torch import nn, Tensor
from torch.nn import functional as Ffrom xtransformer.moe.config import MoeConfig
from xtransformer.moe.expert import BasicExpertclass SparseMoe(nn.Module):def __init__(self, cfg: MoeConfig):super(SparseMoe, self).__init__()self.cfg = cfgself.expert_list = nn.ModuleList([BasicExpert(cfg.in_feature,cfg.out_feature,cfg.hidden_dim,) for _ in range(cfg.expert_num)])self.share_expert = nn.ModuleList([BasicExpert(cfg.in_feature,cfg.out_feature,cfg.hidden_dim,) for _ in range(cfg.share_num)])self.gate = nn.Linear(cfg.in_feature, cfg.expert_num)def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:# x shape (batch_size,seq_len,hidden_dim)# 由于每个专家的选中无法从x张量中抽离出新的张量,因为形状无法对齐, 先将x进行压平# flat_x shape(batch_size * seq_len,-1)flat_x = x.reshape(-1, x.size(-1))# 计算专家权重expert_weights = self.gate(flat_x)expert_weights = F.softmax(expert_weights, dim=-1)# topk_value, topk_idx shape (batch_size*seq_len,top_k)# topk_value 的值是专家权重, topk_idx 值是专家索引,索引我value都是经过排序的topk_value, topk_idx = expert_weights.topk(self.cfg.top_k, -1)# 重新初始化权重topk_value = topk_value / topk_value.sum(dim=-1, keepdim=True)# 由于专家索引在 topk_idx的value上,我们需要将其转换为索引上, 所以最好的办法是topk_idx的value进行one_hot 编码# 此时得到的 topk_idx shape (batch_size*seq_len,top_k,expert_num), 值为 0 和 1 , 1 代表选中的专家,,expert_num 表示专家索引topk_idx = F.one_hot(topk_idx, num_classes=self.cfg.expert_num)# topk_idx shape(expert_num,top_k,batch_size*seq_len)topk_idx = topk_idx.permute(2, 1, 0)# 初始化最终结果张量final_ret = torch.zeros([flat_x.size(0), cfg.out_feature], device=x.device)# 接下来的目的,是找对应类别的专家,去执行他们需要处理的token,并计算权重for expert_idx in range(self.cfg.expert_num):cur_expert = self.expert_list[expert_idx]# cur shape (top_k,batch_size*seq_len) 其值为0 和 1 ,0表示被选中,, 1 表示未被选中cur = topk_idx[expert_idx]# selected_x 为查找非0的所有索引,# 其返回值的元素个数等于cur_x 的维度,# 返回的第一个值代表第一个维度上的索引, 第二个值代表第二维度上的索引 ....# 每一个返回值都是一个元组, 元组的长度都相同selected_topk_idx, selected_token_idx = torch.nonzero(cur, as_tuple=True)# cur_x shape(selected_token_idx,hidden_dim)cur_x = flat_x[selected_token_idx, :]# cur_weight shape (selected_tokens) 值为权重cur_weight = topk_value[selected_token_idx, selected_topk_idx]cur_weight = cur_weight.unsqueeze(dim=-1)# 专家执行 expert_ret shape(selected_token_idx,out_feature)expert_ret = cur_expert(cur_x)# 专家执行结果与权重进行计算# expert_ret shape(selected_token_idx,out_feature)expert_ret = expert_ret * cur_weightfinal_ret.index_add_(0, selected_token_idx, expert_ret)# reshape 到标准结果final_ret = final_ret.reshape(x.size(0), x.size(1), cfg.out_feature)# 计算共享专家的结果for share_expert in self.share_expert:final_ret = final_ret + share_expert(x)return final_ret, expert_weightsif __name__ == '__main__':cfg = MoeConfig()sm = SparseMoe(cfg)ret, _ = sm(torch.rand([2, 3, 1024]))print(ret.shape)