pytorch06:权重初始化

在这里插入图片描述

目录

  • 一、梯度消失和梯度爆炸
    • 1.1相关概念
    • 1.2 代码实现
    • 1.3 实验结果
    • 1.4 方差计算
    • 1.5 标准差计算
    • 1.6 控制网络层输出标准差为1
    • 1.7 带有激活函数的权重初始化
  • 二、Xavier方法与Kaiming方法
    • 2.1 Xavier初始化
    • 2.2 Kaiming初始化
    • 2.3 常见的初始化方法
  • 三、nn.init.calculate_gain

一、梯度消失和梯度爆炸

1.1相关概念

一个简易三层全连接神经网络图和神经元计算如下:
在这里插入图片描述
观察第二个隐藏层的权值的梯度是如何求取的,根据链式法则,可以得到如下计算公式,会发现w2的梯度依赖上一层的输出值H1;
在这里插入图片描述
当H1趋近于0的时候,W2的梯度也趋近于0;—>梯度消失
当H1趋近于无穷的时候,W2的梯度也趋近于无穷;—>梯度爆炸
在这里插入图片描述
一旦出现梯度消失或者梯度爆炸就会导致模型无法训练;

1.2 代码实现

import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seedset_seed(1)  # 设置随机种子class MLP(nn.Module):def __init__(self, neural_num, layers):super(MLP, self).__init__()self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])self.neural_num = neural_numdef forward(self, x):for (i, linear) in enumerate(self.linears):x = linear(x)# x = torch.relu(x)# x = torch.tanh(x)print("layer:{}, std:{}".format(i, x.std()))  # 打印当前值的标准差if torch.isnan(x.std()):  # 判断是什么时候标准差为nanprint("output is nan in {} layers".format(i))breakreturn x# 权值初始化函数def initialize(self):for m in self.modules():if isinstance(m, nn.Linear):  # 判断当前网络层是否是线性层,如果是就进行权值初始化nn.init.normal_(m.weight.data)  # normal: mean=0, 控制标准差std在1左右# nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))# =======这段代码的目的是通过均匀分布初始化并结合tanh激活函数的特性,为神经网络的某一层(线性层)初始化合适的权重# a = np.sqrt(6 / (self.neural_num + self.neural_num))# tanh_gain = nn.init.calculate_gain('tanh')# a *= tanh_gain# nn.init.uniform_(m.weight.data, -a, a)# 将权重矩阵的值初始化为在 [-a, a] 范围内均匀分布的随机数。这个范围是通过之前的计算和调整得到的,目的是使得权重初始化在一个合适的范围内# nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)# ================凯明初始化方法================# nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法# nn.init.kaiming_normal_(m.weight.data)# flag = 0
flag = 1if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 每增加一层网络 标准差扩大根号n倍batch_size = 16net = MLP(neural_nums, layer_nums)print(net)net.initialize()inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1output = net(inputs)print(output)

1.3 实验结果

这里的初始化使用的是标准正态分布normal: mean=0, 控制标准差std在1左右的方法;
在这里插入图片描述
当输出层达到33层后就会出现梯度爆炸,超出了数据精度可以表示的范围。

1.4 方差计算

在这里插入图片描述
1.期望的计算公式
2,3.是方差的计算公式
根据1,2,3,可以得出,x,y的方差计算公式,当x,y的期望值都为0的时候,x,y的方差等于x的方差乘以y的方差。

1.5 标准差计算

在这里插入图片描述
通过计算可以得出每增加一层网络,标准差增加 n \sqrt{n} n ,n也就是神经元的个数;
代码展示:

if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 神经元个数 每增加一层网络 标准差扩大根号n倍batch_size = 16

执行结果:
可以看出第一层标准差是15.95,第二次标准差在上一层的基础上再乘以 256 \sqrt{256} 256
在这里插入图片描述

1.6 控制网络层输出标准差为1

从1.5可以看出D(H)的大小有三个因素决定,分别是n、D(X)、D(w),所以只要保证这三者乘积为1,就可以保证D(H)的值为1;
在这里插入图片描述
当我们权值的标准差为 1 / n \sqrt{1/n} 1/n ,那么就能保证网络层每一层的输出标准差都为1;

代码实现:
在这里插入图片描述

输出结果:
在这里插入图片描述
通过输出结果可以发现,几乎每一层网络输出的标准差都为1.

1.7 带有激活函数的权重初始化

在forward函数里面添加tanh激活函数
在这里插入图片描述
执行结果:
增加tanh激活函数之后,随着网络层的增加,标准差越来越小,从而会导致梯度消失的现象,下面将说明Xavier方法与Kaiming方法是如何解决该问题。
在这里插入图片描述

二、Xavier方法与Kaiming方法

2.1 Xavier初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh
Xavier初始化公式如下:
在这里插入图片描述

代码实现:
手动代码实现
在这里插入图片描述

直接使用pytorch提供的xavier_uniform_函数方法

nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

执行结果:
在这里插入图片描述
可以看到,每一层的网络输出标准差都在0.6左右

2.2 Kaiming初始化

当我们使用带有权值初始化的relu激活函数时,输出结果如下,会发现标准差随着网络层的增加逐渐减小,Kaiming初始化解决了这一问题。
在这里插入图片描述
在这里插入图片描述

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种
公式如下:
在这里插入图片描述

