PyTorch-RNN

首先介绍一下什么是rnn,rnn特别擅长处理序列类型的数据,因为他是一个循环的结构

一个序列的数据依次进入网络A,网络A循环的往后传递。

这就是RNN的基本结构类型。而最早的RNN模型,序列依次进入网络中,之前进入序列的数据会保存信息而对后面的数据产生影响,所以RNN有着记忆的特性,而同时越前面的数据进入序列的时间越早,所以对后面的数据的影响也就越弱,简而言之就是一个数据会更大程度受到其临近数据的影响。但是我们很有可能需要更长时间之前的信息,而这个能力传统的RNN特别弱,于是有了LSTM这个变体。

LSTM

这就是LSTM的模型结构,也是一个向后传递的链式模型,而现在广泛使用的RNN其实就是LSTM,序列中每个数据传入LSTM可以得到两个输出,而这两个输出和序列中下一个数据一起又作为传入LSTM的输入,然后不断地循环向后,直到序列结束。

下面结合pytorch一步一步来看数据传入LSTM是怎么运算的

首先需要定义好LSTM网络,需要nn.LSTM(),首先介绍一下这个函数里面的参数

  • input_size 表示的是输入的数据维数
  • hidden_size 表示的是输出维数
  • num_layers 表示堆叠几层的LSTM,默认是1
  • bias True 或者 False,决定是否使用bias
  • batch_first True 或者 False,因为nn.lstm()接受的数据输入是(序列长度,batch,输入维数),这和我们cnn输入的方式不太一致,所以使用batch_first,我们可以将输入变成(batch,序列长度,输入维数)
  • dropout 表示除了最后一层之外都引入一个dropout
  • bidirectional 表示双向LSTM,也就是序列从左往右算一次,从右往左又算一次,这样就可以两倍的输出

是网络的输出维数,比如M,因为输出的维度是M,权重w的维数就是(M, M)和(M, K),b的维数就是(M, 1)和(M, 1),最后经过sigmoid激活函数,得到的f的维数是(M, 1)。

对于第一个数据,需要定义初始的h_0和c_0,所以nn.lstm()的输入Inputs:input, (h_0, c_0),表示输入的数据以及h_0和c_0,这个可以自己定义,如果不定义,默认就是0

第二步也是差不多的操作,只不多是另外两个权重加上不同的激活函数,一个使用的是sigmoid,一个使用的是tanh,得到的输出

都是(M, 1)。维数都是(K, 1)。

code

1

lstm = nn.LSTM(10, 30, batch_first=True)

可以通过这样定义一个一层的LSTM输入是10,输出是30

1

2

3

4

lstm.weight_hh_l0.size()

lstm.weight_ih_l0.size()

lstm.bias_hh_l0.size()

lstm.bias__ih_l0.size()

可以分别得到权重的维数,注意之前我们定义的4个weights被整合到了一起,比如这个lstm,输入是10维,输出是30维,相对应的weight就是30×10,这样的权重有4个,然后pytorch将这4个组合在了一起,方便表示,也就是lstm.weight_ih_l0,所以它的维数就是120×10

我们定义一个输入

1

2

3

x = Variable(torch.randn((50, 100, 10)))

h0 = Variable(torch.randn(1, 50, 30))

c0 = Variable(torch.randn(1, 50 ,30))

x的三个数字分别表示batch_size为50,序列长度为100,每个数据维数为10

h0的第二个参数表示batch_size为50,输出维数为30,第一个参数取决于网络层数和是否是双向的,如果双向需要乘2,如果是多层,就需要乘以网络层数

c0的三个参数和h0是一致的

1

out, (h_out, c_out) = lstm(x, (h0, c0))

这样就可以得到网络的输出了,和上面讲的一致,另外如果不传入h0和c0,默认的会传入相同维数的0矩阵

这就是我们如何在pytorch上使用RNN的基本操作了,了解完最基本的参数我们才能够使用其来做应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/512188.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024全网最全Excel函数与公式应用

