【动手学深度学习-pytorch】8.5 循环神经网络的从零开始实现

在这里插入图片描述
在这里插入图片描述

转换输入的维度, 以获得形状为(时间步数,批量大小,词表大小)的输出,这将使我们能够更方便地通过最外层的维度, 一步一步地更新小批量数据的隐状态。
在这里插入图片描述>当训练语言模型时,输入和输出来自相同的词表

在这里插入图片描述

循环神经网络模型通过inputs最外层的维度实现循环, 以便逐时间步更新小批量数据的隐状态H
最外层为时间步,与上面的转置相关
输出 output 和隐状态

在这里插入图片描述

我们可以看到输出形状是(时间步数
批量大小,词表大小), 而隐状态形状保持不变,即(批量大小,隐藏单元数)。

在这里插入图片描述

预热(warm-up)期
问题:不用把预测值加到末尾再预测下一个吗?

梯度裁剪

为了防止梯度爆炸或者消失,进行梯度剪裁
在这里插入图片描述

@save
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练网络一个迭代周期(定义见第8章)"""state, timer = None, d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失之和,词元数量for X, Y in train_iter:if state is None or use_random_iter:# 在第一次迭代或使用随机抽样时初始化statestate = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(net, nn.Module) and not isinstance(state, tuple):# state对于nn.GRU是个张量state.detach_()else:# state对于nn.LSTM或对于我们从零开始实现的模型是个张量for s in state:s.detach_()y = Y.T.reshape(-1)X, y = X.to(device), y.to(device)y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)# 因为已经调用了mean函数updater(batch_size=1)metric.add(l * y.numel(), y.numel())return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型(定义见第8章)"""loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 训练和预测for epoch in range(num_epochs):ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)if (epoch + 1) % 10 == 0:print(predict('time traveller'))animator.add(epoch + 1, [ppl])print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')print(predict('time traveller'))print(predict('traveller'))
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())

小结

  • 我们可以训练一个基于循环神经网络的字符级语言模型,根据用户提供的文本的前缀生成后续文本。

  • 一个简单的循环神经网络语言模型包括输入编码、循环神经网络模型和输出生成。

  • 循环神经网络模型在训练以前需要初始化状态,不过随机抽样和顺序划分使用初始化方法不同。

  • 当使用顺序划分时,我们需要分离梯度以减少计算量。

  • 在进行任何预测之前,模型通过预热期进行自我更新(例如,获得比初始值更好的隐状态)。

  • 梯度裁剪可以防止梯度爆炸,但不能应对梯度消失。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/578155.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北京WordPress建站公司

北京wordpress建站,就找北京wordpress建站公司 http://wordpress.zhanyes.com/beijing

C#OpenCvSharp YOLO v3 Demo

目录 效果 项目 代码 下载 效果 项目 代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms; using OpenCvSharp; using S…

SiameseRPN原理详解(个人学习笔记)

参考资源: 视觉目标跟踪SiamRPNSiameseRPN详解CVPR2018视觉目标跟踪之 SiameseRPN 目录) 1. 模型架构1.1 Siamese Network1.2 RPN 2. 模型训练2.1 损失函数2.2 端到端训练2.3 正负样本选择 3. 跟踪阶段总结 SiamRPN是在SiamFC的基础上进行改进而得到的一…

产品推荐 | 基于华为海思ARM+Xilinx FPGA双核的8路SDI高清视频图像处理平台

一、板卡概述 PCIE703 是我司自主研制的一款基于 PCIE 总线架构的高性能综 合视频图像处理平台,该平台采用 Xilinx 的高性能 Kintex UltraScale 系列 FPGA 加上华为海思的高性能视频处理器来实现。 华为海思的 HI3531DV200 是一款集成了 ARM A53 四核处理 器性能强…

003 高并发内存池_整体框架设计

​🌈个人主页:Fan_558 🔥 系列专栏:高并发内存池 🌹关注我💪🏻带你学更多知识 文章目录 前言一、ThreadCache整体框架设计二、CentralCache整体框架设计三、PageCache整体框架设计 小结 前言 在…

个人简历主页搭建系列-05:部署至 Github

前面只是本地成功部署网站,网站运行的时候我们可以通过 localhost: port 进行访问。不过其他人是无法访问我们本机部署的网站的。 接下来通过 Github Pages 服务把网站部署上去,这样大家都可以通过特定域名访问我的网站了! 创建要部署的仓库…

U盘文件突然消失?原因与恢复策略全解析

一、遭遇不测:U盘文件突然消失 在日常生活和工作中,U盘扮演着不可或缺的角色,它小巧便捷,能够随时随地存储和传输文件。然而,有时我们会遭遇一个令人头疼的问题:U盘中的文件突然消失。这种突如其来的变故往…

Web漏洞-深入WAF注入绕过

目录 简要其他测试绕过 方式一:白名单(实战中意义不大) 方式二:静态资源 方式三: url白名单 方式四:爬虫白名单 #阿里云盾防SQL注入简要分析 #安全狗云盾SQL注入插件脚本编写 在攻防实战中,往往需要掌握一些特性,比如服务…

Linux CPU 占用率 100% 排查

Linux CPU 占用率 100% 排查 总体来说分为五个步骤 top 命令定位应用进程 pidtop -Hp [pid] 定位应用进程对应的线程 tidprintf “%x\n” [tid] 将 tid 转换为十六进制jstack [pid] | grep -A 10 [tid 的十六进制] 打印堆栈信息根据堆栈信息分析问题 以下为实战例子 写一段…

SQLBolt,一个练习SQL的宝藏网站

知乎上有人问学SQL有什么好的网站,这可太多了。 我之前学习SQL买了本SQL学习指南,把语法从头到尾看了个遍,但仅仅是心里有数的程度,后来进公司大量的写代码跑数,才算真真摸透了SQL,知道怎么调优才能最大化…

Autosar-Mcal配置详解(免费)-MCU

3.6.1创建、配置RAM 1)配置MCU通用配置项 MCU的通用配置项可参考以下配置: 各配置项的说明如下: Wake Up Factor Clear Isr: 是否在唤醒的中断服务函数中清除Wakeup Factor Wake Up Factors Clear Centralised: 是否在shutdown前集中集中清除Wakeu…