人工智能-循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/207911.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【AI】行业消息精选和分析(11月22日)

今日动态 👓 Video-LLaVA:视觉语言模型革新: - 图像和视频信息转换为文字格式。 - 多模态理解能力,适用于自动问答系统等。 📈 百度文心一言用户数达7000万: 🔊 RealtimeTTS:实时文本…

还记得高中生物书上的莫斯密码吗?利用Python破解摩斯密码的代码示例!

文章目录 前言摩尔斯电码Python实现摩斯密码对照表加密解密测试 完整代码总结关于Python技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集①Python工具包②Python实战案例③Python小游戏源码五、面试资料六、Py…

力扣.面试题 04.06. 后继者(java 树的中序遍历)

Problem: 面试题 04.06. 后继者 文章目录 题目描述思路解题方法复杂度Code 题目描述 设计一个算法,找出二叉搜索树中指定节点的“下一个”节点(也即中序后继)。 如果指定节点没有对应的“下一个”节点,则返回null。 思路 由于题…

「Docker」如何在苹果电脑上构建简单的Go云原生程序「MacOS」

介绍 使用Docker开发Golang云原生应用程序,使用Golang服务和Redis服务 注:写得很详细 为方便我的朋友可以看懂 环境部署 确保已经安装Go、docker等基础配置 官网下载链接直达:Docker官网下载 Go官网下载 操作步骤 第一步 创建一个…

【SpringBoot】ThreadLocal 的详解

一、ThreadLocal 简介 ThreadLocal 叫做线程变量,意思是 ThreadLocal 中填充的变量属于当前线程,该变量对其他线程而言是隔离的,也就是说该变量是当前线程独有的变量。ThreadLocal 为变量在每个线程中都创建了一个副本,那么每个线…

虚拟机centos设置网络模式(桥接|NAT)

前言 桥接模式是通过物理网卡直接与外部网络建立联系的,而NAT模式则是通过虚拟网卡VMnet1或VMnet8通过宿主机共享IP与外部建立网络关系当需要将虚拟机资源共享给局域网用户使用时,宜采用桥接模式;当需要保护虚拟机资源,确保只能由…

封面从这里取好啦

文章目录 前端NPMViteNode.js 后端JavaMavenPython 数据库算法 前端 NPM Vite Node.js 后端 Java Maven Python 数据库 算法

C#使用whisper.net实现语音识别(语音转文本)

目录 介绍 效果 输出信息 项目 代码 下载 介绍 github地址:https://github.com/sandrohanea/whisper.net Whisper.net. Speech to text made simple using Whisper Models 模型下载地址:https://huggingface.co/sandrohanea/whisper.net/tree…

【wireshark】基础学习

TOC 查询tcp tcp 查询tcp握手请求的代码 tcp.flags.ack 0 确定tcp握手成功的代码 tcp.flags.ack 1 确定tcp连接请求的代码 tcp.flags.ack 0 and tcp.flags.syn 1 3次握手后确定发送成功的查询 tcp.flags.fin 1 查询某IP对外发送的数据 ip.src_host 192.168.73.134 查询某…

飞瓜数据B站丨B站UP主11月第3周榜单排行榜榜单(B站平台)发布!

飞瓜轻数发布2023年11月13日-11月19日飞瓜数据UP主排行榜(B站平台),通过充电数、涨粉数、成长指数、带货数据等维度来体现UP主账号成长的情况,为用户提供B站号综合价值的数据参考,根据UP主成长情况用户能够快速找到运营…

竞赛选题 酒店评价的情感倾向分析

前言 🔥 优质竞赛项目系列,今天要分享的是 酒店评价的情感倾向分析 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/post…

Docker Swarm总结(2/3)

目录 8、service 操作 8.1 task 伸缩 8.2 task 容错 8.3 服务删除 8.4 滚动更新 8.5 更新回滚 9、service 全局部署模式 9.1 环境变更 9.2 创建 service 9.3 task 伸缩 10、overlay 网络 10.1 测试环境 1搭建 10.2 overlay 网络概述 10.3 docker_gwbridg 网络基础…