目录
简介
工作原理
设置
设备网格和张量布局
分发
数据并行
ModelParallel 和 LayoutMap
政安晨的个人主页:政安晨
欢迎 👍点赞✍评论⭐收藏
收录专栏: TensorFlow与Keras机器学习实战
希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!
多后端 Keras 的分发 API 完全指南。
简介
Keras 分布 API 是一个新接口,旨在促进 JAX、TensorFlow 和 PyTorch 等各种后端之间的分布式深度学习。这个功能强大的 API 引入了一套支持数据和模型并行的工具,可在多个加速器和主机上高效扩展深度学习模型。
无论是利用 GPU 还是 TPU 的强大功能,API 都提供了一种简化的方法来初始化分布式环境、定义设备网格,以及协调跨计算资源的张量布局。通过 DataParallel 和 ModelParallel 等类,它抽象了并行计算的复杂性,使开发人员更容易加速机器学习工作流程。
工作原理
Keras 分布 API 提供了一个全局编程模型,允许开发人员在全局上下文(就像使用单个设备一样)中组成对张量进行操作的应用程序,同时自动管理跨多个设备的分布。API 利用底层框架(如 JAX),通过称为单程序多数据 (SPMD) 扩展的程序,根据分片指令分发程序和张量。
通过将应用程序与分片指令分离,应用程序接口可以在单个设备、多个设备甚至多个客户端上运行相同的应用程序,同时保留其全局语义。
设置
import os# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data # For dataset input.
设备网格和张量布局
Keras distribution API 中的 keras.distribution.DeviceMesh 类表示为分布式计算配置的计算设备集群。它与 jax.sharding.Mesh 和 tf.dtensor.Mesh 中的类似概念一致,后者用于将物理设备映射到逻辑网格结构。
然后,TensorLayout 类指定了张量在 DeviceMesh 中的分布方式,详细说明了张量沿指定轴线的分片情况,这些轴线与 DeviceMesh 中的轴线名称相对应。
# Retrieve the local available gpu devices.
devices = jax.devices("gpu") # Assume it has 8 local GPUs.# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(axes=("data", None, None, None), device_mesh=mesh
)
分发
Keras 中的 Distribution 类是一个基础抽象类,用于开发自定义分布策略。它封装了在设备网格中分发模型变量、输入数据和中间计算所需的核心逻辑。作为最终用户,您不必直接与该类交互,但可以与它的子类(如 DataParallel 或 ModelParallel)交互。
数据并行
Keras 分布 API 中的 DataParallel 类是为分布式训练中的数据并行策略而设计的,其中模型权重被复制到 DeviceMesh 中的所有设备上,每个设备处理一部分输入数据。
以下是该类的使用示例。
# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)# Set the global distribution.
keras.distribution.set_distribution(data_parallel)# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
执行结果如下:
Epoch 1/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.83490.842325747013092
ModelParallel 和 LayoutMap
当模型权重过大而无法在单个加速器上安装时,ModelParallel(并行模型)就非常有用了。通过此设置,您可以将模型权重或激活张量分发给 DeviceMesh 上的所有设备,并为大型模型启用水平缩放。
在 DataParallel 模型中,所有权重都是完全复制的,与此不同,ModelParallel 模型中的权重布局通常需要定制才能获得最佳性能。我们引入了 LayoutMap,让您从全局角度为任何权重和中间张量指定张量布局。
LayoutMap 是一个类似于 dict 的对象,用于将字符串映射到 TensorLayout 实例。它的行为不同于普通的 Python dict,因为在检索值时,字符串键被视为一个 regex。该类允许你定义 TensorLayout 的命名模式,然后检索相应的 TensorLayout 实例。通常,用于查询的关键字是 variable.path 属性,即变量的标识符。作为一种快捷方式,在插入值时也允许使用轴名的元组或列表,并将其转换为 TensorLayout。
如果没有设置 TensorLayout.device_mesh,布局图还可以选择包含一个设备网格(DeviceMesh)来填充 TensorLayout.device_mesh。
在使用键检索布局时,如果没有完全匹配的键,布局图中的所有现有键都将被视为 regex,并再次与输入键匹配。
如果有多个匹配项,就会引发 ValueError。如果没有找到匹配项,则返回 None。
mesh_2d = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that for any weights that match with d1/kernel, it
# will be sharded with model dimensions (4 devices), same for the d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)model_parallel = keras.distribution.ModelParallel(mesh_2d, layout_map, batch_dim_name="data"
)keras.distribution.set_distribution(model_parallel)inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)# The data will be sharded across the "data" dimension of the method, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
结果如下:
Epoch 1/3/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.warnings.warn("Some donated buffers were not usable:"8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/38/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/38/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.87258/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381 0.8502610325813293
改变网格结构也很容易,可以在更多数据并行或模型并行之间调整计算。您可以通过调整网格的形状来实现这一点。其他任何代码都无需更改。
full_data_parallel_mesh = keras.distribution.DeviceMesh(shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(shape=(1, 8), axis_names=["data", "model"], devices=devices
)