《动手学深度学习(PyTorch版)》笔记8.7

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter8 Recurrent Neural Networks

8.7 Backpropagation Through Time

通过时间反向传播(backpropagation through time,BPTT)是循环神经网络中反向传播技术的一个特定应用,它要求我们将循环神经网络的计算图一次展开一个时间步,以获得模型变量和参数之间的依赖关系,然后,基于链式法则,应用反向传播来计算和存储梯度。由于序列可能相当长,因此依赖关系也可能相当长,在下文中,我们将阐明计算过程会发生什么以及如何在实践中解决它们。

8.7.1 RNN’s Gradient Analysis

我们从一个描述循环神经网络工作原理的简化模型开始,此模型忽略了隐状态的特性及其更新方式的细节,且其数学表示没有明确地区分标量、向量和矩阵。在这个简化模型中,我们将时间步 t t t的隐状态表示为 h t h_t ht,输入表示为 x t x_t xt,输出表示为 o t o_t ot,分别使用 w h w_h wh w o w_o wo来表示隐藏层和输出层的权重。每个时间步的隐状态和输出可以写为:

h t = f ( x t , h t − 1 , w h ) , o t = g ( h t , w o ) , (2) \begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\o_t &= g(h_t, w_o),\end{aligned}\tag{2} htot=f(xt,ht1,wh),=g(ht,wo),(2)

其中 f f f g g g分别是隐藏层和输出层的变换。因此,我们有一个链 { … , ( x t − 1 , h t − 1 , o t − 1 ) , ( x t , h t , o t ) , … } \{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\} {,(xt1,ht1,ot1),(xt,ht,ot),},它们通过循环计算彼此依赖。前向传播相当简单,一次一个时间步的遍历三元组 ( x t , h t , o t ) (x_t, h_t, o_t) (xt,ht,ot),然后通过一个目标函数在所有 T T T个时间步内评估输出 o t o_t ot和对应的标签 y t y_t yt之间的差异:

L ( x 1 , … , x T , y 1 , … , y T , w h , w o ) = 1 T ∑ t = 1 T l ( y t , o t ) . L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t). L(x1,,xT,y1,,yT,wh,wo)=T1t=1Tl(yt,ot).

对于反向传播,按照链式法则:

∂ L ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ o t ∂ g ( h t , w o ) ∂ h t ∂ h t ∂ w h . \begin{aligned}\frac{\partial L}{\partial w_h} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h}.\end{aligned} whL=T1t=1Twhl(yt,ot)=T1t=1Totl(yt,ot)htg(ht,wo)whht.

在上式乘积的第一项和第二项很容易计算,而第三项比较棘手,因为我们需要循环地计算参数 w h w_h wh h t h_t ht的影响。根据式(2), h t h_t ht既依赖于 h t − 1 h_{t-1} ht1又依赖于 w h w_h wh,其中 h t − 1 h_{t-1} ht1的计算也依赖于 w h w_h wh。因此,使用链式法则产生:

∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . (3) \frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.\tag{3} whht=whf(xt,ht1,wh)+ht1f(xt,ht1,wh)whht1.(3)

为了导出上述梯度,假设我们有三个序列 { a t } , { b t } , { c t } \{a_{t}\},\{b_{t}\},\{c_{t}\} {at},{bt},{ct},当 t = 1 , 2 , … t=1,2,\ldots t=1,2,时,序列满足 a 0 = 0 a_{0}=0 a0=0 a t = b t + c t a t − 1 a_{t}=b_{t}+c_{t}a_{t-1} at=bt+ctat1。对于 t ≥ 1 t\geq 1 t1,就很容易得出:

a t = b t + ∑ i = 1 t − 1 ( ∏ j = i + 1 t c j ) b i . (4) a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.\tag{4} at=bt+i=1t1(j=i+1tcj)bi.(4)

基于下列公式替换 a t a_t at b t b_t bt c t c_t ct

a t = ∂ h t ∂ w h , b t = ∂ f ( x t , h t − 1 , w h ) ∂ w h , c t = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 , \begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}, \\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}},\end{aligned} atbtct=whht,=whf(xt,ht1,wh),=ht1f(xt,ht1,wh),

