大模型训练中优化策略(数据并行、模型并行、ZeRO等)

在微调时,模型显存占用主要包括模型参数参数梯度优化器中间结果四个部分。

对于一个 6B 参数量的模型,它的模型参数占用为:

在这里插入图片描述
将模型参数视为基准,模型梯度占用量与模型参数相同。

优化器主采用 Adam Optimizer ,它核心计算公式如下:

在这里插入图片描述由于需要保存 m 和 v,而 m 和 v 规模与参数梯度相同,因此优化器需要两倍显存容量。

同时,在计算中得到的中间结果需要保存在显存中,以便反向传播时计算梯度。 对于每一个中间结果,其数据形状为 [Batch, SeqLen, Dim]。

技术交流&资料

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远.

成立了大模型技术交流群,本文完整代码、相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:来自CSDN + 技术交流

通俗易懂讲解大模型系列

  • 做大模型也有1年多了,聊聊这段时间的感悟!

  • 用通俗易懂的方式讲解:大模型算法工程师最全面试题汇总

  • 用通俗易懂的方式讲解:不要再苦苦寻觅了!AI 大模型面试指南(含答案)的最全总结来了!

  • 用通俗易懂的方式讲解:我的大模型岗位面试总结:共24家,9个offer

  • 用通俗易懂的方式讲解:大模型 RAG 在 LangChain 中的应用实战

  • 用通俗易懂的方式讲解:一文讲清大模型 RAG 技术全流程

  • 用通俗易懂的方式讲解:如何提升大模型 Agent 的能力?

  • 用通俗易懂的方式讲解:ChatGPT 开放的多模态的DALL-E 3功能,好玩到停不下来!

  • 用通俗易懂的方式讲解:基于扩散模型(Diffusion),文生图 AnyText 的效果太棒了

  • 用通俗易懂的方式讲解:在 CPU 服务器上部署 ChatGLM3-6B 模型

  • 用通俗易懂的方式讲解:使用 LangChain 和大模型生成海报文案

  • 用通俗易懂的方式讲解:ChatGLM3-6B 部署指南

  • 用通俗易懂的方式讲解:使用 LangChain 封装自定义的 LLM,太棒了

  • 用通俗易懂的方式讲解:基于 Langchain 和 ChatChat 部署本地知识库问答系统

  • 用通俗易懂的方式讲解:在 Ubuntu 22 上安装 CUDA、Nvidia 显卡驱动、PyTorch等大模型基础环境

  • 用通俗易懂的方式讲解:Llama2 部署讲解及试用方式

  • 用通俗易懂的方式讲解:基于 LangChain 和 ChatGLM2 打造自有知识库问答系统

  • 用通俗易懂的方式讲解:一份保姆级的 Stable Diffusion 部署教程,开启你的炼丹之路

  • 用通俗易懂的方式讲解:对 embedding 模型进行微调,我的大模型召回效果提升了太多了

  • 用通俗易懂的方式讲解:LlamaIndex 官方发布高清大图,纵览高级 RAG技术

  • 用通俗易懂的方式讲解:为什么大模型 Advanced RAG 方法对于AI的未来至关重要?

  • 用通俗易懂的方式讲解:使用 LlamaIndex 和 Eleasticsearch 进行大模型 RAG 检索增强生成

  • 用通俗易懂的方式讲解:基于 Langchain 框架,利用 MongoDB 矢量搜索实现大模型 RAG 高级检索方法

  • 用通俗易懂的方式讲解:使用Llama-2、PgVector和LlamaIndex,构建大模型 RAG 全流程

GPU 显存分析

GPU显存分布.png

Collective Operations

为了节省显存,可以将模型或者数据分配到不同的显卡上,显卡之间有如下几种 Collective Operations

Broadcast

广播.png

The Broadcast operation copies an N-element buffer on the root rank to all ranks.

广播操作将一张显卡上数据广播到所有显卡。

AllReduce、Reduce、ReduceScatter

AllReduce.png

reduce.png

ReduceScatter.png

The AllReduce operation is performing reductions on data (for example, sum, min, max) across devices and writing the result in the receive buffers of every rank.

The Reduce operation is performing the same operation as AllReduce, but writes the result only in the receive buffers of a specified root rank.

The ReduceScatter operation performs the same operation as the Reduce operation, except the result is scattered in equal blocks between ranks, each rank getting a chunk of data based on its rank index.

AllReduce 操作将所有显卡上数据进行聚合如求和取最大值取最小值,并将结果写入所有显卡。

