如何实现分布式数据并行(DDP)¶
本文档介绍了如何在xla中使用torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生xla数据并行方法的区别。你可以在这里找到一个最小可运行示例这里。
背景 / 动机¶
客户长期以来一直希望能够使用 PyTorch 的 DistributedDataParallel API 与 xla 结合,并且我们现在将其作为实验性功能启用。
如何使用DistributedDataParallel¶
对于那些从PyTorch急切模式切换到XLA的人,这里是将您的急切DDP模型转换为XLA模型所需的所有更改。我们假设您已经知道如何在单个设备上使用XLA 。
导入特定于 XLA 的分布式包:
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
初始化 XLA 进程组,与其他进程组(如 nccl 和 gloo)类似。
dist.init_process_group("xla", rank=rank, world_size=world_size)
如果您需要获取秩和世界大小,请使用与 XLA 相关的特定 API。
new_rank = xr.global_ordinal()
world_size = xr.world_size()
将
gradient_as_bucket_view=True传递给 DDP 包装器。
ddp_model = DDP(model, gradient_as_bucket_view=True)
最后使用特定的 XLA 启动器启动您的模型。
torch_xla.launch(demo_fn)
这里我们把所有东西都放在一起了(这个例子实际上取自 DDP教程)。 你编写代码的方式与急切模式的体验非常相似。只是在单个设备上有一些针对xla的具体修改,再加上对脚本的上述五个更改。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the xla process group
dist.init_process_group("xla", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 1000000)
self.relu = nn.ReLU()
self.net2 = nn.Linear(1000000, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank):
# xla specific APIs to get rank, world_size.
new_rank = xr.global_ordinal()
assert new_rank == rank
world_size = xr.world_size()
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss_fn(outputs, labels).backward()
optimizer.step()
# xla specific API to execute the graph
xm.mark_step()
cleanup()
def run_demo(demo_fn):
# xla specific launcher
torch_xla.launch(demo_fn)
if __name__ == "__main__":
run_demo(demo_basic)
基准测试¶
Resnet50 与虚假数据¶
以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行命令python
test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1收集的。统计指标是通过使用此拉取请求中的脚本生成的。速率的单位是每秒图像数。
| 类型 | 均值 | 中位数 | 90th % | 标准差 | CV |
| xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
| DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
我们本地的分布式数据并行方法与DistributedDataParallel包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。 这个结果在考虑到DistributedDataParallel包装器会在DDP运行时引入额外开销的情况下似乎是合理的。
MNIST与虚假数据¶
以下结果是通过在TPU VM V3-8环境中使用ToT PyTorch和PyTorch/XLA执行命令python
test/test_train_mp_mnist.py --fake_data收集的。统计指标是通过使用此拉取请求中的脚本生成的。速率的单位是每秒图像数。
| 类型 | 均值 | 中位数 | 90th % | 标准差 | CV |
| xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
| DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
我们本地的分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。由于数据集较小,前几轮受数据加载影响较大,因此我们比较的是第 90 百分位数。这种性能下降很大,但考虑到模型本身较小,这在一定程度上是可以理解的。额外的 DDP 运行时跟踪开销难以摊销。
MNIST 与真实数据¶
以下结果是在TPU VM V3-8环境,使用命令:python
test/test_train_mp_mnist.py --logdir mnist/,ToT PyTorch和PyTorch/XLA收集的。
我们可以观察到,DDP 包装器即使最终达到了 97.48% 的高准确率,其收敛速度仍然慢于原生 XLA 方法。(原生方法的准确率为 99%。)
免责声明¶
此功能仍处于试验阶段并正在积极开发中。使用时请谨慎,并随时向 XLA GitHub 仓库提交任何错误报告。对于那些对原生 XLA 数据并行方法感兴趣的人,这里是教程。
以下是一些已知的问题,目前正在调查中:
gradient_as_bucket_view=True需要被强制执行。在使用时会遇到一些问题,
torch.utils.data.DataLoader。test_train_mp_mnist.py在处理真实数据时会在退出前崩溃。
Fully Sharded Data Parallel (FSDP) 在 PyTorch XLA¶
PyTorch XLA 中的全分割数据并行(FSDP)是一种工具,用于将 Module 参数在数据并行工作者之间进行分割。
示例用法:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独分割各个层,并由外部包装器处理剩余的参数。
Notes:
The
XlaFullyShardedDataParallel类支持在 https://arxiv.org/abs/1910.02054 中的 ZeRO-2 优化器(梯度和优化器状态分片)和 ZeRO-3 优化器(参数、梯度和优化器状态分片)。ZeRO-3 优化器应该通过嵌套 FSDP 实现
reshard_after_forward=True。见test/test_train_mp_mnist_fsdp_with_ckpt.py和test/test_train_mp_imagenet_fsdp.py中的例子。对于无法 fitting 到单个 TPU 内存或主机 CPU 内存中的大型模型,应该在子模块构建与内部 FSDP 包装之间交错进行。详见 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_。
一个简单的包装器
checkpoint_module已提供(基于torch_xla.utils.checkpoint.checkpoint从 https://github.com/pytorch/xla/pull/3524),用于在给定的nn.Module实例上执行 梯度检查点。请参见test/test_train_mp_mnist_fsdp_with_ckpt.py和test/test_train_mp_imagenet_fsdp.py以获取示例。自动包装子模块:除了手动嵌套FSDP包装外,还可以指定一个
auto_wrap_policy参数,以自动用内部FSDP包装子模块。size_based_auto_wrap_policy在torch_xla.distributed.fsdp.wrap中是一个auto_wrap_policy可调用的示例,此策略将参数数量大于1亿的层进行包装。transformer_auto_wrap_policy在torch_xla.distributed.fsdp.wrap中是类似变换器模型架构的auto_wrap_policy可调用示例。
例如,要自动将所有torch.nn.Conv2d子模块用内嵌FSDP包裹起来,可以使用:
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定一个auto_wrapper_callable参数来使用自定义可调用包装器为子模块(默认包装器就是XlaFullyShardedDataParallel类本身)。例如,可以使用以下代码对每个自动包装的子模块应用梯度检查点(即激活检查点/重新材料化)。
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
在调用优化器时,请直接调用
optimizer.step,不要调用xm.optimizer_step。后者会在不同节点之间减少梯度,但在FSDP中已经对参数进行了分片,因此不需要这样做。在训练过程中保存模型和优化器检查点时,每个训练过程需要保存其自身的(分片的)模型和优化器状态字典(使用
master_only=False并设置不同的路径供每个节点在xm.save中使用)。当恢复训练时,它需要加载对应节点的检查点。请也将
model.get_shard_metadata()与model.state_dict()一起保存,并使用consolidate_sharded_model_checkpoints将模型切片检查点缝合为一个完整的模型状态字典。请参见test/test_train_mp_mnist_fsdp_with_ckpt.py以获取示例。 .. code-block:: python3- ckpt = {
‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),
} ckpt_path = f’/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)
检查点合并脚本也可以通过命令行启动,如下所示。 .. code-block:: bash
# consolidate the saved checkpoints via command line tool python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
该类的实现主要受到并很大程度上遵循了https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html中fairscale.nn.FullyShardedDataParallel的启发。与fairscale.nn.FullyShardedDataParallel的一个最大不同在于,在XLA中我们没有显式的参数存储,因此这里我们采用不同的方法来为ZeRO-3释放完整参数。
MNIST和ImageNet的数据集训练示例¶
最小示例 : ``examples/fsdp/train_resnet_fsdp_auto_wrap.py` <https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py>`_
MNIST: ``test/test_train_mp_mnist_fsdp_with_ckpt.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py>`_ (它也测试了检查点合并)
ImageNet: ``test/test_train_mp_imagenet_fsdp.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py>`_
安装¶
FSDP 在 PyTorch/XLA 1.12 版本及更新的 nightly 版本中可用。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。
克隆 PyTorch/XLA 仓库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST ¶
它在两个epoch左右可以达到约98.9的准确率:
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
此脚本会在自动测试时合并检查点。您也可以手动合并分割的检查点 via
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上训练 ResNet-50 对图像识别模型¶
它在100个周期内大约获得75.9的准确率;下载 ImageNet-1k 到 /datasets/imagenet-1k:
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
您还可以添加 --use_gradient_checkpointing(这需要与 --use_nested_fsdp 或 --auto_wrap_policy 一起使用)以在残差块上应用梯度检查点。
TPU机架上的示例训练脚本(参数量为10亿)¶
要训练无法 fitting 到单个 TPU 的大型模型,应在构建整个模型时应用自动包装或手动将子模块用内嵌 FSDP 包装起来,以实现 ZeRO-3 算法。
请参见 https://github.com/ronghanghu/vit_10b_fsdp_example,了解如何使用此XLA FSDP PR进行分片训练的Vision Transformer (ViT) 模型示例。