【深度学习笔记】6_5 RNN的pytorch实现

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.5 循环神经网络的简洁实现

本节将使用PyTorch来更简洁地实现基于循环神经网络的语言模型。首先,我们读取周杰伦专辑歌词数据集。

import time
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.5.1 定义模型

PyTorch中的nn模块提供了循环神经网络的实现。下面构造一个含单隐藏层、隐藏单元个数为256的循环神经网络层rnn_layer

num_hiddens = 256
# rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)

与上一节中实现的循环神经网络不同,这里rnn_layer的输入形状为(时间步数, 批量大小, 输入个数)。其中输入个数即one-hot向量长度(词典大小)。此外,rnn_layer作为nn.RNN实例,在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输入。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数)。而nn.RNN实例在前向计算返回的隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每一层的隐藏状态都会记录在该变量中;对于像长短期记忆(LSTM),隐藏状态是一个元组(h, c),即hidden state和cell state。我们会在本章的后面介绍长短期记忆和深度循环神经网络。关于循环神经网络(以LSTM为例)的输出,可以参考下图(图片来源)。

在这里插入图片描述

循环神经网络(以LSTM为例)的输出

来看看我们的例子,输出形状为(时间步数, 批量大小, 隐藏单元个数),隐藏状态h的形状为(层数, 批量大小, 隐藏单元个数)。

num_steps = 35
batch_size = 2
state = None
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new[0].shape)

输出:

torch.Size([35, 2, 256]) 1 torch.Size([2, 256])

如果rnn_layernn.LSTM实例,那么上面的输出是什么?

接下来我们继承Module类来定义一个完整的循环神经网络。它首先将输入数据使用one-hot向量表示后输入到rnn_layer中,然后使用全连接输出层得到输出。输出个数等于词典大小vocab_size

# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):def __init__(self, rnn_layer, vocab_size):super(RNNModel, self).__init__()self.rnn = rnn_layerself.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) self.vocab_size = vocab_sizeself.dense = nn.Linear(self.hidden_size, vocab_size)self.state = Nonedef forward(self, inputs, state): # inputs: (batch, seq_len)# 获取one-hot向量表示X = d2l.to_onehot(inputs, self.vocab_size) # X是个listY, self.state = self.rnn(torch.stack(X), state)# 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出# 形状为(num_steps * batch_size, vocab_size)output = self.dense(Y.view(-1, Y.shape[-1]))return output, self.state

6.5.2 训练模型

同上一节一样,下面定义一个预测函数。这里的实现区别在于前向计算和初始化隐藏状态的函数接口。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,char_to_idx):state = Noneoutput = [char_to_idx[prefix[0]]] # output会记录prefix加上输出for t in range(num_chars + len(prefix) - 1):X = torch.tensor([output[-1]], device=device).view(1, 1)if state is not None:if isinstance(state, tuple): # LSTM, state:(h, c)  state = (state[0].to(device), state[1].to(device))else:   state = state.to(device)(Y, state) = model(X, state)if t < len(prefix) - 1:output.append(char_to_idx[prefix[t + 1]])else:output.append(int(Y.argmax(dim=1).item()))return ''.join([idx_to_char[i] for i in output])

让我们使用权重为随机值的模型来预测一次。

model = RNNModel(rnn_layer, vocab_size).to(device)
predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)

输出:

'分开戏想暖迎凉想征凉征征'

接下来实现训练函数。算法同上一节的一样,但这里只使用了相邻采样来读取数据。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes):loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)model.to(device)state = Nonefor epoch in range(num_epochs):l_sum, n, start = 0.0, 0, time.time()data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样for X, Y in data_iter:if state is not None:# 使用detach函数从计算图分离隐藏状态, 这是为了# 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)if isinstance (state, tuple): # LSTM, state:(h, c)  state = (state[0].detach(), state[1].detach())else:   state = state.detach()(output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size)# Y的形状是(batch_size, num_steps),转置后再变成长度为# batch * num_steps 的向量,这样跟输出的行一一对应y = torch.transpose(Y, 0, 1).contiguous().view(-1)l = loss(output, y.long())optimizer.zero_grad()l.backward()# 梯度裁剪d2l.grad_clipping(model.parameters(), clipping_theta, device)optimizer.step()l_sum += l.item() * y.shape[0]n += y.shape[0]try:perplexity = math.exp(l_sum / n)except OverflowError:perplexity = float('inf')if (epoch + 1) % pred_period == 0:print('epoch %d, perplexity %f, time %.2f sec' % (epoch + 1, perplexity, time.time() - start))for prefix in prefixes:print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, idx_to_char,char_to_idx))

使用和上一节实验中一样的超参数(除了学习率)来训练模型。

num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

输出:

epoch 50, perplexity 10.658418, time 0.05 sec- 分开始我妈  想要你 我不多 让我心到的 我妈妈 我不能再想 我不多再想 我不要再想 我不多再想 我不要- 不分开 我想要你不你 我 你不要 让我心到的 我妈人 可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的
epoch 100, perplexity 1.308539, time 0.05 sec- 分开不会痛 不要 你在黑色幽默 开始了美丽全脸的梦滴 闪烁成回忆 伤人的美丽 你的完美主义 太彻底 让我- 不分开不是我不要再想你 我不能这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀 你 在我
epoch 150, perplexity 1.070370, time 0.05 sec- 分开不能去河南嵩山 学少林跟武当 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 习武之人切记 仁者无敌- 不分开 在我会想通 是谁开没有全有开始 他心今天 一切人看 我 一口令秋软语的姑娘缓缓走过外滩 消失的 旧
epoch 200, perplexity 1.034663, time 0.05 sec- 分开不能去吗周杰伦 才离 没要你在一场悲剧 我的完美主义 太彻底 分手的话像语言暴力 我已无能为力再提起- 不分开 让我面到你 爱情来的太快就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不
epoch 250, perplexity 1.021437, time 0.05 sec- 分开 我我外的家边 你知道这 我爱不看的太  我想一个又重来不以 迷已文一只剩下回忆 让我叫带你 你你的- 不分开 我我想想和 是你听没不  我不能不想  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 

