AttentionFreeTransformer 源码解析(一):AFTFull、AFTSimple、AFTLocal

我觉得源码写的很好懂,我就不加注释了,直接上计算流程图。

AFTFull

在这里插入图片描述

class AFTFull(nn.Module):def __init__(self, max_seqlen, dim, hidden_dim=64):super().__init__()'''max_seqlen: the maximum number of timesteps (sequence length) to be fed indim: the embedding dimension of the tokenshidden_dim: the hidden dimension used inside AFT FullNumber of heads is 1 as done in the paper'''self.dim = dimself.hidden_dim = hidden_dimself.to_q = nn.Linear(dim, hidden_dim)self.to_k = nn.Linear(dim, hidden_dim)self.to_v = nn.Linear(dim, hidden_dim)self.project = nn.Linear(hidden_dim, dim)self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))nn.init.xavier_uniform_(self.wbias)def forward(self, x):B, T, _ = x.shapeQ = self.to_q(x).view(B, T, self.hidden_dim)K = self.to_k(x).view(B, T, self.hidden_dim)V = self.to_v(x).view(B, T, self.hidden_dim)temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length'''From the paper'''Q_sig = torch.sigmoid(Q)temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))Yt = torch.mul(Q_sig, weighted)Yt = Yt.view(B, T, self.hidden_dim)Yt = self.project(Yt)return Yt

AFTSimple

在这里插入图片描述

class AFTSimple(nn.Module):def __init__(self, max_seqlen, dim, hidden_dim=64):super().__init__()'''max_seqlen: the maximum number of timesteps (sequence length) to be fed indim: the embedding dimension of the tokenshidden_dim: the hidden dimension used inside AFT FullNumber of Heads is 1 as done in the paper.'''self.dim = dimself.hidden_dim = hidden_dimself.to_q = nn.Linear(dim, hidden_dim)self.to_k = nn.Linear(dim, hidden_dim)self.to_v = nn.Linear(dim, hidden_dim)self.project = nn.Linear(hidden_dim, dim)def forward(self, x):B, T, _ = x.shapeQ = self.to_q(x).view(B, T, self.hidden_dim)K = self.to_k(x).view(B, T, self.hidden_dim)V = self.to_v(x).view(B, T, self.hidden_dim)'''From the paper'''weights = torch.mul(torch.softmax(K, 1), V).sum(dim=1, keepdim=True)Q_sig = torch.sigmoid(Q)Yt = torch.mul(Q_sig, weights)Yt = Yt.view(B, T, self.hidden_dim)Yt = self.project(Yt)return Yt

AFTLocal

在这里插入图片描述

class AFTLocal(nn.Module):def __init__(self, max_seqlen, dim, hidden_dim=64, s=256):super().__init__()'''max_seqlen: the maximum number of timesteps (sequence length) to be fed indim: the embedding dimension of the tokenshidden_dim: the hidden dimension used inside AFT Fulls: the window size used for AFT-Local in the paperNumber of heads is 1 as done in the paper'''self.dim = dimself.hidden_dim = hidden_dimself.to_q = nn.Linear(dim, hidden_dim)self.to_k = nn.Linear(dim, hidden_dim)self.to_v = nn.Linear(dim, hidden_dim)self.project = nn.Linear(hidden_dim, dim)self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))self.max_seqlen = max_seqlenself.s = snn.init.xavier_uniform_(self.wbias)def forward(self, x):B, T, _ = x.shapeQ = self.to_q(x).view(B, T, self.hidden_dim)K = self.to_k(x).view(B, T, self.hidden_dim)V = self.to_v(x).view(B, T, self.hidden_dim)self.wbias = nn.Parameter(torch.Tensor([[self.wbias[i][j] if math.fabs(i-j) < self.s else 0 for j in range(self.max_seqlen)] for i in range(self.max_seqlen)]))temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length'''From the paper'''Q_sig = torch.sigmoid(Q)temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))Yt = torch.mul(Q_sig, weighted)Yt = Yt.view(B, T, self.hidden_dim)Yt = self.project(Yt)return Yt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/60235.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

谷歌发布RT-2大模型,让机器人像人类那样思考

原创 | 文 BFT机器人 大语言模型是指基于深度学习技术的大规模预训练模型&#xff0c;它能够通过学习大量的文本数据来生成人类类似的语言表达&#xff0c;机器人可以通过对大量的语言数据进行学习&#xff0c;从中掌握人类的语言表达方式&#xff0c;进而能够更好地与人进行交…

java 企业工程管理系统软件源码 自主研发 工程行业适用 em

