【Transformer从零开始代码实现 pytoch版】(五)总架构类的实现

Transformer总架构

在这里插入图片描述
在实现完输入部分、编码器、解码器和输出部分之后,就可以封装各个部件为一个完整的实体类了。

【Transformer从零开始代码实现 pytoch版】(一)输入部件:embedding+positionalEncoding

【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

【Transformer从零开始代码实现 pytoch版】(四)输出部件:Linear+softmax

编码器-解码器总结构代码实现

class EncoderDecoder(nn.Module):""" 编码器解码器架构实现、定义了初始化、forward、encode和decode部件"""def __init__(self, encoder, decoder, source_embed, target_embed, generator):""" 传入五大部件参数:param encoder: 编码器:param decoder: 解码器:param source_embed: 源数据embedding函数:param target_embed: 目标数据embedding函数:param generator: 输出部分类被生成器对象"""super(EncoderDecoder, self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = source_embedself.tgt_embed = target_embedself.generator = generator					# 生成器后面会专门用到def forward(self, source, target, source_mask, target_mask):""" 构建数据流入流出:param source: 源数据:param target: 目标数据:param source_mask: 源数据掩码张量:param target_mask: 目标数据掩码张量:return:"""# 注意这里先用的encode和decode函数,又才在其函数里面,再用了encoder和decoderreturn self.decode(self.encode(source, source_mask), source_mask, target, target_mask)def encode(self, source, source_mask):""" 编码函数,编码部件:param source: 源数据张量:param source_mask: 源数据的掩码张量:return: 经过解码器的输出"""return self.encoder(self.src_embed(source), source_mask)def decode(self, memory, source_mask, target, target_mask):""" 解码函数,解码部件:param memory:编码器的输出QV:param source_mask:源数据的掩码张量:param target:目标数据:param target_mask:目标数据的掩码张量:return:"""return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

示例

# 输入参数
vocab_size = 1000
size = d_model = 512# 编码器部分
dropout = 0.2
d_ff = 64				# 隐藏层参数
head = 8				# 注意力头数
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
encoder_layer = EncoderLayer(size, c(attn), c(ff), dropout)
encoder_N = 8
encoder = Encoder(encoder_layer, encoder_N)# 解码器部分
dropout = 0.2
d_ff = 64
head = 8
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
decoder_N = 8
decoder = Decoder(decoder_layer, decoder_N)# 用了nn的embedding作为输入示意
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = Generator(d_model, vocab_size)# 输入张量和掩码张量
source = target = torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]])
source_mask = target_mask = torch.zeros(2, 4, 4)# 实例化编码器-解码器,再带入参数实现
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_res = ed(source, target, source_mask, target_mask)
print(f"ed_res: {ed_res}\n shape:{ed_res.shape}")ed_res: tensor([[[-0.1861,  0.0849, -0.3015,  ...,  1.1753, -1.4933,  0.2484],[-0.3626,  1.3383,  0.1739,  ...,  1.1304,  2.0266, -0.5929],[ 0.0785,  1.4932,  0.3184,  ..., -0.2021, -0.2330,  0.1539],[-0.9703,  1.1944,  0.1763,  ...,  0.1586, -0.6066, -0.6147]],[[-0.9216, -0.0309, -0.6490,  ...,  1.0177,  0.5574,  0.4873],[-1.4097,  0.6678, -0.6708,  ...,  1.1176,  0.1959, -1.2494],[-0.3204,  1.2794, -0.4022,  ...,  0.6319, -0.4709,  1.0520],[-1.3238,  1.1470, -0.9943,  ...,  0.4026,  1.0911,  0.1327]]],grad_fn=<AddBackward0>)shape:torch.Size([2, 4, 512])

编码器-解码器模型构建函数

def make_model(source_vocab, target_vocab, N=6, d_model=512, d_ff=2048, head=8, dropout=0.1):""" 用于构建模型:param source_vocab: 源数据词汇总数:param target_vocab: 目标词汇总数:param N: 解码器/解码器堆叠层数:param d_model: 词嵌入维度:param d_ff: 前馈全连接层隐藏层维度:param dropout: 置0比率:return: 返回构建编码器-解码器模型"""# 拷贝函数,来保证拷贝的函数彼此之间相互独立,不受干扰c = copy.deepcopy# 实例化多头注意力attn = MultiHeadedAttention(head, d_model)# 实例化全连接层ff = PositionwiseFeedForward(d_model, d_ff, dropout)# 实例化位置编码类,得到对象positionposition = PositionalEncoding(d_model, dropout)model = EncoderDecoder(Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),nn.Sequential(Embedding(d_model, source_vocab), c(position)),nn.Sequential(Embedding(d_model, source_vocab), c(position)),Generator(d_model, target_vocab))# 模型结构构建完成后,初始化模型中的参数for p in model.parameters():# 这里判定当参数维度大于1的时候,则会将其初始化成一个服从均匀分布的矩阵if p.dim() > 1:nn.init.xavier_normal(p)        # 生成服从正态分布的数,默认为U(-1, 1),更改第二个参数可以改值return model

示例

