目录

PJRT 运行时

PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到了 PJRT 运行时,该运行时也被 JAX 使用。

如果在 PJRT 中遇到 Bug,请在 GitHub 上提交问题,并添加 runtime 标签。

PyTorch/XLA r2.1的新功能:

  • PJRT 在 PyTorch/XLA r2.1 中已经稳定!

  • 公共运行时API已从torch_xla.experimental.pjrt迁移到torch_xla.runtime

    • The pjrt:// init 方法已被重命名为 xla://,并由 torch_xla.distributed.xla_backend 注册。

    • 之前的 torch_xla.experimental.* 名称在此次发布中仍然可用,以保持兼容性。

  • torchrun 现在支持使用 init_method='xla://'

  • 通过 PJRT C API 新增的 XPU 和 Neuron 插件。

PyTorch/XLA r2.0的新功能:

  • PJRT 将会默认配置,如果你没有传递其他运行时配置。如果你继续设置 XRT 配置(XRT_TPU_CONFIG),这个更改没有任何影响

  • New TPU运行时实现libtpu提高性能最多30%。

  • xm.rendezvous 实现,可扩展到数千个TPU内核

  • [实验性] torch.distributed 支持 TPU v2 和 v3,包括 pjrt:// init_method

TL;DR

  • 要使用PJRT预览运行时,请将环境变量PJRT_DEVICE设置为CPUTPUCUDA

  • 在XRT中,所有分布式工作负载都是多进程的,每个设备一个进程。在PJRT的TPU v2和v3上,工作负载是多进程和多线程的(4个进程,每个进程有2个线程),因此你的工作负载应该是线程安全的。更多信息请参见TPU v2/v3上的多线程API指南中的多进程部分。需要注意的关键差异如下:

    • 要在线程安全的方式下初始化模型,可以在初始化后将参数广播到各个副本 (torch_xla.experimental.pjrt.broadcast_master_param),或者从公共检查点加载每个副本的参数。

    • 对于其他随机数生成,请尽可能使用 torch.Generator。 全局 torch 随机数生成器即使你在副本中设置了相同的 torch.manual_seed 也不是线程安全的。

    • 要使用 torch.distributed,导入 torch_xla.experimental.pjrt_backend 并 使用 xla:// init_method

    • 这些步骤对于GPU和TPU v4是可选的。

XRT 和 PJRT 的样本差异:

 import os

 import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.optim as optim
 import torch.distributed as dist
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr


 def _mp_fn(index):
   device = xm.xla_device()
-  dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+  dist.init_process_group('xla', init_method='xla://')

   torch.manual_seed(42)
   model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
   model = DDP(model, gradient_as_bucket_view=True)

   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)

   for i in range(10):
     data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()

     optimizer.step()
     xm.mark_step()

   # Print mean parameters so we can confirm they're the same across replicas
   print([p.mean() for p in model.parameters()])

 if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

   torch_xla.launch(_mp_fn)

好处

  • 简单的运行时配置:只需将 PJRT_DEVICE 设置为 TPUCPUCUDA,然后开始使用 XLA!或者,让 PJRT 基于您的环境自动选择设备。

  • 性能提升:减少的 gRPC 开销意味着端到端执行速度更快。在 TorchBench 2.0 中,我们在 TPU v4 上观察到了超过 35% 的训练时间改进。

  • Easy pod执行:只需将你的代码复制到每个TPU工作器,然后使用gcloud compute tpus tpuvm ssh --worker=all同时运行它们。

  • 更好的扩展性:消除了 XRT的参数大小限制,并支持多达2048个TPU芯片。

快速上手

要开始使用PJRT与PyTorch/XLA,您只需要设置 PJRT_DEVICE 环境变量。如果您正在使用TPU v2或v3,请继续阅读以了解TPU v2和v3以及v4之间的差异。

CPU

在任何安装了PyTorch/XLA的机器上,你可以像这样在CPU上运行我们的MNIST示例:

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

要创建一个新的TPU并安装PyTorch/XLA r2.0:

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,你可以这样运行我们的 ResNet50 示例:

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

默认情况下,PJRT 将使用所有 TPU 芯片。要只使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS:

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pods