小结

  • PyTorch的nn模块提供了循环神经网络层的实现。
  • PyTorch的nn.RNN实例在前向计算后会分别返回输出和隐藏状态。该前向计算并不涉及输出层计算。

注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/525979.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL中常见的DDL操作及示例,数据库操作及表操作

目录 一、数据库操作 1、创建数据库 2、查看所有数据库 3、使用数据库 4、删除数据库 二、表操作&#xff1a; 1、创建表 2、查看表结构 3、修改表结构 3.1 添加列 3.2 修改列数据类型 3.3 修改列名 3.4 删除列 3.5 修改表名 3.6 删除表 注意&#xff1a; 在数…

倒计时35天

dp预备(来源&#xff1a;b站acm刘春英老师) 1. 2. 3. 4. 5. 6. 7.

去除PDF论文行号的完美解决方案

去除PDF论文行号的完美解决方案 1. 遇到的问题 我想去除论文的行号&#xff0c;但是使用网上的Adobe Acrobat裁剪保存后 如何去掉pdf的行编号&#xff1f; - 知乎 (zhihu.com) 翻译时依然会出现行号&#xff0c;或者是转成word&#xff0c;这样就大大损失了格式&#xff0c;…

flask-sqlalchemy库

彩笔激流勇退。 1. 简介 ORM&#xff0c;对象关系映射。简单来说&#xff0c;ORM将数据库中的表与面向对象中的类建立了一种对应关系。这样&#xff0c;我们要操作数据库&#xff0c;表&#xff0c;记录就可以直接通过操作类或者类实例来完成。 SQLAlchemy 是目前python中最…

小程序网页view多行文本超出隐藏或显示省略号

实现效果&#xff1a; 限制两行&#xff0c;超出即显示省略号 实现&#xff1a;话不多说&#xff0c;展示代码 关键代码 .box{ width:100rpx; overflow:hidden; text-overflow: ellipsis;//超出省略号 display:-webkit-box; -webkit-line-clamp: 2;//显…

leetcode 热题 100_相交链表

题解一&#xff1a; 哈希表&#xff1a;两链表出现的第一个相同的值就是相交节点&#xff0c;因此我们先用哈希记录链表A所有出现过的值&#xff0c;再遍历链表B查找哈希表&#xff0c;找出第一个相同的值即为结果。 import java.util.HashSet;public class Solution {public …

Feign实现微服务间远程调用续;基于Redis实现消息队列用于延迟任务的处理,Redis分布式锁的实现;(黑马头条Day05)

目录 延迟任务和定时任务 使用Redis设计延迟队列原理 点评项目中选用list和zset两种数据结构进行实现 如何缓解Redis内存的压力同时保证Redis中任务能够被正确消费不丢失 系统流程设计 使用Feign实现微服务间的任务消费以及文章自动审核 系统微服务功能介绍 提交文章-&g…

zookeeper Study

zk介绍&#xff1b;一种分布式协调服务。 分布式锁&#xff0c;集群选举&#xff0c;数据同步 。 zk都能进行操作&#xff0c;redis&#xff0c;kafka&#xff0c;rabbitmq&#xff0c;都能够用zk做协调管理服务。关键时zk简单操作。 应用说明&#xff1a; 简单介绍一下流程 &…

Python 3 教程(2)

Python3 基础语法 编码 默认情况下&#xff0c;Python 3 源码文件以 UTF-8 编码&#xff0c;所有字符串都是 unicode 字符串。 当然你也可以为源码文件指定不同的编码&#xff1a; # -*- coding: cp-1252 -*- 上述定义允许在源文件中使用 Windows-1252 字符集中的字符编码&…

02-gitlab的数据备份和恢复

一、gitlab的数据备份 关于数据备份&#xff0c;咱们就不需要多说什么了&#xff0c;主要就是方式数据意外丢失&#xff0c;导致代码上线流程及数据的损坏崩溃&#xff1b; 为了避免严重生产事故&#xff0c;进而gitlab有了数据备份与恢复的功能。 1&#xff0c;修改gitlab的配…

悬浮工具球(仿 iphone 辅助触控)

悬浮工具球&#xff08;仿 iphone 辅助触控&#xff09; 兼容移动端 touch 事件点击元素以外位置收起解决鼠标抬起触发元素的点击事件问题 Demo Github <template><divref"FloatingBal"class"floating_ball":class"[dragging, isClick]&q…

C语言中的UTF-8编码转换处理

C语言UTF-8编码的转换 1.C语言简介2.什么是UTF-8编码&#xff1f;2.1 UTF-8编码特点&#xff1a; 3.C语言中的UTF-8编码转换处理步骤1&#xff1a;获取UTF-8编码的字节流步骤2&#xff1a;解析UTF-8编码步骤3&#xff1a;Unicode码点转换为汉字 4.总结 1.C语言简介 C语言是一门…