自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)

自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

目录标题

  • 自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)
  • 1. Scheduled Sampling(计划采样)
    • 1.1 概念解释
    • 1.2 代码实现
  • 2. Teacher forcing(教师强制)
    • 2.1 概念解释
    • 2.1 代码实现
  • 参考

1. Scheduled Sampling(计划采样)

1.1 概念解释

Scheduled Sampling是一种用于训练序列生成模型的策略,旨在缓解曝光偏差(Exposure Bias)问题。曝光偏差是指模型在训练时接触到的数据分布与测试时的数据分布不一致,导致性能下降。

在Scheduled Sampling中,模型在每个时间步骤都有一定的概率选择使用真实目标序列中的单词作为输入,而不是使用前一个时间步骤生成的单词。这样可以使模型更好地适应真实数据分布,减少曝光偏差问题。

具体来说,Scheduled Sampling使用以下公式计算每个时间步骤生成当前单词的概率:

P ( y t ∣ y 1 , . . . , y t − 1 ) = ( 1 − ϵ ) ∗ P model ( y t ∣ y 1 , . . . , y t − 1 ) + ϵ ∗ P data ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) = (1 - \epsilon) * P_{\text{model}}(y_t|y_1, ..., y_{t-1}) + \epsilon * P_{\text{data}}(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)=(1ϵ)Pmodel(yty1,...,yt1)+ϵPdata(yty1,...,yt1)其中, P ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)表示在给定前面的生成序列条件下生成当前单词 y t y_t yt的概率, P model ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{model}}(y_t|y_1, ..., y_{t-1}) Pmodel(yty1,...,yt1)表示模型生成该单词的概率, P data ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{data}}(y_t|y_1, ..., y_{t-1}) Pdata(yty1,...,yt1)表示真实目标序列中该单词的概率。参数 ϵ \epsilon ϵ用于控制采样策略,可以随着训练的进行而逐渐增加。

1.2 代码实现

下面是一个使用Python实现Scheduled Sampling的示例代码:

在这里插入图片描述

其中, P ( y t ∣ y < t , x ) P(y_t | y_{<t}, x) P(yty<t,x)表示在给定前文和输入的条件下,生成当前时间步的输出的概率。 P model P_{\text{model}} Pmodel表示由模型生成的概率分布, P prev P_{\text{prev}} Pprev表示根据上一个时间步的真实输出计算得到的概率分布。sample是从均匀分布中采样得到的一个随机数,threshold是一个控制Scheduled Sampling引入程度的超参数。

2. Teacher forcing(教师强制)

2.1 概念解释

Teacher forcing(教师强制)是一种在序列生成模型中使用的训练技术。具体来说,当使用RNN(循环神经网络)或类似架构的模型进行序列生成时,每个时间步都会根据前一个时间步的输入和隐藏状态生成输出。在训练期间,如果使用teacher forcing,那么每个时间步的输入将是真实的目标序列(而不是模型自身生成的序列)。这意味着模型在每个时间步都能够观察到正确的答案,从而更容易地学习到正确的模式和规律。

2.1 代码实现

import torch
import torch.nn as nn# 定义序列到序列模型
class Seq2SeqModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(Seq2SeqModel, self).__init__()self.hidden_dim = hidden_dim# 定义编码器self.encoder = nn.RNN(input_dim, hidden_dim)# 定义解码器self.decoder = nn.RNN(output_dim, hidden_dim)# 定义全连接层,将解码器的输出映射为目标序列self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, input_seq, target_seq):# 编码器计算输入序列的隐藏状态_, hidden_state = self.encoder(input_seq)# 解码器初始化隐藏状态decoder_hidden_state = hidden_state# 用真实目标序列作为输入来指导解码器的生成过程decoder_outputs, _ = self.decoder(target_seq, decoder_hidden_state)# 对解码器的输出应用全连接层进行映射output_seq = self.fc(decoder_outputs)return output_seq# 创建模型实例
input_dim = 10
hidden_dim = 20
output_dim = 10
model = Seq2SeqModel(input_dim, hidden_dim, output_dim)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):# 步骤1:将模型设为训练模式model.train()# 步骤2:清零梯度optimizer.zero_grad()# 步骤3:前向传播input_seq = torch.randn(5, 3, input_dim)  # 输入序列target_seq = torch.randn(5, 3, output_dim)  # 目标序列output_seq = model(input_seq, target_seq)# 步骤4:计算损失loss = criterion(output_seq, target_seq)# 步骤5:反向传播和优化loss.backward()optimizer.step()print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

上述代码中,我们定义了一个简单的序列到序列模型Seq2SeqModel,其中包括一个RNN编码器、一个RNN解码器和一个全连接层。在forward方法中,我们首先使用编码器计算输入序列的隐藏状态,然后将隐藏状态作为解码器的初始隐藏状态。接下来,我们使用真实目标序列来指导解码器的生成过程,并将解码器的输出映射为目标序列。在训练阶段,我们使用真实目标序列作为输入来指导模型的生成过程。最后,我们定义了损失函数和优化器,并进行训练。

