深度学习基础理论————训练加速(单/半/混合精度训练)/显存优化(gradient-checkpoint)

news/2025/1/8 15:48:50/文章来源:https://www.cnblogs.com/Big-Yellow/p/18650816

主要介绍单精度/半精度/混合精度训练,以及部分框架(DeepSpeed/Apex)

不同精度训练

单精度训练single-precision)指的是用32位浮点数(FP32)表示所有的参数、激活值和梯度
半精度训练half-precision)指的是用16位浮点数(FP16 或 BF16)表示数据。(FP16 是 IEEE 标准,BF16 是一种更适合 AI 计算的变种)
混合精度训练mixed-precision)指的是同时使用 FP16/BF16 和 FP32,利用二者的优点。通常,模型权重和梯度使用 FP32,而激活值和中间计算使用 FP16/BF16

FP16/BF16/FP32

Image From: https://www.exxactcorp.com/blog/hpc/what-is-fp64-fp32-fp16

不同精度之间对比:

指标 单精度(FP32) 半精度(FP16/BF16) 混合精度
精度 较低(FP16),中(BF16) 中高
显存占用 较低
训练速度 较慢
稳定性 最佳 稳定性低(FP16) 稳定
适用场景 小规模任务 性能优先,大规模模型 性能与稳定的平衡

