[megatron代码阅读] 1. 初始化和组网

news/2025/1/13 16:39:48/文章来源:https://www.cnblogs.com/sunstrikes/p/18668809

pretrain_gpt.py为例, 看megatron的整体逻辑. 本章主要包括megatron初始化相关逻辑, 核心函数为initialize_megatron, setup_model_and_optimizer两个

initialize_megatron

parse_args

从argparse中直接读取超参数配置. 如学习率, 正则化等. 从环境变量中获取rank等

load_args_from_checkpoint

  • 优先从未被持久化的ckpt加载, 并且只加载rank0的args

  • _load_non_persistent_base_checkpoint

    • find_checkpoint_rank_0

      在不知道是否使用pp/ep策略的情况下, 尝试拼装出rank0 ckpt的名称, 如果存在就能定位到实际的存放目录

    • verify_checkpoint_and_load_strategy

      根据是zarr还是 torch_dist选择不同的加载策略

    • TorchCommonLoadStrategy->torch.load()

  • 如果没有非持久化的, 加载远端ckpt

  • 从ckpt里的args替换掉之前解析的部分args, 比如tp/pp/vp等超参数

校验yaml/args, 全局变量设置

_initialize_distributed

pytorch里的get_world_size 返回的是gpu总卡数

初始化torch.distributed

mpu.initialize_model_parallel (并行设置,核心函数)

RankGenerator:

  1. 在每块GPU上启动一个进程(process),每个进程独立执行自己所维护的那部分模型的计算,实现并行训练
  2. 存储tp/pp/dp/ep/cp 各种并行度配置大小. 并且能够从 tp-dp str格式的并行配置里获取 tp/dp对应的mask和并行度大小设置.
  3. get_ranks: 根据parallel_size和mask, 计算各种并行策略拆分后的rank group.

[!NOTE]

举例: 假定有2个8卡机器,node1: rank 0-7,node2: rank 8-15 tp-pp-dp: [2,4,2]

  • _TENSOR_MODEL_PARALLEL_GROUP :[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]。
  • _PIPELINE_MODEL_PARALLEL_GROUP : [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]。
  • _MODEL_PARALLEL_GROUP :tp-pp = 2 * 4 = 8 [0, 1, 4, 5, 8, 9, 12, 13],[2, 3, 6, 7, 10, 11, 14, 15]
  • _DATA_PARALLEL_GROUP :[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]。
分隔样例

注意在PP内输入层和输出层共享一个word_embedding,PP组中的第一个和最后一个rank需要通讯,保证word_embedding完全一致

group全局变量赋值: 每个并行模式有一个分组全局变量.通过 generator_wrapper生成, 自己的进程rank如果在group内, 初始化对应的nccl/gloo torch.distributed.new_group

GlobalMemoryBuffer: 保存每个已经分配出的tensor, 避免显存重分配.

setup_model_and_optimizer

主要逻辑是配置模型组网和优化器.

model_provider: torch gpt组网

megatron/core/transformer, transformer组网核心逻辑, 基于torch.nn.Module, 将涉及到的子模型结构进行了抽象. 通过subModule的方式嵌入自定义module, 便于代码复用

例如

self_attention=ModuleSpec(module=SelfAttention,params={"attn_mask_type": attn_mask_type},submodules=SelfAttentionSubmodules(linear_qkv=ColumnParallelLinear,core_attention=DotProductAttention,linear_proj=RowParallelLinear,q_layernorm=IdentityOp,k_layernorm=IdentityOp,),
)

attention.py里读到之前moduleSpec中的对应linear_qkv的实现, 即TP列并行的Linear实现. 加上TransformerConfig, 就能定义出最终的网络逻辑. TP相关逻辑在后续专门看的时候再细写.

self.linear_qkv = build_module(submodules.linear_qkv,self.config.hidden_size,self.query_projection_size + 2 * self.kv_projection_size,config=self.config,init_method=self.config.init_method,gather_output=False,bias=self.config.add_bias_linear or self.config.add_qkv_bias,skip_bias_add=False,is_expert=False,tp_comm_buffer_name='qkv',
)

torch里实现module时, 主要关注__init__()forward(), bp通过自动微分生成.

配置

配置类 ModelParallelConfig, TransformerConfig

ModelParallelConfig: 主要包括 模型并行/PP/通信overlap相关优化开关/cpuOffload 等相关配置

TransformerConfig: 主要包括 模型结构/MOE/算子fusion加速/激活重计算/Context并行 等配置

models/gpt/gpt_model.py

preprocess

分为word_emb和pos_emb两部分. 输出为 word_emb(b,s,h) + pos_emb(s,h) + tokentype_emb(b,s,h)(需要转置适配)

注意在embedding最后要进行dropout处理, 应该是为了减少模型过拟合的风险

