深度学习 精选笔记(7)前向传播、反向传播和计算图

学习参考:

  • 动手学深度学习2.0
  • Deep-Learning-with-TensorFlow-book
  • pytorchlightning

①如有冒犯、请联系侵删。
②已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。
③非常推荐上面(学习参考)的前两个教程,在网上是开源免费的,写的很棒,不管是开始学还是复习巩固都很不错的。

深度学习回顾,专栏内容来源多个书籍笔记、在线笔记、以及自己的感想、想法,佛系更新。争取内容全面而不失重点。完结时间到了也会一直更新下去,已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。所有文章涉及的教程都会写在开头、一起学习一起进步。

前向传播用于计算模型的预测输出,反向传播用于根据预测输出和真实标签之间的误差来更新模型参数。

前向传播和反向传播是神经网络训练中的核心步骤,通过这两个过程,神经网络能够学习如何更好地拟合数据,提高预测准确性。

一、计算图

计算图(Computational Graph)是一种图形化表示方法,用于描述数学表达式中各个变量之间的依赖关系和计算流程。在深度学习和机器学习领域,计算图常用于可视化复杂的数学运算和函数计算过程,尤其是在反向传播算法中的梯度计算过程中被广泛应用。

计算图通常包括两种节点:

  • 计算节点(Compute Nodes):这些节点表示数学运算,如加法、乘法等。计算节点接受输入,并产生输出。
  • 数据节点(Data Nodes):这些节点表示数据或变量,如输入数据、权重、偏置等。

通过连接计算节点和数据节点的边,构建了一个有向图,其中每个节点表示一个操作,边表示数据流向。计算图可以帮助理解复杂的计算过程,特别是在深度学习中涉及大量参数和运算的情况下。

二、前向传播

前向传播(forward propagation或forward pass) 指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

前向传播(Forward Propagation):

  • 定义:前向传播是指输入数据通过神经网络模型的各层,逐层进行计算并传递至输出层的过程。
  • 作用:在前向传播过程中,输入数据经过神经网络的权重和激活函数的计算,最终得到模型的预测输出。
  • 目的:前向传播的目的是计算模型对输入数据的预测值,为后续的损失函数计算和反向传播提供基础。

1.前向传播的计算图

假设单隐藏层神经网络中,输入样本是 𝐱∈ℝ d, 并且隐藏层不包括偏置项。 这里的中间变量是:
在这里插入图片描述
其中 𝐖(1)∈ℝℎ×𝑑 是隐藏层的权重参数。 将中间变量 𝐳∈ℝℎ 通过激活函数 𝜙 后, 得到长度为 ℎ 的隐藏激活向量是:
在这里插入图片描述
隐藏变量 𝐡也是一个中间变量。 假设输出层的参数只有权重 𝐖(2)∈ℝ𝑞×ℎ, 可以得到输出层变量,它是一个长度为 𝑞 的向量:
在这里插入图片描述
假设损失函数为 𝑙,样本标签为 𝑦,可以计算单个数据样本的损失项,
在这里插入图片描述
根据 𝐿2 正则化的定义,给定超参数 𝜆 ,正则化项为
在这里插入图片描述
其中矩阵的Frobenius范数是将矩阵展平为向量后应用的 𝐿2范数。 最后,模型在给定数据样本上的正则化损失为:
在这里插入图片描述
该函数J就是目标函数。

绘制计算图有助于可视化计算中操作符和变量的依赖关系。

与上述简单网络相对应的计算图, 其中正方形表示变量,圆圈表示操作符。 左下角表示输入,右上角表示输出。 注意显示数据流的箭头方向主要是向右和向上的。
在这里插入图片描述

三、反向传播

