[megatron代码阅读] 2. TP和PP实现

news/2025/2/6 20:35:36/文章来源:https://www.cnblogs.com/sunstrikes/p/18701639

训练并行实现

TensorParallel

张量并行代码路径, 代码路径: megatron/core/tensor_parallel

主要包含Linear / VocabEmbedding / cross_entropy 三部分.

Linear

参数初始化

如果是从checkpoint热启, perform_initialization需要打开这个配置

1.set_tensor_model_parallel_attributes: 设置weight的三个属性: is_parallel / partition_dim / stride

2.调用传入的init_method, 初始化weight. 这里注意要使用同一个随机种子. 如果是expert网络, 每个expert要用自己独立的rng_tracker

3.如果启用expert_parallel, 设置allreduce属性为false. 否则为true

列并行

sequence_parallel回忆, 为了节省显存,拆分了layernorm后的激活存储. 在进入TP前通过allGather获取到完整的激活, 经过TP后再通过reduceScatter分离到各张卡.

grad_accumulation_fusion

代码: https://github.com/NVIDIA/apex/blob/master/csrc/megatron/fused_weight_gradient_dense_cuda.cu

主要作用是在显存受限,无法一次性更新大batch_size的时候, 通过mini-batch来累加多个小批量的梯度到weight.main_grad, 这里fused意思就是在main_grad上原地更新, 最后在用main_grad来更新大batch里的weight.

LinearWithGradAccumulationAndAsyncCommunication

forward:

  • 输入: 如果开sp, torch.distributed.all_gather_into_tensor input
  • Matmul(input, weight) + bias

backward:

  • wgrad_deferral_limit: TODO: 没弄懂用于降低pipeline flush延迟的含义
  • 如果开sp, grad_input bp对应的集合通信操作是dist_reduce_scatter_func
  • 没开sp, grad_input bp对应的是all_reduce_func
  • 如果开了grad_accumulation_fusion & sp, 需要先all_gather input, 也就是X, 因为求grad_weight的时候需要matmul input的转置.
  • 这里的异步优化指的是先进行 grad_input的异步集合通信, 在此同时计算grad_weight, 算完grad_weight后再等grad_input通信完成, 这样就能overlap一部分通信耗时.
行并行

forward

  • 开sp, input不做处理,因为直接输入切分好的input, 经过线性层后ouput reduceScatter到对应的节点
  • 关sp, input需要先进行ReduceScatter对输入x做切分, 经过线性层后output allReduce结果.

Backward: 没有集合通信

PipelineParallel

核心配置参数有两个:

  1. pipeline_model_parallel_size: pp切分数, transformer_layer实际被切分为多少个group

  2. virtual_pipeline_model_parallel_size: 举例 tensor_model_parallel_size=1, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2, 一共有16个transform_layer的情况下, 模型被切分为:

    GPU 0: [1, 2] [9, 10]
    GPU 1: [3, 4] [11, 12]
    GPU 2: [5, 6] [13, 14]
    GPU 3: [7, 8] [15, 16]
    

    一共8个stage, 每个stage有2个layer. PP原理回忆

PP代码逻辑位置 megatron/core/pipeline_parallel

train_step->get_forward_backward_func->forward_backward_pipelining_with_interleaving

image-20250126160014500

P2P通信

有两种方式 batch_isend_irecv_p2p_ops, 后者即send和recv独立作为一个通信操作

batch_isend_irecv: 将send_prev/recv_prev/send_next/recv_next可以异步并发执行

p2p通信步骤:

  1. 传输tensor_shape, int64类型 (类似sequence压缩通信方式, 先传长度)
  2. 对所有的pp group进行遍历, 如果需要recv_prev / recv_next, 先创建空tensor用于结果存储 (这里是否能优化)
  3. 根据是否batch传输, 分别进行并行/串行的方式通信.
  4. 等待通信完成, 进行cuda流同步

1F1B(非交错式)

image-20250126165021086

缺点: 无法支持 p2p通信耗时的overlap

Warmup

num_warmup_microbatchs = min(microbatch, pp_world_size - pp_rank - 1), 比如device1的warmup就是 4 - 0 - 1 = 3, 前3个microbatch warmup的时候, 整体pipeline处于串行的执行状态.

步骤: recv_forward->forward_step->send_forward 再到下一层PP, 直到warmup步骤全部走完.

Steady

在稳态状态下就是1F1B描述的情况. 交替进行fp和bp

