PyTorch/XLA SPMD 用户指南¶
在本用户指南中,我们讨论了如何将 GSPMD 集成到 PyTorch/XLA 中,并提供了一个设计概述来说明 SPMD 分片注释 API 及其构造的工作原理。
什么是 PyTorch/XLA SPMD?¶
GSPMD 是适用于常见 ML 工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单个设备程序转换为具有适当集合的分区程序。此功能允许开发人员编写 PyTorch 程序,就像在单个大型设备上一样,无需任何自定义分片计算运算和/或集体通信即可扩展。

*图 1.两种不同执行策略的比较,(a) 用于非 SPMD 和 (b) 用于 SPMD。*
如何使用 PyTorch/XLA SPMD?¶
下面是一个使用 SPMD 的简单示例
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
让我们一一解释这些概念
SPMD 模式¶
要使用 SPMD,您需要通过 启用它。在 SPMD 模式下,只有一个逻辑设备。分布式计算和 collective 由 .请注意,用户不能将 SPMD 与其他分布式库混合使用。xr.use_spmd()
mark_sharding
网孔¶
对于给定的设备集群,物理网格是互连拓扑的表示形式。
mesh_shape
是一个元组,它将乘以物理设备的总数。device_ids
几乎总是 。np.array(range(num_devices))
还鼓励用户为每个网格维度命名。在上面的示例中,第一个网格维度是维度,第二个网格维度是维度。
data
model
您还可以通过以下方式查看更多网格信息
>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])
分区规范¶
partition_spec 与 Importing 张量具有相同的秩。每个维度都描述了如何在设备网格中对相应的输入张量维度进行分片。在上面的示例中,张量的第一个维度在 dimension 处被分片,第二个维度在 dimension 处被分片。t
data
model
用户还可以对与网格形状具有不同维度的张量进行分片。
t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)
# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))
# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))
# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))
通过 SPMD 实现全分片数据并行 (FSDP)¶
通过 SPMD 或 FSDPv2 的完全分片数据并行是一种实用程序,它在 SPMD 中重新表达了著名的 FSDP 算法。这是 一项实验性功能,旨在为用户提供一个熟悉的界面,以享受 SPMD 带来的所有好处 表。设计文档在这里。
在继续之前,请查看 SPMD 用户指南。您还可以在此处找到最小可运行示例。
用法示例:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))
# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
还可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。下面是一个自动包装每个 .DecoderLayer
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Apply FSDP sharding on each DecoderLayer layer.
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
decoder_only_model.DecoderLayer
},
)
model = FSDPv2(
model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
分片输出¶
为了确保 XLA 编译器正确实现 FSDP 算法,我们需要对权重和激活进行分片。这意味着对 forward 方法的输出进行分片。由于前向函数输出可能会有所不同,因此如果您的模块输出不属于以下类别之一,我们提供分片激活shard_output:
单个张量
一个张量元组,其中第 0 个元素是激活值。
用法示例:
def shard_output(output, mesh):
xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))
model = FSDPv2(my_module, mesh, shard_output)
梯度检查点¶
目前,需要在 FSDP 包装器之前将梯度检查点应用于模块。否则,递归循环到子模块将以无限循环结束。我们将在未来的版本中修复此问题。
用法示例:
from torch_xla.distributed.fsdp import checkpoint_module
model = FSDPv2(checkpoint_module(my_module), mesh)
PyTorch/XLA SPMD 高级主题¶
在本文档中,我们将介绍一些有关 GSPMD 的高级主题。在阅读本文档之前,请阅读 SPMD 用户指南。
PyTorch/XLA SPMD 采用单个设备程序,对程序进行分片并并行执行。SPMD 执行需要使用本机 PyTorch DataLoader,它将数据从主机同步传输到 XLA 设备。这会阻止 input data transfer 期间每一步的训练。为了提高原生数据加载性能,我们让 PyTorch/XLA ParallelLoader 支持直接输入分片 (src),当传递可选的 kwarg _input_sharding_时:
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# assume 4d input and we want to shard at the batch dimension.
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
如果 batch 的每个元素是不同的形状,也可以为 different 指定一个:input_sharding
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
定义张量后,PyTorch/XLA 通常会将张量数据从主机异步传输到设备。这是为了将数据传输与图形跟踪时间重叠。但是,由于 GSPMD 允许用户修改张量分片_after _the张量已定义,因此我们需要进行优化,以防止在主机和设备之间来回传输张量数据。我们介绍了虚拟设备优化,这是一种先将张量数据放置在虚拟设备 SPMD:0 上的技术,然后在所有分片决策最终确定后上传到物理设备。SPMD 模式下的每个张量数据都放置在虚拟设备 SPMD:0 上。虚拟设备作为 XLA 设备 XLA:0 向用户公开,并在物理设备上具有实际分片,如 TPU:0、TPU:1 等。
混合网格¶
Mesh 很好地抽象了物理设备网格的构造方式。用户可以使用逻辑网格以任何形状和顺序排列设备。但是,可以根据物理拓扑定义性能更高的网格,尤其是当它涉及数据中心网络 (DCN) 交叉切片连接时。HybridMesh 创建的网格为此类多切片环境提供开箱即用的良好性能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它们表示内部和外部网络的逻辑网格形状。
from torch_xla.distributed.spmd import HybridMesh
# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)
mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
在 TPU Pod 上运行 SPMD¶
如果您根据设备数量而不是某个硬编码常量来构建网格和分区规范,则无需更改从单个 TPU 主机到 TPU Pod 的代码。要在 TPU Pod 上运行 PyTorch/XLA 工作负载,请参阅我们的 PJRT 指南的 Pod 部分。
XLAShardedTensor¶
xs.mark_sharding
是一个就地作,它将分片注释附加到输入张量,但它也会返回一个 Python 对象。XLAShardedTensor
[RFC] 的主要用例是使用分片规范对本机(在单个设备上)进行注释。注释会立即发生,但由于计算是延迟执行的,因此张量的实际分片会延迟,但输入张量除外,这些张量是毫不延迟地分片的。一旦张量被注释并包装在 中,它就可以作为 传递给现有的 PyTorch 运算和层。这对于确保相同的 PyTorch 层和张量运算可以与 .这意味着用户不需要重写现有的 operations 和 model 代码进行分片计算。即,将满足以下要求:XLAShardedTensor
torch.tensor
XLAShardedTensor
nn.Module
torch.Tensor
XLAShardedTensor
XLAShardedTensor
XLAShardedTensor
是一个子类,直接与原生 Torch作和 .我们过去常常发送到 XLA 后端。PyTorch/XLA 检索附加的分片注释以跟踪图形并调用 XLA SPMDPartitioner。torch.Tensor
module.layers
__torch_dispatch__
XLAShardedTensor
在内部,(及其 global_tensor input)由一个特殊的数据结构提供支持,其中包含对分片设备数据的引用。
XLAShardedTensor
XLATensor
延迟执行后的分片张量可以被收集并作为global_tensor物化回主机,当主机上请求时(例如,打印全局张量的值)。
本地分片的句柄在延迟执行后严格具体化。 公开local_shards以将可寻址设备上的本地分片作为
XLAShardedTensor
List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]
.
此外,我们还在努力整合XLAShardedTensor
到DistributedTensor
支持 XLA 后端 [RFC] 的 API。
DTensor 集成¶
PyTorch 在 2.1 中发布了原型版 DTensor。
我们正在将 PyTorch/XLA SPMD 集成到 DTensor API RFC 中。我们有一个概念验证集成,它调用注释 API 来使用 XLA 对张量及其计算进行分片:distribute_tensor
mark_sharding
import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor
# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
此功能是实验性的,请继续关注即将发布的版本中的更多更新、示例和教程。
torch.compile 的激活分片¶
在 2.3 版本中,PyTorch/XLA 添加了自定义运算,可用于在区域中执行激活分片。这是我们不断努力的一部分,旨在使 + 成为使用 PyTorch/XLA 进行模型推理的推荐方式。使用此自定义作的示例:dynamo_mark_sharding
torch.compile
torch.compile
GSPMD
# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)
SPMD 调试工具¶
我们为 TPU/GPU/CPU 上的 PyTorch/XLA SPMD 用户提供了一个单主机/多主机:您可以使用它来可视化分片张量,也可以用于可视化共享字符串。以下是 TPU 单主机 (v4-8) 上的两个代码示例,其中有 或:shard placement visualization debug tool
visualize_tensor_sharding
visualize_sharding
visualize_tensor_sharding
visualize_sharding
使用的代码片段和可视化结果:
visualize_tensor_sharding
import rich
# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))
# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)

