大模型训练显存优化推理加速方案

当前的深度学习框架大都采用的都是fp32来进行权重参数的存储,比如Python float的类型为双精度浮点数fp64,pytorch Tensor的默认类型为单精度浮点数fp32。随着模型越来越大,加速训练模型的需求就产生了。在深度学习模型中使用fp32主要存在几个问题,第一模型尺寸大,训练的时候对显卡的显存要求高;第二模型训练速度慢;第三模型推理速度慢。其解决方案就是使用低精度计算对模型进行优化。本文主要讲解几种优化显存存储的方法。

1. fp32、fp16、bf16混合精度训练

  • FP32 是单精度浮点数,1位符号位,8位指数,23位表示小数,总共32位
  • BF16 是对FP32单精度浮点数截断数据,即用8bit 表示指数,7bit 表示小数
  • FP16 半精度浮点数,用5bit 表示指数,10bit 表示小数;
    请添加图片描述
    与32位相比,采用BF16/FP16吞吐量可以翻倍,内存需求可以减半。但是这两者精度上差异不一样,BF16 可表示的整数范围更广泛,但是尾数精度较小;FP16 表示整数范围较小,但是尾数精度较高。

1.1 混合精度训练

直接使用半精度进行计算会导致的两个问题的处理:舍入误差(Rounding Error)和溢出错误(Grad Overflow / Underflow)

  • 舍入误差
    float16 的最大舍入误差约为 2 − 10 ~2 ^{-10}  210,比 float32 的最大舍入误差 2 − 23 ~2 ^{-23}  223 要大不少。 对足够小的浮点数执行的任何操作都会将该值四舍五入到零。在反向传播中很多梯度更新值都非常小,但不为零,在反向传播中舍入误差累积可以把这些数字变成0或者nan, 这会导致不准确的梯度更新,影响网络的收敛

  • 溢出错误
    由于 float16 的有效的动态范围(正数部分,负数部分与正数对应)约为 5.96 × 1 0 − 8 ∼ 6.55 × 10 4 5.96\times10^{-8} \sim 6.55\times10{^4} 5.96×1086.55×104,比单精度的 float32 的动态范围 1.4 × 1 0 − 45 ∼ 1.7 × 1 0 38 1.4\times10^{-45} \sim 1.7 \times10^{38} 1.4×10451.7×1038要狭窄很多,精度下降会导致得到的值大于或者小于fp16的有效动态范围,也就是上溢出或者下溢出。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况

针对以上两种情况的解决方法是混合精度训练(Mixed Precision)和损失缩放(Loss Scaling)

  • 混合精度训练
    混合精度训练是一种通过在FP16上执行尽可能多的操作来大幅度减少神经网络训练时间的技术,在像线性层或是卷积操作上,FP16运算较快,但像Reduction运算又需要 FP32的动态范围。通过混合精度训练的方式,便可以在部分运算操作使用FP16,另一部分则使用 FP32,混合精度功能会尝试为每个运算使用相匹配的数据类型,在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。这样在权重更新的时候就不会出现舍入误差导致更新失败,混合精度训练的策略有效地缓解了舍入误差的问题

  • 损失缩放
    尽管使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。损失缩放是指在执行反向传播之前,将损失函数的输出乘以某个标量数(论文建议从8开始)。 乘性增加的损失值产生乘性增加的梯度更新值,提升许多梯度更新值到超过FP16的安全阈值2^-24。 只要确保在应用梯度更新之前撤消缩放,并且不要选择一个太大的缩放以至于产生inf权重更新(上溢出) ,从而导致网络向相反的方向发散

bf16/fp32 混合训练因为两种格式在范围上对齐了,并且 bf16 比 fp16 的范围更大,所以要比 fp16/fp32 混合训练稳定性更高

2. gradient checkpointing

