深度学习 | Transformer模型及代码实现

        Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

        


 

一、Transformer模型

 

1、模型结构

        首先介绍 Transformer 的整体结构,下图是 Transformer 用于中英文翻译的整体结构。

        

        可以看到 Transformer 由 Encoder 和 Decoder 两个部分组成,Encoder 和 Decoder 都包含 6 个 block。 6是随机选择的数字,也可以是其他的数字。我们可以将这个结构看成是串联在一起的电池组,彼此之间通过多次的非线性变换在不同的空间提取更多的信息。

        

        编码器的结构相同,但不共享权重。每个编码器具体来说由两个子层组成。

        自注意力层能够帮助编码器在编码特定单词时考虑句子中其他的单词,输出会输入到一个独立的前馈神经网络中,每个编码器都有相同的前馈神经网络运行。

        解码器又嵌入了一个 Encoder-Decoder注意力层,帮助解码器专注于输入句子的相关部分,类似于seq2seq中的注意力机制。

        在自然语言处理中,注意首先需要使用嵌入算法将每个单词转换为向量,每个词都会嵌入到一个512维的向量中,512是一个超参数,代表训练数据集中句子的最大长度。

        注意 每个词的流动路径都是独立的。词语之间的依赖关系是通过自注意力层来表达的,前向反馈层没有相互之间的计算,所以前向反馈层可以并行计算。

         

        

2、编码器

1)、位置嵌入 Embedding

        Transformer 中单词的输入表示 x 由单词 Embedding 和位置 Embedding 相加得到。

        

        其中单词的 Embedding (嵌入向量) 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

        而位置Embedding(位置编码向量)的则是通过定义的正余弦函数来得到的。

        

        其中,pos 表示单词在句子中的位置,d 表示 PE的维度 (与词 Embedding 一样),2i 表示偶数的维度,2i+1 表示奇数维度 (即 2i≤d, 2i+1≤d)。

        使用这种公式计算 PE 有以下的好处:

        使 PE 能够适应比训练集里面所有句子更长的句子,假设训练集里面最长的句子是有 20 个单词,突然来了一个长度为 21 的句子,则使用公式计算的方法可以计算出第 21 位的 Embedding。可以让模型容易地计算出相对位置,对于固定长度的间距 k,PE(pos+k) 可以用 PE(pos) 计算得到。因为 Sin(A+B) = Sin(A)Cos(B) + Cos(A)Sin(B), Cos(A+B) = Cos(A)Cos(B) - Sin(A)Sin(B)。将单词的词 Embedding 和位置 Embedding 相加,就可以得到单词的表示向量 x,x 就是Transformer 的输入。

        我们之前学习的RNN模型中,并没有使用过位置Embedding,那为什么Transformer中要引入位置信息呢?

        这是因为 Transformer 不采用 RNN 的结构,而是使用全局信息,不能利用单词的顺序信息,而这部分信息对于NLP来说非常重要 所以 Transformer 中使用位置 Embedding 保存单词在序列中的相对或绝对位置。

2)Transformer 中的多头注意力机制

        在 Transformer 论文中,通过添加多头注意力机制,进一步完善了自注意力层。

        也就是说一个输入向量会分别生成不同的 Q K V 的组合,从而得到不同的注意力权重 Z,再拼接到一起,这样的话扩展了模型关注不同位置的能力,为注意力层提供了表示子空间。有点类似CNN中不同的卷积核,用于捕捉输入数据不同维度特征,这样在句子比较长 容易产生歧义或一词多义的情况下也能更好的提取特征信息。

        这些 Q K V 变换矩阵都是通过训练得到的,Transformer中就是用了八个头。

        

        举个例子:当我们在对下面句子中的 it 进行编码时

        不同的颜色代表不同头,颜色深浅代表自注意力权重,

        以it为例,编码时关注重点The animal、tired。

         

        下面是一个演示,不同颜色表示不同的头Q K V,类似CNN中不同的卷积核,捕获不同中的注意力权重。

        

        论文中给出的模型架构如下:

        

        左侧为编码器块,右侧为解码器块。橙色框中的部分就是Multi-Head Attention,它是由多个Self-Attention组成的,可以看到 Encoder block 包含一个Multi-Head Attention,而Decoder block 包含两个Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个Add & Norm 层,这里Add 表示残差连接 (Residual Connection) 用于防止网络退化,Norm 表示 Layer Normalization,用于对每一层的激活值进行归一化。

        关于多头注意力和自注意力的内容,我们上一节已经介绍过,这里不再赘述。

