政安晨:【Keras机器学习实践要点】(十四)—— 使用 Keras3 进行分布式训练

目录

简介

工作原理

设置

设备网格和张量布局

分发

数据并行

ModelParallel 和 LayoutMap


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

多后端 Keras 的分发 API 完全指南。

简介


Keras 分布 API 是一个新接口,旨在促进 JAX、TensorFlow 和 PyTorch 等各种后端之间的分布式深度学习。这个功能强大的 API 引入了一套支持数据和模型并行的工具,可在多个加速器和主机上高效扩展深度学习模型。

无论是利用 GPU 还是 TPU 的强大功能,API 都提供了一种简化的方法来初始化分布式环境、定义设备网格,以及协调跨计算资源的张量布局。通过 DataParallel 和 ModelParallel 等类,它抽象了并行计算的复杂性,使开发人员更容易加速机器学习工作流程。

工作原理


Keras 分布 API 提供了一个全局编程模型,允许开发人员在全局上下文(就像使用单个设备一样)中组成对张量进行操作的应用程序,同时自动管理跨多个设备的分布。API 利用底层框架(如 JAX),通过称为单程序多数据 (SPMD) 扩展的程序,根据分片指令分发程序和张量。

通过将应用程序与分片指令分离,应用程序接口可以在单个设备、多个设备甚至多个客户端上运行相同的应用程序,同时保留其全局语义。

设置

import os# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

设备网格和张量布局

Keras distribution API 中的 keras.distribution.DeviceMesh 类表示为分布式计算配置的计算设备集群。它与 jax.sharding.Mesh 和 tf.dtensor.Mesh 中的类似概念一致,后者用于将物理设备映射到逻辑网格结构。

然后,TensorLayout 类指定了张量在 DeviceMesh 中的分布方式,详细说明了张量沿指定轴线的分片情况,这些轴线与 DeviceMesh 中的轴线名称相对应。

# Retrieve the local available gpu devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(axes=("data", None, None, None), device_mesh=mesh
)

分发


Keras 中的 Distribution 类是一个基础抽象类,用于开发自定义分布策略。它封装了在设备网格中分发模型变量、输入数据和中间计算所需的核心逻辑。

作为最终用户,您不必直接与该类交互,但可以与它的子类(如 DataParallel 或 ModelParallel)交互。

数据并行


Keras 分布 API 中的 DataParallel 类是为分布式训练中的数据并行策略而设计的,其中模型权重被复制到 DeviceMesh 中的所有设备上,每个设备处理一部分输入数据。

以下是该类的使用示例。

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)# Set the global distribution.
keras.distribution.set_distribution(data_parallel)# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)

执行结果如下: 

Epoch 1/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349

0.842325747013092

ModelParallel 和 LayoutMap

当模型权重过大而无法在单个加速器上安装时,ModelParallel(并行模型)就非常有用了。通过此设置,您可以将模型权重或激活张量分发给 DeviceMesh 上的所有设备,并为大型模型启用水平缩放。

在 DataParallel 模型中,所有权重都是完全复制的,与此不同,ModelParallel 模型中的权重布局通常需要定制才能获得最佳性能。我们引入了 LayoutMap,让您从全局角度为任何权重和中间张量指定张量布局。

LayoutMap 是一个类似于 dict 的对象,用于将字符串映射到 TensorLayout 实例。它的行为不同于普通的 Python dict,因为在检索值时,字符串键被视为一个 regex。该类允许你定义 TensorLayout 的命名模式,然后检索相应的 TensorLayout 实例。通常,用于查询的关键字是 variable.path 属性,即变量的标识符。作为一种快捷方式,在插入值时也允许使用轴名的元组或列表,并将其转换为 TensorLayout。

如果没有设置 TensorLayout.device_mesh,布局图还可以选择包含一个设备网格(DeviceMesh)来填充 TensorLayout.device_mesh。

在使用键检索布局时,如果没有完全匹配的键,布局图中的所有现有键都将被视为 regex,并再次与输入键匹配。

如果有多个匹配项,就会引发 ValueError。如果没有找到匹配项,则返回 None。

