《动手学深度学习(PyTorch版)》笔记7.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.5 Batch Normalization

批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:
在每次训练迭代中,首先规范化输入,即通过减去其均值并除以其标准差(两者均基于当前小批量处理)。接下来应用比例系数和比例偏移。请注意,如果使用大小为1的小批量应用批量规范化,将无法学到任何东西,这是因为在减去均值之后,每个隐藏单元将为0,所以只有使用足够大的小批量,批量规范化才有效且稳定。用 x ∈ B \mathbf{x} \in \mathcal{B} xB表示一个来自小批量 B \mathcal{B} B的输入,批量规范化 B N \mathrm{BN} BN根据下式转换 x \mathbf{x} x

B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}. BN(x)=γσ^Bxμ^B+β.

在上式中, μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B是小批量 B \mathcal{B} B的样本均值, σ ^ B \hat{\boldsymbol{\sigma}}_\mathcal{B} σ^B是小批量 B \mathcal{B} B的样本标准差。应用标准化后,生成的小批量的平均值为0和方差为1。由于单位方差是一个主观选择的结果,因此通常包含拉伸参数(scale) γ \boldsymbol{\gamma} γ偏移参数(shift) β \boldsymbol{\beta} β,它们的形状与 x \mathbf{x} x相同,且是需要与其他模型参数一起学习的参数。

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x , σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ . \begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned} μ^Bσ^B2=B1xBx,=B1xB(xμ^B)2+ϵ.

我们在方差估计值中添加一个小的常量 ϵ > 0 \epsilon > 0 ϵ>0,以确保永远不会除以零。此外,估计值 μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B σ ^ B {\hat{\boldsymbol{\sigma}}_\mathcal{B}} σ^B还会通过使用关于平均值和方差的噪声(noise)来抵消缩放问题。乍看起来,这种噪声是一个问题,而事实上它是有益的。

7.5.1 Batch Normalize Layer

批量规范化层和其他层之间的一个关键区别是:由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。我们在下面讨论全连接层和卷积层的批量规范化实现。

7.5.1.1 Fully Connected Layer

我们通常将批量规范化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为x,权重参数和偏置参数分别为 W \mathbf{W} W b \mathbf{b} b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N \mathrm{BN} BN,那么使用批量规范化的全连接层的输出的计算公式如下:
h = ϕ ( B N ( W x + b ) ) . \mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ). h=ϕ(BN(Wx+b)).

7.5.1.2 Convolutional Layer

对于卷积层,可以在卷积操作之后和非线性激活函数之前应用批量规范化。当卷积有多个输出通道时,我们需要对这些通道的每个输出执行批量规范化。每个通道都有自己的拉伸(scale)和偏移(shift)参数,两者都是标量。假设小批量包含 m m m个样本,并且对于每个通道,卷积的输出具有高度 p p p和宽度 q q q,那么对于卷积层,在每个输出通道的 m ⋅ p ⋅ q m \cdot p \cdot q mpq个元素上同时执行每个批量规范化。因此在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

7.5.2 Batch Normalization During Prediction

批量规范化在训练模式和预测模式下的行为和结果通常不同。在训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型,一种常用的方法是通过移动平均(具体方法在batch_norm()函数中定义)估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。而在将训练好的模型用于预测时,我们可以根据整个数据集精确计算批量规范化所需的平均值和方差,不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltdef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.dataclass BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)#参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0#这些参数是nn.Parameter类型,因此会被PyTorch的优化器所管理,梯度更新是自动进行的self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,就将moving_mean和moving_var复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Ynet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))#简洁实现
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/454656.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt|实现时间选择小功能

在软件开发过程中,QtDesigner系统给出的控件很多时候都无法满足炫酷的效果,前一段时间需要用Qt实现选择时间的小功能,今天为大家分享一下! 首先看一下时间效果吧! 如果有需要继续往下看下去哟~ 功能 1:开…

Vue-easy-tree封装及使用

1.使用及安装 下载依赖 npm install wchbrad/vue-easy-tree引入俩种方案 1.在main.js中引入 import VueEasyTree from "wchbrad/vue-easy-tree"; import "wchbrad/vue-easy-tree/src/assets/index.scss" Vue.use(VueEasyTree)2.当前页面引入 import VueEa…

PiflowX新增Apache Beam引擎支持

参考资料: Apache Beam 架构原理及应用实践-腾讯云开发者社区-腾讯云 (tencent.com) 在之前的文章中有介绍过,PiflowX是支持spark和flink计算引擎,其架构图如下所示: 在piflow高度抽象的流水线组件的支持下,我们可以…

2024第九届国际发酵培养基应用与发展技术论坛会议通知

会议简介 随着科技的不断发展,发酵技术已逐渐成为生物制造中的重要支撑,发酵技术应用领域更加广泛。发酵培养基是影响生物发酵产业技术水平、环境友好程度的重要因素之一,为进一步推动发酵培养基的科学应用,提升发酵培养基的高效、稳定应用&…

电商推荐系统

此篇博客主要记录一下商品推荐系统的主要实现过程。 一、获取用户对商品的偏好值 代码实现 package zb.grms;import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configured; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Doub…

开发大佬为什么都不喜欢关电脑?

引言 在平时工作中,咱们程序员这一群体往往展现出一些特有的行为习惯,其中之一便是不喜欢频繁地关闭电脑、拒绝关机、长久待机、特别是苹果的机器。 下面从技术分析与用户行为研究的角度出发,将深入探讨程序员倾向于保持电脑开机状态的原因…

数字孪生网络攻防模拟与城市安全演练

在数字化浪潮的推动下,网络攻防模拟和城市安全演练成为维护社会稳定的不可或缺的环节。基于数字孪生技术我们能够在虚拟环境中进行高度真实的网络攻防模拟,为安全专业人员提供实战经验,从而提升应对网络威胁的能力。同时,在城市安…

P60_ib公式推导2

永磁无刷直流电机B相电流值推导过程(t2过程)

RabbitMQ-1.介绍与安装

介绍与安装 1.RabbitMQ1.0.技术选型1.1.安装1.2.收发消息1.2.1.交换机1.2.2.队列1.2.3.绑定关系1.2.4.发送消息 1.2.数据隔离1.2.1.用户管理1.2.3.virtual host 1.RabbitMQ 1.0.技术选型 消息Broker,目前常见的实现方案就是消息队列(MessageQueue&…

springboot155基于JAVA语言的在线考试与学习交流网页平台

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的 适用于计算机类毕业设计,课程设计参考与学习用途。仅供学习参考, 不得用于商业或者非法用途,否则,一切后果请用户自负。 看运行截图看 第五章 第四章 获取资料方式 **项…

智能优化算法 | Matlab实现飞蛾扑火(MFO)(内含完整源码)

文章目录 效果一览文章概述源码设计参考资料效果一览 文章概述 智能优化算法 | Matlab实现飞蛾扑火(MFO)(内含完整源码) 源码设计 %%%% clear all clc SearchAgents_no=100; % Number of search ag

Photoshop 2023下载安装教程,免费直装版,2步搞定安装,附安装包

准备工作: 1、提前准备好photoshop 2023安装包 没有的可以参考下面方式获取 2、系统要求Windows 10 及以上 安装步骤 1.找到下载好的安装包,直接双击解压 2.双击运行【Set-up.exe】文件 3.点击文件夹图标,更改安装位置 4.点击【继续】&a…