目录

TorchDynamo(torch.compile) 在 PyTorch XLA 中的集成

TorchDynamo 是一个 Python 级别的 JIT 编译器,旨在使未修改的 PyTorch 程序更快。它提供了一个干净的 API 供编译后端插入,并且它的最大特点是动态修改 Python 字节码,在执行前进行调整。在 pytorch/xla 2.0 发布中,PyTorch/XLA 提供了一个实验性的 TorchDynamo 后端,适用于推理和训练。

XLA桥接的工作方式是,Dynamo在识别出模型模式后会提供一个TorchFX图,而PyTorch/XLA则会利用现有的惰性张量技术编译FX图,并返回编译后的函数。

集成

对 PyTorch/XLA 和 Dynamo 的支持目前通过向 backend='openxla' 添加参数来实现 torch.compile。例如:

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))

推理

这里是一个运行 resnet18 的小代码示例:torch.compile

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

With the torch.compile您将会看到,PyTorch/XLA 只在初始化时跟踪一次resent18模型,并且每次调用dynamo_resnet18时都会执行编译后的二进制文件,而不是每次都重新跟踪模型。这里是对比Dynamo和Lazy在Cloud TPU v4-8上使用torch bench进行推理速度分析的结果。

resnet18 | 2.59 resnet50 | 2.64 resnext50_32x4d | 1.91 alexnet | 1.28 mobilenet_v2 | 18.62 mnasnet1_0 | 2.68 vgg16 | 1.33 BERT_pytorch | 7.49 squeezenet1_1 | 2.29 timm_vision_transformer | 3.52 geomean | 3.04

训练

PyTorch/XLA 也支持 Dynamo 进行训练,但它是实验性的,我们正在与 PyTorch 编译器团队合作完善实现。以下是一个使用 torch.compile 训练 ResNet18 的示例。

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)

如果我们使用Lazy张量,我们期望每一步训练提取并执行3个图,而不是1个图。以下是在Cloud TPU v4-8上使用torch bench进行Dynamo与Lazy训练速度分析的比较。

resnet50 | 1.33 resnet18 | 1.33 BERT_pytorch | 3.07 resnext50_32x4d | 1.43 alexnet | 1.12 mobilenet_v2 | 1.4 mnasnet1_0 | 1.19 vgg16 | 0.81 timm_vision_transformer | 1.87 squeezenet1_1 | 1.41 geomean | 1.41

NOTE: We run each model’s fwd and bwd for a single step and then collect the e2e time. In the real world we will run multiple steps at each training job which can easily hide the tracing cost from execution(since it is async). Lazy Tensor will have much better performance in that scenario.

Feature gaps

我们想指出一个目前阻碍我们使用TorchDynamo处理大规模模型的问题。

  1. TorchDynamo 将分别追踪前向和后向过程。对于 PyTorch/XLA 来说,让 XLA 编译器将整个步骤视为一个单独的图形以最佳优化速度是非常重要的。此外,每次设备执行都有固定的开销,这使得在每个训练步骤中执行多个图形的效果不太理想。

与 Lazy Tensor 相比,这个差距使其在实际训练场景中效率较低,尤其是在训练过程中,跟踪成本可能会与执行过程重叠。

Take away

TorchDynamo 提供了一种非常有前景的方法,可以让编译后端隐藏复杂性,并且轻松地以图形格式检索建模代码。与 PyTorch/XLA 传统的惰性张量方式提取图形相比,TorchDynamo 可以在每次迭代时跳过图形跟踪,从而提供更好的推理响应时间。

大多数由PyTorch/XLA支持的模型,在使用新的dynamo-xla桥进行推理时,运行速度有了显著提升。我们的社区正努力扩大支持的模型范围。关于上述提到的训练功能差距,PyTorch/XLA社区对在即将开展的工作中改进训练功能非常兴奋。团队将继续大力投资于TorchDynamo,并与上游合作完善训练故事。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源