【pytorch】多GPU同时训练模型

文章目录

  • 1. 基本原理
    • 单机多卡训练教程——DP模式
  • 2. Pytorch进行单机多卡训练步骤
    • 1. 指定GPU
    • 2. 更改模型训练方式
    • 3. 更改权重保存方式


摘要:多GPU同时训练,能够解决单张GPU显存不足问题,同时加快模型训练。

1. 基本原理

单机多卡训练教程——DP模式

(1)将模型复制到各个GPU中,并将一个batch的数据划分成mini_batch(平均分配) 并分发给每个GPU;
注意:这里的batch_size要大于device数。
(2)各个GPU独自完成mini_batch的前向传播,并把获得的output传递给GPU_0(主GPU) ;
(3) GPU_0整合各个GPU传递过来的output,并计算loss。此时GPU_0可以对这些loss进行一些聚合操作;
(4) GPU_0归并loss之后,并进行后向传播以及梯度下降从而完成模型参数的更新(此时只有GPU_0上的模型参数得到了更新),GPU_0将更新好的模型参数又传递给其余GPU;

以上就是DP模式下多卡GPU进行训练的方式。其实可以看到GPU_0不仅承担了前向传播的任务,还承担了收集loss,并进行梯度下降。因此在使用DP模式进行单机多卡GPU训练的时候会有一张卡的显存利用会比其他卡更多,那就是你设置的GPU_0。

2. Pytorch进行单机多卡训练步骤

只需要在你的代码中改三个地方就可实现

1. 指定GPU

在这里插入图片描述
如上所示,在导入各种库下面使用os.environ["CUDA_VISIBLE_DEVICES"]来指定可识别的GPU,该语句在程序开始前使用。
代码如下:

import torch.nn as nn
import os
os.environ["CUDA_VISIBLE_DEVICES"]= 2,3,1'#指定该程序可以识别的物理GPU编号,这里的你主机上的2号GPU就是训练程序中的主GPUO,这里最好—定要自己指定你自己可以用的gpu号。

2. 更改模型训练方式

在这里插入图片描述
平常的模型训练方式只需要model.cuda()语句即可,在单机多卡训练中,只需要在该语句下面添加一行nn.DataParallel语句即可。
代码如下

model.cuda()
model = nn.DataParallel(model,devise =[0,1,2])#在执行该语句之前最好加上model.cuda(),保证你的模型存在GPU上即可

3. 更改权重保存方式

对于数据,我们只需要按照平常的方式使用.cuda()放置在GPU上即可,内部batch的拆分已经被封装在了DataPanallel模块中。要注意的是,由于我们的model被nn.DataPanallel()包裹住了,所以如果想要储存模型的参数,需要使用:model.module.state_dict()的方式才能取出(不能直接是model.state_dict()
代码如下:

'''
使用单机多卡训练的模型权重保存方式
'''
torch.save(model.module.state_dict(),f'best.pth')  

作为参考,将平常的权重保存方式也写上:

'''
平常的权重保存方式
'''
torch.save(model.state_dict(),f'best.pth')  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/128995.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis-01基本数据结构

1、String 1.1、介绍 String 是最基本的 key-value 结构,key 是唯一标识,value 是具体的值,value其实不仅是字符串, 也可以是数字(整数或浮点数),value 最多可以容纳的数据长度是 512M 1.2、…

京东数据报告:2023年儿童面膜行业数据分析

如今,儿童面膜在不少家长群体中受到追捧。有的家长称自家孩子3岁开始敷面膜,而在某电商平台上,一位母婴博主称自己的女儿才2岁,“已经深知护肤的重要性了,每天洗完澡就嚷嚷着要敷面膜”。 而从市场角度看,…

大数据概述(林子雨慕课课程)

文章目录 1. 大数据概述1.1 大数据概念和影响1.2 大数据的应用1.3 大数据的关键技术1.4 大数据与云计算和物联网的关系云计算物联网 1. 大数据概述 大数据的四大特点:大量化、快速化、多样化、价值密度低 1.1 大数据概念和影响 大数据摩尔定律 大数据由结构化和非…

TCP/IP(二)导论

一 知识铺垫 以下内容参照 <<电子科技大学TCPIP协议原理>>全 ① 协议和标准 一组规则&#xff1a; 交通规则、学生上学的学生守则等;数据通信的规则,有一个专门的名称叫作协议 protocol语义&#xff1a;具体描述在通信当中,每一个信息的具体含义. 二进制bit流…

MySQL字符集大小写不敏感导致的主键冲突问题记录

文章目录 前言问题复原&#xff08;一&#xff09;数据库&#xff08;二&#xff09;表&#xff08;三&#xff09;插入语句&#xff08;四&#xff09; 解决 参考资料 前言 数据入库的时候报了个主键冲突的error&#xff0c;很是纳闷于是乎开始排查摸索起来&#xff0c;发现是…

归并排序及其非递归实现

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 目录 归并排序递归实现 归并排序非递归实现 归并排序递归实现 图示&#xff1a; 代码&#xff1a; 先分再归并&#xff0c;像是后序一般。 //归并排序 void MergeSort(int* arr, int left, int right) {int* temp (int…

css的gap设置元素之间的间隔

在felx布局中可以使用gap来设置元素之间的间隔&#xff1b; .box{width: 800px;height: auto;border: 1px solid green;display: flex;flex-wrap: wrap;gap: 100px; } .inner{width: 200px;height: 200px;background-color: skyblue; } <div class"box"><…

VMware和别的服务器 ,组建局域网那些事 。

利用VMware &#xff0c;实现组件局域网、有可能会受限于WiFi&#xff08;路由器&#xff09; 。 通常不会&#xff0c;除非做了网关设置 相关知识&#xff1a; 禁用局域网隔离&#xff08;LAN Isolation&#xff09;&#xff1a; 某些路由器提供了一个选项&#xff0c;允许您禁…

从0到1基于ChatGLM-6B使用LoRA进行参数高效微调

从0到1基于ChatGLM-6B使用LoRA进行参数高效微调 吃果冻不吐果冻皮 ​ 关注他 cliniNLPer 等 189 人赞同了该文章 ​ 目录 收起 ChatGLM-6B简介 具备的一些能力 局限性 LoRA 技术原理 环境搭建 数据集准备 数据预处理 参数高效微调 单卡模式模型训练 数据并行模式模型训练 模型推…

怎么将自己拍摄的视频静音?详细步骤教会你~

大部分人都会遇到的一个问题&#xff0c;我们在拍摄视频时容易将嘈杂的背景音或环境音录进去&#xff0c;怎样解决这个问题呢&#xff1f;今天就来教大家具体操作步骤&#xff0c;只需用到这个软件即可&#xff01; 第一步&#xff1a;打开我们的【音分轨】APP&#xff0c;进入…

3561-24-8|荧光染料6-fam(Br4)|可作为成像剂

产品简介&#xff1a;6-fam(Br4)是一种荧光染料&#xff0c;广泛应用于生物医学领域中的荧光探针、标记物和成像剂等方面。其分子结构独特&#xff0c;具有良好的荧光量子产率和稳定性&#xff0c;能够在生物体内快速、准确地标记和追踪生物分子和细胞。其优异的荧光性能和化学…

环信web、uniapp、微信小程序SDK报错详解---登录篇

项目场景&#xff1a; 记录对接环信sdk时遇到的一系列问题&#xff0c;总结一下避免大家再次踩坑。这里主要针对于web、uniapp、微信小程序在对接环信sdk时遇到的问题。主要针对报错400、404、401、40 (一) 登录用户报400 原因分析&#xff1a; 从console控制台输出及networ…