则:

∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∑ i = 1 t − 1 ( ∏ j = i + 1 t ∂ f ( x j , h j − 1 , w h ) ∂ h j − 1 ) ∂ f ( x i , h i − 1 , w h ) ∂ w h . (5) \frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}.\tag{5} whht=whf(xt,ht1,wh)+i=1t1(j=i+1thj1f(xj,hj1,wh))whf(xi,hi1,wh).(5)

虽然我们可以使用链式法则递归地计算 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh,但当 t t t很大时这个链就会变得很长,在实践中是不可取的。

8.7.1.1 Cutting Off Time Steps

我们也可以在 τ \tau τ步后截断式(5)中的求和计算,即将求和终止为 ∂ h t − τ / ∂ w h \partial h_{t-\tau}/\partial w_h htτ/wh,这种截断是通过在给定数量的时间步之后分离梯度来实现的。这样做导致该模型主要侧重于短期影响,而不是长期影响,在现实中是可取的。

8.7.1.2 Randomly Truncating

我们也可以用一个随机变量替换 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh,这个随机变量通过序列 ξ t \xi_t ξt实现。序列预定义了 0 ≤ π t ≤ 1 0 \leq \pi_t \leq 1 0πt1,其中 P ( ξ t = 0 ) = 1 − π t P(\xi_t = 0) = 1-\pi_t P(ξt=0)=1πt P ( ξ t = π t − 1 ) = π t P(\xi_t = \pi_t^{-1}) = \pi_t P(ξt=πt1)=πt,因此 E [ ξ t ] = 1 E[\xi_t] = 1 E[ξt]=1。使用 z t z_t zt来替换式(3)中的梯度 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh得到:

z t = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ξ t ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}. zt=whf(xt,ht1,wh)+ξtht1f(xt,ht1,wh)whht1.

ξ t \xi_t ξt的定义中推导出来 E [ z t ] = ∂ h t / ∂ w h E[z_t] = \partial h_t/\partial w_h E[zt]=ht/wh,当 ξ t = 0 \xi_t = 0 ξt=0时,递归计算终止在这个 t t t时间步。这导致了不同长度序列的加权和,其中长序列出现的很少,所以需要适当地加大权重。

在这里插入图片描述

上图说明了当基于循环神经网络使用通过时间反向传播分析数据集的三种策略:

  • 第一行采用随机截断,方法是将文本划分为不同长度的片断;
  • 第二行采用常规截断,方法是将文本分解为相同长度的子序列;
  • 第三行采用通过时间的完全反向传播,结果是产生了在计算上不可行的表达式。

虽然随机截断在理论上具有吸引力,但由于多种因素在实践中并不总比常规截断更好。首先,在对过去若干个时间步经过反向传播后,观测结果足以捕获实际的依赖关系。其次,增加的方差抵消了时间步数越多梯度越精确的事实。第三,模型可能需要经过一定程度的正则化,以防止过拟合。通过常规截断方法,时间反向传播会引入一定程度的正则化效果,有助于控制模型的复杂度,并提高其泛化能力。

8.7.2 Details of BPTT

下面将展示如何计算目标函数相对于所有模型参数的梯度。简单起见,我们考虑一个没有偏置参数的RNN,其在隐藏层中的激活函数使用恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。对于时间步 t t t,设单个样本的输入及其对应的标签分别为 x t ∈ R d \mathbf{x}_t \in \mathbb{R}^d xtRd y t y_t yt。计算隐状态 h t ∈ R h \mathbf{h}_t \in \mathbb{R}^h htRh和输出 o t ∈ R q \mathbf{o}_t \in \mathbb{R}^q otRq的方式为:

h t = W h x x t + W h h h t − 1 , o t = W q h h t , \begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\end{aligned} htot=Whxxt+Whhht1,=Wqhht,

l ( o t , y t ) l(\mathbf{o}_t, y_t) l(ot,yt)表示时间步 t t t处的损失函数,则目标函数的总体损失是:

