PyTorch XLA 中的 TorchDynamo(torch.compile) 集成¶
TorchDynamo 是一个 Python 级别的 JIT 编译器,旨在使未修改的 PyTorch 程序更快。它为编译器后端提供了一个干净的 API 来挂接,其最大的功能是在执行 Python 字节码之前动态修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了一个实验性后端,用于推理和训练。
XLA bridge 的工作方式是,Dynamo 在识别模型模式时将提供 TorchFX 图形,而 PyTorch/XLA 将使用现有的惰性张量技术来编译 FX 图形并返回编译的函数。
集成¶
当前存在对 PyTorch/XLA 和 Dynamo 的支持,方法是将参数添加到 .例如:backend='openxla'
torch.compile
import torch
import torch_xla.core.xla_model as xm
def add(a, b):
a_xla = a.to(xm.xla_device())
b_xla = b.to(xm.xla_device())
return a_xla + b_xla
compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
推理¶
下面是一个运行 resnet18 的小代码示例,使用torch.compile
import torch
import torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='openxla')
for data, _ in loader:
with torch.no_grad():
output = dynamo_resnet18(data)
使用 ,您将看到 PyTorch/XLA 在初始化期间只跟踪 resent18 模型一次,并在每次调用时执行编译后的二进制文件,而不是每次都跟踪模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比较 Dynamo 和 Lazy 的推理速度分析torch.compile
dynamo_resnet18
resnet18 |2.59 RESnet50 |2.64 resnext50_32x4d |1.91 亚历克斯内特 |1.28 mobilenet_v2 |18.62 mnasnet1_0 |2.68 VGG16 系列 |1.33 BERT_pytorch |7.49 squeezenet1_1 |2.29 timm_vision_transformer |3.52 Geomean (吉奥米亚) |3.04
训练¶
PyTorch/XLA 还支持使用 Dynamo 进行训练,但它是实验性的,我们正在与 PyTorch 编译器团队合作迭代实现。下面是一个使用torch.compile
import torch
import torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
optimizer.step()
return pred
def train_model_main(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='openxla')
for data, target in loader:
xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)
如果您使用 Lazy 张量,我们希望每个训练步骤提取并执行 3 个图形,而不是每个训练步骤 1 个图形。以下是在 Cloud TPU v4-8 上使用Torch工作台的 Dynamo 和 Lazy 的训练速度分析。
RESnet50 |1.33 resnet18 |1.33 BERT_pytorch |3.07 resnext50_32x4d |1.43 亚历克斯内特 |1.12 mobilenet_v2 |1.4 mnasnet1_0 |1.19 VGG16 系列 |0.81 timm_vision_transformer |1.87 squeezenet1_1 |1.41 Geomean (吉奥米亚) |1.41
注意:我们运行每个模型的 fwd 和 bwd 一个步骤,然后收集 e2e 时间。在现实世界中,我们将在每个训练作业中运行多个步骤,这可以很容易地隐藏执行的跟踪成本(因为它是异步的)。在这种情况下,Lazy Tensor 的性能会好得多。
功能差距¶
我们想要指出一个差距,它阻止我们在更大规模的模型上使用 TorchDynamo。
TorchDynamo 将向前和向后跟踪到单独的图形中。对于 PyTorch/XLA,请务必让 XLA 编译器将整个步骤视为一个图形,以最好地优化速度。启动每个设备执行还有固定的开销,这使得每个训练步骤执行多个图形不太理想。
与 Lazy Tensor 相比,这种差距使其在实际训练用例中的效率较低,尤其是跟踪成本可能与训练中的执行重叠。
带走¶
TorchDynamo 为编译器后端提供了一种非常有前途的方法,可以向用户隐藏复杂性,并轻松检索图形格式的建模代码。与 PyTorch/XLA 传统的 Lazy Tensor 提取图形的方式相比,TorchDynamo 可以在每次迭代时跳过图形跟踪,从而提供更好的推理响应时间。
PyTorch/XLA 支持的大多数模型在使用新的 dynamo-xla 桥运行推理时都实现了显著的加速。我们的社区正在努力扩展支持的模型集。关于上述训练功能差距,PyTorch/XLA 社区非常高兴能够在我们即将到来的开发工作中缩小训练差距。该团队继续对 TorchDynamo 进行大量投资,并与上游合作以完善训练故事。