3)残差结构

         在每个编码器子层和解码器子层中都使用了残差连接和归一化,他们可以让网络更容易学习复杂特征,从而避免梯度消失和爆炸的问题,同时训练更稳定,层归一化可以加速模型的收敛过程,有助于提高模型的泛化能力和稳定性。

         

3、解码器

        解码器需要同时链接编码器的输出。就像RNN一样,换句话说每一步解码都要使用编码器的输出来生成序列中下一个单词的表示。

        通过连接编码器和解码器模型可以有效的利用编码器对输入序列的理解从而生成更加准确的输出序列,同时也可以避免信息丢失的问题,从而提高模型的整体性能和稳定性。

        

 

4、编解码器协同工作

         编码器首先处理输入序列,然后将顶部的输出转换为 一组注意力向量 K 和 V,这些向量被每个解码器在他的编码器-解码器注意力层中使用,用于帮助解码器将注意力集中在输入序列中的恰当位置。

        在解码器阶段,每一步都会从输出序列中输出一个元素,这个元素的生成既依赖于之前的输出同时也依赖于编码器-解码器注意力层中的注意力向量,通过多次迭代计算,解码器可以逐渐生成完整的输出序列。

        整个过程中,编码器和解码器的协同工作是通过多头注意力机制和残差链接等技术实现的,这使得 Transformer 模型在各种NLP任务中取得了很好的性能。

        重复上述步骤,直到一个特殊的到达符号表示解码器已经完成了输出。可以说解码器的自注意力层和编码器的自注意力层是非常类似的,但是运行方式不同,解码器的自注意力层只允许关注输出序列中之前的位置,以避免信息泄露和信息未来化的问题,在每个解码器中,输入序列要经过多头注意力机制和前馈神经网络进行编码,然后通过编码器-解码器注意力层与编码器的输出 再进行交互,最后生成解码器最后的输出序列。整个过程中位置编码向量也被用来保留单词在序列中的位置信息。

        

 

5、线性层和softmax层

         对于解码器的输出,他是一个浮点向量,如何把他转换成一个单词呢?这就是线性层和softmax层的工作了。

        线性层:一个全连接神经网络,将解码器堆叠生成的向量映射到一个更大的向量,通常称为logits向量。

        每个单元格对应一个单词分数。

        Softmax层:将单词分数转化为概率,选择具有最高概率的单词作输出。

        

 

6、工作流程

        Transformer模型的工作流程主要包含三个步骤:

        第一步:获取输入句子的每一个单词的表示向量 X,X 由单词的 Embedding 和单词位置的 Embedding 相加得到。

         

        第二步:将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x ,传入 Encoder 中,经过 6 个 Encoder block 后可以得到句子所有单词的编码信息矩阵 C,如下图 2。

        单词向量矩阵用 X(n×d)表示, n 是句子中单词个数,d 是表示向量的维度 (论文中 d=512)。每一个 Encoder block 输出的矩阵维度与输入完全一致。

                

        第三步:将 Encoder 输出的编码信息矩阵 C传递到 Decoder 中,Decoder 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。

        需要特别说明的是Transformer中使用了多头注意力机制。

        

        上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 Begin,预测第一个单词 I;然后输入翻译开始符 Begin 和单词 I,预测单词 have,以此类推。这是 Transformer 使用时候的大致流程,接下来是里面各个部分的细节。 

 

