如何执行 DistributedDataParallel(DDP)¶
本文档展示了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel, 并进一步描述了它与原生 XLA 数据并行的区别 方法。你可以在这里找到一个最小可运行的例子。
背景 / 动机¶
长期以来,客户一直要求能够使用 PyTorch 的 DistributedDataParallel API 与 xla 一起使用。在这里,我们将其作为实验性 特征。
如何使用 DistributedDataParallel¶
对于从 PyTorch Eager 模式切换到 XLA 的用户,以下是所有 将 Eager DDP 模型转换为 XLA 模型所需进行的更改。我们假设 您已经知道如何在单个 设备。
导入特定于 xla 的分布式包:
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
Init xla 进程组类似于其他进程组,例如 nccl 和 gloo。
dist.init_process_group("xla", rank=rank, world_size=world_size)
如果需要,请使用 xla 特定的 API 来获取排名和world_size。
new_rank = xr.global_ordinal()
world_size = xr.world_size()
传递给 DDP 包装器。
gradient_as_bucket_view=True
ddp_model = DDP(model, gradient_as_bucket_view=True)
最后,使用 xla 特定的启动器启动您的模型。
torch_xla.launch(demo_fn)
在这里,我们将所有内容放在一起(该示例实际上取自 DDP 教程)。 你编码的方式与 Eager 体验非常相似。只是用 xla 对单个设备的特定作以及对脚本的上述 5 项更改。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the xla process group
dist.init_process_group("xla", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 1000000)
self.relu = nn.ReLU()
self.net2 = nn.Linear(1000000, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank):
# xla specific APIs to get rank, world_size.
new_rank = xr.global_ordinal()
assert new_rank == rank
world_size = xr.world_size()
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss_fn(outputs, labels).backward()
optimizer.step()
# xla specific API to execute the graph
xm.mark_step()
cleanup()
def run_demo(demo_fn):
# xla specific launcher
torch_xla.launch(demo_fn)
if __name__ == "__main__":
run_demo(demo_basic)
标杆¶
带有虚假数据的 Resnet50¶
使用以下命令收集以下结果:在
具有 ToT PyTorch 和 PyTorch/XLA 的 TPU VM V3-8 环境。以及统计数据
指标是使用此 pull 中的脚本生成的
请求。费率的单位为
每秒图像数。python
test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
类型 | 意味 着 | 中位数 | 第 90 % | 标准开发 | 简历 |
xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
我们针对分布式数据的原生方法之间的性能差异 parallel 和 DistributedDataParallel 包装器为:1 - 395.97 / 418.54 = 5.39%。 鉴于 DDP 包装器引入了额外的开销,此结果似乎是合理的 跟踪 DDP 运行时。
带有虚假数据的 MNIST¶
使用以下命令收集以下结果:在具有 ToT 的 TPU VM V3-8 环境中
PyTorch 和 PyTorch/XLA。统计指标是使用
脚本。这
速率的单位是每秒图像数。python
test/test_train_mp_mnist.py --fake_data
类型 | 意味 着 | 中位数 | 第 90 % | 标准开发 | 简历 |
xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
我们针对分布式数据的原生方法之间的性能差异 parallel 和 DistributedDataParallel 包装器为:1 - 14313.78 / 24351.74 = 41.22%.在这里,我们比较第 90% 个,因为数据集很小,首先是一个 很少有轮次受到数据加载的严重影响。这种放缓是巨大的,但使 意义,因为模型很小。额外的 DDP 运行时跟踪开销为 难以摊销。
具有真实数据的 MNIST¶
使用以下命令收集以下结果:在 TPU VM V3-8 环境中,使用
ToT PyTorch 和 PyTorch/XLA。python
test/test_train_mp_mnist.py --logdir mnist/