以device3刚进入steady状态为例:

  1. forward_step: warmup执行了microbatch1, steady执行的第一个forward是 batch2
  2. send_forward_recv_backward: 向device4发batch2的fp结果, 同时等device4返回batch1的bp结果. 这里是同步通信, 需要等bp执行完成, 这时候并没有跑到batch3的fp上.
  3. backward_step: 执行batch1的bp
  4. send_backward_recv_forward: 把batch1的bp结果发给device2, 同时接受device2的batch3 fp结果, 用来执行下一轮的batch3 fp.

5,6,7,8 图上描述的状态和代码是完全一致的, 但1,2,3,4不完全一致.

Cooldown

和warmup刚好是相反的逻辑.根据warmup microbatchs的个数, 等待bp执行完成.

1F1B with interleaving

虚拟流水线的主要目的是让microbatch_size更小更多, 从而减少气泡。方法是让一个device虚拟成 \(v\) 个device,从计算1个连续的layer段(有 \(x\) 个 layer)变成计算 \(v\) 个不连续的layer段(每段 layer 数量为 \(x\)/\(v\)). 比如之前1F1B时device1负责layer 1~4,device2负责 5~8,在 Interleaved 1F1B下device1负责layer 1~2 和 9~10,device2负责 3~4 和 11~12,这样可以让流水线中每个stage更小,因而下个stage的等待时间更短,气泡更小。需要注意的是,micro_batch_size需要是 pipeline_parallel_size的整数倍。

初始化
  1. warmup_batch数计算方法, 如下代码:
    total_num_microbatches = num_microbatches * num_model_chunks  #模型分块数(virtual pipeline size) * microbatchall_warmup_microbatches = Falseif forward_only:num_warmup_microbatches = total_num_microbatcheselse:# 这里*2的原因是 为了充分利用设备资源,会使用双倍缓冲技术。这意味着每个设备会同时处理两个microbatches,一个在前向传播,另一个在后向传播。因此,热身阶段的microbatches数量需要乘以2,以覆盖前向和后向传播。num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2  # microbatch_group_size_per_vp_stage 默认值 = pipeline_parallel_size, 用于num_warmup_microbatches += (num_model_chunks - 1) * config.microbatch_group_size_per_vp_stageif num_warmup_microbatches >= total_num_microbatches:num_warmup_microbatches = total_num_microbatchesall_warmup_microbatches = Truenum_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
  1. 设置schedule_table, 为了方便计算, 将microbatch+chunk重映射成了virtual_microbatch_id
# PP2 N3M5 with VP2 is constructed as below:
# virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
# microbatch_id         | 0 1 2 0 1 2 3 4 3 4
# model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

根据chunk_id, 还能判断出这个virtual_id是这个device上的第一个chunk还是最后一个chunk

  1. recv_tensor_from_previous_stage: 先判断当前stage是否为leading stage(forward第一个, backward最后一个), 如果virtual_microbatch_id < (pipeline_parallel_size - 1), 说明当前stage没有任何前置需要接受的tensor, 否则说明他和之前的最后一个stage连在一起. 以PP=4举个例子:
#       0 1 2 3 ...  这里的microbatch 0的下一个stage是 device0的microbatch3
#     0 1 2 3 ...
#   0 1 2 3 ...
# 0 1 2 3 ...
warmup

注意配置项: overlap_p2p_comm_warmup_flush: 在打开这个开关后支持overlap warmup和flush阶段前向计算和通信, 后面看代码默认这个开关打开, warmup步骤:

  1. 根据microbatch id判断是不是leading_stage, 如果不是的话需要等上一个循环发出的异步接受前向结果的handle.
  2. 异步通信结果保存在fwd_recv_buffer, 异步发出预取下个循环的recv_forward请求
  3. 进行该stage的forward_step
  4. 把output_tensor 异步发出 send_forward, 等上一个循环的send_next_wait_handle完成.
  5. 把通信完的fwd_recv_buffer 赋值给input_tensor用于下个循环的forward
  6. 在warmup的最后, 触发异步等待recv_backward的请求. 方便衔接steady阶段
steady

循环num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches次, 步骤:

  1. 等warmup的recv_prev 异步执行完, 收到forward结果到buffer里
  2. Forward_step
  3. send_forward_recv_forward: 同时接受previous stage的forward结果, 同时把next stage的forward输入发出.
  4. Wait recv_next传回来的grad
  5. Backward_step
  6. send_backward_recv_backward: 反向往之前的stage发grad
  7. 等上一个batch的backward send_prev发完, 相当于一个buffer切换过程.