7、优缺点总结

        Transformer 与 RNN 不同,可以比较好地并行训练。

        Transformer 本身是不能利用单词的顺序信息的,因此需要在输入中添加位置 Embedding,否则 Transformer 就是一个词袋模型了。

        Transformer 的重点是 Self-Attention 结构,其中用到的 Q, K, V矩阵通过输出进行线性变换得到。

        Transformer 中 Multi-Head Attention 中有多个 Self-Attention,可以捕获单词之间多种维度上的相关系数 attention score。

        

 



 

二、Transformer模型代码实现

 

         

1、数据准备

(1)代码包引入

import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
from torch import optim
import random
from tqdm import *
import matplotlib.pyplot as plt

(2)数据集生成

# 数据集生成
soundmark = ['ei',  'bi:',  'si:',  'di:',  'i:',  'ef',  'dʒi:',  'eit∫',  'ai', 'dʒei', 'kei', 'el', 'em', 'en', 'əu', 'pi:', 'kju:','ɑ:', 'es', 'ti:', 'ju:', 'vi:', 'd∧blju:', 'eks', 'wai', 'zi:']alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']t = 1000 #总条数
r = 0.9   #扰动项
seq_len = 6
src_tokens, tgt_tokens = [],[] #原始序列、目标序列列表for i in range(t):src, tgt = [],[]for j in range(seq_len):ind = random.randint(0,25)src.append(soundmark[ind])if random.random() < r:tgt.append(alphabet[ind])else:tgt.append(alphabet[random.randint(0,25)])src_tokens.append(src)tgt_tokens.append(tgt)
src_tokens[:2], tgt_tokens[:2]
([['kju:', 'kei', 'em', 'i:', 'vi:', 'pi:'],['bi:', 'kju:', 'eit∫', 'eks', 'ef', 'di:']],[['q', 'k', 'm', 'e', 'v', 'p'], ['b', 'q', 'h', 'x', 'f', 'd']])
from collections import Counter  # 计数类flatten = lambda l: [item for sublist in l for item in sublist]  # 展平数组
# 构建词表
class Vocab:def __init__(self, tokens):self.tokens = tokens  # 传入的tokens是二维列表self.token2index = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}  # 先存好特殊词元# 将词元按词频排序后生成列表self.token2index.update({token: index + 4for index, (token, freq) in enumerate(sorted(Counter(flatten(self.tokens)).items(), key=lambda x: x[1], reverse=True))})# 构建id到词元字典self.index2token = {index: token for token, index in self.token2index.items()}def __getitem__(self, query):# 单一索引if isinstance(query, (str, int)):if isinstance(query, str):return self.token2index.get(query, 3)elif isinstance(query, (int)):return self.index2token.get(query, '<unk>')# 数组索引elif isinstance(query, (list, tuple)):return [self.__getitem__(item) for item in query]def __len__(self):return len(self.index2token)

(3)数据集构造

from torch.utils.data import DataLoader, TensorDataset#实例化source和target词表
src_vocab, tgt_vocab = Vocab(src_tokens), Vocab(tgt_tokens)
src_vocab_size = len(src_vocab)  # 源语言词表大小
tgt_vocab_size = len(tgt_vocab)  # 目标语言词表大小#增加开始标识<bos>和结尾标识<eos>
encoder_input = torch.tensor([src_vocab[line + ['<pad>']] for line in src_tokens])
decoder_input = torch.tensor([tgt_vocab[['<bos>'] + line] for line in tgt_tokens])
decoder_output = torch.tensor([tgt_vocab[line + ['<eos>']] for line in tgt_tokens])# 训练集和测试集比例8比2,batch_size = 16
train_size = int(len(encoder_input) * 0.8)
test_size = len(encoder_input) - train_size
batch_size = 16# 自定义数据集函数
class MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):super(MyDataSet, self).__init__()self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return self.enc_inputs.shape[0]def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]train_loader = DataLoader(MyDataSet(encoder_input[:train_size], decoder_input[:train_size], decoder_output[:train_size]), batch_size=batch_size)
test_loader = DataLoader(MyDataSet(encoder_input[-test_size:], decoder_input[-test_size:], decoder_output[-test_size:]), batch_size=1)