WordEmbeddings

tensor_parallel.VocabParallelEmbedding

vocab_size表示词表维度, 例如分词预处理后保留能查到的几千个常用单词. 将vocab_size个embed均分存储到global_world_size张卡上, embedding lookup时从对应的存储卡上拉取. 这里把非自身rank的emb通过[start_idx, end_idx)的mask操作置0, 然后通过reduce就能获取完整的词表.

如果配置开了序列并行, reduce操作会变为reduceScatter操作, lookup之后直接分配好sp的输入.

RoPE(旋转位置编码)

位置编码需要满足几个性质: 1. 不能满足交换律, 第m个token与第n个token的位置关系,和第n个token与第m个token的位置关系一定要有区分度。 2.需要有远程衰减性

image-20250108114351771

为了便于加速计算, 可以等价优化为下面这种向量乘法的形式:

image-20250108114806801 image-20250108114336830
tokentype_embedding

类型嵌入层,用于区分输入中不同类型的token, 例如,在BERT中用于区分两个句子,而在某些GPT变种或特定任务中可能用于区分不同类型的输入数据,如对话中的提问和回答.

transformer

self.decoder就是上面通过ModuleSpec获得的module, 可以根据配置选择普通的selfAttention, 还是MLA.

  1. MLA原理: 在模型能力不变基础上,通过KV低秩压缩, 使得推理的KVcache显存占用和计算效率上对比MHA性能有明显提升.
image-20250107171551928
postprocess
1.output_layer & loss

训练时output可以并行, 这里是个TP列并行的方式, 训练方式如下例子:

<s>
<s> i
<s> i love 
<s> i love maching
<s> i love maching learning <eos/>

训练阶段将这个矩阵直接输入到decoder,分别得到 5个输出 \(O_i, i\in [1,2,3,4,5]\), 理想的输出应该是[i, love, maching, learning, ] ,然后 比较\(O_i\)和理想输出的交叉熵,得到loss. 而且这五个序列可以放在一个batch内并行计算.

optimizer

_get_param_groups_and_buffers

从多个model_chunks中遍历所有的param向量, 对其中某些param进行特殊的处理

  • decoupled_lr是为input/output layer单独设置的lr
  • no_weight_decay_cond: 配置参数是否应该执行权重衰减。
  • scale_lr_cond: 对某些指定层的参数进行学习率缩放, 匹配到对应的param_map后执行.
_get_megatron_optimizer_based_on_param_groups

主要逻辑是混合精度optimizer的设置(MixedPrecisionOptimizer), TODO: 细看Apex.FusedAdam, 和torch.adamW的区别在哪里

梯度缩放: DynamicGradScaler

混合精度训练的时候, 用于动态调整梯度缩放比例,以处理梯度爆炸或消失问题.

主要逻辑是有一个初始化scale值, 当连续hysteresis次迭代中出现NaN,torch.max(scale * backoff_factor, min_scale) 用来减小scale\(backoff\_factor \in (0, 1)\).

当连续growth_interval次没出现NaN, 按照_scale * growth_factor_, 放大scale, \(growth\_factor > 1\)

DistributedOptimizer

接口继承自torch.optimizer, 核心逻辑在step(self), 有3个类: FP32Optimizer, ChainedOptimizer, MixedPrecisionOptimizer

FP32Optimizer: fp32训练使用到的, 主要功能是配置了clip_grad后进行normalization, norm分两种, 一种是取max_grad, 一种是l2范数, 通过all_reduce拿到total_norm, 最后用这个值分别对每个param tensor进行scale. 在scale之后就调用的是torch.optimizer.step进行正常的Adam更新.

MixedPrecisionOptimizer: 混合精度训练使用

  • prepare_grads: 先从param.grad copy到 param.main_grad, 这一步同时做了fp16->fp32的转换, 然后检查所有的grad, 先unscale, 再看是否存在NaN. 注意只有fp16需要, bf16不需要.
  • clip_grad_norm: 与FP32Optimizer一样的方法scale grad.
  • step_with_ready_grads: optimizer.step后, 再把fp32的main_param copy回用于下一轮bp的fp16 param里面.

ChainedOptimizer: 用于moe场景, 每个分块子模型配置不同的optimizer时使用. 多个optimizer之间串行执行.

下一节看megatron的模型保存&加载, 并行训练相关代码.

参考链接

ROPE位置编码博客, 论文

MLA原理博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/868694.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

巧用VTable打造炫酷金字塔图表

在数据分析和可视化领域,表格是展示数据直观、有效的方式之一。今天,就让我们来探索如何利用VTable这个强大的表格组件,制作出既美观又富有信息量的金字塔图表,以及深入了解VTable中各种单元格类型的使用方法,让你的表格也能“绘”出精彩图表!在数据分析和可视化领域,表…

