LLM并行训练3-数据并行

news/2024/10/4 18:18:06/文章来源:https://www.cnblogs.com/sunstrikes/p/18274445

前置知识

混合精度训练

image-20240627193640147

在参数存储时采取fp32, 开始进行fp/bp时转成fp16运算, 拿到fp16梯度后再转回fp32更新参数.

ZeRO对显存占用的估算:

  • 模型状态: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型参数备份),momentum(fp32)和variance(fp32)。假设模型参数量 \(\phi\) ,则共需要\(2\Phi + 2\Phi + (4\Phi + 4\Phi + 4\Phi) = 4\Phi + 12\Phi = 16\Phi\) 字节存储,
  • 剩余状态: 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)

Adam

image-20240627210940124

在adam optimizer的计算状态除了参数, 还有一个\(m_t\)(momentum 梯度均值)和\(v_t\)(variance 梯度未中心化方差)需要存储, 一般被称为optimizer state.

AllToAll通信原语

image-20240628204846386

allToall类似于矩阵转置. 相当于我们需要先把每个节点里的数据按照他们要传递给哪个节点排好序, 然后根据切分好的顺序推给对应的节点. 可以看到如果每个节点的数据量是M, 节点数是N, 最终通信总量就是M * N

ZeRO

在传统的训练方法里, 每张卡里存储一份完整的模型状态, 完成bp后allReduce grad,再更新每张卡里的副本. 这样子有N张卡就会多出(N-1)份冗余的参数存储. 当参数规模急剧增大时这种方法就完全不适合训练. ZeRO1 主要是将这些冗余的模型状态干掉, 通过增加通信来解决冗余参数的问题. ZeRO原理动态图
image

  • ZeRO1: 只保留一份MasterWeights+momentum+variance.
  • ZeRO2: 在ZeRO1的基础上去除了grad的冗余
  • ZeRO3: 在ZeRO2的基础上去掉了weights的冗余
image-20240627214641908

训练流程

以ZeRO3为例. 主要分为5步, 假设使用了4张卡进行训练:

  1. 每张卡上存1/4的W, OS和grad. 每张卡训练自己分配到的batch.
  2. fp时, AllGather所有卡上的W,取到全量的W(fp16)进行fp, 完成后只保留自己需要维护的1/4 W, 其他显存释放回池
  3. bp时, AllGather所有卡上的W进行bp, 完成后再抛弃其他卡维护的W
  4. 完成bp后, ReduceScatter所有卡的G, 从其他卡上取到需要需要更新的梯度增量, 然后释放不是自己维护的G.
  5. 使用自己维护的OS和G来更新W, 不需要通信.
image-20240628163731199 image-20240628194209187

通信量分析

定义单卡数据量为\(\phi\)

传统DP: bp完成后需要对梯度进行一次AllReduce, 一共\(2\phi\)

ZeRO1: 只舍弃了OS, bp时需要AllReduce G(Scatter+Gather 共\(2\phi\)). 另外在使用每张卡各自更新W时, 因为W每张卡都存储的全量, 需要从存储OS的卡上把对应更新后的W再拉回来, 所以需要一次Gather(\(\phi\)), 一共需要\(3\phi\)

ZeRO2: 舍弃了OS和G, bp时AllGather G(\(\phi\)), 更新W时从其他卡拉W, 再Gather一次(\(\phi\)), 一共需要\(2\phi\)

ZeRO3: 上面训练过程分析过, 共需要2次Gather和1次Scatter, 一共需要\(3\phi\)

可以看到ZeRO在通信量只增加了1.5倍的情况下, 显存降了60倍. 效果非常显著

ZeRO++

ZeRO存在的问题是会在GPU之间产生大量数据传输开销,降低了训练效率. 主要有两种情况:

  1. 全局batch size较小,而 GPU数量多,这导致每个 GPU 上batch size较小,需要频繁通信

  2. 在低端集群上进行训练,其中跨节点网络带宽有限,导致高通信延迟。

