目录

PJRT 运行时

PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到 PJRT 运行时

如果您在使用 PJRT 时遇到错误,请在 GitHub 上提交带有标签的问题。runtime

PyTorch/XLA r2.1 的新功能

  • PJRT 在 PyTorch/XLA r2.1 中是稳定的!

  • 公共运行时 API 已从 迁移到 。torch_xla.experimental.pjrttorch_xla.runtime

    • init 方法已重命名为 ,并且已注册 由。pjrt://xla://torch_xla.distributed.xla_backend

    • 以前的名称在此 release 以实现兼容性。torch_xla.experimental.*

  • torchrun现在支持在使用 .init_method='xla://'

  • 通过 PJRT C API 为 XPU 和 Neuron 提供新插件。

PyTorch/XLA r2.0 的新功能

  • 如果您不传入任何其他运行时,则默认情况下将配置 PJRT 配置。如果继续设置 XRT 配置 (), 此更改没有影响XRT_TPU_CONFIG

  • 新的 TPU 运行时实现将性能提高了 30%。libtpu

  • 可扩展至数千个 TPU 内核的新实施xm.rendezvous

  • 对 TPU v2 和 v3 的 [实验性] 支持,包括torch.distributedpjrt:// init_method

TL;博士

  • 要使用 PJRT 预览运行时,请将环境变量设置为 、 或PJRT_DEVICECPUTPUCUDA

  • 在 XRT 中,所有分布式工作负载都是多进程的,每个工作负载有一个进程 装置。在 PJRT 的 TPU v2 和 v3 上,工作负载是多进程和多线程的 (4 个进程,每个进程 2 个线程),因此您的工作负载应该是线程安全的。请参阅 TPU v2/v3 上的多线程处理和 API 的多处理部分 指南了解更多信息。需要记住的主要区别:

    • 要以线程安全的方式初始化模型,请广播参数 初始化后跨副本 () 或加载每个 来自公共检查点的 replica 参数。torch_xla.experimental.pjrt.broadcast_master_param

    • 对于其他随机数生成,请尽可能使用。 全局 RNG 不是线程安全的,即使您在副本之间设置相同也是如此。torch.Generatortorchtorch.manual_seed

    • 要使用 、 import 和 使用 .torch.distributedtorch_xla.experimental.pjrt_backendxla://init_method

    • 对于 GPU 和 TPU v4,这些步骤是可选的。

从 XRT 到 PJRT 的样本差异:

 import os

 import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.optim as optim
 import torch.distributed as dist
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr


 def _mp_fn(index):
   device = xm.xla_device()
-  dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+  dist.init_process_group('xla', init_method='xla://')

   torch.manual_seed(42)
   model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
   model = DDP(model, gradient_as_bucket_view=True)

   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)

   for i in range(10):
     data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()

     optimizer.step()
     xm.mark_step()

   # Print mean parameters so we can confirm they're the same across replicas
   print([p.mean() for p in model.parameters()])

 if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

   torch_xla.launch(_mp_fn)

好处

  • 简单的运行时配置:只需设置为 、 或 ,即可开始使用 XLA!或者,让 PJRT 根据您的 环境。PJRT_DEVICETPUCPUCUDA

  • 提高性能:减少 gRPC 的开销意味着更快的端到端 执行。在 TorchBench 2.0 上,我们观察到训练时间缩短了 >35% 在 TPU v4 上。

  • 轻松执行 Pod:只需将代码复制到每个 TPU 工作线程,然后执行它们 所有这些都与 .gcloud compute tpus tpuvm ssh --worker=all

  • 更好的缩放:消除了 XRT 对参数的限制 尺寸和支持高达 2048 TPU 芯片。

快速入门

要开始将 PJRT 与 PyTorch/XLA 一起使用,您需要做的就是设置环境变量。如果您正在使用 TPU v2 或 v3,请保留 阅读以了解 TPU v2 与 v3 和 v4 之间的区别。PJRT_DEVICE

中央处理器

在任何安装了 PyTorch/XLA 的机器上,您都可以在 CPU 上运行我们的 MNIST 示例 喜欢这个:

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

热塑性聚氨酯

要创建安装了 PyTorch/XLA r2.0 的新 TPU,请执行以下作:

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以运行我们的 ResNet50 示例,如下所示:

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置并:TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

豆荚

在 TPU Pod 上,用于在每个 TPU 上并行运行命令:gcloud

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

码头工人