2、模型构建

(1)位置编码

def get_sinusoid_encoding_table(n_position, d_model):def cal_angle(position, hid_idx):return position / np.power(10000, 2 * (hid_idx // 2) / d_model)def get_posi_angle_vec(position):return [cal_angle(position, hid_j) for hid_j in range(d_model)]sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # 偶数位用正弦函数sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # 奇数位用余弦函数return torch.FloatTensor(sinusoid_table)
print(get_sinusoid_encoding_table(30, 512))
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,0.0000e+00,  1.0000e+00],[ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,1.0366e-04,  1.0000e+00],[ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,2.0733e-04,  1.0000e+00],...,[ 9.5638e-01, -2.9214e-01,  7.9142e-01,  ...,  1.0000e+00,2.7989e-03,  1.0000e+00],[ 2.7091e-01, -9.6261e-01,  9.5325e-01,  ...,  1.0000e+00,2.9026e-03,  1.0000e+00],[-6.6363e-01, -7.4806e-01,  2.9471e-01,  ...,  1.0000e+00,3.0062e-03,  1.0000e+00]])

(2)掩码操作

# mask掉没有意义的占位符
def get_attn_pad_mask(seq_q, seq_k):                       # seq_q: [batch_size, seq_len] ,seq_k: [batch_size, seq_len]batch_size, len_q = seq_q.size()batch_size, len_k = seq_k.size()pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)          # 判断 输入那些含有P(=0),用1标记 ,[batch_size, 1, len_k]return pad_attn_mask.expand(batch_size, len_q, len_k)# mask掉未来信息
def get_attn_subsequence_mask(seq):                               # seq: [batch_size, tgt_len]attn_shape = [seq.size(0), seq.size(1), seq.size(1)]subsequence_mask = np.triu(np.ones(attn_shape), k=1)          # 生成上三角矩阵,[batch_size, tgt_len, tgt_len]subsequence_mask = torch.from_numpy(subsequence_mask).byte()  #  [batch_size, tgt_len, tgt_len]return subsequence_mask 

(3)注意力计算函数

