【PyTorch教程】如何使用PyTorch分布式并行模块DistributedDataParallel(DDP)进行多卡训练


本期目录

  • 1. 导入核心库
  • 2. 初始化分布式进程组
  • 3. 包装模型
  • 4. 分发输入数据
  • 5. 保存模型参数
  • 6. 运行分布式训练
  • 7. DDP完整训练代码


  • 本章的重点是学习如何使用 PyTorch 中的 Distributed Data Parallel (DDP) 库进行高效的分布式并行训练。以提高模型的训练速度。

1. 导入核心库

  • DDP 多卡训练需要导入的库有:

    作用
    torch.multiprocessing as mp原生Python多进程库的封装器
    from torch.utils.data.distributed import DistributedSampler上节所说的DistributedSampler,划分不同的输入数据到GPU
    from torch.nn.parallel import DistributedDataParallel as DDP主角,核心,DDP 模块
    from torch.distributed import init_process_group, destroy_process_group两个函数,前一个初始化分布式进程组,后一个销毁分布式进程组

2. 初始化分布式进程组

  • Distributed Process Group 分布式进程组。它包含在所有 GPUs 上的所有的进程。因为 DDP 是基于多进程 (multi-process) 进行并行计算,每个 GPU 对应一个进程,所以必须先创建并定义进程组,以便进程之间可以互相发现并相互通信。

  • 首先来写一个函数 ddp_setup()

    import torch
    import os
    from torch.utils.data import Dataset, DataLoader# 以下是分布式DDP需要导入的核心库
    import torch.multiprocessing as mp
    from torch.utils.data.distributed import DistributedSampler
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed import init_process_group, destroy_process_group# 初始化DDP的进程组
    def ddp_setup(rank, world_size):os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "12355"init_process_group(backend="nccl", rank=rank, world_size=world_size)
    
  • 其包含两个入参:

    入参含义
    rank进程组中每个进程的唯一 ID,范围是[0, world_size-1]
    world_size一个进程组中的进程总数
  • 在函数中,我们首先来设置环境变量:

    环境变量含义
    MASTER_ADDR在rank 0进程上运行的主机的IP地址。单机训练直接写 “localhost” 即可
    MASTER_PORT主机的空闲端口,不与系统端口冲突即可

    之所以称其为主机,是因为它负责协调所有进程之间的通信。

  • 最后,我们调用 init_process_group() 函数来初始化默认分布式进程组。其包含的入参如下:

    入参含义
    backend后端,通常是 nccl ,NCCL 是Nvidia Collective Communications Library,即英伟达集体通信库,用于 CUDA GPUs 之间的分布式通信
    rank进程组中每个进程的唯一ID,范围是[0, world_size-1]
    world_size一个进程组中的进程总数
  • 这样,进程组的初始化函数就准备好了。

【注意】

  • 如果你的神经网络模型中包含 BatchNorm 层,则需要将其修改为 SyncBatchNorm 层,以便在多个模型副本中同步 BatchNorm 层的运行状态。(你可以调用 torch.nn.SyncBatchNorm.convert_sync_batchnorm(model: torch.nn.Module) 函数来一键把神经网络中的所有 BatchNorm 层转换成 SyncBatchNorm 层。)

3. 包装模型

  • 训练器的写法有一处需要注意,在开始使用模型之前,我们需要使用 DDP 去包装我们的模型:

    self.model = DDP(self.model, device_ids=[gpu_id])
    
  • 入参除了 model 以外,还需要传入 device_ids: List[int] or torch.device ,它通常是由 model 所在的主机的 GPU ID 所组成的列表,


4. 分发输入数据

  • DistributedSampler 在所有分布式进程中对输入数据进行分块,确保输入数据不会出现重叠样本。

  • 每个进程将接收到指定 batch_size 大小的输入数据。例如,当你指定了 batch_size 为 32 时,且你有 4 张 GPU ,那么有效的 batch size 为:
    32 × 4 = 128 32 \times 4 = 128 32×4=128

    train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=32,shuffle=False,	# 必须关闭洗牌sampler=DistributedSampler(train_set)	# 指定分布式采样器
    )
    
  • 然后,在每轮 epoch 的一开始就调用 DistributedSamplerset_epoch(epoch: int) 方法,这样可以在多个 epochs 中正常启用 shuffle 机制,从而避免每个 epoch 中都使用相同的样本顺序。

    def _run_epoch(self, epoch: int):b_sz = len(next(iter(self.train_loader))[0])self.train_loader.sampler.set_epoch(epoch)	# 调用for x, y in self.train_loader:...self._run_batch(x, y)
    

