基于Pytorch的DDP训练Mnist数据集

        在前几期的博文中我们讲了pytorch的DDP,但是当时的demo是自制的虚拟数据集(Pytorch分布式训练:DDP),这期文章我们使用Mnist数据集做测试,测试并完善代码。

快速开始

        1.  我们修改一下main函数,在main函数中导入Mnist数据。我这里把测试集关闭了,需要的可以打开。

def main(rank, world_size, max_epochs, batch_size):ddp_setup(rank, world_size)train_dataset = datasets.MNIST(root="./MNIST", train=True, transform=data_tf, download=True)train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,sampler=DistributedSampler(train_dataset))model = Net()# optimzer = torch.optim.Adam(model.parameters(), lr=1e-3)optimzer = torch.optim.SGD(model.parameters(), lr=1e-2)trainer = Trainer(model=model, gpu_id=rank, optimizer=optimzer, train_dataloader=train_dataloader)trainer.train(max_epochs)destroy_process_group()# test_dataset = datasets.MNIST(root="./MNIST", train=False, transform=data_tf, download=True)# test_dataloader = DataLoader(test_dataset,#                              batch_size=32,#                              shuffle=False)# evaluation(model=model, test_dataloader=test_dataloader)

        2.  修改模型结构,非常简单的一个网络


class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear( 64*5*5, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))# print(x.shape)x = x.view(-1, 64*5*5)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return x

        3. 完整代码如下,增加了计算准确率的功能,这些代码可以自己写个函数进行封装的,我太懒了。。。

"""
pytorch分布式训练结构
"""
from time import time
import os
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 多gpu训练所需的包
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_groupdef ddp_setup(rank, world_size):"""每个显卡都进行初始化"""os.environ["MASTER_ADDR"] = "localhost"os.environ["MASTER_PORT"] = "12355"# init_process_group(backend="nccl", rank=rank, world_size=world_size)init_process_group(backend="gloo", rank=rank, world_size=world_size)torch.cuda.set_device(rank)data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear( 64*5*5, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))# print(x.shape)x = x.view(-1, 64*5*5)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return xclass Trainer:def __init__(self, model, train_dataloader, optimizer, gpu_id):self.gpu_id = gpu_idself.model = model.to(gpu_id)self.train_dataloader = train_dataloaderself.optimizer = optimizerself.model = DDP(model, device_ids=[gpu_id])self.criterion = torch.nn.CrossEntropyLoss()def _run_batch(self, xs, ys):self.optimizer.zero_grad()output = self.model(xs)loss = self.criterion(output, ys)loss.backward()self.optimizer.step()_, predicted = torch.max(output, 1)return ys.size(0), (predicted == ys).sum()def _run_epoch(self, epoch):batch_size = len(next(iter(self.train_dataloader))[0])# print(f"|GPU:{self.gpu_id}| Epoch:{epoch} | batchsize:{batch_size} | steps:{len(self.train_dataloader)}")# 打乱数据,随机打乱self.train_dataloader.sampler.set_epoch(epoch)sample_nums = 0train_correct = 0for xs, ys in self.train_dataloader:xs = xs.to(self.gpu_id)ys = ys.to(self.gpu_id)sample_num, correct = self._run_batch(xs, ys)sample_nums += sample_numtrain_correct += correct# print(train_correct.item(), sample_nums)print(f"train_acc: {train_correct.item() / sample_nums * 100 :.3f}")def _save_checkpoint(self, epoch):ckp = self.model.module.state_dict()PATH = f"./params/checkpoint_{epoch}.pt"torch.save(ckp, PATH)def train(self, max_epoch: int):for epoch in range(max_epoch):self._run_epoch(epoch)# if self.gpu_id == 0:#     self._save_checkpoint(epoch)def evaluation(model, test_dataloader):model.eval()model.to("cuda:0")sample_nums = 0train_correct = 0for xs, ys in test_dataloader:xs = xs.to("cuda:0")ys = ys.to("cuda:0")output = model(xs)_, predicted = torch.max(output, 1)sample_nums += ys.size(0)train_correct += (predicted == ys).sum()print(f"test_acc: {train_correct.item() / sample_nums * 100 :.3f}")def main(rank, world_size, max_epochs, batch_size):ddp_setup(rank, world_size)train_dataset = datasets.MNIST(root="./MNIST", train=True, transform=data_tf, download=True)train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,sampler=DistributedSampler(train_dataset))model = Net()# optimzer = torch.optim.Adam(model.parameters(), lr=1e-3)optimzer = torch.optim.SGD(model.parameters(), lr=1e-2)trainer = Trainer(model=model, gpu_id=rank, optimizer=optimzer, train_dataloader=train_dataloader)trainer.train(max_epochs)destroy_process_group()# test_dataset = datasets.MNIST(root="./MNIST", train=False, transform=data_tf, download=True)# test_dataloader = DataLoader(test_dataset,#                              batch_size=32,#                              shuffle=False)# evaluation(model=model, test_dataloader=test_dataloader)if __name__ == "__main__":start_time = time()max_epochs = 50batch_size = 128world_size = torch.cuda.device_count()mp.spawn(main, args=(world_size, max_epochs, batch_size), nprocs=world_size)print(time() - start_time)

训练测试

        我简单的测试了一下单卡和多卡的GPU性能(一张3090、一张3090ti),表格如下:

        在数据量较小的前提下双卡对单卡优势不明显,加大epoch才能看出明显差距。