# 缩放点积注意力计算
class ScaledDotProductAttention(nn.Module):def __init__(self):super(ScaledDotProductAttention, self).__init__()def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v]attn_mask: [batch_size, n_heads, seq_len, seq_len]'''scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.attn = nn.Softmax(dim=-1)(scores)context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]return context, attn#多头注意力计算
class MultiHeadAttention(nn.Module):def __init__(self):super(MultiHeadAttention, self).__init__()self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)def forward(self, input_Q, input_K, input_V, attn_mask):'''input_Q: [batch_size, len_q, d_model]input_K: [batch_size, len_k, d_model]input_V: [batch_size, len_v(=len_k), d_model]attn_mask: [batch_size, seq_len, seq_len]'''residual, batch_size = input_Q, input_Q.size(0)# (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # Q: [batch_size, n_heads, len_q, d_k]K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # K: [batch_size, n_heads, len_k, d_k]V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # V: [batch_size, n_heads, len_v(=len_k), d_v]attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]# context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]output = self.fc(context) # [batch_size, len_q, d_model]return nn.LayerNorm(d_model)(output + residual), attn

(4)构建前馈网络

class PoswiseFeedForwardNet(nn.Module):def __init__(self):super(PoswiseFeedForwardNet, self).__init__()self.fc = nn.Sequential(nn.Linear(d_model, d_ff, bias=False),nn.ReLU(),nn.Linear(d_ff, d_model, bias=False))def forward(self, inputs):                             # inputs: [batch_size, seq_len, d_model]residual = inputsoutput = self.fc(inputs)return nn.LayerNorm(d_model)(output + residual)   # 残差 + LayerNorm

(5)编码器模块

# 编码器层
class EncoderLayer(nn.Module):def __init__(self):super(EncoderLayer, self).__init__()self.enc_self_attn = MultiHeadAttention()  # 多头注意力self.pos_ffn = PoswiseFeedForwardNet()  # 前馈网络def forward(self, enc_inputs, enc_self_attn_mask):'''enc_inputs: [batch_size, src_len, d_model]enc_self_attn_mask: [batch_size, src_len, src_len]'''# enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,Venc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]return enc_outputs, attn# 编码器模块
class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Embedding(src_vocab_size, d_model)self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_vocab_size, d_model), freeze=True)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):'''enc_inputs: [batch_size, src_len]'''word_emb = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]pos_emb = self.pos_emb(enc_inputs) # [batch_size, src_len, d_model]enc_outputs = word_emb + pos_embenc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:# enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

(6)解码器模块

# 解码器层
class DecoderLayer(nn.Module):def __init__(self):super(DecoderLayer, self).__init__()self.dec_self_attn = MultiHeadAttention()self.dec_enc_attn = MultiHeadAttention()self.pos_ffn = PoswiseFeedForwardNet()def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):'''dec_inputs: [batch_size, tgt_len, d_model]enc_outputs: [batch_size, src_len, d_model]dec_self_attn_mask: [batch_size, tgt_len, tgt_len]dec_enc_attn_mask: [batch_size, tgt_len, src_len]'''# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)# dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]return dec_outputs, dec_self_attn, dec_enc_attn# 解码器模块
class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_vocab_size, d_model),freeze=True)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''dec_inputs: [batch_size, tgt_len]enc_intpus: [batch_size, src_len]enc_outputs: [batsh_size, src_len, d_model]'''word_emb = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]pos_emb = self.pos_emb(dec_inputs) # [batch_size, tgt_len, d_model]dec_outputs = word_emb + pos_embdec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # [batch_size, tgt_len, tgt_len]dec_self_attn_subsequent_mask = get_attn_subsequence_mask(dec_inputs) # [batch_size, tgt_len]dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) # [batch_size, tgt_len, tgt_len]dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batc_size, tgt_len, src_len]dec_self_attns, dec_enc_attns = [], []for layer in self.layers:# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len,src_len]dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)dec_self_attns.append(dec_self_attn)dec_enc_attns.append(dec_enc_attn)return dec_outputs, dec_self_attns, dec_enc_attns

(7)Transformer模型

class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.encoder = Encoder()self.decoder = Decoder()self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)def forward(self, enc_inputs, dec_inputs):'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]'''# tensor to store decoder outputs# outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)# enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]enc_outputs, enc_self_attns = self.encoder(enc_inputs)# dec_outpus: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

3、模型训练