您还可以使用 Docker 在具有 PyTorch/XLA 的容器中运行工作负载 预装:

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

请注意,需要对主机 () 的特权访问权限 将 TPU 设备公开给容器。仅支持 TPU Pod 上的 Docker 此时使用 Host Networking。有关更多信息,请参阅 Cloud TPU 文档docker run--privileged--net=host

图形处理器

单节点 GPU 训练

要将 GPU 与 PJRT 一起使用,只需设置并配置为主机上的设备数量即可。例如:PJRT_DEVICE=CUDAGPU_NUM_DEVICES

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您还可以使用 来启动单节点多 GPU 训练。例如torchrun

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上面的示例中,表示要使用多少台机器(物理机或 VM)(由于我们进行单节点训练,因此为 1)。 表示要使用的 GPU 设备数量。--nnodes--nproc-per-node

多节点 GPU 训练

请注意,此功能仅适用于 cuda 12+。与 PyTorch 使用多节点训练的方式类似,您可以按如下方式运行命令:

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes:要使用的 GPU 计算机数量。

  • --node_rank:当前 GPU 计算机的索引。该值可以是 0, 1, ..., ${NUMBER_GPU_VM}-1。

  • --nproc_per_node:当前计算机上要使用的 GPU 设备数。

  • –rdzv_endpoint:GPU 计算机的端点,node_rank==0,格式为 host:port'。''hostport' 可以是机器上的任何可用端口。对于单节点训练/推理,可以省略此参数。will be the internal IP address. The

例如,如果要在 2 台 GPU 计算机(machine_0 和 machine_1)上进行训练,请在第一台 GPU 计算机machine_0上运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台 GPU 计算机上,运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

如果您想在每台计算机上使用不同数量的 GPU 设备,则上述 2 个命令之间的差异是 和 可能。其余的都是相同的。有关 的更多信息,请参阅此页面。--node_rank--nproc_per_nodetorchrun

与 XRT 的区别

尽管在大多数情况下,我们预计 PJRT 和 XRT 大多可以互换工作 从最终用户的角度来看(尤其是在 TPU v4 上),有一些微妙的 需要牢记的重要差异。重要的是,XRT 的设计 围绕 TPU Node 架构,因此它总是会生成一个客户端和一个服务器 进程,即使在 TPU 虚拟机上也是如此。因此,每批输入都有额外的延迟 从序列化和反序列化数据到通过网络发送数据。

PJRT 直接使用本地设备,无需中间服务器进程。在 default 配置,PJRT 会为每个 TPU 芯片创建一个进程,或者 4 个进程 每个 TPU 主机。查看 Cloud TPU 文档 有关 TPU 架构的更多信息。

  • 对于受 限制开销的工作负载,性能可能会提高。

  • 在 XRT 下,服务器进程是唯一与 TPU 交互的进程 设备和客户端进程无法直接访问 TPU 设备。 在分析单主机 TPU(例如 v3-8 或 v4-8)时,您通常会看到 8 设备跟踪记录(每个 TPU 内核一个)。使用 PJRT,每个过程都有一个芯片, 并且该进程的配置文件将仅显示 2 个 TPU 内核。

    • 出于同样的原因,性能分析不适用于使用 XRT 的 TPU Pod,因为 服务器进程独立于用户的模型代码运行。PJRT 确实 没有该约束,因此可以分析 2 个 TPU 内核 进程。

  • PJRT 仅支持 TPU VM 架构,我们没有计划支持 使用 PJRT 的 TPU 节点架构。

  • 使用 PJRT 的运行时配置要简单得多。 莫 运行 TPU Pod 工作负载所需的。相反,请将您的代码复制到每个 TPU 主机 () 并在每个主机上并行运行代码(例如xla_dist[gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))

  • xm.rendezvous已使用 XLA 原生 collective 重新实现 通信以增强大型 TPU Pod 的稳定性。更多信息见下文 详。

TPU v2/v3 上的多线程处理

在 TPU v2 和 v3 上,分布式工作负载始终以多线程方式运行,因为每个 TPU 核心将两个 TPU 核心作为设备公开,并且只有一个进程可以打开 TPU Chip 的 Fragment S 的 Intent S S在其默认配置中,会自动生成 尽可能多的进程(每个 TPU 主机 4 个),并为每个 工艺(每个 TPU 内核一个)。xmp.spawn

注意:在 TPU v4 上,每个 TPU 芯片表示为一个 PyTorch 设备,因此 分布式工作负载将在 4 个进程中运行,每个进程只有一个线程。 这与 XRT 的行为相同。

