机器学习深度学习——常见循环神经网络结构(RNN、LSTM、GRU)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——RNN的从零开始实现与简洁实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

常见循环神经网络结构(RNN、LSTM、GRU)

  • 引言
  • RNN
  • LSTM
    • 门控记忆元
      • 输入门、输出门和遗忘门
      • 候选记忆元
      • 记忆元
      • 隐状态
    • LSTM的简洁实现
  • GRU
    • 结构详解
    • GRU的简洁实现
  • 常用应用方式

引言

之前已经实现讲解并实现过了RNN模型,而LSTM可以弥补RNN的一些缺点,GRU是LSTM的简化版本,这里我们就回顾一下RNN模型,接着循序渐进讲解LSTM和GRU。
CNN和全连接网络的数据表示能力已经很强了,但是我们为啥还需要循环神经网络呢?这是因为现实的问题更复杂,很多数据的输入顺序对于结果都是有很大影响的。如文本数据(尤其是字母和文字的组合),先后顺序具有非常重要的意义,如果打乱,就会无法正确表示原始信息。而相比其他网络,循环神经网络因为具有记忆能力,所以更有效。

RNN

RNN循环神经网络使用torch.nn.RNN()来构建,如下图所示:
在这里插入图片描述
针对t时刻的隐状态,可以由下面公式计算:
h t = φ ( W i h x t + b i h + W h h h t − 1 + b h h ) = φ ( W i h x t + W h h h t − 1 + b h ) 其中: h t 是 t 时刻的隐藏状态; h t − 1 是 t − 1 时刻的隐藏状态 W i h 是输入到隐藏层的权重; W h h 是隐藏层到隐藏层的权重; b i h 是输入到隐藏层的偏置; b h h 是隐藏层到隐藏层的偏置; h_t=φ(W_{ih}x_t+b_{ih}+W_{hh}h_{t-1}+b_{hh})\\ =φ(W_{ih}x_t+W_{hh}h_{t-1}+b_{h})\\ 其中:h_t是t时刻的隐藏状态;h_{t-1}是t-1时刻的隐藏状态\\ W_{ih}是输入到隐藏层的权重;W_{hh}是隐藏层到隐藏层的权重;\\ b_{ih}是输入到隐藏层的偏置;b_{hh}是隐藏层到隐藏层的偏置; ht=φ(Wihxt+bih+Whhht1+bhh)=φ(Wihxt+Whhht1+bh)其中:htt时刻的隐藏状态;ht1t1时刻的隐藏状态Wih是输入到隐藏层的权重;Whh是隐藏层到隐藏层的权重;bih是输入到隐藏层的偏置;bhh是隐藏层到隐藏层的偏置;
激活函数可以使用ReLU或tanh。
虽然在对序列数据进行建模时,RNN有一定记忆能力,但单纯的RNN会随着递归次数的增加,出现权重指数级爆炸或消失的问题,从而难以捕捉长时间关联,并导师训练时收敛困难。

LSTM

LSTM称为长短期记忆网络,是一种特殊的RNN,主要用于解决长序列训练过程中的梯度消失和爆炸问题,能在长序列中获得更好的分析效果。

门控记忆元

记忆元的目的是为了记录附加的信息,要控制记忆元,我们需要下面的几个门:
1、输出门:用来从单元中输出条目
2、输入门:决定何时将数据读入单元
3、遗忘门:重置单元的内容
接下来来看看如何工作的:

输入门、输出门和遗忘门

当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图:
在这里插入图片描述
上图的σ是代表由sigmoid激活函数的全连接层处理,因此三个门的值都在(0,1)范围内,显然计算方法如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) O t = σ ( X t W x o + H t − 1 W h o + b o ) F t = σ ( X t W x f + H t − 1 W h f + b f ) I_t=\sigma(X_tW_{xi}+H_{t-1}W_{hi}+b_i)\\ O_t=\sigma(X_tW_{xo}+H_{t-1}W_{ho}+b_o)\\ F_t=\sigma(X_tW_{xf}+H_{t-1}W_{hf}+b_f) It=σ(XtWxi+Ht1Whi+bi)Ot=σ(XtWxo+Ht1Who+bo)Ft=σ(XtWxf+Ht1Whf+bf)

候选记忆元