Reduce 只会将结果写入一张显卡。

ReduceScatter 则将结果分散在所有显卡中。

AllGather

AllGather.png

The AllGather operation gathers N values from k ranks into an output of size k*N, and distributes that result to all ranks.

AllGather 操作会收集所有显卡数据,并写入所有显卡中。

数据并行

数据并行是将数据分成若干份,装载到不同节点上进行计算。

数据并行.png

数据并行计算流程如下:

  1. 有个参数服务器保存模型参数。
  2. 参数被复制到不同的设备中,构成若干 replicas。每个 replica 处理一部分数据,进行前向传播和反向传播。
  3. 每个设备得到梯度进行 Reduce 操作,得到最终梯度,并按照这个梯度更新参数服务器中的模型参数。
  4. 在后向传播时,每计算完一层的梯度,就可以进行 Reduce 操作,提高并行性。

分布式数据并行

分布式数据并行.png

分布式数据并行中不存在参数服务器,其计算流程如下:

  1. 每个 replica 都保存模型参数,但是分别计算部分数据,进行前向传播和反向传播。
  2. 每个设备都得到梯度后进行 AllReduce 操作,将梯度写入所有设备,每个设备根据自己的优化器和梯度更新参数。

分布式数据并行中,每个设备显存占用情况如图:

分布式数据并行显存占用.png

其中每个设备仍需要保存模型参数、梯度和优化器参数。

模型并行

由于模型越来越大,单个设备保存模型参数、梯度和优化器越来越难。因为深度学习主要是矩阵计算,而矩阵计算可以分块计算,因此可以将模型参数拆成若干份,每份单独计算,以减少显存占用。

在这里插入图片描述

模型并行.png

其计算流程如下:

  1. 将参数矩阵分成若干子矩阵,分发到不同设备中。
  2. 每个设备计算不同矩阵,然后将结果收集起来。

模型并行后,显存占用如下:

模型并行显存占用.png

由于每个设备处理所有数据,因此中间结果都会保存在所有设备中。

ZeRO

在分布式数据并行中,最后梯度更新在不同设备进行的操作相同,多个设备中参数相同,梯度相同,优化器状态相同,存在大量冗余。

ZeRO-1 对优化器状态进行分片。

ZeRO-1.png

ZeRO-1 计算流程如下:

  1. 每个 replica 处理一部分数据输入。
  2. 独立进行前向传播。
  3. 独立进行反向传播。
  4. 得到完整梯度后进行 ReduceScatter,每个 replica 得到对应梯度。
  5. 每个 replica 更新梯度对应的部分参数。
  6. 使用 AllGather 同步更新所有参数。

ZeRO-2 计算流程与1基本相同,ZeRO-2在后向传播时,每计算一层梯度,就可以使用 ReduceScatter 进行同步,提高并行度。同时由于不需要完整计算梯度之后进行 ReduceScatter,每个 replica 只需要保存部分梯度即可。

ZeRO-3 在 2 的基础上,将模型参数进行分片。

ZeRO-3.png

ZeRO-3 计算流程如下:

  1. 每个 replica 处理一部分输入。
  2. 前向传播时,当需要别的层参数,使用 AllGather 获取。
  3. 反向传播时,当需要别的层参数时,使用 AllGather 获取,同时计算出每一层梯度时,使用 ReduceScatter 分发到对应 replica。
  4. 每个 replica 用于部分优化器参数和梯度,进行对应参数更新。

不同 ZeRO 对应的显存占用情况:

ZeRO显存占用.png

流水线并行

将模型一层一层分开,不同层放入不同 GPU 进行计算。个人理解与模型并行不同的是,模型并行保留从头到尾每一层的部分参数,输入可以计算出结果。流水线并行需要等前一层计算完毕才能进行计算。

流水线并行.png

流水线并行显存分析:

流水线并行显存分析.png

混合精度

FP16 相较于 FP32 计算更快,同时占用更少的显存。但同时 FP16 表示的范围小,可能产生溢出错误。

特别的,在权重更新时 gradient * lr 导致下溢出。

混合精度训练的思路在优化器中保留一份 FP32 格式的参数副本,而模型权重、梯度等数据在训练中都是用 FP16 来存储。

混合精度.png

优化器中参数更新在 FP32 格式下保证精度,之后转换为 FP16 格式。

Checkpointing

由于模型反向传播需要中间结果计算梯度,大量中间结果占用大量显存。

