【Transformer从零开始代码实现 pytoch版】(六)模型基本测试运行

模型基本测试及运行

在这里插入图片描述

(1)构建数据生成器

def data_generator(V, batch, num_batch):""" 用于随机生成copy任务的数据:param V: 随机生成数字的最大值+1:param batch: 每次输送给模型更新一次参数的数据量:param num_batch: 输送多少次完成一轮:return:"""# 遍历nbatchesfor i in range(num_batch):# 在循环中使用np的random.randint方法随机生成[1, v)的整数# 每批次10个样本,分布在(batch, 10)形状的矩阵中,然后再把numpy形式转换成torch中的tensordata = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))# 生成起始标志,使数据矩阵中的第一列数字都为1,这一列也就成为了起始标志列# 当解码器进行第一次解码的时候,会使用起始标志列作为输入data[:, 0] = 1# 因为是copy任务,所有source与target是完全相同的,且数据样本作用变量不需要求梯度# 因此requires_grad设置为Falsewith torch.no_grad():target = source = data# 使用Batch对source和target进行对应批次的掩码张量生成,最后使用yield返回yield Batch(source, target)

示例

V = 11                  # 将生成0-10的整数
batch = 20              # 每次喂给模型20个数据进行参数更新
num_batch = 30          # 连续喂30次完成全部数据的遍历res = data_generator(V, batch, num_batch)
print(f"res {res}")res <generator object data_generator at 0x000001BD670E4D60>

(2) 获得Transformer模型及其优化器和损失函数

# 获得Transformer模型机及其优化器和损失函数
from pyitcast.transformer_utils import get_std_opt          # 导入优化器工具包,用于获得标准的针对Transformer模型的优化器
from pyitcast.transformer_utils import LabelSmoothing       # 导入标签平滑工具包,用于标签平滑(小幅度的改变原有标签值的值域)
from pyitcast.transformer_utils import SimpleLossCompute    # 导入损失计算工具包,能够使标签平滑后的结果进行损失计算# 使用make_mode获得model
model = make_model(V, V, N=2)
# 使用get_std_opt获得模拟优化器
model_optimizer = get_std_opt(model)
# 使用LabelSmoothing获得平滑对象
criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)     # 输入目标词汇的总数
# 使用SimpleLossCompute获得利用标签平滑结果的损失计算方法
loss = SimpleLossCompute(model.generator, criterion, model_optimizer)
  • from pyitcast.transformer_utils import get_std_opt :该标准优化器基于Adam优化器,使其对序列到序列的任务更有效
  • from pyitcast.transformer_utils import LabelSmoothing:因为在理论上人工标注的数据可能并非完全正确,会受一些外界隐私影响而产涩会给你一些微笑的偏差,因此使用标签平滑来弥补这种偏差,减少模型对某一条规律的绝对认知,以防过拟合。
  • from pyitcast.transformer_utils import SimpleLossCompute:损失的计算方法可以认为使交叉熵损失函数。

在这里插入图片描述
在这里插入图片描述

(3)运行模型进行训练和评估

# 导入模型单轮训练工具包run_epoch,该工具将对模型使用给定的损失函数计算方法进行单轮参数更新,同时,打印每轮参数更新的损失结果
from pyitcast.transformer_utils import run_epochdef run(model, loss, epochs=10):""" 模型训练函数:param model: 要进行训练的模型:param loss: 使用的损失计算方法:param epochs: 模型的训练轮数:return:"""for epoch in range(epochs):# 使用训练模式,进行反向传播,所有参数将被更新model.train()run_epoch(data_generator(V, 8, 20), model, loss)        # batch_size = 20# 使用评估模型,不进行反向传播,所有参数不会被更新model.eval()run_epoch(data_generator(V, 8, 5), model, loss)        # batch_size = 5

示例

run(model, loss)

(4)使用模型进行贪婪解码

# 贪婪解码
from pyitcast.transformer_utils import greedy_decode        # 导入贪婪解码工具包greedy_decode,每次预测都是选择概率最大的结果作为输出def greedy_run(model, loss, epochs=10):for epoch in range(epochs):model.train()run_epoch(data_generator(V, 8, 20), model, loss)model.eval()run_epoch(data_generator(V, 8, 5), model, loss)# 模型训练结束,进入评估模式model.eval()# 初始化一个输入张量source = torch.LongTensor([[1,3,2,5,4,6,7,8,9,10]])# 定义源数据掩码张量,因为元素都是1,这里1代表不遮掩,因此相当于对数据源没有遮掩source_mask = torch.ones(1, 1, 10)# 起始标志默认为1result = greedy_decode(model, source, source_mask, max_len=10, start_symbol=1)print(result)

示例

greedy_run(model, loss)

(5)小结

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/174496.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自媒体项目详述

