pytorch(十)循环神经网络

文章目录

    • 卷积神经网络与循环神经网络的区别
    • RNN cell结构
    • 构造RNN
    • 例子 seq2seq

卷积神经网络与循环神经网络的区别

卷积神经网络:在卷积神经网络中,全连接层的参数占比是最多的。

卷积神经网络主要用语处理图像、语音等空间数据,它的特点是局部连接和权值共享,减少了参数的数量。CNN通常包括卷积层、池化层和全连接层,卷积层用于提取输入数据的特诊,池化层用语降维和压缩特征,全连接层将特征映射到更加高维度的空间中。

卷积神经网络在处理空间数据的时候,其输入通常是一个二维或者三维的张量,每一个元素对应一个输出。CNN通过卷积操作获取特征,这些特征往往是局部相关的,即每个特征都与输入数据的一个局部区域相关联(相邻的特征往往具有相似性)。

循环神经网络

循环神经网络主要用于处理序列数据,比如自然语言处理、语音识别和时间序列预测(具有时间依赖的数据),它的特点是具有时间轴,能够保留和处理序列中的历史信息。RNN包括循环层,能够接收前一层的信息并且与当前的输入结合,生成当前的输出序列。

循环神经网络处理的数据往往是序列数据,即输入的数据是一个序列,每一个时间步都有一个对应的输出,RNN通过循环层保留和传递历史信息,从而建立时间上的依赖关系。

循环神经网络也使用到了权值共享的想法,RNN在不同的时间位置共享参数(CNN在不同的空间位置共享参数)

RNN cell结构

RNN cell的结构在本质上类似于传统的神经网络模型(输入层、隐藏层、输出层),但是RNN于传统的NN最大的区别在于:隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出,其结构如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

之所以叫做循环神经网络,是因为以上的cell都是同一个cell,参数也只有cell里的一份而已,在训练中,cell层使用的是循环结构,使用不同的输入数据进行训练。

构造RNN

第一种方法 RNNCell

cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)
hidden=cell(input,hidden)
# input of shape(batch,input_size)
# hidden of shape(batch,hidden_size)
# output of shape(batch,hidden_size)
# dataset.shape(seqLen,batchSize,inputSize)
import torchbatch_size=1 # N=1
seq_len=3 # x=[x1,x2,x3]
input_size=4 # [4*3]
hidden_size=2 # [2*3]
# 两种构造方式——1
cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)#(seq,batch=n,features) 序列的长度 * 批量 * 维度
dataset=torch.randn(seq_len,batch_size,input_size)
# 批量 * 维度
hidden=torch.zeros(batch_size,hidden_size)print('dataset:',dataset)
print('hidden:',hidden)# 构造循环
for idx,input in enumerate(dataset):print('=' *20,idx,'='*20)print(input)print('Input size:',input.shape)hidden=cell(input,hidden)print(hidden)print('Outputs size:',hidden.shape)print(hidden)

第二种方法

cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
out,hidden=cell(input,hidden)
# input:表示整个输入序列;input of shape(seqSize,batch,input_size),这里的seqSize实际上就是上面的循环过程
# hidden:表示h0;hidden of shape(numLayers,batch,hidden_size)
# output of shape(seqLen,batch,hidden_size)
import torchbatch_size=1 # N=1
seq_len=3 # x=[x1,x2,x3]
input_size=4 # [4*3]
hidden_size=2 # [2*3]
num_layers=1
# 两种构造方式——2 num_layers表示层数
cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
# cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
# inputs=torch.randn(batch_size,seq_len,input_size)inputs=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(num_layers,batch_size,hidden_size)out,hidden=cell(inputs,hidden)
print('Output size:',out.shape)
print('Output:',out)
print('Hidden size:',hidden.shape)
print('Hidden:',hidden)

例子 seq2seq

需要把 hello 序列经过训练变成 ohlol 序列,x=[x1,x2,x3,x4,x5]

这里是需要训练一个模型,使得输入“hello”,而输出为“ohlol”,实际上对于输入的数据,判别输出的数据属于哪一个类别,这里只有4种字母,所以输出的维度为4。

在这里插入图片描述

将序列数据转变成向量,一般转换为独热向量

在这里插入图片描述

import torchbatch_size=1
input_size=4
hidden_size=4idx2char=['e','h','l','o']
# 字典转变
x_data=[1,0,2,2,3]
y_data=[3,1,2,3,2]one_hot_looup=[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]x_one_hot=[one_hot_looup[x] for x in x_data]# 两个数据中的-1表示序列的长度seqSize
inputs=torch.Tensor(x_one_hot).view(-1,batch_size,input_size)
print('inputs:',inputs)
labels=torch.LongTensor(y_data).view(-1,1)class Model(torch.nn.Module):def __init__(self,input_size,hidden_size,batch_size):super(Model,self).__init__()self.batch_size=batch_sizeself.input_size=input_sizeself.hidden_size=hidden_sizeself.rnncell=torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)def forward(self,input,hidden):hidden=self.rnncell(input,hidden)return hidden# h0def init_hidden(self):return torch.zeros(self.batch_size,self.hidden_size)net=Model(input_size,hidden_size,batch_size)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.1)for epoch in range(15):loss=0optimizer.zero_grad()hidden=net.init_hidden()print('Predicted string:',end='')# 这个实际上就是RNNfor input,label in zip(inputs,labels):# print('input:',input,'label:',label)hidden=net(input,hidden)loss+=criterion(hidden,label)# 输出可能性最大的那个值,也就是预测的值_,idx=hidden.max(dim=1)print(idx2char[idx.item()],end='')loss.backward()optimizer.step()print(',Epoch [%d/15] loss=%.4f'%(epoch+1,loss.item()))

