目录

如何执行 DistributedDataParallel(DDP)

本文档展示了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel, 并进一步描述了它与原生 XLA 数据并行的区别 方法。你可以在这里找到一个最小可运行的例子。

背景 / 动机

长期以来,客户一直要求能够使用 PyTorch 的 DistributedDataParallel API 与 xla 一起使用。在这里,我们将其作为实验性 特征。

如何使用 DistributedDataParallel

对于从 PyTorch Eager 模式切换到 XLA 的用户,以下是所有 将 Eager DDP 模型转换为 XLA 模型所需进行的更改。我们假设 您已经知道如何在单个 设备

  1. 导入特定于 xla 的分布式包:

import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
  1. Init xla 进程组类似于其他进程组,例如 nccl 和 gloo。

dist.init_process_group("xla", rank=rank, world_size=world_size)
  1. 如果需要,请使用 xla 特定的 API 来获取排名和world_size。

new_rank = xr.global_ordinal()
world_size = xr.world_size()
  1. 传递给 DDP 包装器。gradient_as_bucket_view=True

ddp_model = DDP(model, gradient_as_bucket_view=True)
  1. 最后,使用 xla 特定的启动器启动您的模型。

torch_xla.launch(demo_fn)

在这里,我们将所有内容放在一起(该示例实际上取自 DDP 教程)。 你编码的方式与 Eager 体验非常相似。只是用 xla 对单个设备的特定作以及对脚本的上述 5 项更改。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xr.global_ordinal()
    assert new_rank == rank
    world_size = xr.world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, graident_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    torch_xla.launch(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

标杆

带有虚假数据的 Resnet50

使用以下命令收集以下结果:在 具有 ToT PyTorch 和 PyTorch/XLA 的 TPU VM V3-8 环境。以及统计数据 指标是使用此 pull 中的脚本生成的 请求。费率的单位为 每秒图像数。python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

类型 意味 着 中位数 第 90 % 标准开发 简历
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我们针对分布式数据的原生方法之间的性能差异 parallel 和 DistributedDataParallel 包装器为:1 - 395.97 / 418.54 = 5.39%。 鉴于 DDP 包装器引入了额外的开销,此结果似乎是合理的 跟踪 DDP 运行时。

带有虚假数据的 MNIST

使用以下命令收集以下结果:在具有 ToT 的 TPU VM V3-8 环境中 PyTorch 和 PyTorch/XLA。统计指标是使用 脚本。这 速率的单位是每秒图像数。python test/test_train_mp_mnist.py --fake_data

类型 意味 着 中位数 第 90 % 标准开发 简历
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我们针对分布式数据的原生方法之间的性能差异 parallel 和 DistributedDataParallel 包装器为:1 - 14313.78 / 24351.74 = 41.22%.在这里,我们比较第 90% 个,因为数据集很小,首先是一个 很少有轮次受到数据加载的严重影响。这种放缓是巨大的,但使 意义,因为模型很小。额外的 DDP 运行时跟踪开销为 难以摊销。

具有真实数据的 MNIST

使用以下命令收集以下结果:在 TPU VM V3-8 环境中,使用 ToT PyTorch 和 PyTorch/XLA。python test/test_train_mp_mnist.py --logdir mnist/

learning_curves

我们可以观察到 DDP 包装器的收敛速度比原生 XLA 慢 方法,即使它仍然在 结束。(本机方法达到 99%。

免責聲明

此功能仍处于试验阶段,正在积极开发中。在 注意事项,并随时将任何错误提交到 XLA Github repo 中。对于那些对 原生 XLA 数据并行方法,这里是教程

以下是一些正在调查的已知问题:

  • gradient_as_bucket_view=True需要强制执行。

  • 与 一起使用时存在一些问题。 with Real Data 在退出之前崩溃。torch.utils.data.DataLoader​​test_train_mp_mnist.py

PyTorch XLA 中的完全分片数据并行 (FSDP)

PyTorch XLA 中的完全分片数据并行 (FSDP) 是一个用于跨数据并行工作程序对模块参数进行分片的实用程序。

用法示例:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

还可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。

笔记:

  • 该类在 https://arxiv.org/abs/1910.02054 中同时支持 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。XlaFullyShardedDataParallel

    • ZeRO-3 优化器应通过嵌套 FSDP 实现。请参阅 和 有关示例。reshard_after_forward=Truetest/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py

    • 对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应将子模块构造与内部 FSDP 包装交错。有关示例,请参见 ''FSDPViTModel' <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>'_。

  • 提供了一个简单的包装器(基于 FROM https://github.com/pytorch/xla/pull/3524)来对给定实例执行梯度检查点。请参阅 和 有关示例。checkpoint_moduletorch_xla.utils.checkpoint.checkpointnn.Moduletest/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py

  • 自动包装子模块:除了手动嵌套的 FSDP 包装外,还可以指定一个参数来自动使用内部 FSDP 包装子模块。 in 是 callable 的一个示例,此策略包装参数数量大于 100M 的层。 in 是类似 transformer 的模型架构的 Callable 示例。auto_wrap_policysize_based_auto_wrap_policytorch_xla.distributed.fsdp.wrapauto_wrap_policytransformer_auto_wrap_policytorch_xla.distributed.fsdp.wrapauto_wrap_policy

例如,要使用内部 FSDP 自动包装所有子模块,可以使用:torch.nn.Conv2d

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定一个参数来为子模块使用自定义的可调用包装器(默认包装器只是类本身)。例如,可以使用以下内容将梯度检查点(即激活检查点/重新物质化)应用于每个自动包装的子模块。auto_wrapper_callableXlaFullyShardedDataParallel

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 单步执行优化器时,直接调用,不要调用 。后者减少了跨等级的梯度,这对于 FSDP (参数已经分片)来说是必需的。optimizer.stepxm.optimizer_step

  • 在训练期间保存 model 和 optimizer 检查点时,每个训练过程都需要保存自己的(分片)模型和优化器状态字典的检查点(使用 和 为每个 rank 设置不同的路径)。恢复时,需要加载对应 rank 的 checkpoint。master_only=Falsexm.save

  • 还请按如下方式保存,并用于将分片的模型检查点拼接成一个完整的模型状态字典。有关示例,请参阅 。 ..代码块::python3model.get_shard_metadata()model.state_dict()consolidate_sharded_model_checkpointstest/test_train_mp_mnist_fsdp_with_ckpt.py

    ckpt = {

    'model':model.state_dict()、 'shard_metadata':model.get_shard_metadata()、 'optimizer':optimizer.state_dict()、

    } ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth' xm.save(ckpt, ckpt_path, master_only=False)

  • 也可以从命令行启动 checkpoint consolidation 脚本,如下所示。 ..代码块:: bash

    # 通过命令行工具整合保存的 checkpoint python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”

此类的实现在很大程度上受到 in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 的启发,并且主要遵循 的结构。最大的区别之一是,在 XLA 中,我们没有显式参数存储,因此在这里我们采用不同的方法来释放 ZeRO-3 的完整参数。fairscale.nn.FullyShardedDataParallelfairscale.nn.FullyShardedDataParallel


MNIST 和 ImageNet 上的示例训练脚本

安装

FSDP 在 PyTorch/XLA 1.12 版本和更新版本上每晚提供。有关安装指南,请参阅 https://github.com/pytorch/xla#-available-images-and-wheels

克隆 PyTorch/XLA 存储库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

它在 2 个 epoch 中获得大约 98.9 的准确率:

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此脚本在结束时自动测试 checkpoint consolidation。您还可以通过以下方式手动整合分片的 checkpoint

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet

它在 75.9 个 epoch 中获得大约 100 的准确率;下载 ImageNet-1k 到 :/datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您还可以添加 (需要与 or 一起使用 ) 以对残差块应用梯度检查点。--use_gradient_checkpointing--use_nested_fsdp--auto_wrap_policy


TPU Pod 上的示例训练脚本(具有 100 亿个参数)

要训练无法放入单个 TPU 的大型模型,在构建整个模型以实现 ZeRO-3 算法时,应应用自动包装或使用内部 FSDP 手动包装子模块。

有关使用此 XLA FSDP PR 对 Vision Transformer (ViT) 模型进行分片训练的示例,请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源