目录

PyTorch/XLA SPMD用户指南

在本用户指南中,我们将讨论 GSPMD 如何集成到 PyTorch/XLA 中,并提供设计概述以说明 SPMD 分片注解 API 及其构造的工作原理。

PyTorch/XLA SPMD?

GSPMD 是一个用于常见机器学习工作负载的自动并行化系统。XLA 编译器将根据用户提供的分割提示,将单设备程序转换为适当的分区程序。此功能允许开发人员像在单个大设备上编写 PyTorch 程序一样编写代码,无需任何自定义分割计算操作和/或集体通信。

alt_text

*图1. 两种不同执行策略的比较,(a) 非SPMD 和 (b) SPMD。*

如何使用PyTorch/XLA SPMD?

这是一个使用SPMD的简单示例。

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


# Enable XLA SPMD execution mode.
xr.use_spmd()


# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


t = torch.randn(8, 4).to(xm.xla_device())


# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

让我们一一解释这些概念

SPMD模式

为了使用SPMD,您需要通过xr.use_spmd()启用它。在SPMD模式下只有一个逻辑设备。mark_sharding负责分布式计算和集体操作。请注意,用户不能将SPMD与其他分布式库混合使用。

Mesh

对于一组给定的设备,物理网是一个表示互联拓扑的关系图。

  1. mesh_shape 是一个元组,将被乘到物理设备总数上。

  2. device_ids 几乎总是小于 np.array(range(num_devices))

  3. 用户也应该给每个网格维度命名。在上述示例中,第一个网格维度是data维度,第二个网格维度是model维度。

您还可以通过更多网格信息了解详情。

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])

分区规范

partition_spec 的秩与输入张量相同。每个维度描述了对应的输入张量维度如何在设备网格中划分。在上述示例中,张量 t 的第一个维度在 data 维度上进行划分,第二个维度在 model 维度上进行划分。

用户也可以分割维度与网格形状不同的张量。

t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))

进一步阅读

  1. 示例 使用SPMD表达数据并行。

  2. 示例 使用SPMD表达FSDP(完全分片数据并行)。

  3. SPMD高级主题

  4. Spmd分布式检查点

全分割数据并行(FSDP) 通过SPMD

通过SPMD或FSDPv2的完全分片数据并行是一种重新表达著名的FSDP算法在SPMD中的工具。 是一个实验性功能,旨在为用户提供一个熟悉的接口,以享受SPMD带来的所有好处。设计文档 在这里

在继续之前,请先查看SPMD 用户指南。您也可以在这里找到一个最小可运行示例这里

示例用法:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))

# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))

# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也有可能单独分割各个层,并由外部包装处理剩余的参数。以下是一个示例,自动包装每个DecoderLayer

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Apply FSDP sharding on each DecoderLayer layer.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        decoder_only_model.DecoderLayer
    },
)
model = FSDPv2(
    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)

Sharding 输出

为了确保XLA编译器正确实现FSDP算法,我们需要对权重和激活值进行分片。这意味着需要对前向方法的输出进行分片。由于前向函数的输出可能会有所不同,我们提供了`shard_output`来在您的模块输出不属于以下类别时对激活值进行分片。

  1. 一个张量

  2. 一个张量元组,其中第 0 个元素是激活值。

示例用法:

def shard_output(output, mesh):
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))

model = FSDPv2(my_module, mesh, shard_output)

梯度检查点技术

目前,在应用FSDP包装器之前,需要对模块应用梯度检查点技术。否则,递归地遍历子模块可能会导致无限循环。我们将在未来的版本中修复此问题。

示例用法:

from torch_xla.distributed.fsdp import checkpoint_module

model = FSDPv2(checkpoint_module(my_module), mesh)

HuggingFace Llama 2 示例

我们有一个HF Llama 2的分支,用于演示潜在的集成 这里

PyTorch/XLA SPMD 高级主题

在本文件中,我们将介绍一些关于GSPMD的高级主题。请在继续阅读本文档之前阅读SPMD用户指南

PyTorch/XLA SPMD 将单设备程序分割并在多个设备上并行执行。SPMD 执行需要使用原生的 PyTorch DataLoader,该加载器会同步地将数据从主机传输到 XLA 设备。这会在每一步输入数据传输期间阻塞训练。为了提高原生数据加载性能,我们让 PyTorch/XLA ParallelLoader 支持直接输入分割(src),当传递可选关键字参数 _input_sharding_ 时。

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
        train_loader,  # wraps PyTorch DataLoader
        device,
          # assume 4d input and we want to shard at the batch dimension.
        input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))

