神经网络中的归一化

我们今天介绍一下神经网络中的归一化方法~

之前学到的机器学习中的归一化是将数据缩放到特定范围内,以消除不同特征之间的量纲和取值范围差异。通过将原始数据缩放到一个特定的范围内,比如[0,1]或者[-1,1],来消除不同特征之间的量纲和取值范围的差异。这样做的好处包括降低数据的量纲差异,避免某些特征由于数值过大而对模型产生不成比例的影响,以及防止梯度爆炸或过拟合等问题。

神经网络中的归一化用于加速和稳定学习过程,避免梯度问题。 

神经网络的学习其实在学习数据的分布,随着网络的深度增加、网络复杂度增加,一般流经网络的数据都是一个 mini batch,每个 mini batch 之间的数据分布变化非常剧烈,这就使得网络参数频繁的进行大的调整以适应流经网络的不同分布的数据,给模型训练带来非常大的不稳定性,使得模型难以收敛。

如果我们对每一个 mini batch 的数据进行标准化之后,强制使输入分布保持稳定,从而可以加快网络的学习速度并提高模型的泛化能力。参数的梯度变化也变得稳定,有助于加快模型的收敛。

机器学习中的正则化分为L1和L2正则化,sklearn库中的Lasso类和Ridge类来实现L1正则化和L2正则化的线性回归模型。通过调整alpha参数,可以控制正则化的强度。

import numpy as np
from sklearn.linear_model import Lasso
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 生成模拟数据集
X, y = make_regression(n_samples=100, n_features=2, noise=0.1)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建Lasso回归模型,并设置alpha参数为0.1(正则化强度)
lasso = Lasso(alpha=0.1)# 拟合模型
lasso.fit(X_train, y_train)# 预测测试集数据
y_pred = lasso.predict(X_test)# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error:", mse)
  1. Ridge回归模型,fit方法的作用是使用提供的输入特征矩阵X_train和对应的目标值y_train来训练模型,即确定模型的权重参数。这个过程涉及到最小化一个包含L2正则化项的损失函数,以找到最佳的参数值,使得模型在训练集上的表现最优,同时通过正则化避免过拟合。 
  2. 在模型拟合完成后,可以使用predict方法来进行预测。这个方法将使用fit方法中学到的参数来对新的输入数据X_test进行预测,输出预测结果y_pred。因此,fit方法本身并不直接产生预测结果,而是为后续的预测准备了必要的模型参数。

批量归一化公式 

  • λ 和 β 是可学习的参数,它相当于对标准化后的值做了一个线性变换,λ 为系数,β 为偏置;
  • eps 通常指为 1e-5,避免分母为 0;
  • E(x) 表示变量的均值;
  • Var(x) 表示变量的方差;

通过批量归一化(Batch Normalization, 简称 BN)层之后,数据的分布会被调整为均值为β,标准差为γ的分布

批量归一化通过对每个mini-batch数据进行标准化处理,强制使输入分布保持稳定: 

  1. 计算该批次数据的均值和方差:这两个统计量是针对当前批次数据进行计算的。
  2. 利用这些统计数据对批次数据进行归一化处理:这一步将数据转换为一个近似以0为中心,标准差为1的正态分布。
  3. 尺度变换和偏移:为了保持网络的表达能力,通过可学习的参数γ(尺度因子)和β(平移因子)对归一化后的数据进行缩放和位移。

BN 层的接口 

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

  • num_features: 这是输入张量的特征数量,即通道数。它指定了要进行归一化的特征维度。

  • eps: 这是一个小的常数,用于防止除以零的情况。默认值为1e-05。

  • momentum: 这是动量值,用于计算移动平均值。默认值为0.1。

  • affine: 这是一个布尔值,表示是否启用可学习的缩放和位移参数。如果设置为True,则在训练过程中会学习这些参数;如果设置为False,则使用固定的缩放和位移参数。默认值为True。

我们通过一个代码案例来理解一下工作原理 :