整个流程像下面这个流水线示意图.

image-20250206201325086
cooldown

与warmup刚好完全相反, 只有backward的计算和通信操作.

注意在每个阶段完成的时候都回将通信用到的output_tensor重新释放回显存池, 用来缓解显存压力.

参考:

对VPP的进一步优化: https://zhuanlan.zhihu.com/p/681363624

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/879802.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024.2.6鲜花

初探牛顿迭代?推歌 《以恋结缘》 诚、意地の悪い神の所业か? 奇迹?縁?袂触合う不思议 花ひとひら揺れて 不意に宿ってた うなじ解いてく春风 戯れはそこそこに 恋手ほどきしてくだしゃんせ 汤気にほんのり頬染て 夜风に愿ふ …いざ!!蝶と舞ひ花となりて 衣を乱して祓いま…

megatron 2. TP和PP实现

megatron 源码阅读第二篇, 看了TP和PP的对应实现训练并行实现 TensorParallel 张量并行代码路径, 代码路径: megatron/core/tensor_parallel 主要包含Linear / VocabEmbedding / cross_entropy 三部分. Linear 参数初始化 如果是从checkpoint热启, perform_initialization需要打…

【Azure Policy】当Azure策略组中存在多个修正任务时候时的批量处理办法

问题描述 在分配一组策略中包含了很多修正任务时候,从门户上,只能选择一个修正任务执行。 如下图:是否有好的办法,执行全部的修正任务呢?问题解答 从Azure门户的设计来看,只能选择一个修正任务是设计使然。如果想批量执行全部的修正任务,需要使用PowerShell脚本来循环执…

Kotlin空安全

前言 访问空引用的成员变量就会导致空指针异常,在Java中被称作NullPointerException,简称NPE,Kotlin中NPE产生的原因只可能是以下几种:显式调用 throw NullPointerException()使用了!!操作符数据在初始化时不一致,例如:传递一个在构造函数中出现的未初始化的 this 并用于…

Kotlin控制流程

条件与循环 if表达式 Kotlin中的if与Java中的if大致上都差不多,但是Kotlin中没有三元运算符(A ? B : C),可以用if表达式作为代替,例如: Java int a = int a = System.currentTimeMillis() % 2 == 1L ? 1 : 0; Kotlin val a = if (System.currentTimeMillis() % 2 == 1L…

第一次用Markdown

标题 标题2 标题3 标题4 字体 字体 字体姓名 性别 年龄张三 男 20![das]() baidu

【测试基础】web3.0介绍

web3.0介绍 Web3.0也被称为下一代互联网,是对当前互联网(Web2.0)的演进和升级。其目标是实现一个更加去中心化、安全、用户拥有数据主权且具有更好互操作性的互联网环境。Web3.0的核心技术包括区块链、智能合约和加密货币等。 web2.0与web3.0区别 Web2.0和Web3.0的主要区别在…

区块链原理、技术与实践

区块链介绍 区块链是一种分布式账本技术,允许多个参与者共同维护一个不断增长的数据记录列表,每个区块包含一系列交易记录,并通过密码学方法与前一个区块链接起来,形成一个不可篡改和不可逆的链条。 这种基于共识的机制使得区块链具有高度的安全性和透明性。 区块链与传统W…

《高效能人士的七个习惯》

情感账户 勇气和体谅 大石头 自传式回应、同理心倾听:用你的话反映他们的感受和意思,而不是去评论、去判断是否正确

高效能人士的七个习惯

情感账户 勇气和体谅 大石头 自传式回应、同理心倾听:用你的话反映他们的感受和意思,而不是去评论、去判断是否正确

新春“码”启 | Cocos 3D 微信小游戏(第5天):分包构建和上传发布(完美收官)

新春开发 Cocos 3D 微信小游戏计划的第 5 天,详细介绍了如何利用Cocos Creator开发并发布一款3D微信小游戏,包括游戏状态机的设计理念,和微信小游戏主包大小限制时的解决方案——分包策略。从游戏设计、开发、调试到最后成功发布的全过程,为想要进入微信小游戏开发领域的开…

爬虫随笔(一)

爬虫随笔,某牛前几天一直在看js逆向,现在分享一下本人近期学习记录首先分享一个网站,这个网站可以获得request所需要的header和cookie https://curlconverter.com/ 爬取网站就不挂了简单观察发现,该网站是滑动加载,我们可以在滑动加载时获得我们所需要的接口,发现两个链…