稀疏专家模型实现,关键流程分析

news/2025/3/27 16:40:12/文章来源:https://www.cnblogs.com/xiezhengcai/p/18789800
  1. 将数据压平
  2. 通过 nn.Linear(cfg.in_feature, cfg.expert_num) 得到专家权重
  3. 通过 torch.topk 得到 每个top上的权重,以及 以专家索引为value的 专家 (batch_size*seq_len,top_k) , 它表示每个token在不同top_k 上的权重,以及每个token在不同topk_上对应的专家
  4. 通过 F.one_hot “以专家索引为value” 将value变换为索引,值为0,1
  5. 通过 变换将专家索引提到第一位
  6. 通过 torch.nonzero 得到 tok_k 索引列表和 token 索引列表。 它表示不同的top_k与 token之间的一一对应关系
  7. 通过token索引列表从1中拿到需要处理的数据, 并通过专家去执行 (token, output_featue)
  8. 通过 tok_k 索引列表和 token 索引列表 去 top权重上去获得权重 (此时数据为1维度,需要unsqueeze)
    9 通过7和8的数据逐乘计算权重
  9. 用token 索引列表 和最终计算的权重 添加到目标输出结果中。

最终代码:

import torch
from torch import nn, Tensor
from torch.nn import functional as Ffrom xtransformer.moe.config import MoeConfig
from xtransformer.moe.expert import BasicExpertclass SparseMoe(nn.Module):def __init__(self, cfg: MoeConfig):super(SparseMoe, self).__init__()self.cfg = cfgself.expert_list = nn.ModuleList([BasicExpert(cfg.in_feature,cfg.out_feature,cfg.hidden_dim,) for _ in range(cfg.expert_num)])self.share_expert = nn.ModuleList([BasicExpert(cfg.in_feature,cfg.out_feature,cfg.hidden_dim,) for _ in range(cfg.share_num)])self.gate = nn.Linear(cfg.in_feature, cfg.expert_num)def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:# x shape (batch_size,seq_len,hidden_dim)# 由于每个专家的选中无法从x张量中抽离出新的张量,因为形状无法对齐, 先将x进行压平# flat_x shape(batch_size * seq_len,-1)flat_x = x.reshape(-1, x.size(-1))# 计算专家权重expert_weights = self.gate(flat_x)expert_weights = F.softmax(expert_weights, dim=-1)# topk_value, topk_idx shape (batch_size*seq_len,top_k)# topk_value 的值是专家权重, topk_idx 值是专家索引,索引我value都是经过排序的topk_value, topk_idx = expert_weights.topk(self.cfg.top_k, -1)# 重新初始化权重topk_value = topk_value / topk_value.sum(dim=-1, keepdim=True)# 由于专家索引在 topk_idx的value上,我们需要将其转换为索引上, 所以最好的办法是topk_idx的value进行one_hot 编码# 此时得到的 topk_idx shape (batch_size*seq_len,top_k,expert_num), 值为 0 和 1 , 1 代表选中的专家,,expert_num 表示专家索引topk_idx = F.one_hot(topk_idx, num_classes=self.cfg.expert_num)# topk_idx shape(expert_num,top_k,batch_size*seq_len)topk_idx = topk_idx.permute(2, 1, 0)# 初始化最终结果张量final_ret = torch.zeros([flat_x.size(0), cfg.out_feature], device=x.device)# 接下来的目的,是找对应类别的专家,去执行他们需要处理的token,并计算权重for expert_idx in range(self.cfg.expert_num):cur_expert = self.expert_list[expert_idx]# cur shape (top_k,batch_size*seq_len)  其值为0 和 1 ,0表示被选中,, 1 表示未被选中cur = topk_idx[expert_idx]# selected_x 为查找非0的所有索引,# 其返回值的元素个数等于cur_x 的维度,# 返回的第一个值代表第一个维度上的索引, 第二个值代表第二维度上的索引 ....# 每一个返回值都是一个元组, 元组的长度都相同selected_topk_idx, selected_token_idx = torch.nonzero(cur, as_tuple=True)# cur_x  shape(selected_token_idx,hidden_dim)cur_x = flat_x[selected_token_idx, :]# cur_weight shape (selected_tokens) 值为权重cur_weight = topk_value[selected_token_idx, selected_topk_idx]cur_weight = cur_weight.unsqueeze(dim=-1)# 专家执行 expert_ret shape(selected_token_idx,out_feature)expert_ret = cur_expert(cur_x)# 专家执行结果与权重进行计算# expert_ret shape(selected_token_idx,out_feature)expert_ret = expert_ret * cur_weightfinal_ret.index_add_(0, selected_token_idx, expert_ret)# reshape 到标准结果final_ret = final_ret.reshape(x.size(0), x.size(1), cfg.out_feature)# 计算共享专家的结果for share_expert in self.share_expert:final_ret = final_ret + share_expert(x)return final_ret, expert_weightsif __name__ == '__main__':cfg = MoeConfig()sm = SparseMoe(cfg)ret, _ = sm(torch.rand([2, 3, 1024]))print(ret.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/904873.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在SSD1306上显示动态表情符号位图