ZeRO++主要采用了3部分优化: 权重量化 (qwZ), 分层分割存储 (hpZ), 梯度量化 (qgZ). 对比ZeRO通信量减少了4倍, 主要的难点都在减小量化带来的训练误差

权重量化

    def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:q_range = 2**self.config['num_bits'] - 1min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True)scale = q_range / (max_value - min_value)tensor = tensor.sub_(min_value).mul_(scale)tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8)  #对称式量化return tensor, scale, min_value

量化代码在deepspeedcsrc/quantization/quantize.cu cached_quantization 这个kernel里.

如果采用全局fp16->int8的量化会导致极大误差. deepspeed采用了分区量化的方法, 把参数分为固定大小的block后, 先根据这个block的max/min计算出scale(量化系数), 在把这个参数传入量化函数里. 另外在通信的时候应该也需要每个block对应的系数传给接收节点用于反量化.

\[量化公式: clip(round(scale * x), -2^{b-1}+1, 2^{b-1}-1) \]

通过这种方式在通信量减半的同时还能保证精度, 很nice的思路.

image-20240628215923339

分层分割存储

image-20240628194529492

之前ZeRO的W切分方法是根据卡数均分. 在fp/bp之前进行AllGather拉取, 后来发现在机器间进行Gather通信是比较严重的瓶颈. 所以最后W的切分变成了每个节点内存储全量的W, 节点内根据卡数进行切片. 避免跨节点经过网卡的通信, 通过增加显存使用的方式解决通信瓶颈.

显存消耗: ZeRO3的单卡显存消耗为 $\frac{(2+2+K)*\phi}{N} \(, 这里每个节点多存了一份W, 如果有\)\alpha$个物理节点, 那么每张卡使用的显存就多了 \(\frac{\alpha * \phi}{N}\)

梯度量化

如果直接在之前zero RingAllReduce的通信方式上加量化和反量化, 如下图左, 可以看到需要节点个数次量化/反量化. 而每次量化都是有损的, 这样会导致无法接受的训练误差. 为了解决这个问题zero++使用了一次量化->AllToAll通信->一次反量化的操作. 而因为直接进行AllToAll通信量从M(参数量)变成了M*N/Z(N: 节点数, Z:量化压缩率), 这个通信量的增长过大. deepspeed设计了2-hpop all-to-all方法来解决通信问题.

image-20240628200906350

具体图示流程可以参考Deepspeed的blog动态图, 文字版步骤:
image

  1. 节点内的卡间张量切片重排. 主要是因为alltoall切分成了两步, 如果不重排如下图左. 最后顺序会变错位, 然后进行参数量化

    image-20240628210835122
  2. 节点内alltoall通信后反量化.先把卡内能合并的梯度加起来. 这里反量化主要是为了减小梯度累加的精度损失

  3. 再次量化后, 节点间进行allToAll

  4. 拿到通信结果, 反量化后再次reduce. 得到最终的梯度.

这里要进行两次alltoall的原因主要是, 第一次卡间alltoall之后梯度累加可以减少卡数倍的通信规模. 实际deepspeed在实现的时候还把重分片和量化kernel进行了fuse, 进一步优化性能

还有下图的方法, 在通信当前层的时候, 通过多流异步量化下一层要通信的数据. 避免同步等待的浪费

image-20240628211824538

参考

zero: https://arxiv.org/pdf/1910.02054

混合精度训练: https://arxiv.org/pdf/1710.03740

zero++: https://arxiv.org/abs/2306.10209

Deepspeed blog: https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/733070.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

货车运输的五种解法

上完课整的活(这里的“五种解法”之间有实现方式之外的不同) 方法1:最大生成树 + 树上倍增 本题的标准解法,先用 kruskal 建出最大生成树,再在最大生成树跑树上倍增求路径 \(min\) ,时间复杂度为 \(\Theta(n \log n + q \log n)\)。 树上倍增也可以用树剖替换,但是需要两…

一张图让你看懂10种软件架构风格

软件架构风格是构建各种软件系统的基本蓝图,确保它们满足特定的要求和质量属性。通过坚持合适的架构风格,组织可以确保其软件系统的构建与其战略目标保持一致,适应未来的变化,并在面对不断发展的技术环境和用户需求时具有弹性。 以下是最常见的样式:(单体架构):将整个…