d_model = 512   # 字 Embedding 的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度 
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
num_epochs = 50 # 训练50轮
# 记录损失变化
loss_history = []model = Transformer()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.99)for epoch in tqdm(range(num_epochs)):total_loss = 0for enc_inputs, dec_inputs, dec_outputs in train_loader:'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]dec_outputs: [batch_size, tgt_len]'''# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)# outputs: [batch_size * tgt_len, tgt_vocab_size]outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)loss = criterion(outputs, dec_outputs.view(-1))optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss/len(train_loader)loss_history.append(avg_loss)print('Epoch:', '%d' % (epoch + 1), 'loss =', '{:.6f}'.format(avg_loss))
  2%|▏         | 1/50 [00:21<17:16, 21.15s/it]Epoch: 1 loss = 2.6337624%|▍         | 2/50 [00:42<17:06, 21.39s/it]Epoch: 2 loss = 2.0630026%|▌         | 3/50 [01:04<16:48, 21.47s/it]Epoch: 3 loss = 1.8669448%|▊         | 4/50 [01:25<16:16, 21.23s/it]Epoch: 4 loss = 1.80278310%|█         | 5/50 [01:45<15:49, 21.10s/it]Epoch: 5 loss = 1.64321712%|█▏        | 6/50 [02:07<15:27, 21.07s/it]Epoch: 6 loss = 1.80347114%|█▍        | 7/50 [02:27<15:04, 21.02s/it]Epoch: 7 loss = 1.51879416%|█▌        | 8/50 [02:48<14:41, 20.99s/it]Epoch: 8 loss = 1.63284018%|█▊        | 9/50 [03:10<14:23, 21.06s/it]Epoch: 9 loss = 1.44673020%|██        | 10/50 [03:31<14:02, 21.06s/it]Epoch: 10 loss = 1.34034822%|██▏       | 11/50 [03:52<13:40, 21.04s/it]Epoch: 11 loss = 1.36691724%|██▍       | 12/50 [04:13<13:20, 21.06s/it]Epoch: 12 loss = 1.49971526%|██▌       | 13/50 [04:34<13:01, 21.12s/it]Epoch: 13 loss = 1.37144628%|██▊       | 14/50 [04:55<12:41, 21.14s/it]Epoch: 14 loss = 1.38049830%|███       | 15/50 [05:16<12:19, 21.14s/it]Epoch: 15 loss = 1.29818332%|███▏      | 16/50 [05:37<11:53, 20.99s/it]Epoch: 16 loss = 1.10751234%|███▍      | 17/50 [05:57<11:27, 20.85s/it]Epoch: 17 loss = 1.01535536%|███▌      | 18/50 [06:18<11:04, 20.76s/it]Epoch: 18 loss = 0.89157338%|███▊      | 19/50 [06:39<10:41, 20.69s/it]Epoch: 19 loss = 1.03515740%|████      | 20/50 [06:59<10:19, 20.64s/it]Epoch: 20 loss = 1.05994342%|████▏     | 21/50 [07:20<09:58, 20.64s/it]Epoch: 21 loss = 0.99534744%|████▍     | 22/50 [07:40<09:38, 20.65s/it]Epoch: 22 loss = 0.82873046%|████▌     | 23/50 [08:01<09:18, 20.68s/it]Epoch: 23 loss = 0.71740348%|████▊     | 24/50 [08:22<08:59, 20.77s/it]Epoch: 24 loss = 0.76887050%|█████     | 25/50 [08:43<08:39, 20.80s/it]Epoch: 25 loss = 0.71392752%|█████▏    | 26/50 [09:04<08:18, 20.75s/it]Epoch: 26 loss = 0.79791854%|█████▍    | 27/50 [09:24<07:57, 20.74s/it]Epoch: 27 loss = 0.68024656%|█████▌    | 28/50 [09:45<07:36, 20.76s/it]Epoch: 28 loss = 0.61177058%|█████▊    | 29/50 [10:06<07:16, 20.77s/it]Epoch: 29 loss = 0.81035560%|██████    | 30/50 [10:27<06:57, 20.86s/it]Epoch: 30 loss = 0.53748762%|██████▏   | 31/50 [10:48<06:37, 20.93s/it]Epoch: 31 loss = 0.48465064%|██████▍   | 32/50 [11:09<06:15, 20.86s/it]Epoch: 32 loss = 0.44703366%|██████▌   | 33/50 [11:30<05:54, 20.83s/it]Epoch: 33 loss = 0.39907268%|██████▊   | 34/50 [11:51<05:34, 20.90s/it]Epoch: 34 loss = 0.37964970%|███████   | 35/50 [12:12<05:13, 20.92s/it]Epoch: 35 loss = 0.27082372%|███████▏  | 36/50 [12:32<04:52, 20.91s/it]Epoch: 36 loss = 0.33787874%|███████▍  | 37/50 [12:53<04:30, 20.81s/it]Epoch: 37 loss = 0.23544076%|███████▌  | 38/50 [13:14<04:09, 20.77s/it]Epoch: 38 loss = 0.33739378%|███████▊  | 39/50 [13:35<03:49, 20.85s/it]Epoch: 39 loss = 0.26019180%|████████  | 40/50 [13:56<03:28, 20.89s/it]Epoch: 40 loss = 0.21008482%|████████▏ | 41/50 [14:17<03:09, 21.03s/it]Epoch: 41 loss = 0.16861684%|████████▍ | 42/50 [14:38<02:47, 20.97s/it]Epoch: 42 loss = 0.21360786%|████████▌ | 43/50 [14:58<02:25, 20.82s/it]Epoch: 43 loss = 0.11055188%|████████▊ | 44/50 [15:19<02:04, 20.74s/it]Epoch: 44 loss = 0.18356290%|█████████ | 45/50 [15:39<01:43, 20.62s/it]Epoch: 45 loss = 0.09517292%|█████████▏| 46/50 [16:00<01:22, 20.57s/it]Epoch: 46 loss = 0.13238794%|█████████▍| 47/50 [16:20<01:01, 20.52s/it]Epoch: 47 loss = 0.16380596%|█████████▌| 48/50 [16:41<00:40, 20.49s/it]Epoch: 48 loss = 0.15219598%|█████████▊| 49/50 [17:01<00:20, 20.49s/it]Epoch: 49 loss = 0.086681