结尾

如果不出意外DDP的内容已经结束了,后续发现什么好玩的继续发出来
如果觉得文章对你有用请点赞、关注  ->> 你的点赞对我太有用了
群内交流更多技术
130856474  <--  在这里

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/439016.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL之索引使用原则详解(验证索引效率,SQL提示等)

索引使用 验证索引效率 在未建立索引之前&#xff0c;执行如下SQL语句&#xff0c;查看SQL的耗时 select * from tb_user where id 1; 针对字段创建索引 create index idx_sku_sn on tb_sku(sn) ; 针对于用户量大的表中&#xff0c;添加索引要比没有添加索引的字段查询…

【基础算法练习】单调队列与单调栈模板

文章目录 单调栈模板题代码模板算法思想 单调队列模板题代码模板算法思想 单调栈 模板题 题目链接&#xff1a;ACwing 830. 单调栈 代码模板 #include <iostream> #include <vector> #include <stack>using namespace std;const int N 100010;vector<…

STM32标准库+HAL库 | 输入捕获测量PWM的脉冲频率+占空比

提醒&#xff1a;本文的代码Demo中使用的是&#xff0c;单通道捕捉采集PWM输入信号的频率占空比。 在上一篇博客中已经讲解了过PWM输出配置&#xff0c;本文主要讲解TIM输入捕获配置。STM32标准库HAL库 | 高精度动态调节PWM输出频率占空比_hal库改变pwm频率-CSDN博客 目录 1…

Mac安装配置maven

Mac安装配置maven 官网下载地址&#xff1a;https://maven.apache.org/download.cgi 下载好以后解压配置 maven 环境变量 打开终端&#xff0c;输入命令打开配置文件./bash_profile open ~/.bash_profile输入i进入编辑模式,进行maven配置; MAVEN_HOME为maven的本地路径 ex…

机器学习3-简单线性回归

需求&#xff1a; 现在要根据学生的学习时间来预测学习成绩&#xff0c;给出现有数据&#xff0c;用来训练模型并预测新数据。 分析&#xff1a; 使用线性回归模型。 代码&#xff1a; import pandas as pd import matplotlib.pyplot as plt from sklearn.model_selection i…

蓝桥杯备战——8.DS1302时钟芯片

1.分析原理图 由上图可以看到&#xff0c;芯片的时钟引脚SCK接到了P17,数据输出输入引脚IO接到P23,复位引脚RST接到P13。 2.查阅DS1302芯片手册 具体细节还需自行翻阅手册&#xff0c;我只截出重点部分 总结&#xff1a;数据在上升沿写出&#xff0c;下降沿读入&#xff0c;…

PDF控件Spire.PDF for .NET【安全】演示:使用 C# 检测签名的 PDF 是否被修改

对 PDF 文档进行数字签名后&#xff0c;PDF 将被锁定以防止更改或允许检测更改。在本文中&#xff0c;我们将介绍如何使用 Spire.PDF 检测签名的 PDF 是否被修改。 Spire.PDF for .NET 是一款独立 PDF 控件&#xff0c;用于 .NET 程序中创建、编辑和操作 PDF 文档。使用 Spire…

鸿蒙首批原生应用!无感验证已完美适配鸿蒙系统

顶象无感验证已成功适配鸿蒙系统&#xff0c;成为首批鸿蒙原生应用&#xff0c;助力鸿蒙生态的快速发展。 作为全场景分布式操作系统&#xff0c;鸿蒙系统旨在打破不同设备之间的界限&#xff0c;实现极速发现、极速连接、硬件互助、资源共享。迄今生态设备数已突破8亿台&…

结构体与共用体基础

结构体基础用法与共用体简述 1.结构体的定义2.结构体声明及使用3.结构体成员初始化4.结构体占用空间探究4.1 结构体成员所在地址4.2 按地址值访问结构体内容4.3 内存对齐 5.共用体6.总结 1.结构体的定义 之前的课程中&#xff0c;我们介绍了很多数据类型&#xff0c;如整形、浮…

测试用例级别该如何定义 ? 在工作中该如何应用它 ? 把握好这5个场景即可。

1.级别的作用 在编写测试用例的过程中&#xff0c;用例的级别经常是一个不可缺少的字段 &#xff0c;本篇幅就来聊下这个字段 &#xff0c;首先从它的作用是什么呢 &#xff1f;我觉得主要有两点 &#xff0c;分别是 &#xff1a; 用于测试用例不同套件的选取 &#xff0c;即用…

MMCLMC公差计算.exe

一、概要 软件及完整代码请戳这里&#xff1a;MMC&LMC公差计算软件及代码 图1 软件操作界面 本软件功能主要是根据实际应用选择MMR或者LMR原则&#xff0c;输入基本尺寸、形位公差尺寸和实际测量尺寸&#xff0c;即可计算出对应的公差值。以孔的MMR为例见如图2、3&#xf…

Java - JDBC

Java - JDBC 文章目录 Java - JDBC引言JDBC1 什么是JDBC2 MySQL数据库驱动3 JDBC开发步骤4 具体介绍 引言 思考: 当下我们如何操作数据库&#xff1f; 使用客户端工具访问数据库&#xff0c;手工建立连接&#xff0c;输入用户名和密码登录。编写SQL语句&#xff0c;点击执行…