反向传播(Backpropagation):

  • 定义:反向传播是指通过计算损失函数对模型参数的梯度(梯度是一个由偏导数组成的向量,表示函数在某一点处的变化率或者斜率方向、也就是在每个自变量方向上的偏导数),从输出层向输入层传播梯度的过程。
  • 作用:在反向传播过程中,根据损失函数计算模型参数的梯度,然后利用梯度下降等优化算法更新模型参数,以减小损失函数的值。
  • 目的:反向传播的目的是根据模型预测与真实标签的误差,调整神经网络中每个参数的值,使模型能够更好地拟合训练数据,并提高在新数据上的泛化能力。

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。 简言之,该方法根据微积分中的链式规则,按相反的顺序从输出层到输入层遍历网络。 该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。 假设有函数 𝖸=𝑓(𝖷) 和 𝖹=𝑔(𝖸) , 其中输入和输出 𝖷,𝖸,𝖹 是任意形状的张量。 利用链式法则,可以计算 𝖹 关于 𝖷 的导数:

在这里插入图片描述
使用 prod 运算符在执行必要的操作(如换位和交换输入位置)后将其参数相乘。 对于向量,这很简单,它只是矩阵-矩阵乘法。

在前向传播的计算图中,单隐藏层简单网络的参数是 𝐖(1) 和 𝐖(2) 。 反向传播的目的是计算梯度 ∂𝐽/∂𝐖(1)∂𝐽/∂𝐖(2) 。为此,应用链式法则,依次计算每个中间变量和参数的梯度。 计算的顺序与前向传播中执行的顺序相反,因为需要从计算图的结果开始,并朝着参数的方向努力。第一步是计算目标函数 𝐽=𝐿+𝑠 相对于损失项 𝐿 和正则项 𝑠 的梯度。

这里为什么等于1?因为单隐藏层简单网络的最后一层上面是
在这里插入图片描述
根据链式法则计算目标函数关于输出层变量 𝐨 的梯度:
在这里插入图片描述
计算正则化项相对于两个参数的梯度:

在这里插入图片描述
计算最接近输出层的模型参数的梯度 ∂𝐽/∂𝐖(2)∈ℝ𝑞×ℎ 。 使用链式法则得出:

在这里插入图片描述
为了获得关于 𝐖(1)的梯度,需要继续沿着输出层到隐藏层反向传播。 关于隐藏层输出的梯度 ∂𝐽/∂𝐡∈ℝℎ 由下式给出:
在这里插入图片描述
由于激活函数 𝜙 是按元素计算的, 计算中间变量 𝐳的梯度 ∂𝐽/∂𝐳∈ℝℎ 需要使用按元素乘法运算符,用 表示:
在这里插入图片描述
最后,可以得到最接近输入层的模型参数的梯度 ∂𝐽/∂𝐖(1)∈ℝℎ×𝑑 。 根据链式法则,我们得到:
在这里插入图片描述

四、训练神经网络

在训练神经网络时,前向传播和反向传播相互依赖。

对于前向传播,沿着依赖的方向遍历计算图并计算其路径上的所有变量。 然后将这些用于反向传播,其中计算顺序与计算图的相反。

以上述简单网络为例:
正则项:

在这里插入图片描述
反向传播中计算J对W(2)的梯度公式:
在这里插入图片描述
反向传播中计算J对W(1)的梯度公式:
在这里插入图片描述
一方面,在前向传播期间计算正则项取决于模型参数𝐖(1)和 𝐖(2)的当前值。 它们是由优化算法根据最近迭代的反向传播给出的。 另一方面,反向传播期间参数的梯度计算, 取决于由前向传播给出的隐藏变量𝐡的当前值。

因此,在训练神经网络时,在初始化模型参数后, 交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。

注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。 带来的影响之一是需要保留中间值,直到反向传播完成。 这也是训练比单纯的预测需要更多的内存(显存)的原因之一。 此外,这些中间值的大小与网络层的数量和批量的大小大致成正比。 因此,使用更大的批量来训练更深层次的网络更容易导致内存不足(out of memory)错误

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/499328.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

欢迎参与2024年 OpenTiny 开源项目用户调研

🎉 欢迎参与 OpenTiny 开源项目的用户调研 🎉 📣 调研背景 随着 OpenTiny 开源项目的不断发展,我们一直致力于为开发者提供高质量的 Web 前端开发解决方案。为了更好地满足用户需求,提升项目的实用性和易用性&#x…

【Maven】Maven 基础教程(二):Maven 的使用