基于 Performace 分析事件循环

我们是袋鼠云数栈 UED 团队,致力于打造优秀的一站式数据中台产品。我们始终保持工匠精神,探索前端道路,为社区积累并传播经验价值。本文作者:千寻什么是事件循环? 我们为什么需要事件循环?对于 JavaScript 是一门单线程语言我们是肯定的,JavaScript 单线程的特性保证了渲…

万字图文:SaaS业务架构、价值流、业务能力、业务流程、业务对象、组织架构

大家好,我是汤师爷~ 本文为读者提供一个SaaS业务架构的系统性框架,探讨业务架构分析的核心要素,帮助SaaS企业深入剖析目标客户的业务模式,全面理解他们的业务架构。 无论你是SaaS创业者、产品经理还是架构师,本文内容都将为你的系统设计和决策提供帮助。 1 目标与步骤 Saa…

老奶奶看了都会的WSL2连接USB设备教程!

老奶奶看了都会的WSL2-Ubuntu连接USB设备教程!作者:SkyXZ CSDN:SkyXZ~-CSDN博客 博客园:SkyXZ - 博客园参考资料:微软官方文档连接 USB 设备 | Microsoft Learn在Win11上用WSL2安装Ubuntu来开发简直不要太爽!!!但是很多小伙伴会发现,欸~为什么我在宿主机上插入的USB设…

HighReport报表工具V4.0带来十大核心优势变化

1.概述经过一年时间产品升级研发,HighReport报表工具正式推出V4.0版本,报表算法和报表功能获得全面提升。HighReportV4.0带来全面质的飞跃,具有明显的产品优势。 2.亮点一:双父格扩展模型报表引擎核心算法是父子格扩展模型,下面是常见模型一般报表厂商下面的扩展模型是不支…

一个超经典 WinForm,WPF 卡死问题的终极反思

一:背景 1. 讲故事 写这篇文章起源于训练营里一位朋友最近在微信聊到他对这个问题使用了一种非常切实可行,简单粗暴的方式,并且也成功解决了公司里几个这样的卡死dump,如今在公司已是灵魂级人物,让我也尝到了什么叫反哺!对,这个东西叫 Harmony, github网址: https://gi…

nginx 简单实践:静态资源部署、URL 重写【nginx 实践系列之一】

本文为 nginx 简单实践系列文章之一,主要简单实践了两个内容:静态资源部署、重写,仅供参考。〇、前言 本文为 nginx 简单实践系列文章之一,主要简单实践了两个内容:静态资源部署、重写,仅供参考。 关于 Nginx 基础,以及安装和配置详解,可以参考博主过往文章: https://…

题解:AT_abc353_f [ABC353F] Tile Distance

[ABC353F] Tile Distance 题解 cnblogs 题目传送门:洛谷,Atcoder Solution 很恶心人的分类讨论题。 很显然走大格子大概率比走小格子快。 对终点和起点向上下左右枚举大格子,我们就把问题转化为给两个大格子 \((a,b)\)、\((c,d)\),求怎样走最快。 对角的大格子可以通过 \(2…

数字化转型中的项目管理优化:协作工具的优势与应用

一、企业数字化转型的背景与挑战 1.1 数字化转型的驱动力数字化转型是指企业通过采用数字技术、创新流程和业务模式,提升运营效率、创造新价值并优化客户体验。随着云计算、大数据、人工智能和物联网等技术的不断发展,数字化转型已成为企业实现长期竞争力和持续增长的重要战略…

rk3568屏幕抖动问题

问题描述:有时候操作屏幕界面,发现屏幕有抖动的情况。经跟RK原厂沟通,此问题跟给ddr供电的vdd_logic有关系。vdd_logic默认定义:vdd_logic: DCDC_REG1 {regulator-always-on;regulator-boot-on;regulator-min-microvolt = <500000>;regulator-max-microvolt = <13…

B@se-还原错误字母表转码的base64编码

题目: 密文:MyLkTaP3FaA7KOWjTmKkVjWjVzKjdeNvTnAjoH9iZOIvTeHbvD== JASGBWcQPRXEFLbCDIlmnHUVKTYZdMovwipatNOefghq56rs****kxyz012789+/oh holy shit, something is missing... 第一行是密文,有明显的Base64编码特征(等号结尾) 第二行是大小写字母、数字、+、/,有明显的…

打开浏览器Chrome跳转指定页面并全屏打开

办法来源于https://blog.csdn.net/shaofengzong/article/details/119928096 主要用于大屏数据可视化的项目,设置电脑自启动后,打开浏览器的同时默认跳转指定页面并全屏打开。、 办法通过增加谷歌浏览器的启动参数进行实现。 两种方式实现,需要根据需求进行选择默认全屏打开指…