Fabric实现多GPU运行

官方的将pytorch转换为fabric简单分为五个步骤:

步骤 1:

在训练代码的开头创建 Fabric 对象

from lightning.fabric import Fabricfabric = Fabric()

步骤 2:

如果打算使用多个设备(例如多 GPU),就调用 launch()

fabric.launch()

 步骤 3:

在每个模型和优化器对上调用 setup() ,在所有数据加载器上调用 setup_dataloaders()

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

 步骤 4:

删除所有 .to 和 .cuda 调用,因为 Fabric 将自动处理

- model.to(device)    # 删除
- batch.to(device)    # 删除

步骤 5:

将 loss.backward() 替换为 fabric.backward(loss) 

- loss.backward()
+ fabric.backward(loss)

结合起来:

将所有步骤结合起来,这就是代码将如何更改:

  import torchfrom lightning.pytorch.demos import WikiText2, Transformer
+ import lightning as L    # 新增- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    # 删除
+ fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")    # 新增
+ fabric.launch()    # 新增dataset = WikiText2()dataloader = torch.utils.data.DataLoader(dataset)model = Transformer(vocab_size=dataset.vocab_size)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)- model = model.to(device)    # 删除
+ model, optimizer = fabric.setup(model, optimizer)    # 新增
+ dataloader = fabric.setup_dataloaders(dataloader)    # 新增model.train()for epoch in range(20):for batch in dataloader:input, target = batch
-         input, target = input.to(device), target.to(device)    # 删除optimizer.zero_grad()output = model(input, target)loss = torch.nn.functional.nll_loss(output, target.view(-1))
-         loss.backward()    # 删除
+         fabric.backward(loss)    # 新增optimizer.step()

=======================================================================

记录一下自己代码的修改过程 

训练的是DECA的修改版 (En生bs和lm(Wgan*0.01+dinov2))

main_train.py中

导入lighting和Fabric并实例化,实例化的适合也可以加上【precision='32'】,float32位精度 

from lightning import Fabric
import lightning as Lfabric = Fabric(accelerator="cuda",devices=None, strategy="ddp",precision='32')
fabric.launch()# 这里的devices=None:这样就取决于命令行中CUDA_VISIBLE_DEVICES=的gpu名称
# precision='32'是使用32位精度

其他参数可用内容:

fabric = Fabric()fabric = Fabric(devices=2/4/8)fabric = Fabric(devices=1/2/4/8/"auto", strategy="ddp"/"fspd"/"deepspeed"/"auto")

deca.py中

1.导包+初始化fabric

这个import fabric就是上面实例化出来的fabric,我在trainer中又实例化了一下

2.去除原本的DP或者DDP,因为会冲突

注释了上面的DataParallel,使用fabric.setup对模型进行fabric操作

 trainer.py中

主要修改训练函数

1.我在这里又实例化了一遍

from lightning import Fabric
fabric = Fabric(accelerator="cuda",devices=None, strategy="ddp",precision='32')
fabric.launch()

2.Trainer类中初始化时进行了添加

3.主要修改:training_step

 验证的话 也是这么修改

4.数据dataloader处理

 5.fit.py,关于tensorboard报错

每个卡都有损失,tensorboard好像全局损失什么的,会产生冲突

然后最后loss和backward修改

 然后就可以启动了CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup python -u main_train.py --cfg configs/pretrain.yml > train.log 2>&1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/700208.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

搞什么副业可以月入过万?

月入过万的副业不是一件容易的事情,它需要你付出很多努力和时间。以下是一些可能能够实现月入过万的副业 1. 自媒体运营 通过开设自己的公众号、博客或YouTube频道,积累一定的粉丝和流量,然后通过广告、赞助、商品销售等方式赚取收入。 2. …

Python入门系列-02 pip的安装

目录 一、pip介绍二、pip安装检查三、pip安装 一、pip介绍 pip 是 Python 包管理工具,该工具提供了对Python 包的查找、下载、安装、卸载的功能。 二、pip安装检查 你可以通过以下命令来判断是否已安装。 pip --version # Python2.x 版本命令 pip3 --versio…

Redis分布式缓存

分布式缓存 引入: 一:持久化: 1.1.RDB持久化: 1.2.AOF文件: 记得关闭RDB,开启AOF。 注意,AOF默认是详细的记录每一条命令,即使是对同一个key的多次修改,RDB只会记录最…

云端的艺术革命:云渲染如何重塑动画与视觉特效产业

在 2019 年,乔恩费儒(Jon Favreau)决定重拍迪士尼的经典电影《狮子王》。他的创新构想是以真实动物为模型,在非洲草原上拍摄,由真实动物“出演”的辛巴和其他角色,随后通过配音赋予它们生命。 为了实现这一…

emp.dll文件丢失荒野大镖客,怎么快速修复emp.dll

缺失或损坏的 DLL 文件是会导致系统或软件故障的,DLL(动态链接库)文件是 Windows 操作系统中至关重要的一部分,它们允许多个程序共享代码和资源,从而减少内存占用和增强系统性能。然而,当EMP.dll文件丢失或…

基于springboot+vue+Mysql的在线答疑系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

The Quantcast File System——论文泛读

VLDB 2013 Paper 分布式元数据论文阅读笔记整理 问题 在2013年之前,由于网络链路带宽有限,数据在集群中移动速度慢,因此Hadoop尽量将数据留在原来的位置,并将处理代码发送给它。随着网络链路的发展,可以之前更高的数…

使用Nginx对网站资源进行加密访问并限制访问IP

你好呀,我是赵兴晨,文科程序员。 大家在工作中有没有遇到过这样的需求,新上的网站部署到生产服务器上,但是还没公开,只允许个别高层领导看。 思来想去,我想到了一个简单的方法,通过Nginx对网站…

MATLAB科技绘图与数据分析

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。…

惠普发布全新AI战略,重塑办公空间 引领企业智能化新浪潮

近日、全球知名科技公司惠普在北京隆重举办了以“用智能,开启无限可能”为主题的2024惠普商用AI战略暨AI PC新品发布会,此次盛会标志着惠普在人工智能领域迈出了重要一步,惠普紧跟时代步伐,推出了更高效、更安全、更灵活的AI PC产…

强化训练:day9(添加逗号、跳台阶、扑克牌顺子)

文章目录 前言1. 添加逗号1.1 题目描述2.2 解题思路2.3 代码实现 2. 跳台阶2.1 题目描述2.2 解题思路2.3 代码实现 3. 扑克牌顺子3.1 题目描述3.2 解题思路3.3 代码实现 总结 前言 1. 添加逗号   2. 跳台阶   3. 扑克牌顺子 1. 添加逗号 1.1 题目描述 2.2 解题思路 我的写…

Network Compression

听课(李宏毅老师的)笔记,方便梳理框架,以作复习之用。本节课主要讲了Network Compression,包括为什么要压缩,压缩的主要手段(pruning,knowledge distillation,parameter quantization,architect…