如果每个批次元素的形状不同,也可以为它们指定不同的input_sharding

# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
        train_loader,  # wraps PyTorch DataLoader
        device,
          # specify different sharding for each input of the batch.
        input_sharding={
          'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
          'y': xs.ShardingSpec(input_mesh, ('data', None))
        }
)

PyTorch/XLA通常在定义张量后异步地将张量数据从主机传输到设备,以重叠数据传输与图跟踪时间。然而,由于GSPMD允许用户在张量定义之后修改张量分割方式,我们需要一种优化方法来防止不必要的张量数据在主机和设备之间来回传输。为此,我们引入了虚拟设备优化技术,在所有分割决策最终确定之前,先将张量数据放置在虚拟设备SPMD:0上,然后再上传到物理设备。在SPMD模式下,每张张量数据都会被放置在虚拟设备SPMD:0上。虚拟设备会作为XLA设备XLA:0暴露给用户,实际的分割片则位于TPU:0、TPU:1等物理设备上。

混合网格

Mesh 很好地抽象了物理设备网格的构建方式。用户可以使用逻辑网格以任意形状和顺序排列设备。然而,在涉及数据中心网络(DCN)跨切片连接的情况下,可以根据物理拓扑定义更高效的网格。HybridMesh 为这种多切片环境创建了一个即插即用且性能良好的网格。它接受 ici_mesh_shape 和 dcn_mesh_shape,分别表示内网和外网的逻辑网格形状。

from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

在TPU Pod上运行SPMD

从单个TPU主机转换到TPU Pod不需要更改代码,如果你根据设备数量而不是某些硬编码常量来构建你的网格和分区规范。要在TPU Pod上运行PyTorch/XLA工作负载,请参阅我们PJRT指南的Pods部分

XLAShardedTensor

xs.mark_sharding 是一个原地操作,将会在输入张量上附加分片注解,但同时也会返回一个 XLAShardedTensor 的 Python 对象。

The main use case for XLAShardedTensor [RFC] 是为了在单个设备上的原生 torch.tensor 上添加分片规格。注释立即发生,但实际分片操作会在计算过程中延迟进行,除非是输入张量,它们会立即分片。一旦张量被注释并包裹在 XLAShardedTensor 中,它就可以作为 nn.Module 层传递给现有的 PyTorch 操作和 nn.Module 层,以满足 torch.Tensor 的要求。这意味着用户不需要为分片计算重写现有的操作和模型代码。具体来说,XLAShardedTensor 将满足以下要求:

  • XLAShardedTensortorch.Tensor 的一个子类,并且可以直接与原生的 torch 操作和 module.layers 进行交互。我们使用 __torch_dispatch__XLAShardedTensor 发送到 XLA 后端。PyTorch/XLA 会检索附加的分片注解以跟踪图形并调用 XLA SPMDPartitioner。

  • Internally, XLAShardedTensor(及其global_tensor输入)由XLATensor支持,后者使用特殊的数据结构来持有对分割设备数据的引用。

  • 在惰性执行后,分片张量可能在主机上请求时被收集并重新材料化为全局张量(例如,打印全局张量的值)。

  • The handles to the local shards are materialized strictly after the lazy execution. XLAShardedTensor 暴露了 local_shards 方法以返回可寻址设备上的本地碎片,表示为 List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]

也有一个正在进行的努力,将XLAShardedTensor集成到DistributedTensorAPI以支持XLA后端[RFC]。

DTensor集成

PyTorch 在 2.1 版本中原型发布了 DTensor。 我们正在将 PyTorch/XLA SPMD 整合到 DTensor API RFC 中。我们有一个针对 distribute_tensor 的概念验证集成,它使用 XLA 调用 mark_sharding 注解 API 来划分张量及其计算:

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

该功能尚处于实验阶段,请关注后续发布的更新、示例和教程。

Activation Sharding for torch.compile

在2.3版本中,PyTorch/XLA添加了自定义操作dynamo_mark_sharding,可以在torch.compile区域中执行激活分片。这属于我们持续努力的一部分,旨在使torch.compile+GSPMD成为使用PyTorch/XLA进行模型推理推荐的方式。使用此自定义操作的示例:

# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)

SPMD调试工具