在TPU Pod上,使用gcloud可以在每个TPU上并行运行您的命令:

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您也可以使用 Docker 在容器中运行工作负载,并预先安装 PyTorch/XLA:

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

请注意,docker run 需要对主机有特权访问权限 (--privileged) 才能将 TPU 设备暴露给容器。目前,TPU pod 上的 Docker 仅支持 使用主机网络 --net=host。有关更多信息,请参阅 Cloud TPU 文档

GPU

单节点GPU训练

要使用PJRT的GPU,只需设置PJRT_DEVICE=CUDA并将GPU_NUM_DEVICES配置为主机上的设备数量。例如:

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您也可以使用 torchrun 来启动单节点多GPU训练。例如,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上面的例子中,--nnodes 表示要使用的机器数量(物理机或虚拟机),它是 1,因为我们进行的是单节点训练。--nproc-per-node 表示要使用的 GPU 设备数量。

多节点GPU训练

请注意,此功能仅适用于cuda 12+. 类似于PyTorch如何进行多节点训练,您可以运行以下命令:

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes: 要使用的GPU机器数量。

  • --node_rank: 当前GPU机器的索引。该值可以是0, 1, …, ${NUMBER_GPU_VM}-1。

  • --nproc_per_node: 当前机器上使用的GPU设备数量。

  • –rdzv_endpoint: GPU机器中 node_rank==0 的端点,格式为 host:port`。`hostwill be the internal IP address. The`端口可以是机器上任意可用的端口。对于单节点训练/推理,此参数可以省略。

例如,如果你想在两台GPU机器上进行训练,分别是machine_0和machine_1,在第一台GPU机器machine_0上运行:

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台GPU机器上运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

上述两个命令之间的区别在于--node_rank,如果要在每台机器上使用不同数量的GPU设备,可能是--nproc_per_node。其余部分都是相同的。有关torchrun的更多信息,请参阅此页面

XRT的区别

尽管在大多数情况下,我们期望 PJRT 和 XRT 在最终用户的角度下基本互换使用(特别是在 TPU v4 上),但仍有一些细微的区别需要注意。重要的是,XRT 是围绕 TPU 节点架构设计的,因此它会在 TPU VM 上也始终启动客户端和服务器进程。因此,每批输入数据会增加额外的延迟,这是因为数据需要被序列化和反序列化以便通过网络发送。

PJRT 直接使用本地设备,无需中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档

  • 性能提升可能适用于受制于开销的工作负载。

  • 在XRT中,服务器进程是唯一与TPU设备交互的进程,客户端进程没有直接访问TPU设备的权限。当对单主机TPU(例如v3-8或v4-8)进行性能分析时,通常会看到8个设备跟踪记录(每个TPU核心一个)。而使用PJRT时,每个进程拥有一块芯片,并且来自该进程的分析只会显示2个TPU核心。

    • 由于同样的原因,使用XRT时TPU Pod的性能分析无法工作,因为服务器进程独立于用户的模型代码运行。PJRT没有这个限制,因此在一个TPU Pod中,每个进程中可以对2个TPU内核进行性能分析。

  • PJRT 只支持 TPU VM 架构,我们目前没有计划支持 PJRT 的 TPU Node 架构。

  • PJRT使得运行时配置显著简化。xla_dist不需要用于运行TPU Pod工作负载。相反,将代码复制到每个TPU主机 ([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)) 并在每个主机上并行运行代码(例如 [gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)

  • xm.rendezvous 已使用XLA原生集体通信重新实现,以增强在大型TPU Pod上的稳定性。更多详情请见下方。

TPU v2/v3上的多线程

在TPU v2和v3上,分布式工作负载始终以多线程方式运行,因为每个TPU核心会暴露两个TPU核心作为设备,并且同一时间只能有一个进程打开一个TPU芯片。在默认配置下,xmp.spawn会自动启动尽可能多的进程(每台TPU主机4个进程),并在每个进程中创建两个线程(每个TPU核心一个线程)。

注意:在TPU v4上,每个TPU芯片被表示为一个PyTorch设备,因此分布式工作负载将在4个进程中运行,每个进程只有一个线程。这与XRT的行为相同。

在大多数情况下,这通常不需要对现有代码进行重大更改。 在大多数情况下,您需要做出的主要改变是模型初始化。 由于torch的全局RNG在各个线程之间共享,即使您在每个副本中都将torch.manual_seed设置为相同的值,结果在不同线程和运行之间也会有所不同。为了在各个副本之间获得一致的参数,您可以使用torch_xla.experimental.pjrt.broadcast_master_param将一个副本的参数广播到所有其他副本,或者从公共检查点加载每个副本的参数。

xm.rendezvous更改

新在 PyTorch/XLA r2.0

通过XRT,工作节点0运行一个网格主服务,所有工作节点上的进程都会通过gRPC连接到该服务。实际上,我们在使用数千个芯片的TPU Pod时发现,在单个网格主进程中运行会出现不可靠的情况,原因在于工作节点0收到的入站连接数量过多。单一客户端进程超时可能会导致失败,并迫使整个工作负载重新启动。

因此,我们已使用原生XLA集体通信重新实现xm.rendezvous,这在大型TPU集群上更加稳定且经过充分测试。这与XRT实现相比带来了两个新的约束:

  • 因为负载必须成为XLA图的一部分,xm.mark_step会在数据传输前后被调用。在模型代码中间调用xm.rendezvous可能会强制进行不必要的编译。

  • 因为XLA不允许集体操作在一部分工作者上运行,所以所有工作者必须参与rendezvous

如果您需要xm.rendezvous的行为(即在不改变XLA图和/或同步子集worker的情况下通信数据),请考虑使用 ``torch.distributed.barrier` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`_ 或 ``torch.distributed.all_gather_object` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object>`_ 与一个gloo进程组。如果您也在使用xla torch.distributed 后端,可以使用torch.new_group创建一个gloo子组。请参见这个例子 >。请注意这些约束:

  • torch.distributed 在 TPUs v2/v3 上不完全支持。只有部分使用 xla 后端的操作被实现,并且 gloo 可能在多线程上下文中不会按预期工作。

  • 在我们的实验中,gloo 不适合扩展到数千个TPU芯片,因此 请预期这种替代方案在大规模使用时不如使用 xm.rendezvous 结合 PJRT 可靠。

PJRT 和 torch.distributed

新在 PyTorch/XLA r2.0

当使用 PJRT 与 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 时,我们强烈建议使用新的 xla:// init_method,它会通过查询运行时自动找到复制品 ID、世界大小和主 IP 地址。例如:

import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  torch_xla.launch(_all_gather)

注:尽管在TPU v4上不需要xla:// init_method,但仍然建议使用它。如果您使用env://,则必须将MASTER_ADDR设置为具有设备0的IP主机,这不一定是worker 0。xla:// init_method可以自动找到这个IP。

注意:对于TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend,因为 torch.distributed 对TPU v2/v3的支持仍处于实验阶段。

有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅 ``ddp.md` <./ddp.md>`_ 关于 TPU V4。对于一个同时使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下 示例脚本

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

性能

TorchBench 显示,在 PJRT 与 XRT 对比下,各项任务的平均训练时间都有所提升,TPU v4-8 的平均提升幅度超过 35%。不同任务和模型类型带来的好处差异显著,范围从 0% 到 175%。以下图表展示了按任务细分的详细情况:

PJRT vs XRT

新 TPU 运行时

新在 PyTorch/XLA r2.0

PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持, 用于访问基于 TFRT 的新 TPU 运行时,在 libtpu 中。现在当设置 PJRT_DEVICE=TPU 时, 这是默认的运行时。在 1.13 版本中使用的基于 StreamExecutor 的旧版 TPU 运行时在 2.0 版本中仍然可以通过 PJRT_DEVICE=TPU_LEGACY 使用, 但在未来的版本中将被移除。如果你遇到仅在 TPU 上发生而在 TPU_LEGACY 上不发生的问题,请在 GitHub 上提交问题。

在大多数情况下,我们期望两种运行时的性能相似,但在某些情况下,新运行时可能会快达30%。以下图表展示了各项任务的分解情况:

TFRT vs StreamExecutor

注意:本图表中显示的改进也包含在PJRT与XRT的比较中。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源