为什么要梯度累积

文章目录

    • 梯度累积
      • 什么是梯度累积
      • 如何理解理解梯度累积
        • 梯度累积的工作原理
      • 梯度累积的数学原理
        • 梯度累积过程
        • 如何实现梯度累积
      • 梯度累积的可视化

梯度累积

什么是梯度累积

随着深度学习模型变得越来越复杂,模型的训练通常需要更多的计算资源,特别是在训练期间需要更多的内存。在训练深度学习模型时,在硬件资源有限的情况下,很难使用大批量数据进行有效学习。大批量数据通常可以带来更好的梯度估计,但同时也需要大量的内存。

梯度累积是一种巧妙的技术,它允许在不增加内存需求的情况下,有效地使用更大的批量数据来训练深度学习模型。

如何理解理解梯度累积

梯度累积本质上涉及将大批量划分为较小的子批量,并在这些子批量上累积计算出的梯度。这一过程模拟了使用较大批量训练的情况。

梯度累积的工作原理

以下是梯度累积过程的逐步分解:

  1. 分而治之:将你的硬件无法处理的大批量划分为更小的、可管理的子批量。
  2. 累积梯度:不是在处理每个子批量后更新模型参数,而是在几个子批量上累积梯度。
  3. 参数更新:在处理了预定义数量的子批量后,使用累积的梯度来更新模型参数。

这种方法使得模型能够利用大批量的稳定性和收敛性,而不必提高内存成本。

梯度累积的数学原理

在这里插入图片描述

梯度累积过程

在深度学习模型中,一个完整的前向和反向传播过程如下:

  • 前向传播:数据通过神经网络,层层处理后得到预测结果。

  • 损失计算:使用损失函数计算预测结果与实际值之间的差异。以平方误差损失函数为例:

    L ( θ ) = 1 2 ( h ( x k ) − y k ) 2 L(\theta) = \frac{1}{2} (h(x_k) - y_k)^2 L(θ)=21(h(xk)yk)2

    这里 L ( θ ) L(\theta) L(θ) 表示损失函数, θ \theta θ 代表模型参数, h ( x k ) h(x_k) h(xk) 是对输入 x k x_k xk 的预测输出, y k y_k yk 是对应的真实输出。

  • 反向传播:计算损失函数相对于模型参数的梯度(对上式求导):

    ∇ θ L ( θ ) = ( h ( x k ) − y k ) ⋅ ∇ θ h ( x k ) \nabla_\theta L(\theta) = (h(x_k) - y_k) \cdot \nabla_\theta h(x_k) θL(θ)=(h(xk)yk)θh(xk)

  • 梯度累积:在传统的训练过程中,每完成一个批次的数据处理后就会更新模型参数。而在梯度累积中,梯度不是立即用来更新参数,而是累加多个小批次的梯度:

    G = ∑ i = 1 n ∇ θ L i ( θ ) G = \sum_{i=1}^{n} \nabla_{\theta} L_i(\theta) G=i=1nθLi(θ)

    这里 G G G 是累积梯度, L i ( θ ) L_i(\theta) Li(θ) 是第 i i i 个batch的损失函数。

  • 参数更新:累积足够的梯度后,使用以下公式更新参数:

    θ = θ − η ⋅ G \theta = \theta - \eta \cdot G θ=θηG
    其中 l r lr lr 是学习率,用于控制更新的步长。

如何实现梯度累积

以下是在 PyTorch 中实现梯度累积的示例:

# 模型定义
model = ...
optimizer = ...# 累积步骤数
accumulation_steps = 4for epoch in range(num_epochs):optimizer.zero_grad()for i, (inputs, labels) in enumerate(dataloader):outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()# 只有在处理足够数量的子批量后才更新参数if (i + 1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()# 如果批量大小不是累积步数的倍数,确保在每个epoch结束时更新if (i + 1) % accumulation_steps != 0:optimizer.step()optimizer.zero_grad()

这个例子中,accumulation_steps 定义了在参数更新前需要累积的batch数量。

梯度累积的可视化

为了更好地理解梯度累积的影响,可视化可以非常有帮助。以下是一个例子,说明如何在神经网络中可视化梯度流,以监控梯度是如何被累积和应用的:

import matplotlib.pyplot as plt# 绘制梯度流动的函数
def plot_grad_flow(named_parameters):ave_grads = []layers = []for n, p in named_parameters:if (p.requires_grad) and ("bias" not in n):layers.append(n)ave_grads.append(p.grad.abs().mean())plt.plot(ave_grads, alpha=0.3, color="b")plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k")plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")plt.xlim(xmin=0, xmax=len(ave_grads))plt.xlabel("层")plt.ylabel("平均梯度")plt.title("网络中的梯度流")plt.grid(True)plt.show()# 在训练过程中或训练后调用此函数以可视化梯度流
plot_grad_flow(model.named_parameters())

