NNDL 作业10 BPTT

习题6-1P 推导RNN反向传播算法BPTT.

我的推导

和PPT结果对比,可得答案没问题

习题6-2 推导公式(6.40)和公式(6.41)中的梯度. 

习题6-3 当使用公式(6.50)作为循环神经网络的状态更新公式时, 分析其可能存在梯度爆炸的原因并给出解决方法. 

解决方法. 

梯度消失

挺好奇门控循环单元的,就看到我室友呕心沥血的巨作,我拜读一下,大呼牛逼!

   门控机制的核心思想是通过一些门控单元,来控制信息的流动和保存。这些门控单元充当了数据的筛选器,可以选择性地让某些信息通过或阻止。主要的门控单元有以下两种:

  1. 遗忘门(Forget Gate):

    • 在LSTM中存在遗忘门,它决定了前一时刻的记忆状态中哪些信息需要被保留,哪些需要被遗忘。
    • 遗忘门的输出是一个在0到1之间的值,用于加权前一时刻的记忆状态。
  2. 输入门(Input Gate):

    • 输入门决定了当前时刻的输入信息中哪些部分需要被添加到记忆状态中。
    • 输入门的输出是一个在0到1之间的值,表示对应位置的输入是否重要。

      这两个门控单元使得LSTM网络能够更好地处理长序列信息,允许网络选择性地记住和遗忘信息。

习题6-2P 设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试. 

import torch
import numpy as np
class RNNCell:def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):self.weight_ih = weight_ihself.weight_hh = weight_hhself.bias_ih = bias_ihself.bias_hh = bias_hhself.x_stack = []  # 存储输入样本self.dx_list = []  # 存储反向传播计算得到的输入梯度self.dw_ih_stack = []  # 存储反向传播计算得到的输入权重梯度self.dw_hh_stack = []  # 存储反向传播计算得到的隐藏层权重梯度self.db_ih_stack = []  # 存储反向传播计算得到的输入偏置梯度self.db_hh_stack = []  # 存储反向传播计算得到的隐藏层偏置梯度self.prev_hidden_stack = []  # 存储前一时刻的隐藏状态self.next_hidden_stack = []  # 存储当前时刻的隐藏状态self.prev_dh = None  # 临时缓存,用于存储前一时刻的隐藏状态梯度def __call__(self, x, prev_hidden):self.x_stack.append(x)next_h = np.tanh(np.dot(x, self.weight_ih.T)+ np.dot(prev_hidden, self.weight_hh.T)+ self.bias_ih + self.bias_hh)  # 前向传播计算隐藏状态self.prev_hidden_stack.append(prev_hidden)self.next_hidden_stack.append(next_h)self.prev_dh = np.zeros(next_h.shape)  # 清空隐藏状态梯度缓存return next_hdef backward(self, dh):x = self.x_stack.pop()prev_hidden = self.prev_hidden_stack.pop()next_hidden = self.next_hidden_stack.pop()d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)  # 计算当前时刻的隐藏状态梯度self.prev_dh = np.dot(d_tanh, self.weight_hh)  # 更新前一时刻隐藏状态梯度缓存dx = np.dot(d_tanh, self.weight_ih)  # 计算输入梯度self.dx_list.insert(0, dx)dw_ih = np.dot(d_tanh.T, x)  # 计算输入权重梯度self.dw_ih_stack.append(dw_ih)dw_hh = np.dot(d_tanh.T, prev_hidden)  # 计算隐藏层权重梯度self.dw_hh_stack.append(dw_hh)self.db_ih_stack.append(d_tanh)  # 存储输入偏置梯度self.db_hh_stack.append(d_tanh)  # 存储隐藏层偏置梯度return self.dx_listif __name__ == '__main__':np.random.seed(123)torch.random.manual_seed(123)np.set_printoptions(precision=6, suppress=True)# 创建一个PyTorch的RNN模型rnn_PyTorch = torch.nn.RNN(4, 5).double()# 使用PyTorch参数初始化一个对应的NumPy RNN模型rnn_numpy = RNNCell(rnn_PyTorch.all_weights[0][0].data.numpy(),rnn_PyTorch.all_weights[0][1].data.numpy(),rnn_PyTorch.all_weights[0][2].data.numpy(),rnn_PyTorch.all_weights[0][3].data.numpy())nums = 3x3_numpy = np.random.random((nums, 3, 4))  # 随机生成输入样本x3_tensor = torch.tensor(x3_numpy, requires_grad=True)h3_numpy = np.random.random((1, 3, 5))  # 随机生成初始隐藏状态h3_tensor = torch.tensor(h3_numpy, requires_grad=True)dh_numpy = np.random.random((nums, 3, 5))  # 随机生成隐藏状态梯度dh_tensor = torch.tensor(dh_numpy, requires_grad=True)h3_tensor = rnn_PyTorch(x3_tensor, h3_tensor)  # PyTorch前向传播h_numpy_list = []h_numpy = h3_numpy[0]for i in range(nums):h_numpy = rnn_numpy(x3_numpy[i], h_numpy)  # NumPy前向传播h_numpy_list.append(h_numpy)h3_tensor[0].backward(dh_tensor)  # PyTorch反向传播for i in reversed(range(nums)):rnn_numpy.backward(dh_numpy[i])  # NumPy反向传播# 打印NumPy和PyTorch的输出结果进行对比print("numpy_hidden :\n", np.array(h_numpy_list))print("torch_hidden :\n", h3_tensor[0].data.numpy())print("-----------------------------------------------")# 打印NumPy和PyTorch的输入梯度进行对比print("dx_numpy :\n", np.array(rnn_numpy.dx_list))print("dx_torch :\n", x3_tensor.grad.data.numpy())print("------------------------------------------------")# 打印NumPy和PyTorch的输入权重梯度进行对比print("dw_ih_numpy :\n",np.sum(rnn_numpy.dw_ih_stack, axis=0))print("dw_ih_torch :\n",rnn_PyTorch.all_weights[0][0].grad.data.numpy())print("------------------------------------------------")# 打印NumPy和PyTorch的隐藏层权重梯度进行对比print("dw_hh_numpy :\n",np.sum(rnn_numpy.dw_hh_stack, axis=0))print("dw_hh_torch :\n",rnn_PyTorch.all_weights[0][1].grad.data.numpy())print("------------------------------------------------")# 打印NumPy和PyTorch的输入偏置梯度进行对比print("db_ih_numpy :\n",np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))print("db_ih_torch :\n",rnn_PyTorch.all_weights[0][2].grad.data.numpy())print("-----------------------------------------------")print("db_hh_numpy :\n",np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))print("db_hh_torch :\n",rnn_PyTorch.all_weights[0][3].grad.data.numpy())