在这里插入图片描述

独热向量的缺点

  • 独热向量的维度过高,比如26个字母有26维度。给每一个汉字都映射一个独热向量会有维度的诅咒;
  • 独热向量过于稀疏;
  • 独热向量是硬编码,不是学习出来的,每一个向量对应都是确定的。

解决方案:Embedding,也就是把高维度的稀疏的数据映射成低维的稠密的数据中,也就是数据降维

降维后的网络

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/540381.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

复现文件上传漏洞

一、搭建upload-labs环境 将下载好的upload-labs的压缩包,将此压缩包解压到WWW中,并将名称修改为upload,同时也要在upload文件中建立一个upload的文件。 然后在浏览器网址栏输入:127.0.0.1/upload进入靶场。 第一关 选择上传文件…

Qt/QML编程之路:基于QWidget编程及各种2D/3D/PIC绘制的示例(45)

关于使用GWidget,这里有一个示例,看了之后很多图形绘制,控件使用,及最基本的QWidget编程都比较清楚了。ui的绘制: 运行后的界面如 工程中有非常丰富的关于各种图形的绘制,比如上图中circle,还有image。有下面一段readme的说明: # EasyQPainter Various operation pra…

AI预测福彩3D第9弹【2024年3月15日预测--第2套算法重新开始计算第1次测试】

我有一个彩友交流群,里面有很多热衷研究彩票技术的彩友,经过群里很多彩友的建议,我昨天晚上又设计了一套新的算法,这套算法加入了012路权重,将012路的权重按照开出概率设为不同的权重系数。 然后与第1套算法的进行相乘…

RabbitMq踩坑记录

1、连接报错:Broker not available; cannot force queue declarations during start: java.io.IOException 2.1、原因:端口不对 2.2、解决方案: 检查你的连接配置,很可能是你的yml里面的端口配置的是15672,更改为5672即…

STM32基础--中断应用

本文章里面假设中断就是异常,不做区分。 异常类型 F103 在内核水平上搭载了一个异常响应系统,支持为数众多的系统异常和外部中断。其中系统异常有 8 个(如果把 Reset 和 HardFault 也算上的话就是 10 个),外部中断有…

upload-labs 0.1 靶机详解

下载地址https://github.com/c0ny1/upload-labs/releases Pass-01 他让我们上传一张图片,我们先尝试上传一个php文件 发现他只允许上传图片格式的文件,我们来看看源码 我们可以看到它使用js来限制我们可以上传的内容 但是我们的浏览器是可以关闭js功能的…

Linux 大页内存 Huge Pages 虚拟内存

Linux 大页内存 Huge Pages 虚拟内存 - 秋来叶黄 - 博客园 (cnblogs.com) Linux为什么要有大页内存?为什么DPDK必须要设置大页内存?这都是由系统架构决定的。一开始为了解决一个问题,设计了对应的方案,随着事物的发展&#xff0c…

产品必会的7个Axure使用技巧(中阶)

一、格式刷 格式刷可以用来复制样式 先选择需要复制样式的控件,然后在点击钢笔旁边的“更多”,点击选择格式刷 然后选择需要复制样式的目标控件,点击“应用”,就可以将样式复制过来。 不过在Axure9中,可以直接右键复…

2025MEM/MBA学习计划-1

三年大考, 从2022年开始,经历了2022年失败,2023年因为口罩弃考,2024年A线差4分,超B线6分,但我不想调剂,今年继续。英语42分给我了比较大信心,今年的英语我要过65。从今天开始,先从数…

Redis:使用redis-dump导出、导入、还原数据实例

redis的备份和还原,借助了第三方的工具,redis-dump 1、安装必要环境 yum -y install zlib-devel openssl-devel2、安装redis-dump 安装ruby: ruby下载地址:https://www.ruby-lang.org/zh_cn/downloads/ 我下载的是 2.5.0 版本…

网站首页添加JS弹屏公告窗口教程

很多小白站长会遇到想给自己的网站添加一个弹屏公告&#xff0c;用于做活动说明、演示站提示等作用与目的。 下面直接上代码&#xff1a;&#xff08;直接复制到网页头部、底部php、HTML文件中&#xff09; <script src"https://www.mohuda.com/site/js/sweetalert.m…

孙宇晨最新研判:加密货币将成为全球金融基础设施的一部分

近日,波场TRON创始人、火币HTX全球顾问委员会委员孙宇晨接受了在加密社区有重要影响力的媒体平台Bankless的专访,就自己的从业经历、涉足加密行业的理想、波场TRON本身的发展和未来的市场走向等话题进行了详细的分享。 孙宇晨认为,波场TRON的使命是为那些没有银行账户的人提供…