5. 保存模型参数

  • 由于我们前面已经使用 DDP(model) 包装了模型,所以现在 self.model 指向的是 DDP 包装的对象而不是 model 模型对象本身。如果此时我们想读取模型底层的参数,则需要调用 model.module

  • 由于所有 GPU 进程中的神经网络模型参数都是相同的,所以我们只需从其中一个 GPU 进程那儿保存模型参数即可。

    ckp = self.model.module.state_dict()	# 注意需要添加.module
    ...
    ...
    if self.gpu_id == 0 and epoch % self.save_step == 0:	# 从gpu:0进程处保存1份模型参数self._save_checkpoint(epoch)
    

6. 运行分布式训练

  • 包含 2 个新的入参 rank (代替 device) 和 world_size

  • 当调用 mp.spawn 时,rank 参数会被自动分配。

  • world_size 是整个训练过程中的进程数量。对 GPU 训练来说,指的是可使用的 GPU 数量,且每张 GPU 都只运行 1 个进程。

    def main(rank: int, world_size: int, total_epochs: int, save_step: int):ddp_setup(rank, world_size)	# 初始化分布式进程组train_set, model, optimizer = load_train_objs()train_loader = prepare_dataloader(train_set, batch_size=32)trainer = Trainer(model=model,train_loader=train_loader,optimizer=optimizer,gpu_id=rank,	# 这里变了save_step=save_step)trainer.train(total_epochs)destroy_process_group()	# 最后销毁进程组if __name__ == "__main__":import systotal_epochs = int(sys.argv[1])save_step = int(sys.argv[2])world_size = torch.cuda.device_count()mp.spawn(main, args=(world_size, total_epochs, save_step), nprocs=world_size)
    
  • 这里调用了 torch.multiprocessingspawn() 函数。该函数的主要作用是在多个进程中执行指定的函数,每个进程都在一个独立的 Python 解释器中运行。这样可以避免由于 Python 全局解释器锁 (GIL) 的存在而限制多线程并发性能的问题。在分布式训练中,通常每个 GPU 或计算节点都会运行一个独立的进程,通过进程之间的通信实现模型参数的同步梯度聚合

  • 可以看到调用 spawn() 函数时,传递 args 参数时并没有传递 rank ,这是因为会自动分配,详见下方表格 fn 入参介绍。

    入参含义
    fn: function每个进程中要执行的函数。该函数会以 fn(i, *args) 的形式被调用,其中 i 是由系统自动分配的唯一进程 ID ,args 是传递给该函数的参数元组
    args: tuple要传递给函数 fn 的参数
    nprocs: int要启动的进程数量
    join: bool是否等待所有进程完成后再继续执行主进程 (默认值为 True)
    daemon: bool是否将所有生成的子进程设置为守护进程 (默认为 False)

7. DDP完整训练代码

首先,创建了一个训练器 Trainer 类。