💂 个人网站:【 海拥】【神级代码资源网站】【办公神器】🤟 基于Web端打造的:👉轻量化工具创作平台💅 想寻找共同学习交流的小伙伴,请点击【全栈技术交流群】 引言 Excel是一款广泛应用于商业、教育和个人…

内存冷热标记 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 100分 题解: Java / Python / C 题目描述 现代计算机系统通常存在多级的存储设备,针对海量的 wordload 的优化的一种思路是将热点内存页优化先放到快速存储层级,这就需要对内存页进…

【vue/组件封装】封装一个带条件筛选的搜索框组件(多组条件思路、可多选)详细流程

引入:实现一个带有筛选功能的搜索框,封装成组件; 搜索框长这样子: 点击右侧筛选图标后弹出层,长这样子: 实际应用中有多组筛选条件,这里为了举栗子就展示一组; 预览:…

ES核心概念(45-48)(56-62)(101-103)

ES集群 ES集群(Cluster)包含多个节点(服务器),整体提供服务 核心概念 索引Index:类似于mysql中的表 映射Mapping:数据的结构信息 文档:相当于表中的一条记录 分片: 将数据分成多片…

010-原型链

原型链 1、概念2、原理3、new 操作符原理4、应用 1、概念 原型链:javascript的继承机制,是指获取JavaScript对象的属性会顺着其_proto_的指向寻找,直至找到Object.prototype上。 2、原理 💡 Tips:构造函数 Fn&#…

Day12-【Java SE进阶】JDK8新特性:Lambda表达式、方法引用、常见算法、正则表达式、异常

一、JDK8新特性 1.Lambda表达式 Lambda表达式是JDK 8开始新增的一种语法形式;作用:用于简化名内部类的代码写法。 注意:Lambda表达式并不是说能简化全部匿名内部类的写法,只能简化函数式接口的匿名内部类。 有且仅有一个抽象方法的接口。注意:将来我们见到的大部…

电商直播大屏是什么?想搞这个怎么做?

随着电商行业的快速发展,直播带货已成为当下最热门的市场营销方式之一。为了更好地掌握直播数据,为企业决策提供有力支持,电商直播数据大屏应运而生。 一、电商直播数据大屏概述 电商直播数据大屏是一种集成了多种数据源的大屏幕可视化展示…

C if...else 语句

一个 if 语句 后可跟一个可选的 else 语句,else 语句在布尔表达式为 false 时执行。 语法 C 语言中 if…else 语句的语法: if(boolean_expression) {/* 如果布尔表达式为真将执行的语句 */ } else {/* 如果布尔表达式为假将执行的语句 */ }如果布尔表…

主流抠图算法trimap-based/free

GitHub - JizhiziLi/matting-survey: Deep Image Matting: A Comprehensive SurveyDeep Image Matting: A Comprehensive Survey. Contribute to JizhiziLi/matting-survey development by creating an account on GitHub.https://github.com/JizhiziLi/matting-survey数据集介…

C语言文件操作,linux文件操作,文件描述符,linux下一切皆文件,缓冲区,重定向

目录 C语言文件操作 如何打开文件以及打开文件方式 读写文件 关闭文件 Linux系统下的文件操作 open 宏标志位 write,read,close,lseek接口 什么是当前路径? linux下一切皆文件 文件描述符 文件描述符排序 C语言文件操…

【更新2022】各省数字经济水平测算 原始数据+结果 2011-2022

数据说明:参照赵涛等(2020)的文章,利用熵值法和主成分对省市数字经济水平进行测算,原始数据来自第五期北大数字普惠金融指数,含原始数据,以及熵值法、主成分两种测算结果。一、数据介绍 数据名…

【EI会议征稿通知】第七届交通运输与土木建筑国际学术论坛(ISTTCA 2024)

第七届交通运输与土木建筑国际学术论坛(ISTTCA 2024) 2024 7th International Symposium on Traffic Transportation and Civil Architecture 交通运输是经济发展的先行官,而岩土是发展交通运输网络无法避开的话题。将传统的土木工程技术与先…