《PyTorch深度学习实践》第十二讲循环神经网络基础

一、RNN简介

1、RNN网络最大的特点就是可以处理序列特征,就是我们的一组动态特征。比如,我们可以通过将前三天每天的特征(是否下雨,是否有太阳等)输入到网络,从而来预测第四天的天气。
       我们可以看RNN的网络结构如下:

二、RNN cell用法

import torchbatch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)# (seq, batch, features)
dataset = torch.randn(seq_len, batch_size, input_size)
print(dataset)
hidden = torch.zeros(batch_size, hidden_size)
print(hidden)for idx, input in enumerate(dataset):print( '=' * 20, idx, '=' * 20)print( 'Input size: ', input.shape)hidden = cell(input, hidden)print( 'outputs size: ', hidden.shape)print(hidden)

三、RNN用法

import torchbatch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度
num_layers = 4  # 隐藏层数量cell = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)# (seqLen, batchSize, inputSize)
inputs = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(num_layers, batch_size, hidden_size)
out, hidden = cell(inputs, hidden)print( 'Output size:', out.shape)
print( 'Output:', out)
print( 'Hidden size: ', hidden.shape)
print( 'Hidden: ', hidden)

四、Embedding

把input变为稠密的数据

代码:

import torch# parameters
num_class = 4
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5# 准备数据集
idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch, seq_len)
y_data = [3, 1, 2, 3, 2]    # (batch * seq_len)inputs = torch.LongTensor(x_data)   # Input should be LongTensor: (batchSize, seqLen)
labels = torch.LongTensor(y_data)   # Target should be LongTensor: (batchSize * seqLen)# 构建模型
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = torch.nn.Embedding(input_size, embedding_size)self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)self.fc = torch.nn.Linear(hidden_size, num_class)def forward(self, x):hidden = torch.zeros(num_layers, x.size(0), hidden_size)x = self.emb(x)  # (batch, seqLen, embeddingSize)x, _ = self.rnn(x, hidden)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, hidden_size)x = self.fc(x)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)return x.view(-1, num_class)  # reshape to use Cross Entropy: (𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆×𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)net = Model()# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)# 训练模型
for epoch in range(15):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()_, idx = outputs.max(dim=1)idx = idx.data.numpy()print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

 运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/504364.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《CrackCollect》

CrackCollect 类型:益智学习 视角:2d 乐趣点:趣味化英语学习,闯关增加学习动力 时间:2019 个人职责: 1、所有功能的策划讨论 2、所有开发工作 3、所有上架工作 此游戏旨在针对英语水平处于初级阶段的人&…

【已亲测有效】如何彻底删除nodejs,避免影响安装新版本

第一步开始菜单搜索uninstall node.js,点击之后等待删除(删除node_modules文件夹以及以下这些文件) 第二步手动删除nodejs下载位置的其他文件夹。(就是另外自己新建的两个文件夹node_cache和node_global) 到这里其实应…

STM32标准库开发—实时时钟(BKP+RTC)

BKP配置结构 注意事项 BKP基本操作 时钟初始化 RCC_APB1PeriphClockCmd(RCC_APB1Periph_PWR, ENABLE);RCC_APB1PeriphClockCmd(RCC_APB1Periph_BKP, ENABLE);PWR_BackupAccessCmd(ENABLE);//设置PWR_CR的DBP,使能对PWR以及BKP的访问读写寄存器操作 uint16_t ArrayW…

Python学习 问题汇总(None)

None的总结 在Python中,对于一些变量往往需要赋初始值,为了防止初始值与正常值混淆,通常采用置0或置空操作,置0比较简单,置空则是赋NoneNone是一个空值,可以赋给任意类型的变量,起到占位的作用…

社区店经营实战策略:如何打造火爆生意并持续盈利?

在竞争激烈的商业环境中,经营一家成功的社区店需要一套全面而有效的策略。作为一名开鲜奶吧5年的创业者,我将分享一些关键的经营策略,帮助你打造火爆生意并实现持续盈利。 1、 市场调研: 在开店之前,深入了解你所在社…

ubuntu22.04安裝mysql8.0

官网下载mysql:MySQL :: Download MySQL Community Server 将mysql-server_8.0.20-2ubuntu20.04_amd64.deb-bundle.tar上传到/usr/local/src #解压压缩文件 tar -xvf mysql-server_8.0.20-2ubuntu20.04_amd64.deb-bundle.tar解压依赖包依次输入命令 sudo dpkg -i m…

微服务简介及其相关技术栈

目录 1、简介 2、技术栈 3、单体架构 4、分布式架构 5、微服务 6、总结 🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎Pyth…

c++之旅——第二弹

大家好啊,这里是c之旅第二弹,跟随我的步伐来开始这一篇的学习吧! 如果有知识性错误,欢迎各位指正!!一起加油!! 创作不易,希望大家多多支持哦! 一、内存四区…

浅析TSN网络之车载以太网协议测试

TSN是一项从视频音频数据领域延伸至工业领域、汽车领域的技术。TSN最初来源于音视频领域的应用需求,当时该技术被称为AVB,由于针对音视频网络需要较高的带宽和最大限度的实时,借助AVB能较好的传输高质量音视频。 2012年,AVB任务组…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的交通信号标志识别软件(Python+PySide6界面+训练代码)

摘要:开发高效的交通信号标志识别软件对于提升道路安全和自动驾驶技术发展具有重要意义。本篇博客详细阐述了如何利用深度学习构建一个交通信号标志识别软件,并提供了完整的实现代码。该软件基于先进的YOLOv8算法,并对比了YOLOv7、YOLOv6、YO…

自动化测试的10大误区!

自动化测试因提高效率,减少重复工作的特性而被广泛采用。然而,随着自动化测试的普及,自动化测试也面临一系列挑战和误解。 这些误区不仅影响了测试的有效性,还会导致一定的项目风险,为了确保自动化测试能够真正提升测…

MySQL误truncate截断后数据恢复2024.3.1

近期很多MySQL数据丢失情况,很多是人为误操作导致。MySQL数据库丢失可能由truncate截断表、delete删除表中数据行、delete删除表、delete删除库、操作系统rm删除数据库文件、硬盘坏道等情况导致。本案例是一个误截断表导致的丢失。 不管哪种情况,第一时…