《Maven 基础教程》系列,包含以下 2 篇文章: Maven 基础教程(一):基础介绍、开发环境配置Maven 基础教程(二):Maven 的使用 😊 如果您觉得这篇文章有用 ✔️ 的话&#…

Sanctuary AI旗下世界上首个采用Carbon驱动的人形通用机器人Phoenix最新演示视频

Phoenix,作为世界上首个采用Carbon驱动的人形通用机器人,展现了一种开创性和独特的AI控制系统,赋予了机器人接近人类的智能水平。这一革命性的Carbon系统能够将自然语言无缝转化为现实世界中的实际行动,从而赋予Phoenix跨足多个行…

Java中的时间API:Date、Calendar到Java.time的演变

引言 在软件开发中,处理时间和日期是一项基本且不可或缺的任务。无论是日志记录、用户信息管理还是复杂的定时任务,准确地处理时间都显得至关重要。然而,时间的处理并不像它看起来那么简单,尤其是当我们考虑到时区、夏令时等因素…

基于Python3的数据结构与算法 - 06 topk问题

一、引入 问题&#xff1a;目前共有n个数&#xff0c;设计算法得到前k大的数。&#xff08;m<n&#xff09; 解决思路&#xff1a; 排序后切片&#xff1a;O(n*lognm) O(n*logn)排序LowB三人组&#xff1a;O(mn) 例如冒泡排序&#xff0c;交换m次&#xff0c;即可取前m…

印刷行业实施MES管理系统有哪些重要的意义

随着科技的飞速发展和市场需求的不断变化&#xff0c;印刷行业正面临着前所未有的挑战和机遇。为了在激烈的市场竞争中立于不败之地&#xff0c;印刷企业急需通过引入先进的管理系统来提升生产效率、优化资源配置、确保产品质量。而万界星空科技MES正是这样一种能够帮助印刷企业…

力扣:9. 回文数

力扣&#xff1a;9. 回文数 给你一个整数 x &#xff0c;如果 x 是一个回文整数&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 回文数是指正序&#xff08;从左向右&#xff09;和倒序&#xff08;从右向左&#xff09;读都是一样的整数。 例如&#xf…

面试经典150题——合并区间

​"Do not wait to strike till the iron is hot; but make it hot by striking." - William Butler Yeats 1. 题目描述 2. 题目分析与解析 2.1 思路一 还是先来分析题目&#xff0c;观察一下重叠的区间和不重叠的区间数组之间的关系是什么&#xff1a; 比如区间…

Mendix 开发实践指南|Mendix的核心概念

在当今快速变化的技术环境中&#xff0c;Mendix平台以模型驱动开发方法&#xff0c;重新定义了应用程序的构建过程。本章内容&#xff0c;将深入探讨Mendix的几大核心概念&#xff1a;模型驱动开发、微流、纳流 、 实体模型和页面&#xff0c;旨在帮助我们全面理解Mendix平台的…

Windows虚拟主机如何开启网页debug模式

前不久&#xff0c;有客户咨询想要知道如何开启网页debug模式,以便后期他网站出现异常可以自行排查。这边了解到他当前使用的是Hostease 的Windows 虚拟主机&#xff0c;而开启网页debug模式的操作步骤如下&#xff1a; 1.Hostease的Windows虚拟主机都是带Plesk面板的,因此需要…

本届挑战赛亚军方案:面向微服务架构系统中无标注、多模态运维数据的异常检测、根因定位与可解释性分析

CheerX团队来自于南瑞研究院系统平台研发中心&#xff0c;中心主要从事NUSP电力自动化通用软件平台的关键技术研究与软件研发。 选题分析 图1 研究现状 本次CheerX团队的选题紧密贴合了目前的运维现状。实际运维中存在多种问题导致运维系统的不可用。比如故障发生时&#xff…

C语言:字符函数 字符串函数 内存函数

C语言&#xff1a;字符函数 & 字符串函数 & 内存函数 字符函数字符分类函数字符转换函数tolowertoupper 字符串函数strlenstrcpystrcatstrcmpstrstrstrtok 内存函数memcpymemmovememsetmemcmp 字符函数 顾名思义&#xff0c;字符函数就是作用于字符的函数&#xff0c;…