心得体会:

1、学习这东西果然是孰能生巧,推导第一个题时候,还在认认真真的做,第二个和第三个题发现完全可以类比过去,就最后一个偏导不一样,所以直接将这钱的式子略加修改即可

2、卷积神经网络和RNN的不同之处就是,卷积神经并不权重共享,每一层的参数都是不一样的,而RNN是权重共享的。

3、代码又是看上一届的,不过要自己过一遍,不能囫囵吞枣

4、温故而知新,好好复习了一下之前学的两个激活函数sigmoid 和relu

选择relu等梯度大部分落在常数上的激活函数

relu函数的导数在正数部分是恒等于1的,因此在深层网络中使用relu激活函数就不会导致梯度消失和爆炸的问题。并且tanh 和 sigmoid 激活函数需要使用指数计算, 而ReLU只需要max(),因此他计算上更简单,计算成本也更低 。

 参考链接:

NNDL 作业9:分别使用numpy和pytorch实现BPTT-CSDN博客

【23-24 秋学期】NNDL 作业10 BPTT-CSDN博客

L5W1作业1 手把手实现循环神经网络-CSDN博客

NNDL 作业10 BPTT [HBU]-CSDN博客

梯度爆炸与梯度消失是什么?有什么影响?如何解决?_梯度爆炸和梯度消失-CSDN博客

原来ReLU这么好用?一文带你深度了解ReLU激活函数-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/277790.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Spring的AOP】Spring的简介、案例与工作流程

