抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/234890.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MacOS + Android Studio 通过 USB 数据线真机调试

环境:Apple M1 MacOS Sonoma 14.1.1 软件:Android Studio Giraffe | 2022.3.1 Patch 3 设备:小米10 Android 13 一、创建测试项目 安卓 HelloWorld 项目: 安卓 HelloWorld 项目 二、数据线连接手机 1. 手机开启开发者模式 参考&#xff1…

Jmeter+ant+jenkins实现持续集成看这一篇就搞定了!

jmeterantjenkins持续集成 一、下载并配置jmeter 首先下载jmeter工具,并配置好环境变量;参考:https://www.cnblogs.com/YouJeffrey/p/16029894.html jmeter默认保存的是.jtl格式的文件,要设置一下bin/jmeter.properties,文件内容…

【前端】多线程 worker

VUE3 引用 npm install worker-loader 在vue.config.js文件的defineConfig里加上配置参数 chainWebpack: config > {config.module.rule(worker-loader).test(/\.worker\.js$/).use({loader: worker-loader,options: {inline: true}}).loader(worker-loader).end()}先在…

传教士与野人过河问题

代码模块参考文章:传教士与野人过河问题(numpy、pandas)_python过河问题_醉蕤的博客-CSDN博客 问题描述 一般的传教士和野人问题(Missionaries and Cannibals):有N个传教士和C个野人来到河边准 备渡河。…

EasyMicrobiome-易扩增子、易宏基因组等分析流程依赖常用软件、脚本文件和数据库注释文件

啥也不说了,这个好用,给大家推荐:YongxinLiu/EasyMicrobiome (github.com) 大家先看看引用文献吧,很有用:https://doi.org/10.1002/imt2.83 还有这个,后面马上介绍:YongxinLiu/EasyAmplicon: E…

mabatis基于xml方式和注解方式实现多表查询

前面步骤 http://t.csdnimg.cn/IPXMY 1、解释 在数据库中,单表的操作是最简单的,但是在实际业务中最少也有十几张表,并且表与表之间常常相互间联系; 一对一、一对多、多对多是表与表之间的常见的关系。 一对一:一张…

zookeeper 单机伪集群搭建简单记录(实操课程系列)

本系列是zookeeper相关的实操课程,课程测试环环相扣,请按照顺序阅读测试来学习zookeeper 1、官方下载加压后,根目录下新建data和log目录,然后分别拷贝两份,分别放到D盘,E盘,F盘 2、data目录下面…

基于SSM的影视创作论坛设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

docker集群的详解以及超详细搭建

文章目录 一、问题引入1. 多容器位于同一主机2. 多容器位于不同主机 二、介绍三、特性四、概念1. 节点nodes2. 服务(service)和任务(task)3. 负载均衡 五、docker网络1. overlay网络 六、docker集群搭建1. 环境介绍2. 创建集群3. 集群网络4. 加入工作节点 七、部署可视化界面po…

[NOIP2016 普及组] 回文日期

枚举好题&#xff0c;直接枚举答案 看看在不在范围内就行了 注意二月份 92200229是合法的~ 82200228也是合法的&#xff01; #include<bits/stdc.h> using namespace std;map<int,int>mp;int main() {mp[1] mp[3] mp[5] mp[7] mp[8] mp[10] mp[12] 31;mp[…

【MATLAB】异常数据识别

基于分位数的异常点识别 首先&#xff0c;给定了一个原始数据序列x。然后&#xff0c;计算了序列x的上四分位数和下四分位数&#xff0c;并根据这两个值计算了异常点的阈值。上四分位数减去1.5倍的四分位数范围得到异常值下界&#xff0c;下四分位数加上1.5倍的四分位数范围得…

5年经验之谈 —— 接口测试主要测哪些方面?

当今互联网时代&#xff0c;接口测试已经成为软件测试的一个重要组成部分。接口测试是指对系统各个接口进行验证&#xff0c;确保接口的正确性、稳定性和安全性。接口测试是软件开发过程中不可缺少的环节&#xff0c;它旨在确保接口能够正常工作&#xff0c;并且满足所需要的规…