其计算与上面类似,但是使用tanh来作为激活函数,函数范围为(-1,1),计算方式为:
G t = t a n h ( X t W x g + H t − 1 W h g + b g ) G_t=tanh(X_tW_{xg}+H_{t-1}W_{hg}+b_g) Gt=tanh(XtWxg+Ht1Whg+bg)
如图所示:
在这里插入图片描述

记忆元

在LSTM中,有两个门用于实现一种输入和遗忘的机制:输入门控制采用多少来自候选记忆元的新数据,而遗忘门控制保留多少过去的记忆元的内容。使用按元素乘法,得出:
C t = F t ⨀ C t − 1 + I t ⨀ G t C_t=F_t \bigodot C_{t-1}+I_t \bigodot G_t Ct=FtCt1+ItGt
若遗忘门始终为1且输入门始终为0,则过去的记忆元 将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。
如下图所示:
在这里插入图片描述

隐状态

最后是计算隐状态,这里就是输出门的作用了。LSTM中,它是记忆元的tanh的门控版本,确保了隐状态的值在(-1,1)之间:
H t = O t ⨀ t a n h ( C t ) H_t=O_t \bigodot tanh(C_t) Ht=Ottanh(Ct)
只要输出门接近1,就能有效将所有记忆换递给预测部分,对于输出门接近0,我们只保留记忆元内的所有信息,而不需要更新隐状态。
那么整体的LSTM图示如下所示:
在这里插入图片描述

LSTM的简洁实现

使用高级API,我们可以直接实例化LSTM模型。这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节:

from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.1, 48684.5 tokens/sec on cpu
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

GRU

结构详解

LSTM对很多需要“长期记忆”的任务来说效果显著。但是门控状态太多,导致需要训练更多的参数,使得训练难度加大。因此提出循环门控单元GRU,GRU通过将遗忘门和输入门组合在一起,减少了门的数量,并做了其他改变,在保证记忆能力同时,提升网络训练效率。其组成如下所示:
在这里插入图片描述
而每个GRU单元针对输入进行下面函数的计算:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) 候选隐状态 H t ′ = t a n h ( X t W x h + ( R t ⨀ H t − 1 ) W h h + b h ) 其中 R t ⨀ H t − 1 可以减少以往遗忘状态的影响: 每当 R t 接近 1 时,我们恢复一个传统 R N N 网络; R t 接近 0 时,候选隐状态是以 X t 作为输入的多层感知机的结果 H t = Z t ⨀ H t − 1 + ( 1 − Z t ) ⨀ H t ′ Z t 接近 1 时,模型倾向于保留旧状态; Z t 接近 0 时,倾向于候选隐状态 R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\\ 候选隐状态H_t^{'}=tanh(X_tW_{xh}+(R_t \bigodot H_{t-1})W_{hh}+b_h)\\ 其中R_t \bigodot H_{t-1}可以减少以往遗忘状态的影响:\\ 每当R_t接近1时,我们恢复一个传统RNN网络;\\ R_t接近0时,候选隐状态是以X_t作为输入的多层感知机的结果\\ H_t=Z_t \bigodot H_{t-1}+(1-Z_t) \bigodot H_t^{'}\\ Z_t接近1时,模型倾向于保留旧状态;Z_t接近0时,倾向于候选隐状态 Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)候选隐状态Ht=tanh(XtWxh+(RtHt1)Whh+bh)其中RtHt1可以减少以往遗忘状态的影响:每当Rt接近1时,我们恢复一个传统RNN网络;Rt接近0时,候选隐状态是以Xt作为输入的多层感知机的结果Ht=ZtHt1+(1Zt)HtZt接近1时,模型倾向于保留旧状态;Zt接近0时,倾向于候选隐状态
总之,GRU有以下显著特征:
1、重置门有助于捕获序列中的短期依赖关系
2、更新门有助于捕获序列中的长期依赖关系

GRU的简洁实现

from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

运行结果:

perplexity 1.0, 12581.5 tokens/sec on cpu
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

运行图片:
在这里插入图片描述

常用应用方式

循环神经网络中的不同的输入输出对应情况都有不同的应用方式。其中,一对多的网络结构可以用于图像描述(根据输入的一张图像,自动使用文字描述图像内容);多对一的网络结构可用于文本分类;多对多的网络结构可用于语言翻译。
比如,我们可以用RNN来做手写体分类,可以用LSTM来做中文新闻分类,可以用GRU来进行情感分类等等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/61197.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot整合JMH做优化实战

