[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023@author: chengxf2
"""import torch 
from torch import nn
from torch import optimclass DNN(nn.Module):'''它是一个序列容器,是nn.Module的子类。 `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。'''def __init__(self):super(DNN, self).__init__()self.net = nn.Sequential(nn.Linear(in_features=28*28, out_features=500),nn.Sigmoid(),nn.Linear(in_features=500, out_features=10),nn.Sigmoid())def forward(self, input):output = self.net(input)return outputdef train():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = DNN()criteon = torch.nn.CrossEntropyLoss(reduction='mean')optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)batch_size= 5data = torch.rand((batch_size,28*28))epochs = 2target = torch.tensor([0,1,2,3,4])target = target.to(device)for epoch in range(epochs):yHat = model(data)loss = criteon(yHat, target)loss.backward()print("\n loss ",loss)optimizer.step()if __name__ == "__main__":train()

 


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/288764.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java实现限流算法

下面是一个使用Java实现的令牌桶算法的例子: import java.util.concurrent.atomic.AtomicLong;public class RateLimiter {private final long capacity; // 令牌桶容量private final long rate; // 令牌生成速率private AtomicLong tokens; // 当前令牌数量privat…

RAII智能指针

RAII resource acquisition is initialization RAII是利用对象声明周期来控制程序资源的简单技术 在对象构造时获取资源,控制着对资源的访问使之在对象的生命周期内始终保持有效,最后在对象析构的时候释放资源。借此,我们实际上把管理一份资源…

2024Web自动化测试的技术框架和工具有哪些?

Web 自动化测试是一种自动化测试方式,旨在模拟人工操作对 Web 应用程序进行测试。这种测试方式可以提高测试效率和测试精度,减少人工测试的工作量和测试成本。在 Web 自动化测试中,技术框架和工具起着至关重要的作用。本文将介绍几种常见的 W…

VM安装Sonoma【笔记】

VMware Workstation安装MacOS Sonoma 1、配置虚拟机,根据系统性能调整参数; 2、先不焦急启动虚拟机,打开虚拟机存储目录,以文本方式打开.vmx文件(这里以Sonoma.vmx为例); 这里只针对Inter CP…

vue-springboot二手图书商城交易系统ij5dr

本系统依赖于MySQL数据库来储存信息,系统完成后,所有需要的数据都要从数据库中读取,这也意味着无论是插入、更新还是删除操作,只要对数据有改动的操作都需要与数据库交互,因此,系统的全部数据都要储存在数据…

C++相关闲碎记录(18)

1、strftime()的转换指示器 #include <locale> #include <chrono> #include <ctime> #include <iostream> #include <exception> #include <cstdlib> using namespace std;int main () {try {// query local time:auto now chrono::syste…

Python 多维数组详解(numpy)

文章目录 1 概述1.1 numpy 简介1.2 ndarray 简介 2 数组操作2.1 创建数组&#xff1a;array()2.2 裁切数组&#xff1a;切片2.3 拼接数组&#xff1a;concatenate()2.4 拆分数组&#xff1a;array_split()2.5 改变数组形状&#xff1a;reshape() 3 元素操作3.1 获取元素&#x…

飞天使-k8s知识点1-kubernetes架构简述

文章目录 名词功能要点 k8s核心要素CNCF 云原生框架简介k8s组建介绍 名词 CI 持续集成, 自动化构建和测试&#xff1a;通过使用自动化构建工具和自动化测试套件&#xff0c;持续集成可以帮助开发人员自动构建和测试他们的代码。这样可以快速检测到潜在的问题&#xff0c;并及早…

【QT】解决QTableView修改合并单元格内容无法修改到合并范围内的单元格

问题:修改合并单元格的内容 修改合并单元格的内容时,希望直接修改到合并范围内的单元格,Qt没有实现这个功能,需要自己写出 Delegate来实现 方案:Delegate class EditDelegate : public QStyledItemDelegate {public:EditDelegate(QTableView *view): tableView(view){}pu…

位操作符详解(C语言)

前言 C语言中的位操作符是用来对数据的二进制表示进行位级操作的运算符。这些操作符包括位与&#xff08;&&#xff09;、位或&#xff08;|&#xff09;、位异或&#xff08;^&#xff09;、位取反&#xff08;~&#xff09;&#xff0c;这些位操作符可以用来进行各种位级…

华为选择“力图生存”!国家队正式出手,外媒:鸿蒙将全面爆发

引言 在国际舞台上&#xff0c;国与国之间的关系错综复杂&#xff0c;舆论的力量也十分重要。近日&#xff0c;关于华为鸿蒙系统失去用户的预测成为热议的话题。这背后所面对的挑战和对抗也异常严峻。本文将解释鸿蒙系统的崛起与前景展望&#xff0c;揭示其中的机遇与挑战。 …

9ACL访问控制列表

为什么要有访问控制&#xff08;Access Control List&#xff09;&#xff1f; 因为我可能在局域网中提供了一些服务&#xff0c;我只希望合法的用户可以访问&#xff0c;其他非授权用户不能访问。 原理比较简单&#xff0c;通过对数据包里的信息做过滤&#xff0c;实现访问控…