Checkpointing 思路是保存部分隐藏层的结果(作为检查点),其余的中间结果直接释放。当反向传播需要计算梯度时,从检查点开始重新前向传播计算中间结果,得到梯度后再次释放。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/475221.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenGL学习——14.投光物_点光源

前情提要:本文代码源自Github上的学习文档“LearnOpenGL”,我仅在源码的基础上加上中文注释。本文章不以该学习文档做任何商业盈利活动,一切著作权归原作者所有,本文仅供学习交流,如有侵权,请联系我删除。L…

SaaS(软件即服务)是什么,在中国有发展?

SaaS(Software as a Service,软件即服务)是一种基于互联网的软件交付模式,用户通过互联网访问和使用软件,而无需在本地安装和维护软件。这种模式通常以订阅的形式提供,用户按照一定的周期支付费用&#xff…

K8s进阶之路-Pod的生命周期

Pod创建过程: 首先创建一个pod,然后创建一个API Server 和 Etcd【把创建出来的信息存储在etcd中】 然后创建 Scheduler,监控API Server是否有新的Pod,如果有的话,会通过调度算法,把pod调度某个node上 在nod…

C++之Easyx——图形库的基本准备工作

什么是Easyx? EasyX Graphics Library 是针对 Visual C 的免费绘图库,支持 VC6.0 ~ VC2022,简单易用,学习成本极低,应用领域广泛。目前已有许多大学将 EasyX 应用在教学当中。 它比Red PandaDev C上的图形库功能要强…

浅谈iPaaS对企业转型的重要性

面对数字化转型的大浪潮,众多企业都期望着能快速实现全面的数字化转型,让企业在日益激烈的竞争中拥有更稳的市场地位,提升自身的实力及能力,奠定更坚实的基底。但在数字化转型过程中,部分企业数字化基础水平较薄弱&…

云手机在引流方面有什么优势?

对于电商商家而言,无论是在亚马逊还是其他平台,有效的流量来源主要集中在短视频引流和社交电商营销。要在新兴社交平台为企业电商带来更多流量,不可忽视云手机的关键作用和独特优势。 云手机的定义与作用 在经营TikTok、Facebook和INS账号时&…

网络原理 - HTTP/HTTPS(1)

HTTP HTTP是什么 HTTP("全程超文本协议")是一种应用非常广泛的应用层协议. 文本:字符串(能在utf8/gbk)码表上找到合法字符. 超文本:不仅是字符串,还能携带图片啥的(HTML). 富文本:类似于word文档这种. HTTP诞生于1991年.目前已经发展为最主流使用的一种应用层协议.…

中国社科院与英国斯特灵大学创新与领导力博士—应该怎样选专业

现如今其实有很多人感觉只是平台成就自己,离开平台自己并无一技之长或过人之处。但是又不想如此安稳过日,一直终老。所以现在大多数人都会去想在职读个博士。 基本上都是在职博士专业为那些希望边工作边获得博士学位的在在职人员开设的,那么&…

Vue3+Vite+TS+Pinia+ElementPlus+Router+Axios创建项目

目录 初始项目组成1. 创建项目1.1 下载项目依赖1.2 项目自动启动1.3 src 别名设置vite.config.ts配置文件tsconfig.json配置若新创项目ts提示 1.4 运行测试 2. 清除默认样式2.1 样式清除代码下载2.2 src下创建公共样式文件夹style2.3 main.js中引入样式2.4 安装sass解析插件 2.…

Mysql开启bin-log日志

目录 一、安装配置 二、mysqlbinlog命令 一、安装配置 yum -y install mariadb mariadb-server#安装mysql数据库#默认配置文件/etc/my.cnfvim /etc/my.cnflog-binmariadb-bin #开启二进制日志 systemctl restart mariadb#会在/car/lib/mysql/产生二进制日志文件&#xff0…

HBuilderX 插件开发指南(一):从插件开发到发布的完整流程

前端目前主流使用的IDE工具有VS Code、Sublime Text3、HBuilder X等等 本期我们主要了解HBuilder X,作为前端通用型开发工具,拥有可视化的操作方式,内置相关环境,开箱即用,无需配置nodejs等优点外,对uni-a…

matlab代码--基于注水法的MIMO信道容量实现

今天接触一个简单的注水法程序,搞懂数学原理即可看懂代码。 1 注水法简介 详细原理可以参考: MIMO的信道容量以及实现 大致理论就是利用拉格朗日乘子法,求解信道容量的最大化问题,得到的解形如往水池中注水的形式,最…