这段时间接手项目出现各种问题,令人不胜烦扰。吐槽下公司做项目完全靠人堆,大上快上风格注定留下一地鸡毛,修修补补不如想如何提升同事代码水准免得背锅。偶然看到关于JMH对于优化java代码的直观性,于是有了这篇文章,希…

使用docker安装mysql(谷粒商城)

前提准备:已经安装好了centos7 系统和docker容器 1、直接su root使用管理员下载镜像文件; 可以使用docker images查看下载是否成功 docker pull mysql:5.7bug1: 如果出现空间不足,比如报错no space left on device;我…

Oracle 开发篇+Java调用OJDBC访问Oracle数据库

标签:JAVA语言、Oracle数据库、Java访问Oracle数据库释义:OJDBC是Oracle公司提供的Java数据库连接驱动程序 ★ 实验环境 ※ Oracle 19c ※ OJDBC8 ※ JDK 8 ★ Java代码案例 package PAC_001; import java.sql.Connection; import java.sql.ResultSet…

Redis布隆过滤器的原理和应用场景,解决缓存穿透

目录 一、redis 二、布隆过滤器 三、缓存穿透问题 四、布隆过滤器解决缓存穿透 一、redis Redis(Remote Dictionary Server)是一种开源的内存数据存储系统,也是一个使用键值对(Key-Value)方式的高性能数据库。Red…

putty使用记录

在官网下载并安装putty 一、SSH 二、FTP open 192.168.1.118 put -r C:\Users\Administrator\Desktop\test /opt/lanren312/test # 上传(文件夹) get -r /opt/lanren312/test C:\Users\Administrator\Desktop\test2 # 下载(文件夹&#xff…

JS逆向系列之猿人学爬虫第14题-备而后动-勿使有变

文章目录 题目地址参数分析参考jspython 调用往期逆向文章推荐题目地址 https://match.yuanrenxue.cn/match/14题目难度标的是困难,主要难在js混淆部分。 参数分析 初始抓包有无限debugger反调试,可以直接hook 函数构造器过掉无限debugger Function.prototype.__construc…

【React学习】—函数式组件(四)

【React学习】—函数式组件&#xff08;四&#xff09; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><ti…

FreeRTOS( 任务与中断优先级,临界保护)

资料来源于硬件家园&#xff1a;资料汇总 - FreeRTOS实时操作系统课程(多任务管理) 目录 一、中断优先级 1、NVIC基础知识 2、FreeRTOS配置NVIC 3、SVC、PendSV、Systick中断 4、不受FreeRTOS管理的中断 5、STM32CubeMX配置 二、任务优先级 1、任务优先级说明 2、任务…

Boost开发指南-4.3optional

optional 在实际的软件开发过程中我们经常会遇到“无效值”的情况&#xff0c;例如函数并不是总能返回有效值&#xff0c;很多时候函数正确执行了&#xff0c;但结果却不是合理的值。如果用数学语言来解释&#xff0c;就是返回值位于函数解空间之外。 求一个数的倒数&#xf…

75. 颜色分类

题目链接&#xff1a;力扣 解题思路&#xff1a;因为整个nums数组中只有0&#xff0c;1&#xff0c;2三个数组成。对nums升序排序后&#xff0c;0一定都在数组的最左边&#xff0c;2一定都在数组的最右边&#xff0c;1在数组的中间。那么只需要将0移动到数组的左边&#xff0c;…

基于低代码和数字孪生技术的电力运维平台设计

电力能源服务商在为用能企业提供线上服务的时候&#xff0c;不可避免要面对用能企业的各种个性化需求。如果这些需求和想法都要靠平台厂家研发人员来实现&#xff0c;那在周期、成本、效果上都将是无法满足服务运营需要的&#xff0c;这也是目前很多线上能源云平台应用效果不理…

React使用antd的图片预览组件,点击哪个图片就预览哪个的设置

使用了官方推荐的相册模式的预览&#xff0c;但是点击预览之后&#xff0c;每次都是从图片列表的第一张开始预览&#xff0c;而不是点击哪张就从哪张开始预览&#xff1a; 所以这里我就封装了一下&#xff0c;对初始化预览的列表进行了逻辑处理&#xff1a; 当点击开始预览的…