linux三剑客-grep、sed、awk

Linux三剑客是Linux系统中最重要的三个命令,它们以其强大的功能和广泛的应用场景而闻名。这三个工具的组合使用几乎可以完美应对Shell中的数据分析场景,因此被统称为Linux三剑客。 1、grep grep是一个强大的文本搜索工具,用于在文件内容中查找指定的字符串,并将匹配到的行输…

联通网络无法使用FTP,无法使用21端口连接的解决方法

最近家里换了联通的网络,结果发现连接不上FTP了,本来以为是软件的问题。最后发现只有21端口的FTP连接不上,其它的端口没问题。 首先想到的是肯定是联通的光猫把21端口给关闭了。然后就想着通过192.168.1.1来设置一下光猫。专业网站制作、系统开发订制、微信公众号开发、接外…

LINUX命令-sed

sed命令是用于对文本文件做内容操作的神器,常见的增删改都可以,熟练运用可提高shell脚本编写能力和在terminal下的工作效率。本文编辑小绝技-sed sed命令是用于对文本文件做内容操作的神器,常见的增删改都可以,查没必要用它,用grep或者gvim打开用vim的搜索匹配就行。 sed …

毕业好几年了还要考研吗?

其实,毕业多少年都不影响我们考研,因为考研本身并没有年限或者年龄上的限制。所以,在是否考研这个问题上,我们真正应该思考的是,我们是否已经对未来做了一个比较合理的规划,考研这件事是否在未来的规划中有着重要的影响,如果是,而且现实条件也允许我们去考,那么,就应…

ubuntu 下使用netplan配置网络

一个yaml走遍天下。 netplan 是Ubuntu底层网络配置的封装,它允许使用yaml的方式“声明式”的配置底层网络,不管底层网络是NetworkManager还是networkd. netplan 官网,使用静态配置的示例: https://netplan.readthedocs.io/en/stable/netplan-tutorial/#using-static-ip-add…

(并查集+双向映射)

题意: 思路: 题目就是让我们实现把一个代表数x的集合加到另一个代表数y的集合中多次操作,这个很容易想到用并查集维护,将相同数字的下标放到一个集合中,集合所代表的数字,用“集合的首领”和代表的数字做一个双射,这样既能表示出集合所带表的数,还能辅助之后输出集合,…

2024年Java学习路线

java 最新学习路线2024Java学习路线(快速版) 核心基础:Java基础→MySQL→JDBC→JavaWeb 微服务核心:Maven→Gradle→Spring6→SpringMVC→MyBatis→MyBatisPlus→SSM →Redis7→SpringBoot2→SpringCloud 微服务生态:Git→Docker→Elasticsearch→ZooKeeper→Nginx→Sprin…

DApp设计与开发 课程笔记(二)remix | hardhat | 测试驱动开发

笔记对应课程内容为成都信息工程大学区块链产业学院老师梁培利的DApp 设计与开发 04-06 课 笔记中提到的名词不做过多解释 不懂就搜!Remix IDE的基本使用 官网:https://remix.ethereum.org/建议使用其网页版而不是桌面版,侧重于比较实用的特性而不是全部的介绍。 支持编写合…

DApp设计与开发 课程笔记(二)

笔记对应课程内容为成都信息工程大学区块链产业学院老师梁培利的DApp 设计与开发 04-06 课 笔记中提到的名词不做过多解释 不懂就搜!Remix IDE的基本使用 官网:https://remix.ethereum.org/建议使用其网页版而不是桌面版,侧重于比较实用的特性而不是全部的介绍。 支持编写合…

OOP第三次博客

write_by_23201707_gongjunjie oop第三次博客 一:前言 这次博客不出意外是oop课程的最后一次博客了,不过这次博客pta只有两题,但是我想说的是,最后一次pta也是够难的, 但是好像我自己的设计也有很大的问题,第七次pta遗留下了一点问题,导致第八次出现了很多问题 二:关于…