【深度学习中的批量归一化BN和层归一化LN】BN层(Batch Normalization)和LN层(Layer Normalization)的区别

文章目录

  • 1、概述
  • 2、BN层
  • 3、LN层
  • 4、Pytorch的实现
  • 5、BN层和LN层的对比

1、概述

  • 归一化(Normalization) 方法:指的是把不同维度的特征(例如序列特征或者图像的特征图等)转换为相同或相似的尺度范围内的方法,比如把数据特征映射到[0, 1]或[−1, 1]区间内,或者映射为服从均值为0、方差为1的标准正态分布。
  • 那为什么要进行归一化?

样本特征由于来源和度量单位的不同或者经过多个卷积层处理后导致不同来源或者不同卷积层的输入特征尺度存在较大差异,模型的优化方向可能会被尺度较大的特征所主导。而进行归一化可以使得尺度大致处于同一范围内,从而有利于模型的训练和优化。

  • BN层(Batch Normalization):是在不同样本之间进行归一化。
  • LN层(Layer Normalization):是在同一样本内部进行归一化。
  • 以下的图简单展示了二者的区别:
    在这里插入图片描述
    参考链接:https://blog.csdn.net/qq_44397802/article/details/128452207

2、BN层

  • 下图很清晰的解释了BN层:由于是Batch Normalization,那么简单来说,就是针对Batch中的不同样本之间求均值和标准差,再做归一化

1)如下图,针对神经元的输出进行BN,确定Batch size为N,但是不同类型样本的维度可能不一样(下图中维度为1,例如图像经过卷积以后维度为C × \times ×H × \times ×W)
2)不论维度为多少,各个样本之间的维度是相同的,因此针对不同样本之间的对应维度计算出均值和标准差,肯定与每个样本的维度相同(下图中,均值和标准差都为一维,对于图像,均值和标准差的维度为C × \times ×H × \times ×W)
3)针对每个神经元训练一组可学习的参数 γ \gamma γ β \beta β,用于对输出的每个响应值做缩放和平移。
4)注意如果样本为一维,可学习参数的组数与输出的响应值的数量相等,也与神经元的个数相等;如果样本是图像,输入为N × \times ×C1 × \times ×H × \times ×W,卷积核个数为C2,那么输出为N × \times ×C2 × \times ×H × \times ×W,因此可学习参数的组数与输出通道数相等,为C2,也与卷积核个数相等。
5)所以简单来说,可学习参数的组数就与通道数相等。

在这里插入图片描述

3、LN层

  • 一般来说,层归一化所做的就是,对于图像,即输入为N × \times ×C × \times ×H × \times ×W的特征图:在每个样本内部,计算所有像素点的均值和标准差,并针对每个像素点训练一组可学习参数 γ \gamma γ β \beta β,用于进行缩放和平移,以归一化到同一范围内。

  • 如下图所示,针对的是一个样本中的所有通道内的所有像素。也就是说和Batch无关。

  • 因此可学习参数的组数就等于C × \times ×H × \times ×W。
    在这里插入图片描述

  • 计算公式:
    在这里插入图片描述

4、Pytorch的实现

  • BN层的实现:
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
''
num_features:输入尺寸为(N,C,H,W),则该值为C
''
  • LN层的实现:
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
''
1)normalized_shape:归一化的尺寸,输入的尺寸必须符合:[∗×normalized_shape[0]×normalized_shape[1]×…×normalized_shape[1]]
如果为单个整数,则对最后一维进行归一化
2)elementwise_affine:是否具有可学习的参数,默认为True
''
  • 如下为BN和LN层的实现,以及参数量的计算
import torch
from torch import nn# NLP Example
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
print(layer_norm)
param_num = sum([param.numel() for param in layer_norm.parameters()])
print(param_num)
output_embed = layer_norm(embedding)
print(output_embed.shape)
  • 输出为:
LayerNorm((10,), eps=1e-05, elementwise_affine=True)
20
torch.Size([20, 5, 10])
import torch
from torch import nn# Image Example
N, C, H, W = 20, 5, 10, 10
input0 = torch.randn(N, C, H, W)
# Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
layer_norm = nn.LayerNorm([C, H, W]) # 参数量为C*H*W*2
print(layer_norm)
param_num = sum([param.numel() for param in layer_norm.parameters()])
print(param_num)
output = layer_norm(input0)
print(output.shape)input1 = torch.randn(N, C, H, W)
batch_norm = nn.BatchNorm2d(C) # 参数量为C*2
print(batch_norm)
param_num1 = sum([param.numel() for param in batch_norm.parameters()])
print(param_num1)
output1 = batch_norm(input1)
print(output1.shape)
  • 输出为:
