目录

如何实现分布式数据并行(DDP)

本文档介绍了如何在xla中使用torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生xla数据并行方法的区别。你可以在这里找到一个最小可运行示例这里

背景 / 动机

客户长期以来一直希望能够使用 PyTorch 的 DistributedDataParallel API 与 xla 结合,并且我们现在将其作为实验性功能启用。

如何使用DistributedDataParallel

对于那些从PyTorch急切模式切换到XLA的人,这里是将您的急切DDP模型转换为XLA模型所需的所有更改。我们假设您已经知道如何在单个设备上使用XLA

  1. 导入特定于 XLA 的分布式包:

import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
  1. 初始化 XLA 进程组,与其他进程组(如 nccl 和 gloo)类似。

dist.init_process_group("xla", rank=rank, world_size=world_size)
  1. 如果您需要获取秩和世界大小,请使用与 XLA 相关的特定 API。

new_rank = xr.global_ordinal()
world_size = xr.world_size()
  1. gradient_as_bucket_view=True 传递给 DDP 包装器。

ddp_model = DDP(model, gradient_as_bucket_view=True)
  1. 最后使用特定的 XLA 启动器启动您的模型。

torch_xla.launch(demo_fn)

这里我们把所有东西都放在一起了(这个例子实际上取自 DDP教程)。 你编写代码的方式与急切模式的体验非常相似。只是在单个设备上有一些针对xla的具体修改,再加上对脚本的上述五个更改。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xr.global_ordinal()
    assert new_rank == rank
    world_size = xr.world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, graident_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    torch_xla.launch(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

基准测试

Resnet50 与虚假数据

以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行命令python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1收集的。统计指标是通过使用此拉取请求中的脚本生成的。速率的单位是每秒图像数。

类型 均值 中位数 90th % 标准差 CV
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我们本地的分布式数据并行方法与DistributedDataParallel包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。 这个结果在考虑到DistributedDataParallel包装器会在DDP运行时引入额外开销的情况下似乎是合理的。

MNIST与虚假数据

以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行命令python test/test_train_mp_mnist.py --fake_data收集的。统计指标是通过使用此拉取请求中的脚本生成的。速率的单位是每秒图像数。

类型 均值 中位数 90th % 标准差 CV
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我们本地的分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。由于数据集较小,前几轮受数据加载影响较大,因此我们比较的是第 90 百分位数。这种性能下降很大,但考虑到模型本身较小,这在一定程度上是可以理解的。额外的 DDP 运行时跟踪开销难以摊销。

MNIST 与真实数据

以下结果是在TPU VM V3-8环境,使用命令:python test/test_train_mp_mnist.py --logdir mnist/,ToT PyTorch和PyTorch/XLA收集的。

learning_curves

我们可以观察到,DDP 包装器即使最终达到了 97.48% 的高准确率,其收敛速度仍然慢于原生 XLA 方法。(原生方法的准确率为 99%。)

免责声明

此功能仍处于试验阶段并正在积极开发中。使用时请谨慎,并随时向 XLA GitHub 仓库提交任何错误报告。对于那些对原生 XLA 数据并行方法感兴趣的人,这里是教程

以下是一些已知的问题,目前正在调查中:

  • gradient_as_bucket_view=True 需要被强制执行。

  • 在使用时会遇到一些问题,torch.utils.data.DataLoader​​test_train_mp_mnist.py在处理真实数据时会在退出前崩溃。

Fully Sharded Data Parallel (FSDP) 在 PyTorch XLA

PyTorch XLA 中的全分割数据并行(FSDP)是一种工具,用于将 Module 参数在数据并行工作者之间进行分割。

示例用法:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独分割各个层,并由外部包装器处理剩余的参数。

Notes:

  • The XlaFullyShardedDataParallel 类支持在 https://arxiv.org/abs/1910.02054 中的 ZeRO-2 优化器(梯度和优化器状态分片)和 ZeRO-3 优化器(参数、梯度和优化器状态分片)。

    • ZeRO-3 优化器应该通过嵌套 FSDP 实现 reshard_after_forward=True。见 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 中的例子。

    • 对于无法 fitting 到单个 TPU 内存或主机 CPU 内存中的大型模型,应该在子模块构建与内部 FSDP 包装之间交错进行。详见 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_。

  • 一个简单的包装器 checkpoint_module 已提供(基于 torch_xla.utils.checkpoint.checkpointhttps://github.com/pytorch/xla/pull/3524),用于在给定的 nn.Module 实例上执行 梯度检查点。请参见 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 以获取示例。

  • 自动包装子模块:除了手动嵌套FSDP包装外,还可以指定一个auto_wrap_policy参数,以自动用内部FSDP包装子模块。size_based_auto_wrap_policytorch_xla.distributed.fsdp.wrap中是一个auto_wrap_policy可调用的示例,此策略将参数数量大于1亿的层进行包装。transformer_auto_wrap_policytorch_xla.distributed.fsdp.wrap中是类似变换器模型架构的auto_wrap_policy可调用示例。

例如,要自动将所有torch.nn.Conv2d子模块用内嵌FSDP包裹起来,可以使用:

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定一个auto_wrapper_callable参数来使用自定义可调用包装器为子模块(默认包装器就是XlaFullyShardedDataParallel类本身)。例如,可以使用以下代码对每个自动包装的子模块应用梯度检查点(即激活检查点/重新材料化)。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在调用优化器时,请直接调用 optimizer.step,不要调用 xm.optimizer_step。后者会在不同节点之间减少梯度,但在FSDP中已经对参数进行了分片,因此不需要这样做。

  • 在训练过程中保存模型和优化器检查点时,每个训练过程需要保存其自身的(分片的)模型和优化器状态字典(使用master_only=False并设置不同的路径供每个节点在xm.save中使用)。当恢复训练时,它需要加载对应节点的检查点。

  • 请也将 model.get_shard_metadata()model.state_dict() 一起保存,并使用 consolidate_sharded_model_checkpoints 将模型切片检查点缝合为一个完整的模型状态字典。请参见 test/test_train_mp_mnist_fsdp_with_ckpt.py 以获取示例。 .. code-block:: python3

    ckpt = {

    ‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),

    } ckpt_path = f’/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)

  • 检查点合并脚本也可以通过命令行启动,如下所示。 .. code-block:: bash

    # consolidate the saved checkpoints via command line tool python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”