gradient checkpointing(梯度检查点)的工作原理是在反向传播时重新计算深度神经网络的中间值(通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

3. Xformers

Xformers 应该是目前社区知名度最高的优化加速方案了,所谓 Xformers 指的是该库将各种transformer 架构的模型囊括其中。注意,该库仅适用于N卡,特点是加速图片生成并降低显存占用,代价是输出图像不稳定,有可能比不开Xformers略差。各种transformer变体可以参考 A Survey of Transformers.

参考

  • 彻底搞懂float16与float32的计算方式
  • pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等
  • facebookresearch/xformers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/115412.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Prometheus+Grafana可视化监控【Redis状态】

文章目录 一、安装Docker二、安装Redis数据库(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装redis_exporter七、Grafana添加Redis监控模板 一、安装Docker 注意:我这里使用之前写好脚本进行安装Docker,如果已…

Python实现查询一个文件中的pdf文件中的关键字

要求,查询一个文件中的pdf文件中的关键字,输出关键字所在PDF文件的文件名及对应的页数。 import os import PyPDF2def search_pdf_files(folder_path, keywords):# 初始化结果字典,以关键字为键,值为包含关键字的页面和文件名列表…

七天学会C语言-第七天(结构体)

1.定义结构体 例 1&#xff1a;把一个学生的信息(包括学号、姓名、性别、住址等 4 项信息) 放在一个结构体变量中&#xff0c;然后输出这个学生的信息。 #include <stdio.h>struct Student {int student_id;char name[30];char gender;char address[60]; };int main() …

合并两个升序链表,哨兵位的理解

开始时也要判断是否有一个链表本来就是空&#xff0c;如果是&#xff0c;直接返回另外一个链表 代码&#xff1a; struct ListNode* mergeTwoLists(struct ListNode* list1, struct ListNode* list2){if(list1NULL){return list2;}if(list2NULL){return list1;} struct ListN…

【收藏】如何最快取得NISP二级和CISP

【收藏】如何最快取得NISP二级和CISP &#x1f449;今天小编来给大家讲解一下如何最快的取得NISP二级和CISP证书 ✅我们从如下几个方面为大家讲解&#xff1a; &#x1f53a;报名条件 &#x1f53a;考试形式 &#x1f53a;考试题型 &#x1f53a;如何备考 &#x1f53a;证书用途…

[Linux入门]---Linux项目自动化构建工具-make/Makefile

目录 1.背景2.make指令输入make默认为Makefile文件第一条指令执行Makefile文件对gcc指令特殊处理及原理特殊符号 3.总结 1.背景 会不会写makefile&#xff0c;从一个侧面说明了一个人是否具备完成大型工程的能力一个工程中的源文件不计数&#xff0c;其按类型、功能、模块分别放…

maven多模块依赖包程序包xxx不存在

背景 rpc-common 被 rpc-server、rpc-client依赖 项目地址 https://github.com/pjmike/springboot-rpc-demo mvn clean install 打包时报错 报错信息 程序包xxxx不存在 找不到符号 原因分析 原因还不清楚&#xff0c;网友们帮解答一下 解决 主pom.xml 添加 <packaging…

微调大型语言模型(一):为什么要微调(Why finetune)?

今天我们来学习Deeplearning.ai的在线课程 微调大型语言模型(一)的第一课&#xff1a;为什么要微调(Why finetune)。 我们知道像GPT-3.5这样的大型语言模型(LLM)它所学到的知识截止到2021年9月&#xff0c;那么如果我们向ChatGPT询问2022年以后发生的事情&#xff0c;它可能会…

LeetCode 刷题记录——从零开始记录自己一些不会的(二)

20. 替换后的最长重复字符 题意 给你一个字符串 s 和一个整数 k 。你可以选择字符串中的任一字符&#xff0c;并将其更改为任何其他大写英文字符。该操作最多可执行 k 次。 在执行上述操作后&#xff0c;返回包含相同字母的最长子字符串的长度。 思路 代码 class Solution…

API(十一) 获取openresty编译信息

一 ngx.config 说明&#xff1a; 不常用,了解即可 ngx.config.subsystem 说明&#xff1a; 用的四层还是七层代理 ngx.config.debug 说明&#xff1a; 返回的是boolean类型, openresty rpm安装一般没有 --with-debug编译选项对比&#xff1a; nginx rpm 安装一般携带 --wi…

【二叉树魔法:链式结构与递归的纠缠】

本章重点 二叉树的链式存储二叉树链式结构的实现二叉树的遍历二叉树的节点个数以及高度二叉树的创建和销毁二叉树的优先遍历和广度优先遍历二叉树基础oj练习 1.二叉树的链式存储 二叉树的链式存储结构是指&#xff0c;用链表来表示一棵二叉树&#xff0c;即用链来指示元素的逻辑…

Labelme分割标注软件

Labelme分割标注软件 1、环境配置与安装1.1 创建conda虚拟环境(建议)1.2 安装Labelme 2、简单使用2.1 创建label标签文件2.2 启动labelme2.3 打开文件/文件夹2.4 设置保存结果路径2.5 标注目标2.6 保存json文件格式 3 格式转换3.1 转换语义分割标签3.2 转换实例分割标签 相关重…