需要注意的是,在实际应用中,模型的推理阶段并不会使用真实目标序列来指导生成过程。在推理阶段,可以将前一个时间步的模型输出作为下一个时间步的输入,从而进行序列的自我生成。

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

Scheduled Sampling的搜索结果_百度图片搜索 (baidu.com)
Teacher forcing RNN的搜索结果_百度图片搜索 (baidu.com)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/23441.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode 无法格式化python代码、无法格式化C++代码(vscode格式化失效)另一种解决办法:用外部工具yapf格式化(yapf工具)

文章目录 我真的解决方法&#xff1a;用yapfyapf工具使用方法示例格式化单个文件&#xff08;格式化前先用-d参数预先查看格式化更改内容&#xff0c;以决定是否要更改&#xff09;格式化某个目录递归格式化某个目录 20230716 齐拉帕&#xff0c;我删除了虚拟环境目录&#xff…

electron globalShortcut 快捷键与系统全局快捷键冲突

用 electron 开发自己的接口测试工具&#xff08;Post Tools&#xff09;&#xff0c;在设置了 globalShortcut 快捷键后&#xff0c;发现应用中的快捷键与系统全局快捷键冲突了&#xff0c;导致系统快捷键不可正常使用。 快捷键配置 export function initGlobalShortcut(main…

Java设计模式之结构型-外观模式(UML类图+案例分析)

目录 一、基础概念 二、UML类图 三、角色设计 四、案例分析 五、总结 一、基础概念 外观模式&#xff0c;为子系统中的一组接口提供一个一致的界面&#xff0c;此模式定义了一个高层接口&#xff0c;这个接口使得这一子系统更加容易使用。 二、UML类图 三、角色设计 角…

Modbus协议是什么?Modbus协议类型解析

什么是Modbus协议?Modbus 是由 Modicon(现为施耐德电气公司的一个品牌)在 1979 年发明的一种工业控制总线协议&#xff0c;是全球第一个真正用于工业现场的总线协议。Modbus 以其简单、健壮、开放而且不需要特许授权的特点&#xff0c;成为通用通信协议。为了适应以太网环境&a…

RDS-Tools RDS-Knight Crack

RDS 高级安全性 利用全面的网络安全工具箱中有史以来最强大的安全功能集来保护您的 RDS 基础架构。 全方位 360 保护 无与伦比的功能集 无与伦比的物有所值 企业远程桌面安全。现代工作空间的智能解决方案。 办公室正在权力下放。远程办公室和移动员工数量创历史新高。随…

细节:双花括号({{ ... }})在Vue.js中的用法

问题&#xff1a; 为什么后端返回的是数字类型时&#xff0c; {{ form.orderPrice }}可以拿到值展示&#xff0c; {{ form.orderPrice || "-" }} 不可以&#xff1f; 接口返回数据&#xff1a; <el-form-item label"订单金额&#xff1a;" prop"…

Qt Creator常用快捷键及技巧

文章目录 1.[Qt Creator常用快捷键及技巧提升编码效率]2.win10上安装QT &#xff0c;选择安装组件3.qt配置过程中主要注意的几点4.目录结构附&#xff1a;网友整理快捷方式&#xff1a; 1.[Qt Creator常用快捷键及技巧提升编码效率] (https://blog.csdn.net/luoyayun361/artic…

ChatGPT 最佳实践指南之:使用外部工具

Use external tools 使用外部工具 Compensate for the weaknesses of GPTs by feeding them the outputs of other tools. For example, a text retrieval system can tell GPTs about relevant documents. A code execution engine can help GPTs do math and run code. If a …

graylog源码搭建

这里主要讲如何源码安装graylog 下载地址&#xff1a; https://www.graylog.org/downloads/ 下载带有JVM的源码文件源码安装 下载graylog-5.1.3-linux-x64.tgz&#xff0c;并上传到Centos中&#xff0c;执行以下操作 tar -zxvf graylog-5.1.3-linux-x64.tgzcd /etcmkdir -p …

ARPACK特征值求解分析

线性方程组求解、特征值求解是(数值)线性代数的主要研究内容。力学、电磁等许多问题&#xff0c;最终都可以归结为特征值、特征向量的求解。 ARPACK使用IRAM(Implicit Restarted Arnoldi Method)求解大规模系数矩阵的部分特征值与特征向量。了解或者熟悉IRAM算法&#xff0c;必…

Spring Boot中如何解决Redis的缓存穿透、缓存击穿、缓存雪崩?

Redis的缓存穿透、缓存击穿、缓存雪崩 一、概述 ① 缓存穿透&#xff1a;大量请求根本不存在的key ② 缓存雪崩&#xff1a;Redis中大量key集体过期 ③ 缓存击穿&#xff1a;Redis中一个热点key过期 三者出现的根本原因&#xff1a;Redis命中率下降&#xff0c;请求直接打…

机器学习 day27(反向传播)

导数 函数在某点的导数为该点处的斜率&#xff0c;用height / width表示&#xff0c;可以看作若当w增加ε&#xff0c;J(w,b)增加k倍的ε&#xff0c;则k为该点的导数 反向传播 tensorflow中的计算图&#xff0c;由有向边和节点组成。从左向右为正向传播&#xff0c;神经网…