Transformer实战 单词预测

  •    🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Dataset
import math, os, torchclass TransformerModel(nn.Module):def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):super().__init__()self.pos_encoder = PositionalEncoding(d_model, dropout)# 编码器层堆栈encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)# 编码器堆栈. pytorch已经实现了Transformer编码器层的堆栈,这里直接调用即可self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.embedding = nn.Embedding(ntoken, d_model)self.d_model = d_modelself.linear = nn.Linear(d_model, ntoken)self.init_weights()# 初始化权重def init_weights(self) -> None:initrange = 0.1self.embedding.weight.data.uniform_(-initrange, initrange)self.linear.bias.data.zero_()self.linear.weight.data.uniform_(-initrange, initrange)def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:"""Arguments:src: Tensor, 形状为 [seq_len, batch_size]src_mask: Tensor, 形状为 [seq_len, seq_len]Returns:最终的 Tensor, 形状为 [seq_len, batch_size, ntoken]"""src = self.embedding(src) * math.sqrt(self.d_model)src = self.pos_encoder(src)output = self.transformer_encoder(src, src_mask)output = self.linear(output)return outputclass PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(p=dropout)# 位置编码器的初始化部分position = torch.arange(max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))pe = torch.zeros(max_len, 1, d_model)pe[:, 0, 0::2] = torch.sin(position * div_term)pe[:, 0, 1::2] = torch.cos(position * div_term)# 注册为持久状态变量,不参与参数优化self.register_buffer('pe', pe)def forward(self, x: Tensor) -> Tensor:"""Arguments:x: Tensor, 形状为 [seq_len, batch_size, embedding_dim]Returns:最终的 Tensor, 形状为 [seq_len, batch_size, embedding_dim]"""x = x + self.pe[:x.size(0)]return self.dropout(x)

wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

二、加载数据集

wikitext2_dir = "d:/wikitext-2-v1/wikitext-2"# Modify the data processing function to read from the local file
def data_process(file_path: str) -> Tensor:with open(file_path, 'r', encoding='utf-8') as file:data = [torch.tensor(vocab(tokenizer(line)), dtype=torch.long) for line in file]return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))# Load train, validation, and test data from local files
train_file = os.path.join(wikitext2_dir, "wiki.train.tokens")
val_file = os.path.join(wikitext2_dir, "wiki.valid.tokens")
test_file = os.path.join(wikitext2_dir, "wiki.test.tokens")train_data = data_process(train_file)
val_data = data_process(val_file)
test_data = data_process(test_file)# 使用数据处理函数处理数据集
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)# 设置设备优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 将数据集分批的函数
def batchify(data: Tensor, bsz: int) -> Tensor:# 计算批次大小nbatch = data.size(0) // bsz# 裁剪掉多余的部分使得能够完全分为批次data = data.narrow(0, 0, nbatch * bsz)# 重新整理数据维度为[批次, 批次大小]data = data.view(bsz, -1).t().contiguous()# 将数据移动到指定设备return data.to(device)# 批次大小
batch_size = 20
eval_batch_size = 10# 应用batchify函数分批处理训练集、验证集和测试集
train_data = batchify(train_data, batch_size)  # 结果为 [序列长度, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:"""生成批次数据参数:source: Tensor, 形状为 `[full_seq_len, batch_size]`i: int, 批次索引值返回:tuple (data, target).- data包含输入 [seq_len, batch_size],- target包含标签 [seq_len * batch_size]"""seq_len = min(bptt, len(source) - 1 - i)data = source[i:i+seq_len]target = source[i+1:i+1+seq_len].reshape(-1)return data, target

三、实例初始化

ntokens = len(vocab) # 词汇表的大小
emsize = 200         # 嵌入维度
nhid = 200           # nn.TransformerEncoder 中间层的维度
nlayers = 2          # nn.TransformerEncoder层的数量
nhead = 2            # nn.MultiheadAttention 头的数量
dropout = 0.2        # 丢弃率# 初始化 Transformer 模型,并将其发送到指定设备
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

四、训练模型

import time# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
lr = 5.0 # 学习率
optimizer = torch.optim.SGD(model.parameters(), lr=lr) # 使用SGD优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95) # 学习率衰减# 训练函数
def train(model: nn.Module) -> None:model.train() # 开启训练模式total_loss = 0.log_interval = 200 # 每隔200个batch打印一次日志start_time = time.time()num_batches = len(train_data) // bpttfor batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)output = model(data)output_flat = output.view(-1, ntokens)loss = criterion(output_flat, targets)optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # 梯度裁剪optimizer.step() # 更新参数total_loss += loss.item()if batch % log_interval == 0 and batch > 0:lr = scheduler.get_last_lr()[0] # 获取当前学习率ms_per_batch = (time.time() - start_time) * 1000 / log_intervalcur_loss = total_loss / log_intervalppl = math.exp(cur_loss)print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | 'f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | 'f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')total_loss = 0start_time = time.time()# 评估函数
def evaluate(model: nn.Module, eval_data: Tensor) -> float:model.eval() # 开启评估模式total_loss = 0.with torch.no_grad():for i in range(0, eval_data.size(0) - 1, bptt):data, targets = get_batch(eval_data, i)output = model(data)output_flat = output.view(-1, ntokens)total_loss += criterion(output_flat, targets).item()return total_loss / (len(eval_data) - 1)

训练函数train通过多次迭代数据,并使用梯度下降来更新模型的权重。它还包括了每个日志间隔打印损失和困惑度(perplexity,常用于语言模型的评估指标)。评估函数evaluate用于计算模型在验证集或测试集上的性能,但不会进行参数更新。代码还展示了如何使用学习率调度器来随着训练进行逐步减小学习率。