该类的实现主要受到并很大程度上遵循了https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.htmlfairscale.nn.FullyShardedDataParallel的启发。与fairscale.nn.FullyShardedDataParallel的一个最大不同在于,在XLA中我们没有显式的参数存储,因此这里我们采用不同的方法来为ZeRO-3释放完整参数。


MNIST和ImageNet的数据集训练示例

安装

FSDP 在 PyTorch/XLA 1.12 版本及更新的 nightly 版本中可用。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。

克隆 PyTorch/XLA 仓库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

它在两个epoch左右可以达到约98.9的准确率:

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此脚本会在自动测试时合并检查点。您也可以手动合并分割的检查点 via

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上训练 ResNet-50 对图像识别模型

它在100个周期内大约获得75.9的准确率;下载 ImageNet-1k/datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您还可以添加 --use_gradient_checkpointing(这需要与 --use_nested_fsdp--auto_wrap_policy 一起使用)以在残差块上应用梯度检查点。


TPU机架上的示例训练脚本(参数量为10亿)

要训练无法 fitting 到单个 TPU 的大型模型,应在构建整个模型时应用自动包装或手动将子模块用内嵌 FSDP 包装起来,以实现 ZeRO-3 算法。

请参见 https://github.com/ronghanghu/vit_10b_fsdp_example,了解如何使用此XLA FSDP PR进行分片训练的Vision Transformer (ViT) 模型示例。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源