混合精度训练(https://arxiv.org/pdf/1710.03740):

为什么不只用单精度训练(速度快/显存占用少)
1、直接使用半精度(FP16)容易引发数值问题,如溢出(overflow)下溢(underflow):这里是因为单精度有效尾数(约10位尾数)较单精度要小得多,那么就会有一个问题因此在训练过程中,如果激活函数的梯度非常小,可能会因精度不足而被舍弃为零,导致梯度下溢。此外,当数值超过半精度的表示范围时,也会发生溢出问题。这些限制会使训练难以正常进行,导致模型无法收敛或性能下降;
2、舍入误差(Rounding Error) 舍入误差指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,用一张图清晰地表示:

FP16/BF16/FP32

Image: https://zhuanlan.zhihu.com/p/79887894
总的来说就是:如果只用半精度会导致精度损失严重,因此就会提出用混合精度进行训练

解决上面用单精度造成的问题,在混合精度训练中论文提到的解决办法:

  • 1、FP32 MASTER COPY OF WEIGHTS
    模型权重会同时维护两个版本:1、FP32权重(Master Copy):以32位浮点数表示,用于存储和更新权重的精确值。2、FP16权重(Working Copy):以16位浮点数表示,用于前向传播和反向传播的计算,减少显存占用并加速运算

这里就会有一个问题,反向传播过程中要计算梯度,如果(梯度用FP16)梯度很小,不也还是会出现溢出问题,作者后续提到LOSS SCALING可以解决这种问题。如果梯度很大也会导致溢出问题,梯度计算使用FP16,但在权重更新之前,梯度会转换为 FP32 精度进行累积和存储,从而避免因溢出导致的权重更新错误。
另外之所以要用FP32对权重进行保存这是因为,作者研究发现更新 FP16 权重会导致 80% 的相对准确度损失。
we match FP32 training results when updating an
FP32 master copy of weights after FP16 forward and backward passes, while updating FP16 weights
results in 80% relative accuracy loss
另外一方面,如果拷贝权重,不也等同于把显存的占用拉大了?参考知乎上描述显存占用上主要是中间过程值

FP16/BF16/FP32
  • 2、LOSS SCALING

下图展示了 SSD 模型在训练过程中,激活函数梯度的分布情况,容易发现部分梯度值如果用FP16容易导致最后的梯度值变为0,这样就会导致上面提到的溢出问题,那么论文里面的做法就是:在反向传播前将loss增打\(2^k\)倍,这样就会保证不发生下溢出(乘一个常数,后面再去除这个常数不影响结果),如何反向传播再去除这个常数即可。

FP16/BF16/FP32
  • 3、Apex实现混合精度训练
git clone https://github.com/NVIDIA/apex
cd apex
python3 setup.py install

分别用Apex和torch原生的ampMNIST数据集上进行测试(模型:1层卷积+池化+2层全连接层)

# Apex
from apex import amp
...
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic")
...
with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()# Amp
from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
...
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

ApexAmp参数(https://nvidia.github.io/apex/amp.html):
1、opt_level欧1而不是零1):
O0:纯FP32训练,可以作为accuracy的baseline;
O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;
2、loss_scale="dynamic"
损失值处理(LOSS SCALING)默认是动态(初始一个较大的值,检查到溢出就减小)

测试效果:
准确率变化上

在公开数据集(CIFAR10)上进行测试(模型为resnet50)测试使用的设备为4090

训练集上变化

Run Smoothed Value Step Time 显存占用
scalar-CIFAR10/scalar-256-amp 0.8026 0.9364 11 16.99 min 15508
scalar-CIFAR10/scalar-256-apex 0.8093 0.9366 11 16.51 min 13166
scalar-CIFAR10/scalar-256-fp32 0.7946 0.9456 11 22.27 min 22818

测试集上变化

Run Smoothed Value Step Time 显存占用
scalar-CIFAR10/scalar-256-amp 0.7302 0.8031 11 16.99 min 15508
scalar-CIFAR10/scalar-256-apex 0.7323 0.7956 11 16.51 min 13166
scalar-CIFAR10/scalar-256-fp32 0.7250 0.8092 11 22.27 min 22818

根据知乎:Nicolas和Dreaming.O实验建议:

  • 1、判断你的GPU是否支持FP16:支持的有拥有Tensor Core的GPU(2080Ti、Titan、Tesla等),不支持的(Pascal系列)
import torchif torch.cuda.is_available():device = torch.device("cuda")compute_capability = torch.cuda.get_device_capability(device)print(f"Compute Capability: {compute_capability[0]}.{compute_capability[1]}")
else:print("CUDA is not available.")

结果\(≥7\)说明支持

  • 2、开启混合精度加速后,Training 对 CPU 的利用率会变得很敏感

如果训练时候 CPU 大量被占用的话,会导致严重的减速。具体表现在:CPU被大量占用后,GPU-kernel的利用率下降明显。估计是因为混合精度加速有大量的cast操作需要CPU参与,如果CPU拖了后腿,则会导致GPU的利用率也下降。

  • 3、使用Apex框架会出现溢出情况

因为在Apexamp默认使用的是dynamic可以改为1024或者2048

显存优化

gradient-checkpoint参考: https://www.cnblogs.com/Big-Yellow/p/18646083

参考

1、https://arxiv.org/pdf/1710.03740
2、https://www.exxactcorp.com/blog/hpc/what-is-fp64-fp32-fp16
3、https://zhuanlan.zhihu.com/p/79887894
4、https://zhuanlan.zhihu.com/p/84219777
5、https://nvidia.github.io/apex/amp.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/863533.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用学生优惠创建 Azure Database for MySQL 数据库

文章首先强调了需要一个已通过学生认证的 Azure 账户,然后详细讲解了从登录 Azure 门户页面、选择免费服务、配置服务器和网络等步骤,最终成功创建并部署 Azure Database for MySQL。前言 在此之前,你需要拥有一个已通过学生认证的 Azure 账户。关于通过 Azure 学生认证,网…

【Windows】修改虚拟内存位置

这篇文章详细介绍了如何在 Windows 系统中将虚拟内存文件(pagefile.sys)从 C 盘移动到其他盘。步骤包括查看文件位置、检查和关闭 C 盘加密、修改注册表、设置新的虚拟内存位置并重启电脑,最终实现文件转移。问题:系统优化中,希望将pagefile.sys文件(即虚拟内存)移动到其…

ex7.3

import numpy as np from scipy.interpolate import lagrange import matplotlib.pyplot as plt import matplotlib yx = lambda x: 1/(1+x**2) def fun(n): x = np.linspace(-5, 5, n+1) p = lagrange(x, yx(x)) # n次插值多项式 return p x0 = np.linspace(-5, 5, 100) plt…

【hashMap扩容】关于hashMap扩容以后,新下标的理解

首先我们知道hashMap在存取元素的时候的下标算法是这样子的 根据当前元素(e)的hash值((e.hashCode()) ^ (e.hashCode() >>> 16))去与上当前hashMap的容量减一(Cap-1) put和get都是如此 put get所以在扩容算法中,元素的坐标也应是用这种方式存的,看一下代码我们会发现…

【自动化测试基础】Pytest前后置处理

Pytest的前后置(固件、夹具)处理 有一些初始化配置和测试之后的收尾,只需要处理一次,这个时候我们就要用到夹具。 Pytest提供了以下几种setup和teardown方法:setup_function 和 teardown_function: 用于每个测试函数 setup_method 和 teardown_method: 用于每个测试方法(…

进阶大模型开发框架LangChain

本文来自博客园,作者:王竹笙,转载请注明原文链接:https://www.cnblogs.com/edeny/p/18650785

【unity】学习制作类银河恶魔城游戏-6-

碰撞检查 控制面板定义变量射线功能创建射线实体分配射线实体调整射线编辑碰撞代码 创建地面和墙面的层判断是否碰撞到了地面这行代码的作用是:从groundCheck的位置开始,向下(Vector2.down)投射一条射线,距离为groundCheckDistance,只检测whatIsGround指定的层上的物体。…

直播预告丨社区年度交流会 《RTE 和 AI 融合生态洞察报告 2024》发布

新的一年开始,是时候再深度交流一次了!欢迎关注 1 月 4 日周六晚 社区年度交流会的 线上直播 。这将是一群 实时多模态 AI 开发者 的聚会。我们将一起探讨 Voice Agent 在 AI 陪伴助手、AI 硬件和 AI 企业服务等应用场景中的技术突破与产品创新。同时,我们也会交流 RTE 开发…

从 LB Ingress 到 ZTM:集群服务暴露新思路

12 月 28 日, KubeSphere 社区联合 Higress 社区主办的云原生 AI Meetup 广州站成功召开,我们非常荣幸邀请到CNCF Ambassador、Flomesh 社区布道师——张晓辉老师,张老师为大家带来了一场主题为「从 LB Ingress 到 ZTM:集群服务暴露新思路」的主精彩分享。以下为演讲实录,…

CH32V203F6P6-TSSOP20测试之03---三种烧录方式

CH32V203F6P6-TSSOP20支持三种下载方式:USB下载、串口下载(用串口2即8脚PA2为TX2接下载的RX,9脚PA3为RX2接下载的TX)和SW二线下载。CH32V203F6P6-TSSOP20的BOOT1内置接GND,而BOOT0外露,用户可以选择两种启动模式,因而支持USB下载和串口下载。接法可以选择下面两种方法的…

第二章 BIOS -- MBR

本文是对《操作系统真相还原第二章》学习的笔记,欢迎大家一起交流。第二章 BIOS --> MBR 本文是对《操作系统真象还原》第二章学习的笔记,欢迎大家一起交流。 第一棒 BIOS 首先我们要先明白计算机的启动过程,在 x86 模式下,开机的一瞬间,cpu 的 cs:ip 寄存器被强制初始…

Office Tool Plus v10 - Microsoft office安装使用激活一条龙

下载Office Tool Plus Office Tool Plus是一款相当牛逼的office安装工具,并且安装完了顺带激活,也可以很快捷的卸载office清除激活信息等等。 👉👉点击下载 Office Tool Plus移除老的office安装office 点击页面菜单-部署在产品这栏添加需要的产品.进行下载安装,等待即可…