在大多数情况下,这不需要对现有代码进行大量更改。 在大多数情况下,您必须进行的主要更改是建模初始化。 由于 的全局 RNG 在线程之间共享,因此结果会有所不同 在线程之间运行,即使您设置为相同的值 在每个副本中。要在副本之间获得一致的参数,请使用广播一个副本的 参数添加到所有其他副本中,或者从 common checkpoint 的 Intent Barrier 中。torchtorch.manual_seedtorch_xla.experimental.pjrt.broadcast_master_param

对 xm.rendezvous 的更改

PyTorch/XLA r2.0 中的新增功能

使用 XRT 时,worker 0 运行一个网格主站服务,所有 worker 上的所有进程 通过 gRPC 连接到该服务。在实践中,我们发现,运行单个 Mesh Master 工艺在具有数千个芯片的 TPU pod 上不可靠,因为 到 worker 0 的入站连接数。单个客户端进程计时 out 可能会导致失败并强制整个工作负载重新启动。

因此,我们使用原生 XLA 集合体重新实现 通信,这在大型 TPU Pod 上更加稳定且经过了充分测试。这 与 XRT 实现相比,施加了两个新约束:xm.rendezvous

  • 由于有效负载必须成为 XLA 图形的一部分,因此 在传输数据之前和之后调用。在模型代码中间调用可能会强制进行不需要的编译。xm.mark_stepxm.rendezvous

  • 由于 XLA 不允许集合作在 worker 的 .rendezvous

如果您需要 (即传输数据 而无需更改 XLA 图形和/或同步工作程序子集), 考虑使用 ''torch.distributed.barrier' <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>'_ 或 ''torch.distributed.all_gather_object' <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object>'_ 替换为进程组。如果您还使用后端,则可以使用 创建子组。查看此内容 来自 PyTorch 文档的示例。请记住以下限制:xm.rendezvousglooxlatorch.distributedtorch.new_groupgloo

  • torch.distributed在 TPU v2/v3 上不完全支持。只有 使用后端的作已实现,并且可能不会实现 在多线程上下文中按预期工作。xlagloo

  • 在我们的实验中,不能很好地扩展到数千个 TPU 芯片,因此 预计此替代方案不如使用 PJRT 在大尺度上。glooxm.rendezvous

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新增功能

当使用 PJRT 时,我们强烈建议使用新的 ,它会自动 通过查询运行时来查找副本 ID、世界大小和主 IP。为 例:torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)xla://init_method

import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  torch_xla.launch(_all_gather)

注意:尽管 TPU v4 不需要 init_method,但它仍然是必需的 推荐。如果使用 ,则必须将其设置为具有 设备 0,它并不总是工作线程 0。init_method 发现这个 IP 自动。xla://env://MASTER_ADDRxla://

注意:对于 TPU v2/v3,您仍然需要导入,因为 中的 TPU v2/v3 支持仍处于试验阶段。torch_xla.experimental.pjrt_backendtorch.distributed

有关在 PyTorch/XLA 上使用的更多信息,请参阅 TPU V4 上的 ddp.md<./ddp.md>'_。对于同时使用 DDP 和 PJRT 的示例, 在 TPU 上运行以下示例脚本DistributedDataParallel

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

性能

TorchBench 显示使用 PJRT 后各项任务的平均训练时间有所改善 与 XRT 相比,TPU v4-8 的平均改进超过 35%。这 收益因任务和模型类型而异,从 0% 到 175% 不等。 下图显示了按任务划分的细分:

PJRT 与 XRT

新的 TPU 运行时

PyTorch/XLA r2.0 中的新增功能

PyTorch/XLA r2.0 版本引入了对 PJRT 插件的支持 API / 用于在 .现在是 default runtime when 已设置。旧版的基于 StreamExecutor 的 1.13 中使用的 TPU 运行时在 2.0 版本中仍可用,但将在将来的版本中删除。如果您遇到 仅 on 上发生的问题,请提交 Issue 在 GitHub 上。libtpuPJRT_DEVICE=TPUPJRT_DEVICE=TPU_LEGACYTPUTPU_LEGACY

在大多数情况下,我们希望两个运行时之间的性能相似,但 在某些情况下,新的运行时间可能会提高 30%。下图 按任务显示细分:

TFRT 与 StreamExecutor

注意:此图表中显示的改进也包含在 PJRT 与 XRT 中 比较。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源