PJRT 运行时¶
PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到了 PJRT 运行时,该运行时也被 JAX 使用。
如果在 PJRT 中遇到 Bug,请在 GitHub 上提交问题,并添加 runtime 标签。
PyTorch/XLA r2.1的新功能:
PJRT 在 PyTorch/XLA r2.1 中已经稳定!
公共运行时API已从
torch_xla.experimental.pjrt迁移到torch_xla.runtime。The
pjrt://init 方法已被重命名为xla://,并由torch_xla.distributed.xla_backend注册。之前的
torch_xla.experimental.*名称在此次发布中仍然可用,以保持兼容性。
torchrun现在支持使用init_method='xla://'。通过 PJRT C API 新增的 XPU 和 Neuron 插件。
PyTorch/XLA r2.0的新功能:
PJRT 将会默认配置,如果你没有传递其他运行时配置。如果你继续设置 XRT 配置(
XRT_TPU_CONFIG),这个更改没有任何影响New TPU运行时实现
libtpu提高性能最多30%。新
xm.rendezvous实现,可扩展到数千个TPU内核[实验性]
torch.distributed支持 TPU v2 和 v3,包括pjrt://init_method
TL;DR¶
要使用PJRT预览运行时,请将环境变量
PJRT_DEVICE设置为CPU、TPU或CUDA在XRT中,所有分布式工作负载都是多进程的,每个设备一个进程。在PJRT的TPU v2和v3上,工作负载是多进程和多线程的(4个进程,每个进程有2个线程),因此你的工作负载应该是线程安全的。更多信息请参见TPU v2/v3上的多线程和API指南中的多进程部分。需要注意的关键差异如下:
要在线程安全的方式下初始化模型,可以在初始化后将参数广播到各个副本 (
torch_xla.experimental.pjrt.broadcast_master_param),或者从公共检查点加载每个副本的参数。对于其他随机数生成,请尽可能使用
torch.Generator。 全局torch随机数生成器即使你在副本中设置了相同的torch.manual_seed也不是线程安全的。要使用
torch.distributed,导入torch_xla.experimental.pjrt_backend并 使用xla://init_method。这些步骤对于GPU和TPU v4是可选的。
XRT 和 PJRT 的样本差异:
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr
def _mp_fn(index):
device = xm.xla_device()
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')
torch.manual_seed(42)
model = nn.Linear(128, 10).to(device)
+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
for i in range(10):
data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Print mean parameters so we can confirm they're the same across replicas
print([p.mean() for p in model.parameters()])
if __name__ == '__main__':
- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
+ # Recommended: set PJRT_DEVICE to your local device type
+ os.environ['PJRT_DEVICE'] = 'TPU'
torch_xla.launch(_mp_fn)
好处¶
简单的运行时配置:只需将
PJRT_DEVICE设置为TPU、CPU或CUDA,然后开始使用 XLA!或者,让 PJRT 基于您的环境自动选择设备。性能提升:减少的 gRPC 开销意味着端到端执行速度更快。在 TorchBench 2.0 中,我们在 TPU v4 上观察到了超过 35% 的训练时间改进。
Easy pod执行:只需将你的代码复制到每个TPU工作器,然后使用
gcloud compute tpus tpuvm ssh --worker=all同时运行它们。更好的扩展性:消除了 XRT的参数大小限制,并支持多达2048个TPU芯片。
快速上手¶
要开始使用PJRT与PyTorch/XLA,您只需要设置
PJRT_DEVICE 环境变量。如果您正在使用TPU v2或v3,请继续阅读以了解TPU v2和v3以及v4之间的差异。
CPU¶
在任何安装了PyTorch/XLA的机器上,你可以像这样在CPU上运行我们的MNIST示例:
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
TPU¶
要创建一个新的TPU并安装PyTorch/XLA r2.0:
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT
在 v4-8 上,你可以这样运行我们的 ResNet50 示例:
git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
默认情况下,PJRT 将使用所有 TPU 芯片。要只使用一个 TPU 芯片,请配置
TPU_PROCESS_BOUNDS 和 TPU_VISIBLE_CHIPS:
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Pods¶
在TPU Pod上,使用gcloud可以在每个TPU上并行运行您的命令:
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
Docker¶
您也可以使用 Docker 在容器中运行工作负载,并预先安装 PyTorch/XLA:
export DOCKER_IMAGE=gcr.io/...
# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"
# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
请注意,docker run 需要对主机有特权访问权限 (--privileged)
才能将 TPU 设备暴露给容器。目前,TPU pod 上的 Docker 仅支持
使用主机网络 --net=host。有关更多信息,请参阅 Cloud TPU 文档。
GPU¶
单节点GPU训练¶
要使用PJRT的GPU,只需设置PJRT_DEVICE=CUDA并将GPU_NUM_DEVICES配置为主机上的设备数量。例如:
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
您也可以使用 torchrun 来启动单节点多GPU训练。例如,
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在上面的例子中,--nnodes 表示要使用的机器数量(物理机或虚拟机),它是 1,因为我们进行的是单节点训练。--nproc-per-node 表示要使用的 GPU 设备数量。
多节点GPU训练¶
请注意,此功能仅适用于cuda 12+. 类似于PyTorch如何进行多节点训练,您可以运行以下命令:
PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
--nnodes: 要使用的GPU机器数量。--node_rank: 当前GPU机器的索引。该值可以是0, 1, …, ${NUMBER_GPU_VM}-1。--nproc_per_node: 当前机器上使用的GPU设备数量。–rdzv_endpoint: GPU机器中 node_rank==0 的端点,格式为 host:port`。`host
will be the internal IP address. The`端口可以是机器上任意可用的端口。对于单节点训练/推理,此参数可以省略。
例如,如果你想在两台GPU机器上进行训练,分别是machine_0和machine_1,在第一台GPU机器machine_0上运行:
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在第二台GPU机器上运行
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
上述两个命令之间的区别在于--node_rank,如果要在每台机器上使用不同数量的GPU设备,可能是--nproc_per_node。其余部分都是相同的。有关torchrun的更多信息,请参阅此页面。
XRT的区别¶
尽管在大多数情况下,我们期望 PJRT 和 XRT 在最终用户的角度下基本互换使用(特别是在 TPU v4 上),但仍有一些细微的区别需要注意。重要的是,XRT 是围绕 TPU 节点架构设计的,因此它会在 TPU VM 上也始终启动客户端和服务器进程。因此,每批输入数据会增加额外的延迟,这是因为数据需要被序列化和反序列化以便通过网络发送。
PJRT 直接使用本地设备,无需中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档。
性能提升可能适用于受制于开销的工作负载。
在XRT中,服务器进程是唯一与TPU设备交互的进程,客户端进程没有直接访问TPU设备的权限。当对单主机TPU(例如v3-8或v4-8)进行性能分析时,通常会看到8个设备跟踪记录(每个TPU核心一个)。而使用PJRT时,每个进程拥有一块芯片,并且来自该进程的分析只会显示2个TPU核心。
由于同样的原因,使用XRT时TPU Pod的性能分析无法工作,因为服务器进程独立于用户的模型代码运行。PJRT没有这个限制,因此在一个TPU Pod中,每个进程中可以对2个TPU内核进行性能分析。
PJRT 只支持 TPU VM 架构,我们目前没有计划支持 PJRT 的 TPU Node 架构。
PJRT使得运行时配置显著简化。
xla_dist不需要用于运行TPU Pod工作负载。相反,将代码复制到每个TPU主机 ([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)) 并在每个主机上并行运行代码(例如[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))xm.rendezvous已使用XLA原生集体通信重新实现,以增强在大型TPU Pod上的稳定性。更多详情请见下方。
TPU v2/v3上的多线程¶
在TPU v2和v3上,分布式工作负载始终以多线程方式运行,因为每个TPU核心会暴露两个TPU核心作为设备,并且同一时间只能有一个进程打开一个TPU芯片。在默认配置下,xmp.spawn会自动启动尽可能多的进程(每台TPU主机4个进程),并在每个进程中创建两个线程(每个TPU核心一个线程)。
注意:在TPU v4上,每个TPU芯片被表示为一个PyTorch设备,因此分布式工作负载将在4个进程中运行,每个进程只有一个线程。这与XRT的行为相同。
在大多数情况下,这通常不需要对现有代码进行重大更改。
在大多数情况下,您需要做出的主要改变是模型初始化。
由于torch的全局RNG在各个线程之间共享,即使您在每个副本中都将torch.manual_seed设置为相同的值,结果在不同线程和运行之间也会有所不同。为了在各个副本之间获得一致的参数,您可以使用torch_xla.experimental.pjrt.broadcast_master_param将一个副本的参数广播到所有其他副本,或者从公共检查点加载每个副本的参数。
xm.rendezvous更改¶
新在 PyTorch/XLA r2.0
通过XRT,工作节点0运行一个网格主服务,所有工作节点上的进程都会通过gRPC连接到该服务。实际上,我们在使用数千个芯片的TPU Pod时发现,在单个网格主进程中运行会出现不可靠的情况,原因在于工作节点0收到的入站连接数量过多。单一客户端进程超时可能会导致失败,并迫使整个工作负载重新启动。
因此,我们已使用原生XLA集体通信重新实现xm.rendezvous,这在大型TPU集群上更加稳定且经过充分测试。这与XRT实现相比带来了两个新的约束:
因为负载必须成为XLA图的一部分,
xm.mark_step会在数据传输前后被调用。在模型代码中间调用xm.rendezvous可能会强制进行不必要的编译。因为XLA不允许集体操作在一部分工作者上运行,所以所有工作者必须参与
rendezvous。
如果您需要xm.rendezvous的行为(即在不改变XLA图和/或同步子集worker的情况下通信数据),请考虑使用
``torch.distributed.barrier` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`_
或
``torch.distributed.all_gather_object` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object>`_
与一个gloo进程组。如果您也在使用xla torch.distributed
后端,可以使用torch.new_group创建一个gloo子组。请参见
torch.distributed在 TPUs v2/v3 上不完全支持。只有部分使用xla后端的操作被实现,并且gloo可能在多线程上下文中不会按预期工作。在我们的实验中,
gloo不适合扩展到数千个TPU芯片,因此 请预期这种替代方案在大规模使用时不如使用xm.rendezvous结合 PJRT 可靠。
PJRT 和 torch.distributed¶
新在 PyTorch/XLA r2.0
当使用 PJRT 与 torch.distributed 和
[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
时,我们强烈建议使用新的 xla:// init_method,它会通过查询运行时自动找到复制品 ID、世界大小和主 IP 地址。例如:
import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt
# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend
def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='xla://')
t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
xm.mark_step()
print(output)
if __name__ == '__main__':
torch_xla.launch(_all_gather)
注:尽管在TPU v4上不需要xla:// init_method,但仍然建议使用它。如果您使用env://,则必须将MASTER_ADDR设置为具有设备0的IP主机,这不一定是worker 0。xla:// init_method可以自动找到这个IP。
注意:对于TPU v2/v3,您仍然需要导入
torch_xla.experimental.pjrt_backend,因为
torch.distributed 对TPU v2/v3的支持仍处于实验阶段。
有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅
``ddp.md` <./ddp.md>`_ 关于 TPU V4。对于一个同时使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下 示例脚本:
PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1
性能¶
TorchBench 显示,在 PJRT 与 XRT 对比下,各项任务的平均训练时间都有所提升,TPU v4-8 的平均提升幅度超过 35%。不同任务和模型类型带来的好处差异显著,范围从 0% 到 175%。以下图表展示了按任务细分的详细情况:
新 TPU 运行时¶
新在 PyTorch/XLA r2.0
PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,
用于访问基于 TFRT 的新 TPU 运行时,在 libtpu 中。现在当设置 PJRT_DEVICE=TPU 时,
这是默认的运行时。在 1.13 版本中使用的基于 StreamExecutor 的旧版 TPU 运行时在 2.0 版本中仍然可以通过 PJRT_DEVICE=TPU_LEGACY 使用,
但在未来的版本中将被移除。如果你遇到仅在 TPU 上发生而在 TPU_LEGACY 上不发生的问题,请在 GitHub 上提交问题。
在大多数情况下,我们期望两种运行时的性能相似,但在某些情况下,新运行时可能会快达30%。以下图表展示了各项任务的分解情况:
注意:本图表中显示的改进也包含在PJRT与XRT的比较中。