分布式执行引擎ray入门--(3)Ray Train

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfileimport torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Composeimport ray.train.torchdef train_func(config):# Model, Loss, Optimizermodel = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# model.to("cuda")  # This is done by `prepare_model`# [1] Prepare model.model = ray.train.torch.prepare_model(model)criterion = CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=0.001)# Datatransform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])data_dir = os.path.join(tempfile.gettempdir(), "data")train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=128, shuffle=True)# [2] Prepare dataloader.train_loader = ray.train.torch.prepare_data_loader(train_loader)# Trainingfor epoch in range(10):for images, labels in train_loader:# This is done by `prepare_data_loader`!# images, labels = images.to("cuda"), labels.to("cuda")outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# [3] Report metrics and checkpoint.metrics = {"loss": loss.item(), "epoch": epoch}with tempfile.TemporaryDirectory() as temp_checkpoint_dir:torch.save(model.module.state_dict(),os.path.join(temp_checkpoint_dir, "model.pt"))ray.train.report(metrics,checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),)if ray.train.get_context().get_world_rank() == 0:print(metrics)# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(train_func,scaling_config=scaling_config,# [5a] If running in a multi-node cluster, this is where you# should configure the run's persistent storage that is accessible# across all worker nodes.# run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))model = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpointdef train_func(config):...torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainertrainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/527426.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS 8启动流程

一、BIOS与UEFI BIOS Basic Input Output System的缩写,翻译过来就是“基本输入输出系统”,是一种业界标准的固件接口,第一次出现在1975年,是计算机启动时加载的第一个程序,主要功能是检测和设置计算机硬件&#xff…

验证码安全

目录 验证码识别&复用&调用&找回密码重定向&状态值 res 修改-找回密码修改返回状态值判定验证通过 验证码爆破-知道验证码规矩进行无次数限制爆破 短信轰炸原理 验证码识别&复用&调用&找回密码重定向&状态值 res 修改-找回密码修改返回状态…

【UE5】创建蓝图

创建GamePlay需要的相关蓝图 项目资源文末百度网盘自取 在 内容游览器 文件夹中创建文件夹,命名为 Blueprints ,用来放这个项目的所有蓝图(Blueprint) 在 Blueprints 文件夹下新建文件夹 GamePlay ,用存放GamePlay相关蓝图 在 Blueprints 文件夹下创建文…

算法详解——leetcode150(逆波兰表达式)

欢迎来看博主的算法讲解 博主ID:代码小豪 文章目录 逆波兰表达式逆波兰表达式的作用代码将中缀表达式转换成后缀表达式文末代码 逆波兰表达式 先来看看leetcode当中的原题 大多数人初见逆波兰表达式的时候大都一脸懵逼,因为与平时常见的表达式不同&am…

CentOS网络故障排查秘笈:实战指南

前言 作为一名热爱折腾 Linux 的技术达人,我深知网络故障会让人抓狂!在这篇文章里,我和你分享了我的心得体会,从如何分析问题、识别瓶颈,到利用各种神器解决网络难题。不管你是新手小白还是老鸟大神,这里都…

计算机网络-第4章 网络层(2)

主要内容:网络层提供的两种服务:虚电路和数据报(前者不用)、ip协议、网际控制报文协议ICMP、路由选择协议(内部网关和外部网关)、IPv6,IP多播,虚拟专用网、网络地址转换NAT,多协议标…

Oracle Essbase 多维库导入文件数据步骤操作

第一步: 先确定导入数据的维度数量(清楚自己需要导入什么数据和范围) 第二步: 设置加载的规则 1.创建规则 2.编辑规则-》打开数据文件 通过数据文件来确定加载规则的加载格式 先查看数据文件格式: 将数据文件导入&…

jquery基础

1、jQuery的下载 官网地址:jQuery 版本: 1x:兼容IE678等低版本浏览器,官网不再更新 2x:不兼容IE678等低版本浏览器,官网不再更新 3x:不兼容IE678等低版本浏览器,是官方主要更新…

nmcli绑定bond双网卡(active-backup模式)

当前网卡mac地址IP都不一样 创建名为“jbl”的新连接,并将其模式设置为“active-backup” nmcli connection add type bond ifname jbl mode active-backup添加物理网卡到bond(JBL),两个物理网卡添加到新创建的bond连接中 nmcli connection add type bond-slave…

力扣--76. 最小覆盖子串

给你一个字符串 s 、一个字符串 t 。返回 s 中涵盖 t 所有字符的最小子串。如果 s 中不存在涵盖 t 所有字符的子串,则返回空字符串 "" 。 注意: 对于 t 中重复字符,我们寻找的子字符串中该字符数量必须不少于 t 中该字符数量。如…

js【详解】Promise

为什么需要使用 Promise ? 传统回调函数的代码层层嵌套,形成回调地狱,难以阅读和维护,为了解决回调地狱的问题,诞生了 Promise 什么是 Promise ? Promise 是一种异步编程的解决方案,本身是一个构…

[虚拟机保护逆向] [HGAME 2023 week4]vm

[虚拟机保护逆向] [HGAME 2023 week4]vm 虚拟机逆向的注意点:具体每个函数的功能,和其对应的硬件编码的*长度* 和 *含义*,都分析出来后就可以编写脚本将题目的opcode转化位vm实际执行的指令 :分析完成函数功能后就可以编写脚本输出…