import torch
import os
from torch.utils.data import Dataset, DataLoader# 以下是分布式DDP需要导入的核心库
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group# 初始化DDP的进程组
def ddp_setup(rank: int, world_size: int):"""Args:rank: Unique identifier of each process.world_size: Total number of processes."""os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "12355"init_process_group(backend="nccl", rank=rank, world_size=world_size)class Trainer:def __init__(self,model: torch.nn.Module,train_loader: DataLoader,optimizer: torch.optim.Optimizer,gpu_id: int,save_step: int	# 保存点(以epoch计)) -> None:self.gpu_id = gpu_id,self.model = DDP(model, device_ids=[self.gpu_id])	# DDP包装模型self.train_loader = train_loader,self.optimizer = optimizer,self.save_step = save_stepdef _run_batch(self, x: torch.Tensor, y: torch.Tensor):self.optimizer.zero_grad()output = self.model(x)loss = torch.nn.CrossEntropyLoss()(output, y)loss.backward()self.optimizer.step()def _run_epoch(self, epoch: int):b_sz = len(next(iter(self.train_loader))[0])self.train_loader.sampler.set_epoch(epoch)	# 调用set_epoch(epoch)洗牌print(f'[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_loader)}')for x, y in self.train_loader:x = x.to(self.gpu_id)y = y.to(self.gpu_id)self._run_batch(x, y)def _save_checkpoint(self, epoch: int):ckp = self.model.module.state_dict()torch.save(ckp, './checkpoint.pth')print(f'Epoch {epoch} | Training checkpoint saved at ./checkpoint.pth')def train(self, max_epochs: int):for epoch in range(max_epochs):self._run_epoch(epoch)if self.gpu_id == 0 and epoch % self.save_step == 0:self._save_checkpoint(epoch)

然后,构建自己的数据集、数据加载器、神经网络模型和优化器。

def load_train_objs():train_set = MyTrainDataset(2048)model = torch.nn.Linear(20, 1)	# load your modeloptimizer = torch.optim.SGD(model.parameters(), lr=1e-3)return train_set, model, optimizerdef prepare_dataloader(dataset: Dataset, batch_size: int):return DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,	# 必须关闭pin_memory=True,sampler=DistributedSampler(dataset=train_set)	# 指定DistributedSampler采样器)

最后,定义主函数。

def main(rank: int, world_size: int, total_epochs: int, save_step: int):ddp_setup(rank, world_size)	# 初始化分布式进程组train_set, model, optimizer = load_train_objs()train_loader = prepare_dataloader(train_set, batch_size=32)trainer = Trainer(model=model,train_loader=train_loader,optimizer=optimizer,gpu_id=rank,	# 这里变了save_step=save_step)trainer.train(total_epochs)destroy_process_group()	# 最后销毁进程组if __name__ == "__main__":import systotal_epochs = int(sys.argv[1])save_step = int(sys.argv[2])world_size = torch.cuda.device_count()mp.spawn(main, args=(world_size, total_epochs, save_step), nprocs=world_size)

至此,你就已经成功掌握了 DDP 分布式训练的核心用法了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/176967.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM的OA办公管理系统的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

Oracle数据库、实例、用户、表空间和表之间的关系

一、Oracle数据库中数据库、实例、用户、表空间和表(索引、视图、存储过程、函数、对象等对象)之间的关系。 1、Oracle的数据库是由一些物理文件组成:数据文件控制文件重做日志文件归档日志文件参数文件报警和跟踪日志文件备份文件。 2、实…

MHA:故障切换

MHA: masterhight availabulity:基于主库的高可用环境下:主从复制 故障切换 主从的架构。 MHA:最少要一主两从 mysql的单点故障问题,一旦主库崩溃,MHA可以在0-30秒内自动完成故障切换。 工作原理&#…

上机实验四 图的最小生成树算法设计 西安石油大学数据结构

实验名称:图的最小生成树算法设计 (1)实验目的: 掌握最小生成树算法,利用kruskal算法求解最小生成树。 (2)主要内容: 利用kruskal算法求一个图的最小生成树,设计Krus…

社区论坛小程序系统源码+自定义设置+活动奖励 自带流量主 带完整的搭建教程

大家好啊,又到了罗峰来给大家分享好用的源码的时间了。今天罗峰要给大家分享的是一款社区论坛小程序系统。社区论坛已经成为人们交流、学习、分享的重要平台。然而,传统的社区论坛往往功能单一、缺乏个性化设置,无法满足用户多样化的需求。而…

「Verilog学习笔记」优先编码器Ⅰ

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 分析 分析编码器的功能表: 当使能El1时,编码器工作:而当E10时,禁止编码器工作,此时不论8个输入端为何种状态&…

【rl-agents代码学习】02——DQN算法

文章目录 Highway-env Intersectionrl-agents之DQN*Implemented variants*:*References*:Query agent for actions sequence探索策略神经网络实现小结1 Record the experienceReplaybuffercompute_bellman_residualstep_optimizerupdate_target_network小结2 exploration_polic…

【吐血总结】前端开发:一文带你精通Vue.js前端框架(七)

文章目录 前言1️⃣事件处理器2️⃣表单3️⃣总结 前言 上一篇中我们学习了vue.js 的条件语句、循环语句等知识点.,现在让我们接着Vue系列的学习。 Vue中事件处理器、表单等在开发中的作用不可或缺,本文将基于实例进行以上知识点的讲解。 1️⃣事件处理器…

交换机堆叠 配置(H3C)

堆叠用来干什么? 一台交换机网口有限,无法满足网络需求; 无法达到网络要求,为了扩展核心设备的转发要求,不改变原来网络, 可以使用新交换机和原来交换机组成IRF。 配合聚合可以达到备用作用,防…

爆款元服务!教你如何设计高使用率卡片

元服务的概念相信大家已经在 HDC 2023 上有了很详细的了解,更轻便的开发方式,让开发者跃跃欲试。目前也已经有很多开发者开发出了一些爆款元服务,那么如何让你的元服务拥有更高的传播范围、更高的用户使用率和更多的用户触点呢?设…

Java实现音频转码,WAV、MP3、AMR互转

1.背景 最近在集成一款产品支持语音双向对讲,首先是采集小程序的音频下发给设备端,然后可以控制设备录音生成音频链路让小程序播放。在这个过程中发现,设备除了AMR格式的音频外,其他的音频都不支持,而微信小程序有不支…

2023年好用的远程协同运维工具当属行云管家!

对于IT小伙伴而言,一款好用的远程协同运维工具是非常重要的,不仅可以提高工作效率,还能第一时间解决运维难题,所以好用的远程协同工具是非常必要的。这里就给大家推荐一款哦! 2023年好用的远程协同运维工具当属行云管…