RNN LSTM

参考资料:

  • 《机器学习2022》李宏毅
  • 史上最详细循环神经网络讲解(RNN/LSTM/GRU) - 知乎 (zhihu.com)
  • LSTM如何来避免梯度弥散和梯度爆炸? - 知乎 (zhihu.com)

1 RNN 的结构

首先考虑这样一个 slot filling 问题:

image-20230702152122630

注意到,上图中 Taipei 的输出为 destination。如果我们只是单纯地将每个词向量输入到一个神经网络中,那么对于"leave Taipei on …" 这句话,模型对 Taipei 的输出也会是 destination,但我们希望它是 departure。要实现这一目的,必须要引入当前向量与上下文的关系,于是就有了循环神经网络(RNN):

image-20230702163833223

注意到,RNN 与一般的神经网络的主要区别在与将隐层的上一次输出保存并作为本次的输入,即:
O t = g ( V ⋅ S t ) S t = f ( U ⋅ X t + W ⋅ S t − 1 ) \begin{align} O_t&=g(V\cdot S_t)\notag\\ S_t&=f(U\cdot X_t+W\cdot S_{t-1})\notag \end{align} OtSt=g(VSt)=f(UXt+WSt1)

矩阵 U , W , V U,W,V U,W,V 即为 RNN 的参数,与 t t t 无关。

引入时间这一维度,RNN 可以表示为如下结构:

image-20230702152604317

如果采用双向 RNN ,则每个向量都可以充分地考虑到上下文。

2 RNN 的梯度消失与梯度爆炸

考虑这样一个简单的 RNN 结构:

假设神经元没有激活函数(激活函数的导数一般是恒 < 1 <1 <1 的),则有:
S 1 = W x X 1 + W s S 0 + b 1 O 1 = W o S 1 + b 2 S 2 = W x X 2 + W s S 1 + b 1 O 2 = W o S 2 + b 2 S 3 = W x X 3 + W s S 2 + b 1 O 3 = W o S 3 + b 2 \begin{align} S_1&=W_xX_1+W_sS_0+b_1\quad&O_1=W_oS_1+b_2\notag\\ S_2&=W_xX_2+W_sS_1+b_1\quad&O_2=W_oS_2+b_2\notag\\ S_3&=W_xX_3+W_sS_2+b_1\quad&O_3=W_oS_3+b_2\notag\\ \end{align} S1S2S3=WxX1+WsS0+b1=WxX2+WsS1+b1=WxX3+WsS2+b1O1=WoS1+b2O2=WoS2+b2O3=WoS3+b2
t 3 t_3 t3 时刻的损失函数为 L 3 L_3 L3 ,则有:
∂ L 3 ∂ W o = ∂ L 3 ∂ O 3 ∂ O 3 ∂ W o ∂ L 3 ∂ W x = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( X 3 + W s ( X 2 + W s X 1 ) ) ∂ L 3 ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( S 2 + W s ( S 1 + W s S 0 ) ) \begin{align} \frac{\partial L_3}{\partial W_o}&=\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial W_o}\notag\\ \frac{\partial L_3}{\partial W_x}&=\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\bigg(X_3+W_s\Big(X_2+W_sX_1\Big)\bigg)\notag\\ \frac{\partial L_3}{\partial W_s}&=\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\bigg(S_2+W_s\Big(S_1+W_sS_0\Big)\bigg)\notag\\ \end{align} WoL3WxL3WsL3=O3L3WoO3=O3L3S3O3(X3+Ws(X2+WsX1))=O3L3S3O3(S2+Ws(S1+WsS0))

这部分的公式和参考资料里的不太一样,但我感觉参考资料里的公式不太严格吧?

所以,任意时刻损失函数对 W x , W s W_x,W_s Wx,Ws 的偏导为:
∂ L t ∂ W x = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∑ k = 1 t W s t − k X k ∂ L t ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∑ k = 1 t W s t − k S k − 1 \begin{align} \frac{\partial L_t}{\partial W_x}&=\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\sum\limits_{k=1}^{t}W_s^{t-k}X_k\notag\\ \frac{\partial L_t}{\partial W_s}&=\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\sum\limits_{k=1}^{t}W_s^{t-k}S_{k-1}\notag\\ \end{align} WxLtWsLt=O3L3S3O3k=1tWstkXk=O3L3S3O3k=1tWstkSk1
W s ∈ ( 0 , 1 ) W_s\in(0,1) Ws(0,1) 时,损失函数对 W x , W s W_x,W_s Wx,Ws 的偏导会逐渐“遗忘”距离较远的梯度,所以模型很难学习到距离较远的依赖关系。

W s > 1 W_s>1 Ws>1 时,前面的梯度对当前的影响会随着距离增加而指数级增大,甚至变成 NaN.

3 LSTM

LSTM(Long Short-term Memory) 是 RNN 的变体,并且已经逐渐成为了 RNN 的代名词,其基本结构如下图所示:

image-20230702153438853

相比普通的 RNN ,LSTM增加了输入门、输出门和遗忘门。

image-20230702185523392

上图中, z f , z i , z , z o z_f,z_i,z,z_o zf,zi,z,zo 均有相应的权值矩阵乘上拼接后的输入向量得到。