解锁您的SSD1306上充满活力的视觉效果!学习毫不费力地显示动态表情符号位图,并以风格增强您的项目。 在本教程中,我们将通过使用PCBX在线模拟环境在SSD1306 OLED显示器上显示位图图像的过程。本教程将介绍设置PCBX模拟,格式化位图数据,配置显示大小和管理图像延迟。步骤1:…

redis基础数据结构——ZipList

ZipList 基于特殊写法实现的双端链表,由一系列特殊编码的连续内存块组成,可以像deque一样在双端压入/弹出,并且时间复杂度在O(1) 整体ZL结构如下zlbytes(uint32):当前zl总的byte数。 zltail(uint32):尾结点的offset,指向的是最后一个entry的起始地址。 zllen(uint16):记…

day:28 postman——环境变量(依赖,关联接口)

一.接口的环境变量 (1)定义变量 可以将需要填写的值设为变量 变量设置:{{}}(2)添加环境变量 方法一:方法二:(3)查看环境变量(4)选择环境,执行二.依赖接口 先登录接口成功,生成cookie值,才能让后面接口依赖 cookie值是保持会话 查看cookie值方法 方法一:方法二:…

L1.1 技术和产品准备度

L1.1 技术和产品准备度 技术和产品准备度 技术与产品的演进 ​ 上面这张图展示了如何在技术尚未完全成熟时,启动产品开发,以及技术如何随着新需求或洞察逐步演进,并支持产品的更新换代。产品1.0:由先前研发的的技术3支撑,加上“产品开发可以在预期的技术开发成果的基础上提…

从故障响应到客户信赖:华为ITR流程的五大核心步骤与实战案例

华为究竟是如何在与西方巨头的激烈竞争中崭露头角、脱颖而出的呢?答案是:凭借卓越的服务。今天我们来探讨一下华为是如何通过卓越的服务赢得全球市场的。 一、华为的三件大事 华为前高管费敏曾经总结过,华为的业务可以分成三件大事:1. 开发产品:这就是 IPD 流程,负责从有…

提升生产效率的关键: ethercat转TCPIP智能通信

大家好。最近在数据互联互通方面,我们迎来了一个重要的突破。作为生产管理系统的核心组成部分,数据互联互通一直是一个亟待解决的挑战。我们知道,EtherCAT和TCP/IP是两种不同的通信协议,它们之间的互通性一直存在问题。不过,现在有一款新产品值得关注,这款产品能够实现Et…

Trae初体验

Trae(国际版)的Ai搭载Claude-3.7-Sonnet(完全免费且速度很快)和DeepSeek-R1以及V3(不存在服务器繁忙)以及GPT-4o Trae国服的Ai搭载DouBao和DeepSeek。用Claude-3.7-Sonnet 写一个简易的贪吃蛇小游戏:这个贪吃蛇游戏包含以下功能:使用方向键控制蛇的移动 吃到食物会增加长度和…

C# 从零开始使用Layui.Wpf库开发WPF客户端

一、简介最近需要开发一个桌面版的工具软件,之前用得更多的是Winform,作为一个全干工程师,我们也要兼顾下WPF,趁此机会再研究下开源控件库。MaQaQ:Winform真好用(有个HZHControls控件库,值得一看)。 二、准备工作找了下开源控件库,诸如MaterialDesignInXAML、HandyCon…

聚点和闭包中点的等价条件

聚点有以下等价描述: 闭包中点有以下等价描述:这些等价描述在与导集和闭包的证明中能起到很大的作用。下面是一个例子。

Itext5生成高质量、易识别、适合小尺寸标签打印的二维码

高质量、易识别、小尺寸二维码生成 1.增大二维码的原始尺寸(例如 1000 x 1000 或更大),再缩放为 PDF 所需的大小。这样可以保留更多像素细节,提高识别率。 2.降低容错级别到 L 或 M,如果你的内容不是特别长或复杂的话,这样能减少密集度。 3.优化缩放方式: • 使用 Buffe…

【Docker】安装部署jenkins

docker安装部署jenkinsdocker安装jenkins  1、下载jenkins  2、创建挂载目录  3、启动jenkins容器  4、验证jenkins是否启动成功  5、获取管理员密码  6、下载安装插件 docker安装jenkins【1】下载jenkins拉取jenkins镜像 docker pull jenkins/jenkins:2.426.2-lts…