代码实现:

# ================凯明初始化方法================
nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
# nn.init.kaiming_normal_(m.weight.data)  # 使用pytorch自带方法

输出结果:
在这里插入图片描述

2.3 常见的初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

三、nn.init.calculate_gain

主要功能:计算激活函数的方差变化尺度(也就是输入数据的方差/经过激活函数之后的方差)
主要参数
• nonlinearity: 激活函数名称
• param: 激活函数的参数,如Leaky ReLU的negative_slop

代码实现:

flag = 1if flag:x = torch.randn(10000)out = torch.tanh(x)gain = x.std() / out.std()  # 手动计算print('gain:{}'.format(gain))tanh_gain = nn.init.calculate_gain('tanh')  # pytorch自带函数print('tanh_gain in PyTorch:', tanh_gain)

输出结果:
在这里插入图片描述
总结:任何数据在经过tanh激活函数之后,方差缩小大约1.6倍。感兴趣的话也可以使用relu进行实验,最后我的到的结果方差尺度大约是1.4左右。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/318935.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony从入门到放弃(一)

OpenHarmony从入门到放弃(二) 一、OpenHarmony的基本概念和特性 OpenHarmony是由开放原子开源基金会孵化及运营的开源项目,其目标是构建一个面向全场景、全连接、全智能的时代的智能终端设备操作系统。 分布式架构 OpenHarmony采用分布式…

前端工程化回顾-vite 构建神器

1.构建vite 项目 pnpm create vite2.常用的配置: 1.公共资源路径配置: base: ./, 默认是/2.路径别名配置: resolve: {alias: {: path.resolve(__dirname, ./src),ass: path.resolve(__dirname, ./src/assets),comp: path.resolve(__dirnam…

基于SSM的校园快递管理系统

目录 前言 开发环境以及工具 项目功能介绍 学生: 管理员: 详细设计 获取源码 前言 本项目是一个基于IDEA和Java语言开发的基于SSM的校园快递管理系统应用。应用包含学生端和管理员端等多个功能模块。 欢迎使用我们的校园快递管理系统!我…

【Java】面向对象程序设计 期末复习总结

语法基础 数组自带长度属性 length&#xff0c;可以在遍历的时候使用&#xff1a; int []ages new int[10];for (int i 0; i < ages.length; i)System.out.println(ages[i]); 数组可以使用增强式for语句进行只读式遍历&#xff1a; int[] years new int[10];for (int ye…

基于决策树、随机森林和层次聚类对帕尔默企鹅数据分析

作者&#xff1a;i阿极 作者简介&#xff1a;数据分析领域优质创作者、多项比赛获奖者&#xff1a;博主个人首页 &#x1f60a;&#x1f60a;&#x1f60a;如果觉得文章不错或能帮助到你学习&#xff0c;可以点赞&#x1f44d;收藏&#x1f4c1;评论&#x1f4d2;关注哦&#x…

JMM到底如何理解?JMM与MESI到底有没有关系?

今天给大家分享一篇对于理解Java的多线程&#xff0c;特别重要的一个知识点&#xff1a;JMM。在JVM中增加线程机制&#xff0c;首当其冲就是要实现JMM&#xff0c;即Java内存模型。JMM也是大家真正理解Java多线程的基础。 但是大家对于JMM&#xff0c;可以说大多数小伙伴对其的…

3dmax灯光缓存参数应该怎么设置?

细分&#xff1a;用来决定灯光缓存的样本数量&#xff0c;样本数量以此数值的平方来计算。数值越高&#xff0c;效果越好&#xff0c;速度越慢。 一般出图建议1000到1800之间已经足够了 采样大小&#xff1a;用来控制灯光缓存的样本尺寸大小&#xff0c;较小的数值意味着较小的…

python统计分析——直方图(sns.histplot)

使用seanborn.histplot()函数绘制直方图 from matplotlib.pyplot as plt import seaborn as snsdata_setnp.array([2,3,3,4,4,4,4,5,5,6]) plt.hist(fish_data) &#xff08;1&#xff09;dataNone, 表示数据源。 &#xff08;2&#xff09;xNone, 表示直方图的分布垂直与x轴…

Pycharm恢复默认设置

window 系统 找到下方目录-->删除. 再重新打开Pycharm C:\Users\Administrator\.PyCharm2023.3 你的不一定和我名称一样 只要是.PyCharm*因为版本不同后缀可能不一样 mac 系统 请根据需要删除下方目录 # Configuration rm -rf ~/Library/Preferences/PyCharm* # Caches …

uni-app中实现元素拖动

uni-app中实现元素拖动 1、代码示例 <template><movable-area class"music-layout"><movable-view class"img-layout" :x"x" :y"y" direction"all"><img :src"musicDetail.bgUrl" :class&…

算法与数据结构之数组(Java)

目录 1、数组的定义 2、线性结构与非线性结构 3、数组的表现形式 3.1 一维数组 3.2 多维数组 4、重要特性&#xff1a;随机访问 5、ArrayList和数组 6、堆内存和栈内存 7、数组的增删查改 7.1 插入数据 7.2 删除一个数据 7.3 修改数组 7.4 查找数据 8、总结 什么…