文章目录 1. 什么是AOP2. AOP的核心概念3. AOP的入门案例原始代码思路分析第一步:导入坐标第二步:制作连接点(原始操作,Dao接口与实现类)第三步:制作共性功能(通知类与通知)第四步&a…

继续看回溯问题

关卡名 继续看回溯问题 我会了✔️ 内容 1.复习递归和N叉树,理解相关代码是如何实现的 ✔️ 2.理解回溯到底怎么回事 ✔️ 3.掌握如何使用回溯来解决二叉树的路径问题 ✔️ 1 复原IP地址 这也是一个经典的分割类型的回溯问题。LeetCode93.有效IP地址正好由四…

TrustZone之完成器:外围设备和内存

到目前为止,在本指南中,我们集中讨论了处理器,但TrustZone远不止是一组处理器功能。要充分利用TrustZone功能,我们还需要系统其余部分的支持。以下是一个启用了TrustZone的系统示例: 本节探讨了该系统中的关键组件以及它们在TrustZone中的作用。 完成器:外围设备…

概念解读稳定性保障

什么是稳定 百度百科关于稳定的定义: “稳恒固定;没有变动。” 很明显这里的“稳定”是相对的,通常会有参照物,例如 A 车和 B 车保持相同速度同方向行驶,达到相对平衡相对稳定的状态。 那么软件质量的稳定是指什么…

PhotoMaker——通过堆叠 ID 嵌入定制逼真的人像照片

论文网址链接:https://arxiv.org/abs/2312.04461 详情网址链接:PhotoMaker 开源代码网址链接:GitHub - TencentARC/PhotoMaker: PhotoMaker 文本到图像AI生成的最新进展在根据给定文本提示合成逼真的人类照片方面取得了显着进展。然而&#…

UDS DTC老化机制

文章目录 简介基本概念1、操作周期(Operation Cyle)2、错误计数(FDC, Fault Detection Counter)3、确认阈值(Confirmation Threshold)4、老化计数(Aging Counter)5、老化阈值(Aging Threshold) 老化条件非排放 DTC 示例参考 简介 当某个DTC在一定次数的操作循环内,…

蓝桥杯专题-真题版含答案-【扑克牌排列】【放麦子】【纵横放火柴游戏】【顺时针螺旋填入】

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

mac电脑html文件 局域网访问

windows html文件 局域网访问 参考 https://blog.csdn.net/qq_38935512/article/details/103271291mac电脑html文件 局域网访问 开发工具vscode 安装vscode插件 Live Server 完成后打开项目的html 右键使用Live Server打开页面 效果如下,使用本地ip替换http://12…

湖农大邀请赛shell_rce漏洞复现

湖农大邀请赛 shell_rce 复现 在 2023 年湖南农业大学邀请赛的线上初赛中&#xff0c;有一道 shell_rce 题&#xff0c;本文将复现该题。 题目内容&#xff0c;打开即是代码&#xff1a; <?phpclass shell{public $exp;public function __destruct(){$str preg_replace…

vue文件下载请求blob文件流token失效的问题

页面停留很久token失效没有刷新页面&#xff0c;这时候点击下载依然可以导出文件&#xff0c;但是文件打不开且接口实际上返回的是401&#xff0c;这是因为文件下载的方式通过window创建a标签的形式打开的&#xff0c;并没有判断token失效问题 const res await this.$axios.…

python基本数据类型(一)-字符串

1.字符串 字符串就是一系列字符&#xff0c;在Python中&#xff0c;用引号括起的都是字符串&#xff0c;其中的引号可以是单引号&#xff0c;也可以是双引号&#xff0c;如下所示&#xff1a; "This is a string." This is also a string.这种灵活性让你能够在字符…

【Idea】SpringBoot项目中,jar包引用冲突异常的排查 / SM2算法中使用bcprov-jdk15to18的报错冲突问题

问题描述以及解决方法&#xff1a; 项目中使用了bcprov-jdk15to18 pom依赖&#xff0c;但是发现代码中引入的版本不正确。 追溯代码发现版本引入的是bcprov-jdk15on&#xff0c;而不是bcprov-jdk15to18&#xff0c;但是我找了半天pom依赖也没有发现有引入bcprov-jdk15on依赖。…