Pytorch distributed 概述
本节我们介绍一下 torch.distributed
Pytorch 分布式库主要包含一套并行的模块,一个通信层,以及对于运行和debug大规模训练的infra
主要有以下四个并行的apis:
- DDP(分布式数据并行)
- FSDP (fully sharded data-parallel training)
- Tensor parallel(tp)
- pipeline parallel(pp)
分片原语:
DTensor
和 DeviceMesh
是可以根据在N维的进程分组进行构建来开启并行。
DTensor
: 表示一个 sharded and/or replicated 的tensor,可以根据操作自动地reshard tensorDeviceMesh
: 将 device communicator 抽象为 一个多维数组,可以管理底层的ProcessGroup
实例 来在一个多维的并行上进行集合通信。
通信api:
pytorch分布式通信层(c10d
)提供了集合通信api(例如 all_reduce, all_gather) 以及 P2P 的api (例如send和isend)
launcher
torchrun
是一个通常使用的launch脚背,可以在本地和远程机器上spawns 进程来运行分布式的pytorch程序
应用并行来scale你的模型
数据并行:模型被复制到每个进程上
模型并行:模型被放进一个GPU内
- 如果你的模型能放入一个GPU,想使用多GPU进行scale,那就使用
DDP
.- 如果使用了多个节点,用
torchrun
来launch多个pytorch进程
- 如果使用了多个节点,用
- 如果不能放进GPU,那就使用
FSDP
- 如果到达了
FSDP
的scale极限,使用tp 及 pp