使用的代码片段和可视化结果:
visualize_sharding
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)

您可以在 TPU/GPU/CPU 单主机上使用这些示例,并将其修改为在多主机上运行。您可以将其修改为 sharding-style 和 。tiled
partial_replication
replicated
自动分片¶
我们引入了一项名为 RFC 的新 PyTorch/XLA SPMD 功能。这是 和 中的一个实验性功能,它支持单个 TPUVM 主机。auto-sharding
r2.3
nightly
XLA:TPU
PyTorch/XLA 自动分片可以通过以下方法之一启用:
设置 envvar
XLA_AUTO_SPMD=1
在代码的开头调用 SPMD API:
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
使用 和 进行呼叫 :
pytorch.distributed._tensor.distribute_module
auto-policy
xla
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
# Currently, model should be loaded to xla device via distribute_module.
model = MyModule() # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)
或者,可以设置以下 options/env-vars 来控制 基于 XLA 的自动分片通行证:
XLA_AUTO_USE_GROUP_SHARDING
:参数的组重新分片。默认设置。XLA_AUTO_SPMD_MESH
:用于自动分片的逻辑网格形状。例如,对应于具有 4 个全局设备的 2×2 网格。如果未设置,则 将使用默认的 Device Mesh 形状。XLA_AUTO_SPMD_MESH=2,2
num_devices,1
分布式检查点¶
PyTorch/XLA SPMD 通过专用实例与 torch.distributed.checkpoint 库兼容。用户可以通过这个通用接口同步保存和加载 checkpoint。Planner
SPMDSavePlanner 和 SPMDLoadPlanner (src) 类使 and 函数能够直接在 的分片上运行,从而在 SPMD 训练中实现分布式检查点的所有优势。save
load
XLAShardedTensor
以下是同步分布式检查点 API 的演示:
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc
# Saving a state_dict
state_dict = {
"model": model.state_dict(),
"optim": optim.state_dict(),
}
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
)
...
# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
"model": model.state_dict(),
}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
实验性的 CheckpointManager 接口为函数提供了更高级别的 API,以启用一些关键功能:torch.distributed.checkpoint
托管检查点:每个检查点由 通过执行步骤来识别。跟踪的所有步骤都是可访问的 通过该方法,并且任何跟踪的步骤都可以 使用 恢复。
CheckpointManager
CheckpointManager.all_steps
CheckpointManager.restore
异步检查点:通过 API 获取的检查点将写入持久性存储 异步解锁,以在 checkpoint 的持续时间内解锁训练。这 input 分片 state_dict 首先移动到 CPU 中,然后再执行检查点 dispatch 到后台线程。
CheckpointManager.save_async
抢占时自动检查点:在 Cloud TPU 上,可以检测到抢占 以及在进程终止之前获取的检查点。要使用,请确保您的 TPU 是通过启用了自动检查点的 QueuedResource 配置的, 并确保在构造 CheckpointManager (默认情况下启用此选项)。
chkpt_on_preemption
FSSpec 支持:使用 fsspec 存储后端启用 直接对任何与 fsspec 兼容的文件系统(包括 GCS)执行检查点作。
CheckpointManager
CheckpointManager 的示例用法如下:
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
# Choose the highest step
best_step = max(tracked_steps)
# Before restoring the checkpoint, the optimizer state must be primed
# to allow state to be loaded into it.
prime_optimizer(optim)
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
chkpt_mgr.restore(best_step, state_dict)
model.load_state_dict(state_dict['model'])
optim.load_state_dict(state_dict['optim'])
# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
...
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
if chkpt_mgr.save_async(step, state_dict):
print(f'Checkpoint taken at step {step}')
在分布式检查点中,state_dicts 是就地加载的,只有
加载 checkpoint 的必需分片。由于优化器状态是 lazyly
created,则状态在第一次调用之前不存在,并且
尝试加载未启动的优化器将失败。optimizer.step
为此提供了 Utility 方法:它运行一列假火车
步骤将所有梯度设置为零并调用 .这是一个
destructive 方法,并将同时触及模型参数和优化器状态,
因此,它应该只在恢复之前调用。prime_optimizer
optimizer.step
要使用分布式检查点等 API,需要一个进程
group 是必需的。在 SPMD 模式下,不支持后端,因为
compiler 负责所有 collectives。torch.distributed
xla
相反,必须使用 CPU 进程组(如 )。在 TPU 上,仍然支持init_method来发现主 IP、全局世界大小、
和主机等级。初始化示例如下:gloo
xla://
import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr
xr.use_spmd()
# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')