LayerNorm((5, 10, 10), eps=1e-05, elementwise_affine=True)
1000
torch.Size([20, 5, 10, 10])
BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
10
torch.Size([20, 5, 10, 10])

5、BN层和LN层的对比

  • 简单对比如下:
    在这里插入图片描述

参考链接:https://blog.csdn.net/hymn1993/article/details/122719043

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/62548.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分清性能测试,负载测试,压力测试这三个的区别

做测试一年多来,虽然平时的工作都能很好的完成,但最近突然发现自己在关于测试的整体知识体系上面的了解很是欠缺,所以,在工作之余也做了一些测试方面的知识的补充。不足之处,还请大家多多交流,互相学习。 …

【MongoDB】数据库、集合、文档常用CRUD命令

目录 一、数据库操作 1、创建数据库操作 2、查看当前有哪些数据库 3、查看当前在使用哪个数据库 4、删除数据库 二、集合操作 1、查看有哪些集合 2、删除集合 3、创建集合 三、文档基本操作 1、插入数据 2、查询数据 3、删除数据 4、修改数据 四、文档分页查询 …

论文阅读---《Unsupervised ECG Analysis: A Review》

题目 无监督心电图分析一综述 摘要 电心图(ECG)是检测异常心脏状况的黄金标准技术。自动检测心电图异常有助于临床医生分析心脏监护仪每天产生的大量数据。由于用于训练监督式机器学习模型的带有心脏病专家标签的异常心电图样本数量有限,对…

无涯教程-Perl - msgctl函数

描述 该函数使用参数ID,CMD和ARG调用系统函数msgctrl()。您可能需要包括IPC::SysV包以获得正确的常量。 语法 以下是此函数的简单语法- msgctl ID, CMD, ARG返回值 该函数返回0,但如果系统函数成功返回0和1,则返回true。 Perl 中的 msgctl函数 - 无涯教程网无涯教程网提供…

Go语言工程实践之测试与Gin项目实践

Go 语言并发编程 及 进阶与依赖管理_软工菜鸡的博客-CSDN博客 03 测试 回归测试一般是QA(质量保证)同学手动通过终端回归一些固定的主流程场景 集成测试是对系统功能维度做测试验证,通过服务暴露的某个接口,进行自动化测试 而单元测试开发阶段,开发者对单独的函数…

【Java 回忆录】Java全栈开发笔记文档

这里能学到什么? 实战代码文档一比一记录实战问题和解决方案涉及前端、后端、服务器、运维、测试各方面通过各方面的文档与代码,封装一套低代码开发平台直接开腾讯会议,实实在线一起分享技术问题核心以 Spring Boot 作为基础框架进行整合后期…

C++ 计算 拟合优度R^2

解决的问题: 拟合优度(Goodness of Fit)是指回归直线对观测值的拟合程度,度量拟合优度的统计量是可决系数(亦称确定系数) R?。R最大值为 1。R%的值越接近1,说明回归直线对观测值的拟合程度越好,反之,R%值越小&#x…

MySQL_多表关系

多表关系 一对一关系 用户和用户详情 关系:一对一的关系 用途:用于单表拆分,将一张表的基础字段放在一张表中,其它字段放在另一张表中,可以提升查询效率 实现:在任意一张表里面添加外键,关联…

竞赛项目 深度学习的水果识别 opencv python

文章目录 0 前言2 开发简介3 识别原理3.1 传统图像识别原理3.2 深度学习水果识别 4 数据集5 部分关键代码5.1 处理训练集的数据结构5.2 模型网络结构5.3 训练模型 6 识别效果7 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习…

Sharding-JDBC概述

前言 ​ 随着业务数据量的增加,原来所有的数据都是在一个数据库上的,网络IO及文件IO都集中在一个数据库上的,因此CPU、内存、文件IO、网络IO都可能会成为系统瓶颈。当业务系统的数据容量接近或超过单台服务器的容量、QPS/TPS接近或超过单个数…

生态系统服务(InVEST模型)

第一天: 1. 生态系统服务理论联系实践案例讲解 2. InVEST模型的开发历程、不同版本的差异及对数据需求的讲解 3. InVEST所需数据的要求(分辨率、格式、投影系统等)、获取及标准化预处理讲解 4. InVEST运行常见问题及处理解决方法讲解 5.…

个人用C#编写的壁纸管理器 - 开源研究系列文章

今天介绍一下笔者自己用C#开发的一个小工具软件:壁纸管理器。 开发这个小工具的初衷是因为Windows操作系统提供的功能个人不满意,而且现在闲着,所以就随意写了个代码。如果对读者有借鉴参考作用就更好了,能够直接代码段复用即可。…