best_val_loss = float('inf') # 初始设置最佳验证集损失为无穷大
epochs = 1 # 设置训练的总轮数为1
best_model_params = None # 用于存储最佳模型参数# 使用临时目录存储模型参数
with TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, "best_model_params.pt")# 循环遍历每个epochfor epoch in range(1, epochs + 1):epoch_start_time = time.time()train(model)val_loss = evaluate(model, val_data) # 在验证集上评估当前模型print('-' * 89)elapsed = time.time() - epoch_start_timeprint(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | valid loss {val_loss:5.2f} | 'f'valid ppl {math.exp(val_loss):8.2f}')print('-' * 89)# 检查当前epoch的验证集损失是否为最佳if val_loss < best_val_loss:best_val_loss = val_loss # 更新最佳验证集损失best_model_params = model.state_dict() # 保存最佳模型参数# 保存有最佳验证集损失的模型参数torch.save(best_model_params, best_model_params_path)scheduler.step() # 更新学习率# 加载最佳模型参数,以便在测试集上进行评估或进一步训练
model.load_state_dict(torch.load(best_model_params_path))

五、评估模型

test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)print('-' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | 'f'test ppl {test_ppl:8.2f}')
print('-' * 89)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/659933.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在config.json文件中配置出来new mars3d.graphic.PolylineCombine({大量线合并渲染类型的geojson图层

在config.json文件中配置出来new mars3d.graphic.PolylineCombine({大量线合并渲染类型的geojson图层 问题场景&#xff1a; 1.浏览官网示例的时候图层看到大量线数据合并渲染的示例 2.矢量数据较大量级的时候&#xff0c;这种时候怎么在config.json文件中尝试配置呢&#x…

大数据运维之数据质量管理

第1章 数据质量管理概述 1.1 数据质量管理定义 数据质量管理&#xff08;Data Quality Management&#xff09;&#xff0c;是指对数据从计划、获取、存储、共享、维护、应用、消亡生命周期的每个阶段里可能引发的各类数据质量问题&#xff0c;进行识别、度量、监控、预警等一…

【好书推荐8】《智能供应链:预测算法理论与实战》

【好书推荐8】《智能供应链&#xff1a;预测算法理论与实战》 写在最前面编辑推荐内容简介作者简介目录精彩书摘前言/序言我为什么要写这本书这本书能带给你什么 致谢 &#x1f308;你好呀&#xff01;我是 是Yu欸 &#x1f30c; 2024每日百字篆刻时光&#xff0c;感谢你的陪伴…

ARP学习及断网攻击

1.什么是ARP ARP&#xff08;Address Resolution Protocol&#xff09;是一种用于在IPv4网络中将IP地址映射到MAC地址的协议。在计算机网络中&#xff0c;每个网络接口都有一个唯一的MAC地址&#xff08;Media Access Control address&#xff09;&#xff0c;用于识别网络设备…

tomcat部署

1.客户端和服务器端的交互过程 客户端发送请求给服务器 由服务器中的服务器软件拦截请求 根据请求调动相应的Java业务逻辑执行相关的处理 我们前面知道Java代码的运行势必提前将其装载在JVM上 而服务器软件一般都是由Java代码编写 所以两者都要装载在JVM上 而Java业务逻辑装载…

个人学习资源整理

文章目录 视频相关stl源码讲解相关 网站相关CPP网站 视频相关 stl源码讲解相关 跳转 网站相关 CPP网站 https://cplusplus.com/ https://gcc.gnu.org/

C语言实验-函数与模块化程序设计

一&#xff1a; 编写函数fun&#xff0c;其功能是&#xff1a;输入一个正整数&#xff0c;将其每一位上为偶数的数取出重新构成一个新数并输出。主函数负责输入输出&#xff0c;如输入87653142&#xff0c;则输出8642。&#xff08;main函数->fun函数&#xff09; #define _…

解决RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

下图说明在一瞬间我的GPU就被占满了 我的模型在训练过程中遇到了 CUDA 相关的错误&#xff0c;这是由于 GPU资源问题或内存不足导致的。这类错误有时候也可能是由于某些硬件兼容性问题或驱动程序问题引起的。 为了解决这个问题&#xff0c;可以尝试以下几个解决方案&#xff1a…

实验14 MVC

二、实验项目内容&#xff08;实验题目&#xff09; 编写代码&#xff0c;掌握MVC的用法。【参考课本 例1 】 三、源代码以及执行结果截图&#xff1a; example7_1.jsp&#xff1a; <% page contentType"text/html" %> <% page pageEncoding "ut…

【C/C++】动态内存管理(C:malloc,realloc,calloc,free || C++:new,delete)

&#x1f525;个人主页&#xff1a; Forcible Bug Maker &#x1f525;专栏&#xff1a; C | | C语言 目录 前言C/C内存分布C语言中的动态内存管理&#xff1a;malloc/realloc/realloc/freemallocrealloccallocfree C中的动态内存管理&#xff1a;new/deletenew和delete操作内…

2-4 任务:等差数列求和

本次实战的目标是计算1到100的累加和。我们将使用Java编程语言&#xff0c;通过三种不同的循环结构&#xff08;for循环、while循环和do-while循环&#xff09;来实现这个任务。在每个循环结构中&#xff0c;我们将逐步累加数字&#xff0c;并在最后输出结果。 首先&#xff0…

从零开始构建大语言模型(MEAP)

原文&#xff1a;annas-archive.org/md5/c19a4ef8ab1664a3c5a59d52651430e2 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 一、理解大型语言模型 本章包括 大型语言模型&#xff08;LLM&#xff09;背后的基本概念的高层次解释 探索 ChatGPT 类 LLM 源自的 Transfo…