我们为使用单主机/多主机的TPU/GPU/CPU上的PyTorch/XLA SPMD用户提供了一个shard placement visualization debug tool:您可以使用visualize_tensor_sharding来可视化分片张量,或者使用visualize_sharding来可视化共享字符串。以下是TPU单主机(v4-8)上的两个代码示例,分别使用visualize_tensor_shardingvisualize_sharding

  • 代码片段用于 visualize_tensor_sharding 和可视化结果:

import rich

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
visualize_tensor_sharding example on TPU v4-8(single-host)
  • 代码片段用于 visualize_sharding 和可视化结果:

from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
visualize_sharding example on TPU v4-8(single-host)

您可以在TPU/GPU/CPU单主机上使用这些示例,并对其进行修改以在多主机上运行。并且您可以将其修改为tiledpartial_replicationreplicated的分片样式。

Auto-Sharding

我们正在介绍一个新的PyTorch/XLA SPMD功能,称为auto-shardingRFC。这在r2.3nightly中是一个实验性功能,并支持XLA:TPU和一个单个TPUVM主机。

PyTorch/XLA 自动分片可以通过以下方式启用:

  • 设置环境变量 XLA_AUTO_SPMD=1

  • 在代码的开始处调用SPMD API:

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • 调用 pytorch.distributed._tensor.distribute_moduleauto-policyxla

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)

可选地,可以设置以下选项/环境变量来控制基于XLA的自动分片优化的行为:

  • XLA_AUTO_USE_GROUP_SHARDING: 参数的分组重分布。默认设置。

  • XLA_AUTO_SPMD_MESH: 逻辑网格形状,用于自动分片。例如,XLA_AUTO_SPMD_MESH=2,2 对应一个由4个全局设备组成的2x2网格。如果未设置,默认使用一个设备网格形状为 num_devices,1

分布式检查点

PyTorch/XLA SPMD 与 torch.distributed.checkpoint 库兼容,通过专用的 Planner 实例。用户能够通过这个通用接口同步保存和加载检查点。

SPMDSavePlanner 和 SPMDLoadPlanner 类(源码)使 saveload 函数能够直接在 XLAShardedTensor 的分片上操作,从而实现 SPMD 训练中分布式检查点的所有优势。

这是一个同步分布式检查点 API 的演示:

import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...

# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

实验性的 CheckpointManager 接口提供了对 torch.distributed.checkpoint 函数的更高层次API,以启用一些关键功能:

  • 管理检查点: 每个由CheckpointManager捕获的检查点都通过所处的步数进行标识。所有跟踪的步数都可以通过CheckpointManager.all_steps方法访问,并且可以使用CheckpointManager.restore恢复任何跟踪的步数。

  • 异步检查点: 通过 CheckpointManager.save_async API 取得的检查点会异步写入持久存储,以在检查点期间不阻塞训练。输入分片的状态字典首先会被移动到 CPU,然后检查点被派发到后台线程。

  • 自动检查点在抢占时启用:在Cloud TPU上,可以检测到抢占并在进程终止前创建检查点。要使用此功能,请确保您的TPU通过启用了自动检查点的QueuedResource进行配置,并确保在构造CheckpointManager时设置了chkpt_on_preemption参数(此选项默认启用)。

  • FSSpec 支持: CheckpointManager 使用 fsspec 存储后端,可以直接将检查点保存到任何 fsspec 兼容的文件系统,包括 GCS。

以下是 CheckpointManager 的示例用法:

from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    # Before restoring the checkpoint, the optimizer state must be primed
    # to allow state to be loaded into it.
    prime_optimizer(optim)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')

在分布式检查点中,state_dicts 是就地加载的,并且只加载检查点所需的分片。由于优化器状态是懒加载的,直到第一次 optimizer.step 调用时,状态才存在,尝试加载未初始化的优化器将会失败。

The utility method prime_optimizer is provided for this: it runs a fake train step by setting all gradients to zero and calling optimizer.step. 这是一个破坏性方法,将会修改模型参数和优化器状态, 所以它应该只在即将恢复时被调用。

要使用torch.distributed个API,如分布式检查点,需要一个进程组。在SPMD模式下,xla后端不受支持,因为编译器负责所有的集体操作。

相反,必须使用CPU进程组,例如gloo。在TPUs上,仍然支持xla:// init_method来发现主IP、全局世界大小和主机排名。以下是一个初始化示例:

import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr

xr.use_spmd()

# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源