mesh_2d = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that for any weights that match with d1/kernel, it
# will be sharded with model dimensions (4 devices), same for the d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)model_parallel = keras.distribution.ModelParallel(mesh_2d, layout_map, batch_dim_name="data"
)keras.distribution.set_distribution(model_parallel)inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)# The data will be sharded across the "data" dimension of the method, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)

结果如下:

Epoch 1/3/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.warnings.warn("Some donated buffers were not usable:"8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/38/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/38/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.87258/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381  0.8502610325813293

改变网格结构也很容易,可以在更多数据并行或模型并行之间调整计算。您可以通过调整网格的形状来实现这一点。其他任何代码都无需更改。

full_data_parallel_mesh = keras.distribution.DeviceMesh(shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(shape=(1, 8), axis_names=["data", "model"], devices=devices
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/585920.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实现一个Google身份验证代替短信验证

最近才知道公司还在做国外的业务,要实现一个登陆辅助验证系统。咱们国内是用手机短信做验证,当然 这个google身份验证只是一个辅助验证登陆方式。看一下演示 看到了嘛。 手机下载一个谷歌身份验证器就可以 。 谷歌身份验证器,我本身是一个基…

sap 实施商

领导者象限中的三巨头   德勤、埃森哲、IBM,这三家公司行业从业者基本都知道,三巨头在国内的实力也是挺强的。 领导者象限中的四朵花 四朵花: 凯捷,塔塔咨询,安永,普华永道(Pwc) 国内本土公司&…

【Linux】动态库与静态库

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 🌝每一个不曾起舞的日子,都是对生命的辜负 目录 前言 1.为什么要有库&…

电脑分辨率怎么调,电脑分辨率怎么调整

随着电脑的普及以及网络的发展,我们现在在工作中都离不开对电脑的使用,今天小编教大家设置电脑分辨率,现在我们先了解这个分辨率是什么?通常电脑的显示分辨率就是屏幕分辨率,显示屏大小固定时,显示分辨率越高图像越清…

EVM Layer2 主流解决方案

深度解析主流 EVM Layer 2 解决方案:zk Rollups 和 Optimistic Rollups 随着以太坊网络的不断演进和 DeFi 生态系统的迅速增长,以太坊 Layer 2 解决方案日益受到关注。 其中,zk Rollups 和 Optimistic Rollups 作为两种备受瞩目的主流 EVM&…

ubuntu18.04 pycharm

一、下载pycharm (1)进入官网下载Download PyCharm: The Python IDE for data science and web development by JetBrains 选择专业版(professional)直接点击下载(download),我下载的是2023.3…

设计模式7--建造者模式

定义 案例一 案例二 优缺点

nodejs基础学习(一)

nodejs逆向python爬虫学习笔记 第一章 nodejs基础 nodejs基础 nodejs逆向python爬虫学习笔记开发环境vscodeF5运行注释js逆向作用变量**1、var全局**2、let块级作用域: {} if while for。。。等等循环中使用3、const 块级作用域,常量,不可以修改/重新定…

智慧公厕四大核心能力,赋能城市公共厕所智能化升级

公共厕所是城市基础设施中不可或缺的一部分,但由于传统的公共厕所在建设与规划上,存在一定的局限性,导致环境卫生差、管理难度大、使用体验不佳等问题,给市民带来了很多不便。而智慧公厕作为城市智能化建设的重要组成部分&#xf…

ElMessage自定义样式

ElMessage自定义样式 默认样式 从顶部出现,3 秒后自动消失。 常用于主动操作后的反馈提示。 import { ElMessage } from element-plusElMessage.success({message: res.data.msg,duration: 300,style: {marginTop: 200px,// 设置提示框的宽度width: 500px, // 设置…

STM32学习和实践笔记(4): 分析和理解GPIO_InitTypeDef GPIO_InitStructure (a)

深入分析及学习一下上面这一段代码的构成与含义。 首先,这个GPIO_InitTypeDef GPIO_InitStructure;其实与int a 是完全类似的语法格式以及含义。 GPIO_InitStructure就相当于a这样一个变量。不过从这个变量的名字可以知道,这是一个用于GPIO初始化的结构…

DeepMind联合创始人Demis Hassabis因对人工智能的贡献被授予英国爵士勋章

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…