总体框架 本项目主要着手于获取最新最热新闻资讯&#xff0c;以微服务构架为技术基础搭建校内仅供学生教师使用的校园新媒体app。以文章为主线的核心业务主要分为如下子模块。自媒体模块实现用户创建功能、文章发布功能、素材管理功能。app端用户模块实现文章搜索、文章点赞、…

【论文阅读】CTAB-GAN: Effective Table Data Synthesizing

论文地址&#xff1a;[2102.08369] CTAB-GAN: Effective Table Data Synthesizing (arxiv.org) 介绍 虽然数据共享对于知识发展至关重要&#xff0c;但遗憾的是&#xff0c;隐私问题和严格的监管&#xff08;例如欧洲通用数据保护条例 GDPR&#xff09;限制了其充分发挥作用。…

PySide/PYQT如何用Qt Designer和代码来设置文字属性,如何设置文字颜色?

文章目录 📖 介绍 📖🏡 环境 🏡📒 实现方法 📒📝 Qt Designer设置📝 代码📖 介绍 📖 本人介绍如何使用Qt Designer/代码来设置字体属性(包含字体颜色) 🏡 环境 🏡 本文使用Pyside6来进行演示📒 实现方法 📒 📝 Qt Designer设置 首先打开Qt De…

爱心代码--C语言特供(可直接复制,亲测有效)

情人节到了&#xff0c;作为一名程序员&#xff0c;我们拥有属于我们的浪漫。 这里我总结了几种常见的爱心代码&#xff0c;简单易上手。 一.这是一种最为常见的爱心代码 #include<stdio.h> #include<Windows.h>int main() {float x, y, a;for (y 1.5; y > -1.…

手摸手入门Springboot+Grafana10.2接收JSON

JSON&#xff08;JavaScript Object Notation, JS对象简谱&#xff09;是一种轻量级的数据交换格式。它基于 ECMAScript&#xff08;European Computer Manufacturers Association, 欧洲计算机协会制定的js规范&#xff09;的一个子集&#xff0c;采用完全独立于编程语言的文本…

面试?看完这篇就够了-深入分析从点击应用图标到应用界面展示

作者&#xff1a;GeeJoe 从点击桌面图标到应用界面展示 从桌面点击图标到应用界面第一帧绘制出来&#xff0c;整个流程涉及的过程复杂&#xff0c;为了便于理解&#xff0c;这里将整个流程分为四个阶段&#xff1a;应用进程启动阶段、应用进程初始化阶段、Activity 启动阶段、…

【JVM】Java内存溢出分析(堆溢出、栈溢出、方法区溢出、直接内存溢出)

&#x1f4eb;作者简介&#xff1a;小明java问道之路&#xff0c;2022年度博客之星全国TOP3&#xff0c;专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化&#xff0c;文章内容兼具广度、深度、大厂技术方案&#xff0c;对待技术喜欢推理加验证&#xff0c;就职于…

Golang 中的 Context 包

简介 今天&#xff0c;我们将讨论 Go 编程中非常重要的一个主题&#xff1a;context 包。如果你现在觉得它很令人困惑&#xff0c;不用担心 — 在本文结束时&#xff0c;你将像专家一样处理 context&#xff01; 想象一下&#xff0c;你在一个主题公园&#xff0c;兴奋地准备…

界面控件DevExtreme图表和仪表(v23.1) - 新功能(Angular,React,Vue,jQuery)

本文将为大家总结下DevExtreme在v23.1版本中发布的一些与图表和仪表盘相关的功能。 DevExtreme拥有高性能的HTML5 / JavaScript小部件集合&#xff0c;使您可以利用现代Web开发堆栈&#xff08;包括React&#xff0c;Angular&#xff0c;ASP.NET Core&#xff0c;jQuery&#…

2023年【汽车驾驶员(中级)】免费试题及汽车驾驶员(中级)考试试卷

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年【汽车驾驶员&#xff08;中级&#xff09;】免费试题及汽车驾驶员&#xff08;中级&#xff09;考试试卷&#xff0c;包含汽车驾驶员&#xff08;中级&#xff09;免费试题答案和解析及汽车驾驶员&#xff08;…

文心一言 VS 讯飞星火 VS chatgpt (135)-- 算法导论11.3 1题

一、用go语言&#xff0c;假设我们希望查找一个长度为 n 的链表&#xff0c;其中每一个元素都包含一个关键字k并具有散列值h(k)。每一个关键字都是长字符串。那么在表中查找具有给定关键字的元素时&#xff0c;如何利用各元素的散列值呢? 文心一言&#xff1a; 在Go语言中&a…

【JAVA学习笔记】70 - 反射

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter23/src 反射 一、反射的引出 package com.yinhai.reflection.question;import com.yinhai.Cat;import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IO…