LSTM 可以解决 RNN 梯度消失的问题,因为如果不考虑遗忘门,距离再远的梯度也可以通过 c i → ⋯ → c t − 1 → c t c_{i}\rightarrow\cdots\rightarrow c_{t-1}\rightarrow c_t cict1ct 这条路径无损地传递到到当前的梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/9849.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

postgresql内核分析 spinlock与lwlock原理与实现机制

​专栏内容&#xff1a; postgresql内核源码分析 手写数据库toadb 并发编程 个人主页&#xff1a;我的主页 座右铭&#xff1a;天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物. 概述 在postgresql 中&#xff0c;有大量的并发同步&#xff0…

硬件速攻-激光测距传感器VL530L0X

介绍 VL53L0X是一种时间飞行&#xff08;Time-of-Flight&#xff0c;TOF&#xff09;技术的激光测距传感器芯片。TOF技术利用红外激光发射器发送短脉冲光束&#xff0c;并通过测量光束从传感器到目标物体返回的时间来计算距离。 外观 现象 串口打印数据 接线 VCC 3.3V G…

MFC 单文档模式

Doc类利用自带框架存数据 void CCADDoc::Serialize(CArchive& ar) {if (ar.IsStoring()){// TODO: 在此添加存储代码//保存数据到文件ar << m_nShapeCount;for (int i 0; i < m_arrShapes.GetSize(); i){CShape* pShape NULL;pShape (CShape*)m_arrShapes[i];…

Prometheus 时序数据

一 时序索引 Prometheus 存储的是时序数据&#xff0c;时间戳&#xff08;timestamp&#xff09;来源于服务端本地的系 统时间。Prometheus 使用 Unix 时间戳&#xff08;即自 1970 年 1 月 1 日 00:00:00 UTC 起经过的秒数&#xff09;表示时间。 数 据 格 式 &#xff1a; …

C++—异常与类型转换、大小端存储、不使用额外空间的情况下交换两个数

异常 常见的异常包括&#xff1a;数组下标越界&#xff0c;除法计算的时候除数为0&#xff0c;动态分配空间时空间不足。 try&#xff0c;throw&#xff0c;catch #include <iostream> using namespace std; int main() {double m 1, n 0;try {cout << "b…

基于 RK3399+fpga 的 VME 总线控制器设计(一)总体设计

2.1 需求分析及技术指标 2.1.1 需求分析 VME 总线控制器需要实现数据传输、中断处理、测量显示等功能。同时还需 要具有操作系统、底层驱动程序以及功能接口等&#xff0c;以方便用户进行上层应用软件开 发及使用。 本课题需要实现 VME 控制器的国产化开发&#xff0…

FPGA实验六:PWM信号调制器设计

目录 一、实验目的 二、设计要求 三、实验代码 1.顶层文件代码 2.仿真文件部分代码 3.系统工程文件 四、实验结果及分析 1、引脚锁定 2、仿真波形及分析 3、下载测试结果及分析 五、实验心得 一、实验目的 &#xff08;1&#xff09;掌握通信信号调制过程及实现原理…

【模式识别目标检测】——模式识别技术车牌检测应用

目录 引入 一、模式识别主要方法 1、统计模式识别 2、基于隐马尔可夫模型识别 3、模糊模式识别 4、人工神经网络模式识别 总结 二、模式识别应用 1、车牌定位 2、车牌识别 参考文献&#xff1a; 引入 人在观察事物或现象时&#xff0c;常寻找它与其他事物或现象不同…

嵌入式Linux开发实操(二):uboot+kernal

要理解如何进行嵌入式Linux编程,必须知道系统启动引导过程: 上电后,芯片将开始执行其启动固件,它就是uboot,主要目的是加载一个程序,然后在芯片上运行它,uboot通过查看引导模式寄存器、保险丝Fuses或GPIO引脚的状态来确定从哪里加载程序,比如从从eMMC flash启动。 SPL是…

【MySQL体系结构及CetOS7安装MySQL和修改密码】

MySQL体系结构及安装MySQL MySQL体系结构CentOS7安装MySQL四种方法1、离线安装2、在线安装3、通用二级制方式4、容器方式安装 设置及修改密码忘记密码恢复 MySQL体系结构 MySQL是一种常用的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;其体系结构包括以下&…

更改VS code Jupyter 插件的默认快捷键

更改vscode 中Jupyter插件的默认快捷键&#xff0c;解放插入空行的系统快捷键 替换Jupyter默认快捷键 更改vscode 中Jupyter插件的默认快捷键&#xff0c;解放插入空行的系统快捷键打开keyboard shortcuts 设置方法一方法二 更换快捷键 end Jupyter 插件很好的在VS code中集成了…

华为OD机试真题B卷 Python 实现【整理扑克牌】,附详细解题思路

目录 一、题目描述步骤1步骤2步骤3 二、输入描述三、输出描述四、解题思路五、Python算法源码六、效果展示1、输入2、输出3、说明 一、题目描述 给定一组数字&#xff0c;表示扑克牌的牌面数字&#xff0c;忽略扑克牌的花色&#xff0c;请按如下规则对这一组扑克牌进行整理&am…