​ 工程项目管理软件&#xff08;工程项目管理系统&#xff09;对建设工程项目管理组织建设、项目策划决策、规划设计、施工建设到竣工交付、总结评估、运维运营&#xff0c;全过程、全方位的对项目进行综合管理 工程项目各模块及其功能点清单 一、系统管理 1、数据字典&#…

【UE4 RTS】06-Camera Edge Scroll

前言 本篇实现的效果是当玩家将鼠标移至屏幕边缘时&#xff0c;视野会相应的上下左右移动 效果 步骤 1. 打开玩家控制器“RTS_PlayerController_BP”&#xff0c;在类默认值中设置如下选项 新建一个宏&#xff0c;命名为“EdgeSroll”&#xff0c; 添加两个输入和三个输出&a…

科技资讯|苹果手机版Vision Pro头显专利曝光,内嵌苹果手机使用

根据美国商标和专利局&#xff08;USPTO&#xff09;公示的清单&#xff0c;苹果公司近日获得了一项头显相关的技术专利&#xff0c;展示了一款亲民款 Vision Pro 头显&#xff0c;可以将 iPhone 放置在头显内部充当屏幕。 根据patentlyapple 媒体报道&#xff0c;这是苹果公司…

AP2915DC-DC降压恒流驱动IC LED电源驱动芯片 汽车摩托电动车灯

AP2915 是一款可以一路灯串切换两路灯串的降压 恒流驱动器,高效率、外围简单、内置功率管&#xff0c;适用于 5-80V 输入的高精度降压 LED 恒流驱动芯片。内置功 率管输出功率可达 12W&#xff0c;电流 1.2A。 AP2915 一路灯亮切换两路灯亮&#xff0c;其中一路灯亮可 以全亮&a…

virtualBox桥接模式下openEuler镜像修改IP地址、openEule修改IP地址、openEule设置IP地址

安装好openEuler后,设置远程登入前,必不可少的一步,主机与虚拟机之间的通信要解决,下面给出详细步骤: 第一步:检查虚拟机适配器模式:桥接模式 第二步:登入虚拟机修改IP cd /etc/sysconfig/network-scripts vim ifcfg-enpgs3 没有vim的安装或者用vi代替:sudo dnf …

工业互联网发展在即 博晨(BOCHEN)攻克“卡脖子”难题

5G时代的到来&#xff0c;正在悄然掀起一场智能化技术改革的风暴。工业互联网未来一定要走向制造智能化&#xff0c;这可能是我们未来工业互联网推动工业系统新生态的核心问题。”中国电子信息行业联合会专家委员会主任董云庭就曾表示。目前&#xff0c;工业互联网已经覆盖至国…

剪切、复制、粘贴事件

剪切、复制、粘贴事件 oncopy 事件在用户拷贝元素上的内容时触发。onbeforecut 事件在用户剪切文本&#xff0c;且文本还未删除时触发触发。oncut 事件在用户剪切元素的内容时触发。onbeforepaste 事件在用户向元素中粘贴文本之前触发。onpaste 事件在用户向元素中粘贴文本时触…

使用Flask.Request的方法和属性,获取get和post请求参数(二)

1、Flask中的request 在Python发送Post、Get等请求时&#xff0c;我们使用到requests库。Flask中有一个request库&#xff0c;有其特有的一些方法和属性&#xff0c;注意跟requests不是同一个。 2、Post请求&#xff1a;request.get_data() 用于服务端获取客户端请求数据。注…

RocketMQ发送消息失败:error CODE: 14 DESC: service not available now, maybe disk full

在执行业务时&#xff0c;发现MQ控制台没有查询到消息&#xff0c;在日志中发现消息发送失败&#xff0c;报错error CODE: 14 DESC: service not available now, maybe disk full 分析报错应该是磁盘空间不足&#xff0c;导致broker不能进行正常的消息存储刷盘&#xff0c;去查…

WPF实战项目十一(API篇):待办事项功能api接口

1、新建ToDoController.cs继承基础控制器BaseApiController&#xff0c;但是一般业务代码不写在控制器内&#xff0c;业务代码写在Service&#xff0c;先新建统一返回值格式ApiResponse.cs&#xff1a; public class ApiResponse{public ApiResponse(bool status, string mess…

ssm学院党员管理系统源码和论文PPT

ssm学院党员管理系统源码和论文PPT002 开发工具&#xff1a;idea 数据库mysql5.7(mysql5.7最佳) 数据库链接工具&#xff1a;navcat,小海豚等 开发技术&#xff1a;java ssm tomcat8.5 选题意义、价值和目标&#xff1a; 随着鄂尔多斯应用技术学院招生规模的不断扩大&…