L = 1 T ∑ t = 1 T l ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t). L=T1t=1Tl(ot,yt).

模型绘制一个计算图如下所示。

在这里插入图片描述

上图中的模型参数是 W h x \mathbf{W}_{hx} Whx W h h \mathbf{W}_{hh} Whh W q h \mathbf{W}_{qh} Wqh。通常,训练该模型需要分别计算: ∂ L / ∂ W h x \partial L/\partial \mathbf{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \mathbf{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \mathbf{W}_{qh} L/Wqh。根据上图中的依赖关系,我们可以沿箭头的相反方向遍历计算图,依次计算和存储梯度。为了灵活地表示链式法则中不同形状的矩阵、向量和标量的乘法,我们继续使用4.7中所述的 prod \text{prod} prod运算符。

首先有:

∂ L ∂ o t = ∂ l ( o t , y t ) T ⋅ ∂ o t ∈ R q . (6) \frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.\tag{6} otL=Totl(ot,yt)Rq.(6)

接着得到:

∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ ∈ R q × h \frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top\in \mathbb{R}^{q \times h} WqhL=t=1Tprod(otL,Wqhot)=t=1TotLhtRq×h

其中 ∂ L / ∂ o t \partial L/\partial \mathbf{o}_t L/ot是由式(6)给出的。

接下来,如上图所示,在最后的时间步 T T T,目标函数 L L L仅通过 o T \mathbf{o}_T oT依赖于隐状态 h T \mathbf{h}_T hT。因此,我们通过使用链式法可以很容易地得到梯度$\partial L/\partial \mathbf{h}_T :

∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T ∈ R h . (7) \frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}\in \mathbb{R}^h.\tag{7} hTL=prod(oTL,hToT)=WqhoTLRh.(7)

隐状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \mathbf{h}_t \in \mathbb{R}^h L/htRh在任何 t < T t < T t<T时都可以递归地计算为:

∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t . (8) \frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.\tag{8} htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL.(8)

对于任何时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT展开递归计算得:

∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . (9) \frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.\tag{9} htL=i=tT(Whh)TiWqhoT+tiL.(9)

我们可以从式(9)中看到,这个简单的线性例子已经陷入到 W h h ⊤ \mathbf{W}_{hh}^\top Whh的潜在的非常大的幂。在这个幂中,小于1的特征值将会消失,大于1的特征值将会发散。这在数值上是不稳定的,表现形式为梯度消失或梯度爆炸,解决此问题的一种方法如8.7.1中所述。

最后,应用链式规则得:

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ ∈ R h × d , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ ∈ R h × d , \begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top\in \mathbb{R}^{h \times d},\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top\in \mathbb{R}^{h \times d}, \end{aligned} WhxLWhhL=t=1Tprod(htL,Whxht)=t=1ThtLxtRh×d,=t=1Tprod(htL,Whhht)=t=1ThtLht1Rh×d,

其中 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t L/ht由式(7)和式(8)递归计算得到,是影响数值稳定性的关键量。在训练过程中一些中间值会被存储,以避免重复计算,例如存储 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t L/ht,以便在计算 ∂ L / ∂ W h x \partial L / \partial \mathbf{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L / \partial \mathbf{W}_{hh} L/Whh时使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/471029.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python算法题集_二叉树的直径

Python算法题集_二叉树的直径 题543&#xff1a;二叉树的直径1. 示例说明2. 题目解析- 题意分解- 优化思路- 测量工具 3. 代码展开1) 标准求解【DFS字典引用】2) 改进版一【DFS全局变量】3) 改进版二【DFS递归返回】 4. 最优算法 本文为Python算法题集之一的代码示例 题543&am…

Linux内核-时间子系统(时钟中断)专题汇总

文章目录 概要一、专题汇总1.1、优秀系列博文1.2、时间子系统1.3、高精度定时器hrtimer1.4、RTC硬件芯片驱动 概要 中断机制是计算机系统的重要组成部分&#xff0c;在Linux中也不例外&#xff0c;中断按照来源分为硬中断和软中断&#xff0c;而硬中断根据硬件范围分为外中断和…

GPT翻译网站的加载与使用

Sider: ChatGPT侧边栏 GPTs, GPT-4 Turbo, 联网, 绘图 sider.ai https://chromewebstore.google.com/detail/sider-chatgpt%E4%BE%A7%E8%BE%B9%E6%A0%8F-gpts-g/difoiogjjojoaoomphldepapgpbgkhkb?hlzh-CN 加入与移除 第二个翻译网站 https://chromewebstore.google.com/…

MATLAB知识点:poissrnd函数(★★☆☆☆)生成泊松分布的随机数

讲解视频&#xff1a;可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇&#xff08;数学建模清风主讲&#xff0c;适合零基础同学观看&#xff09;_哔哩哔哩_bilibili 节选自第3章&#xff1a;课后习题讲解中拓展的函数 在讲解第三…

Sibelius安装包免费下载激活指南,西贝柳斯,专业作曲打谱软件

Sibelius来自芬兰音乐巨匠西贝柳斯的故乡&#xff0c;被誉为世界上最强的五线谱软件。Sibelius功能全面、音色音质精准受到广大作曲家的喜爱。其乐谱记号十分全面&#xff0c;所有的乐谱都可以应付自如&#xff0c;Sibelius可以迅速完成作曲、编曲、发布任务&#xff0c;轻松开…

react+ts【项目实战一】配置项目/路由/redux

文章目录 1、项目搭建1、创建项目1.2 配置项目1.2.1 更换icon1.2.2 更换项目名称1.2.1 配置项目别名 1.3 代码规范1.3.1 集成editorconfig配置1.3.2 使用prettier工具 1.4 项目结构1.5 对css进行重置1.6 注入router1.7 定义TS组件的规范1.8 创建代码片段1.9 二级路由和懒加载1.…

微服务中台架构的设计与实现

本文将探讨微服务中台架构的设计与实现&#xff0c;介绍如何通过微服务的方式进行系统拆分和组合&#xff0c;构建灵活、可扩展且易于维护的中台架构&#xff0c;以加速企业的数字化转型和提升竞争力。 ## 1. 引言 随着企业规模的不断扩大和业务的日益复杂化&#xff0c;传统…

飞天使-k8s知识点18-kubernetes实操3-pod的生命周期

文章目录 探针的生命周期流程图prestop 探针的生命周期 docker 创建&#xff1a;在创建阶段&#xff0c;你需要选择一个镜像来运行你的应用。这个镜像可以是公开的&#xff0c;如 Docker Hub 上的镜像&#xff0c;也可以是你自己创建的自定义镜像。创建自己的镜像通常需要编写一…

Tuxera NTFS2024版本的文件操作功能有哪些特点?

Tuxera NTFS通过集成先进的文件系统驱动程序和算法&#xff0c;实现了对多种文件系统的全面支持。具体来说&#xff0c;它具备以下功能和特点&#xff0c;使其能够支持多种文件系统&#xff1a; Tuxera NTFS2024下载如下: https://wm.makeding.com/iclk/?zoneid58824 先进的…

wechat协议接口免费分享(价值5w)

几乎涵盖所有功能&#xff0c;仅供学习交流使用&#xff01; 接口地址 https://apifox.com/apidoc/shared-86280587-c0f0-480a-85de-ee292b4aae82/doc-3721193

RK3568平台开发系列讲解(实验篇)杂项设备驱动实验

🚀返回专栏总目录 文章目录 一、什么是杂项设备驱动二、杂项设备的注册和卸载三、杂项设备驱动实验代码沉淀、分享、成长,让自己和他人都能有所收获!😄 一、什么是杂项设备驱动 在 Linux 中,把无法归类的五花八门的设备定义成杂项设备。相较于字符设备,杂项设备有以下两…

Activation of network connection failed(ubuntu连不上网)

ubuntu连不上网&#xff0c;看了好几个方法找到个有用的记录一下 1. 还原默认设置 2. 更改适配器&#xff1a;加上vmware bridge protocol