我们可以观察到 DDP 包装器的收敛速度比原生 XLA 慢 方法,即使它仍然在 结束。(本机方法达到 99%。
免責聲明¶
此功能仍处于试验阶段,正在积极开发中。在 注意事项,并随时将任何错误提交到 XLA Github repo 中。对于那些对 原生 XLA 数据并行方法,这里是教程。
以下是一些正在调查的已知问题:
gradient_as_bucket_view=True
需要强制执行。与 一起使用时存在一些问题。 with Real Data 在退出之前崩溃。
torch.utils.data.DataLoader
test_train_mp_mnist.py
PyTorch XLA 中的完全分片数据并行 (FSDP)¶
PyTorch XLA 中的完全分片数据并行 (FSDP) 是一个用于跨数据并行工作程序对模块参数进行分片的实用程序。
用法示例:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
还可以单独对各个层进行分片,并让外部包装器处理任何剩余的参数。
笔记:
该类在 https://arxiv.org/abs/1910.02054 中同时支持 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。
XlaFullyShardedDataParallel
ZeRO-3 优化器应通过嵌套 FSDP 实现。请参阅 和 有关示例。
reshard_after_forward=True
test/test_train_mp_mnist_fsdp_with_ckpt.py
test/test_train_mp_imagenet_fsdp.py
对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应将子模块构造与内部 FSDP 包装交错。有关示例,请参见 ''FSDPViTModel' <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>'_。
提供了一个简单的包装器(基于 FROM https://github.com/pytorch/xla/pull/3524)来对给定实例执行梯度检查点。请参阅 和 有关示例。
checkpoint_module
torch_xla.utils.checkpoint.checkpoint
nn.Module
test/test_train_mp_mnist_fsdp_with_ckpt.py
test/test_train_mp_imagenet_fsdp.py
自动包装子模块:除了手动嵌套的 FSDP 包装外,还可以指定一个参数来自动使用内部 FSDP 包装子模块。 in 是 callable 的一个示例,此策略包装参数数量大于 100M 的层。 in 是类似 transformer 的模型架构的 Callable 示例。
auto_wrap_policy
size_based_auto_wrap_policy
torch_xla.distributed.fsdp.wrap
auto_wrap_policy
transformer_auto_wrap_policy
torch_xla.distributed.fsdp.wrap
auto_wrap_policy
例如,要使用内部 FSDP 自动包装所有子模块,可以使用:torch.nn.Conv2d
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定一个参数来为子模块使用自定义的可调用包装器(默认包装器只是类本身)。例如,可以使用以下内容将梯度检查点(即激活检查点/重新物质化)应用于每个自动包装的子模块。auto_wrapper_callable
XlaFullyShardedDataParallel
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
单步执行优化器时,直接调用,不要调用 。后者减少了跨等级的梯度,这对于 FSDP (参数已经分片)来说是必需的。
optimizer.step
xm.optimizer_step
在训练期间保存 model 和 optimizer 检查点时,每个训练过程都需要保存自己的(分片)模型和优化器状态字典的检查点(使用 和 为每个 rank 设置不同的路径)。恢复时,需要加载对应 rank 的 checkpoint。
master_only=False
xm.save
还请按如下方式保存,并用于将分片的模型检查点拼接成一个完整的模型状态字典。有关示例,请参阅 。 ..代码块::python3
model.get_shard_metadata()
model.state_dict()
consolidate_sharded_model_checkpoints
test/test_train_mp_mnist_fsdp_with_ckpt.py
- ckpt = {
'model':model.state_dict()、 'shard_metadata':model.get_shard_metadata()、 'optimizer':optimizer.state_dict()、
} ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth' xm.save(ckpt, ckpt_path, master_only=False)
也可以从命令行启动 checkpoint consolidation 脚本,如下所示。 ..代码块:: bash
# 通过命令行工具整合保存的 checkpoint python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
此类的实现在很大程度上受到 in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 的启发,并且主要遵循 的结构。最大的区别之一是,在 XLA 中,我们没有显式参数存储,因此在这里我们采用不同的方法来释放 ZeRO-3 的完整参数。fairscale.nn.FullyShardedDataParallel
fairscale.nn.FullyShardedDataParallel
MNIST 和 ImageNet 上的示例训练脚本¶
最小示例 : ''examples/fsdp/train_resnet_fsdp_auto_wrap.py' <https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py>`_
MNIST: ''test/test_train_mp_mnist_fsdp_with_ckpt.py' <https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py>'_ (它还测试检查点整合)
ImageNet: ''test/test_train_mp_imagenet_fsdp.py' <https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py>`_
安装¶
FSDP 在 PyTorch/XLA 1.12 版本和更新版本上每晚提供。有关安装指南,请参阅 https://github.com/pytorch/xla#-available-images-and-wheels。
克隆 PyTorch/XLA 存储库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST¶
它在 2 个 epoch 中获得大约 98.9 的准确率:
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
此脚本在结束时自动测试 checkpoint consolidation。您还可以通过以下方式手动整合分片的 checkpoint
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet¶
它在 75.9 个 epoch 中获得大约 100 的准确率;下载 ImageNet-1k 到 :/datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
您还可以添加 (需要与 or 一起使用 ) 以对残差块应用梯度检查点。--use_gradient_checkpointing
--use_nested_fsdp
--auto_wrap_policy
TPU Pod 上的示例训练脚本(具有 100 亿个参数)¶
要训练无法放入单个 TPU 的大型模型,在构建整个模型以实现 ZeRO-3 算法时,应应用自动包装或使用内部 FSDP 手动包装子模块。
有关使用此 XLA FSDP PR 对 Vision Transformer (ViT) 模型进行分片训练的示例,请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example。