source_vocab = target_vocab = 11
N = 6
res = make_model(source_vocab, target_vocab, N)
print(res)EncoderDecoder((encoder): Encoder((layers): ModuleList((0-5): 6 x EncoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-1): 2 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(decoder): Decoder((layers): ModuleList((0-5): 6 x DecoderLayer((self_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(src_attn): MultiHeadedAttention((linears): ModuleList((0-3): 4 x Linear(in_features=512, out_features=512, bias=True))(dropout): Dropout(p=0.1, inplace=False))(feed_forward): PositionwiseFeedForward((w1): Linear(in_features=512, out_features=2048, bias=True)(w2): Linear(in_features=2048, out_features=512, bias=True)(dropout): Dropout(p=0.1, inplace=False))(sublayer): ModuleList((0-2): 3 x SublayerConnection((norm): LayerNorm()(dropout): Dropout(p=0.1, inplace=False)))))(norm): LayerNorm())(src_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(tgt_embed): Sequential((0): Embedding((lut): Embedding(512, 11))(1): PositionalEncoding((dropout): Dropout(p=0.1, inplace=False)))(generator): Generator((project): Linear(in_features=512, out_features=11, bias=True))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/171085.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速入门笔记软件『Obsidian』

前言 Obsidian 是基于 Markdown 语法的笔记软件&#xff0c;界面简洁&#xff0c;使用简单&#xff0c;功能实用&#xff0c;支持跨平台数据同步&#xff0c;实现基于双向链接的知识图谱&#xff0c;同时提供各种各样的扩展主题和插件 本文将会详细讲解笔记软件 Obsidian 的安…

微信小程序使用阿里巴巴矢量图标

一&#xff0c;介绍 微信小程序使用图标有两种方式&#xff0c;一种是在线获取&#xff0c;一种是下载到本地使用&#xff0c; 第一种在线获取的有个缺点就是图标是灰色的&#xff0c;不能显示彩色图标&#xff0c;而且第一种是每次请求资源的&#xff0c;虽然很快&#xff0…

STM32-EXTI中断

EXTI简介 EXTI&#xff08;Extern Interrupt&#xff09;外部中断 EXTI可以监测指定GPIO口的电平信号&#xff0c;当其指定的GPIO口产生电平变化时&#xff0c;EXTI将立即向NVIC发出中断申请&#xff0c;经过NVIC裁决后即可中断CPU主程序&#xff0c;使CPU执行EXTI对应的中断程…

具名挂载和匿名挂载

匿名卷挂载 &#xff1a; -v 的时候只指定容器内的路径 如下面这个&#xff1a;/etc/nginx 1.docker run -d -P --name nginx -v /etc/nginx nginx 2.查看所有卷 docker volume ls 这里发现&#xff0c;这就是匿名挂载&#xff0c;只指定容器内的路径&#xff0c;没有指定…

CLIP:万物分类(视觉语言大模型)

本文来着公众号“AI大道理” ​ 论文地址&#xff1a;https://arxiv.org/abs/2103.00020 传统的分类模型需要先验的定义固定的类别&#xff0c;然后经过CNN提取特征&#xff0c;经过softmax进行分类。然而这种模式有个致命的缺点&#xff0c;那就是想加入新的一类就得重新定义…

学习网络编程No.9【应用层协议之HTTPS】

引言&#xff1a; 北京时间&#xff1a;2023/10/29/7:34&#xff0c;好久没有在周末早起了&#xff0c;该有的困意一点不少。伴随着学习内容的深入&#xff0c;知识点越来越多&#xff0c;并且对于爱好刨根问底的我来说&#xff0c;需要了解的知识就像一座大山&#xff0c;压得…

Java自学第11课:电商项目(4)重新建立项目

经过前几节的学习&#xff0c;我们已经找到之前碰到的问题的原因了。那么下面接着做项目学习。 1 新建dynamic web project 建立时把web.xml也生成下&#xff0c;省的右面再添加。 会询问是否改为java ee环境&#xff1f;no就行&#xff0c;其实改过来也是可以的。这个不重要。…

web前端开发第3次Dreamweave课堂练习/html练习代码《网页设计语言基础练习案例》

目标图片&#xff1a; 文字素材&#xff1a; 网页设计语言基础练习案例 ——几个从语义上和文字相关的标签 * h标签&#xff08;h1~h6&#xff09;&#xff1a;用来定义网页的标题&#xff0c;成对出现。 * p标签&#xff1a;用来设置网页的段落&#xff0c;成对出现。 * b…

​软考-高级-系统架构设计师教程(清华第2版)【第3章 信息系统基础知识(p120~159)-思维导图】​

软考-高级-系统架构设计师教程&#xff08;清华第2版&#xff09;【第3章 信息系统基础知识(p120~159)-思维导图】 课本里章节里所有蓝色字体的思维导图

Python数据容器(序列操作)

序列 1.什么是序列 序列是指&#xff1a;内容连续、有序。可以使用下标索引的一类数据容器 列表、元组、字符串。均可以视为序列 2.序列的常用操作 - 切片 语法&#xff1a;序列[起始下标:结束下标:步长]起始下标表示从何处开始&#xff0c;可以留空&#xff0c;留空视作从…

C语言--1,5,10人民币若干,现在需要18元,一共有多少种?

今天小编给大家分享一下穷举法的一道典型例题 一.题目描述 1,5,10人民币若干,现在需要18元,一共有多少种? 二.思路分析 总共有18块钱&#xff0c;设1元有x张&#xff0c;5元有y张&#xff0c;10元有z张&#xff0c;则有表达式&#xff1a;x5y10z18&#xff0c;穷举法最重要的…

常见面试题-Redis底层的SDS、ZipList、ListPack

Redis 的 SDS 了解吗&#xff1f; 答&#xff1a; Redis 创建了 SDS&#xff08;simple dynamic string&#xff09; 的抽象类型作为 String 的默认实现 SDS 的结构如下&#xff1a; struct sdshdr {// 字节数组&#xff0c;用于保存字符串char buf[];// buf[]中已使用字节…