参考资料:

  1. Gradient Accumulation Algorithm

  2. Performing gradient accumulation with 🤗 Accelerate

  3. 梯度累加(Gradient Accumulation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/675409.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2009-2022年上市公司华证ESG评级评分数据(含细分项)

2009-2022年上市公司华证ESG评级评分数据(含细分项) 1、时间:2009-2022年 2、来源:华证ESG 3、指标:证券代码、证券简称、综合评级、年度、综合得分、E评级、E得分、S评级、S得分、G评级、G得分 4、范围&#xff1…

C语言 函数的定义与调用

上文 C语言 函数概述 我们对函数进行了概述 本文 我们来说函数的定义和调用 C语言规定 使用函数之前,首先要对函数进行定义。 根据模块化程序设计思想,C语言的函数定义是互相平行、独立的,即函数定义不能嵌套 C语言函数定义 分为三种 有参函…

Ansible自动运维工具之playbook

一.inventory主机清单 1.定义 Inventory支持对主机进行分组,每个组内可以定义多个主机,每个主机都可以定义在任何一个或多个主机组内。 2.变量 (1)主机变量 [webservers] 192.168.10.14 ansible_port22 ansible_userroot ans…

[微信小程序] 入门笔记2-自定义一个显示组件

[微信小程序] 入门笔记2-自定义一个显示组件 0. 准备工程 新建一个工程,删除清空app的内容和其余文件夹.然后自己新建pages和components创建1个空组件和1个空页面. 设定 view 组件的默认样式,使其自动居中靠上,符合习惯.在app.wxss内定义,作用做个工程. /**app.wxss**/ /* 所…

【第6节课笔记】LagentAgentLego

Lagent 最中间部分的是LLM,即为大语言模型模块,他可以思考planning和调用什么action,再将其转发给动作执行器action executer执行。 支持的工具如下: Arxiv 搜索 Bing 地图 Google 学术搜索 Google 搜索 交互式 IPython 解释器 IP…

【Java笔记】多线程:一些有关中断的理解

文章目录 线程中断的作用线程的等待状态WAITINGTIMED_WAITING 线程从等待中恢复 java.lang.Thread中断实现相关方法中断标识interrupted 一些小练习Thread.interrupt() 只唤醒线程并修改中断标识sleep() 清除中断状态标识 Reference 线程中断的作用 线程中断可以使一个线程从等…

Linux命名管道的创建及应用

目录 一、命名管道的定义即功能 1.1创建命名管道 1.2匿名管道和命名管道的区别 1.3命名管道的打开规则 1.4系统调用unlink 二、进程间命名管道的创建及使用 2.1Comm.hhp 2.2PipeServer.cc 2.3PipeClient.cc 一、命名管道的定义即功能 管道应用的一个限制就是只能在具有…

MySQL之聚合函数与应用

1. 前言 上文我们讲到了单行函数.实际上SQL还有一类叫做聚合函数, 它是对一组数组进行汇总的函数, 输入的是一组数据的集合, 输出的是单个值. 2. 聚合函数 用于处理一组数据, 并对一组数据返回一个值. 有如下几种聚合函数 : AVG(), SUM(), MAX(), MIN(), COUNT(). 3. AVG(…

上位机图像处理和嵌入式模块部署(树莓派4b镜像烧录经验总结)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 陆陆续续也烧录了好多次树莓派的镜像了,这里面有的时候很快,有的时候很慢。特别是烧录慢的时候,也不知道是自己…

基于51单片机的闭环反馈直流电机PWM控制电机转速测量( proteus仿真+程序+设计报告+原理图+讲解视频)

基于51单片机的闭环反馈直流电机PWM控制转速测量( proteus仿真程序设计报告原理图讲解视频) 仿真图proteus7.8及以上 程序编译器:keil 4/keil 5 编程语言:C语言 设计编号:S0086 1. 主要功能: 基于51单片机的闭环…

Figma 高效技巧:设计系统中的图标嵌套

Figma 高效技巧:设计系统中的图标嵌套 在设计中,图标起着不可或缺的作用。一套便捷易用的图标嵌套方法可以有效提高设计效率。 分享一下我在图标嵌套上走过的弯路和经验教训。我的图标嵌套可以分三个阶段: 第一阶段:建立图标库 一…

语音识别--光谱门控降噪

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计7267字,阅读大概需要3分钟 🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号&#xf…