100%|██████████| 50/50 [17:22<00:00, 20.84s/it]Epoch: 50 loss = 0.085496
plt.plot(loss_history)
plt.ylabel('train loss')
plt.show()

4、模型预测

model.eval()
translation_results = []correct = 0
error = 0for enc_inputs, dec_inputs, dec_outputs in test_loader:'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]dec_outputs: [batch_size, tgt_len]'''# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)# outputs: [batch_size * tgt_len, tgt_vocab_size]outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)# pred形状为 (seq_len, batch_size, vocab_size) = (1, 1, vocab_size)# dec_outputs, dec_self_attns, dec_enc_attns = model.decoder(dec_inputs, enc_inputs, enc_output)outputs = outputs.squeeze()pred_seq = []for output in outputs:next_token_index = output.argmax().item()if next_token_index == tgt_vocab['<eos>']:breakpred_seq.append(next_token_index)pred_seq = tgt_vocab[pred_seq]tgt_seq = dec_outputs.squeeze().tolist()# 需要注意在<eos>之前截断if tgt_vocab['<eos>'] in tgt_seq:eos_idx = tgt_seq.index(tgt_vocab['<eos>'])tgt_seq = tgt_vocab[tgt_seq[:eos_idx]]else:tgt_seq = tgt_vocab[tgt_seq]translation_results.append((' '.join(tgt_seq), ' '.join(pred_seq)))for i in range(len(tgt_seq)):if i >= len(pred_seq) or pred_seq[i] != tgt_seq[i]:error += 1else:correct += 1print(correct/(correct+error))
0.3333333333333333
translation_results
[('h x n y e k', 'h y y y k'),('y l z k i t', 't i t j i t y'),('t s x e e v', 's s v e e v'),('e g a m t h', 'f i h h h'),...................

 


参考

Chapter-11/11.7 Transformer代码实现.ipynb · 梗直哥/Deep-Learning-Code - Gitee.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/313720.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

合伙企业法关于合伙企业的要求

合伙协议可以载明合伙企业的经营期限和合伙人争议的解决方式。 合伙协议经全体合伙人签名、盖章后生效。合伙人依照合伙协议享有权利&#xff0c;承担责任。 经全体合伙人协商一致&#xff0c;可以修改或者补充合伙协议。 申请合伙企业设立登记&#xff0c;应当向企业登记机关提…

CEC2017(Python):麻雀搜索算法SSA求解CEC2017(提供Python代码)

一、CEC2017简介 参考文献&#xff1a; [1]Awad, N. H., Ali, M. Z., Liang, J. J., Qu, B. Y., & Suganthan, P. N. (2016). “Problem definitions and evaluation criteria for the CEC2017 special session and competition on single objective real-parameter numer…

大甩卖-(CWRU)轴承故障诊数据集和代码全家桶

Python-凯斯西储大学&#xff08;CWRU&#xff09;轴承数据解读与分类处理 Python轴承故障诊断 (一)短时傅里叶变换STFT Python轴承故障诊断 (二)连续小波变换CWT_pyts 小波变换 故障-CSDN博客 Python轴承故障诊断 (三)经验模态分解EMD_轴承诊断 pytorch-CSDN博客 Pytorch…

基于C#的机械臂欧拉角与旋转矩阵转换

欧拉角概述 机器人末端执行器姿态描述方法主要有四种&#xff1a;旋转矩阵法、欧拉角法、等效轴角法和四元数法。所以&#xff0c;欧拉角是描述机械臂末端姿态的重要方法之一。 关于欧拉角的历史&#xff0c;由来已久&#xff0c;莱昂哈德欧拉用欧拉角来描述刚体在三维欧几里…

IBM介绍?

IBM&#xff0c;全名国际商业机器公司&#xff08;International Business Machines Corporation&#xff09;&#xff0c;是一家全球知名的美国科技公司。它成立于1911年&#xff0c;总部位于美国纽约州阿蒙克市&#xff08;Armonk&#xff09;&#xff0c;是世界上最大的信息…

240101-5步MacOS自带软件无损快速导出iPhone照片

硬件准备&#xff1a; iphone手机Mac电脑数据线 操作步骤&#xff1a; Step 1: 找到并打开MacOS自带的图像捕捉 Step 2: 通过数据线将iphone与电脑连接Step 3&#xff1a;iphone与电脑提示“是否授权“&#xff1f; >>> “是“Step 4&#xff1a;左上角选择自己的设…

Redis:原理+项目实战——Redis实战1(session实现短信登录(并剖析问题))

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位大四、研0学生&#xff0c;正在努力准备大四暑假的实习 &#x1f30c;上期文章&#xff1a;Redis&#xff1a;原理速成项目实战——Redis的Java客户端 &#x1f4da;订阅专栏&#xff1a;Redis速成 希望文章对你们有所帮助…

产品经理学习-从0-1搭建策略产品

从0-1搭建策略产品 目录&#xff1a; 回顾策略产品 如何从0-1搭建策略产品 回顾策略产品 之前也了解过从产品实施的角度来看&#xff0c;策略就是针对问题的解决方案&#xff0c;在互联网时代更集中体现在2个维度&#xff1a;业务场景和数据应用 如何从0-1搭建策略产品 我们…

黑马程序员SSM框架-SpringBoot

视频连接&#xff1a;SpringBoot-01-SpringBoot工程入门案例开发步骤_哔哩哔哩_bilibili SpringBoot简介 入门程序 也可以基于官网创建项目。 SpringBoot项目快速启动 下面的插件将项目运行所需的依赖jar包全部加入到了最终运行的jar包中&#xff0c;并将入口程序指定。 Spri…

SpringCloud-高级篇(九)

&#xff08;1&#xff09;Seata高可用 我们学习了Seata的各种用法了&#xff0c;Seata的服务是单节点部署的&#xff0c;这个服务如果挂了&#xff0c;整个事务都没有办法完了&#xff0c;下面我们学习Seata的高可用的知识。 实现高可用&#xff0c;还是比较简单&#xff0c;…

modelsim安装使用

目录 modelsim 简介 modelsim 简介 ModelSim 是三大仿真器公司之一mentor的产品&#xff0c;他可以模拟行为、RTL 和门级代码 - 通过独立于平台的编译提高设计质量和调试效率。单内核模拟器技术可在一种设计中透明地混合 VHDL 和 Verilog&#xff0c;常用在fpga 的仿真中。 #…

合伙企业有哪些分类

合伙企业分为&#xff1a;普通合伙企业和有限合伙企业。其中&#xff0c;普通合伙企业又包含特殊的普通合伙企业。 1、普通合伙企业由2人以上普通合伙人(没有上限规定)组成。 普通合伙企业中&#xff0c;合伙人对合伙企业债务承担无限连带责任。 特殊的普通合伙企业中&#xf…