import torch
import torch.nn as nn# 定义输入数据的形状
batch_size = 32
num_channels = 3
height = 64
width = 64# 创建输入张量
input_data = torch.randn(batch_size, num_channels, height, width)# 创建批量归一化层
bn_layer = nn.BatchNorm2d(num_features=num_channels, eps=1e-05, momentum=0.1, affine=True)# 将输入数据传入批量归一化层
output_data = bn_layer(input_data)# 打印输出数据的形状
print("Output shape:", output_data.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/671849.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Excel利用数据透视表将二维数据转换为一维数据(便于后面的可视化分析)

一维数据:属性值都不可合并,属性值一般在第一列或第一行。 二维数据:行属性或列属性是可以继续合并的,如下数据中行属性可以合并为【月份】 下面利用数据透视表将二维数据转换为一维数据: 1、在原来的数据上插入数据透…

机器人系统ros2-开发实践05-将静态坐标系广播到 tf2(Python)-定义机器人底座与其传感器或非移动部件之间的关系

发布静态变换对于定义机器人底座与其传感器或非移动部件之间的关系非常有用。例如,最容易推断激光扫描仪中心框架中的激光扫描测量结果。 1. 创建包 首先,我们将创建一个用于本教程和后续教程的包。调用的包learning_tf2_py将依赖于geometry_msgs、pyth…

蛋白质/聚合物防污的机器学习(材料基因组计划)

前言:对于采用机器学习去研究聚合物的防污性能,以及或者其他性质。目前根据我的了解我认为最困难的点有三条: 其一:数据,对于将要训练的数据必须要有三点要求,1.数据要多,也就是大数据&#xff…

设置定位坐标+请按任意键继续

设置定位坐标 目的 在编程和游戏开发中,设置定位坐标的目的是为了确定对象在屏幕或游戏世界中的具体位置。坐标通常由一对数值表示,例如 (x, y),其中 x 表示水平位置,y 表示垂直位置。设置定位坐标的目的包括: 1. **精…

jquery项目 html使用export import方式调用模块

jquery的老项目,引入vue3, 需要方便使用export, import方式引用一些常用的方法与常量 导出模块 export js/numberUtil.js /*** Description:* Author Lani* date 2024/1/10*//* * 【金额】 保留2位小数,不四舍五入 * 5.992550 >5.99 , 2 > 2.…

Vue3 动态引入图片: require is not defined报错

问题&#xff1a;在 Vue3 项目中&#xff0c;使用 require 引入图片&#xff0c;报错 require is not defined 原因&#xff1a; Vue3 使用的是 vite&#xff0c;而 require 是 Webpack 的方法。 官网说明&#xff1a; 解决代码&#xff1a; <template><div v-fo…

ElasticSearch知识点汇总

1、ES中的​​​​​​​倒排索引是什么。 倒排索引&#xff0c;是通过分词策略&#xff0c;形成了词和文章的映射关系表&#xff0c;这种词典映射表即为倒排索引 2、ES是如何实现master选举的。 选举过程主要包括以下几个步骤&#xff1a; 心跳检测&#xff1a; 每个节点…

websevere服务器从零搭建到上线(四)|muduo网络库的基本原理和使用

文章目录 muduo源码编译安装muduo框架讲解muduo库编写服务器代码示例代码解析用户连接的创建和断开回调函数用户读写事件回调 使用vscode编译程序配置c_cpp_properties.json配置tasks.json配置launch.json编译 总结 muduo源码编译安装 muduo依赖Boost库&#xff0c;所以我们应…

webstorm 常用插件

安装插件步骤&#xff1a; 打开软件&#xff0c;文件 -- 设置-- 插件 -- 输入插件名称 -- 安装 代码截图: code screenShots 先选中代码&#xff0c;按 ctrl shift alt a&#xff0c;就可截取选中的代码颜色注释: comments highlighter 对注释的文字改变颜色高亮成对符号: h…

Redis开源社区持续壮大,华为云为Valkey项目注入新的活力

背景 今年3月21日&#xff0c;Redis Labs宣布从Redis 7.4版本开始&#xff0c;将原先比较宽松的BSD源码使用协议修改为RSAv2和SSPLv1协议&#xff0c;意味着 Redis在OSI&#xff08;开放源代码促进会&#xff09;定义下不再是严格的开源产品。Redis官方表示&#xff0c;开发者…

[Docker]容器的网络类型以及云计算

目录 知识梗概 1、常用命令2 2、容器的网络类型 3、云计算 4、云计算服务的几种主要模式 知识梗概 1、常用命令2 上一篇已经学了一些常用的命令&#xff0c;这里补充两个&#xff1a; 导出镜像文件&#xff1a;[rootdocker ~]# docker save -o nginx.tar nginx:laster 导…

Codeforces Round 943 (Div. 3) C-G

C. Assembly via Remainders 思路&#xff1a; 我们可以注意到&#xff0c;数组的长度只有 500 500 500 &#xff0c;并且每个数的大小都在 500 500 500 以内&#xff0c;再看向这题&#xff0c;容易知道&#xff0c;当第一个数确定之后&#xff0c;之后所有的数字都会确定下…