Eager 模式 + Compile API¶
在本文档中,我们将介绍如何使用PyTorch/XLA的新实验性eager模式与compileAPI。目标是使PyTorch/XLA的体验更加接近原生PyTorch,并简化开发过程。
背景¶
目前,默认情况下 PyTorch/XLA 运行在 LazyTensor 跟踪模式下。在以下代码中
import torch
import torch_xla
import torchvision
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)
# model execution, same as `xm.mark_step`
torch_xla.sync()
实际模型编译和设备执行发生在调用 torch_xla.sync 时。这种做法有多个缺点。
用户往往会对框架何时进行跟踪何时执行感到困惑。
非核心模型代码(例如数据预处理)通常会生成一些小型的待执行任务,并将其泄露到主图(step函数)中,从而导致重新编译。整个图的重新编译通常非常昂贵。
重新编译发生的时间和原因很难调试。
为缓解上述问题,我们想引入新的用户体验,结合 eager 和 compile 的优势。
基本用法¶
import torch
import torch_xla
import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)
# Compilation and execution happens right away.
res = compiled_model(input)
请注意
目前用户必须手动启用激进模式由
torch_xla.experimental.eager_mode(True)。需要编译的代码区域应该用
torch_xla.compile包裹。
The implementation of the torch_xla.compile 实际上非常简单,它在进入目标函数时禁用 eager 模式并开始跟踪。当目标函数返回时,它会调用 torch_xla.sync() 并重新启用 eager 模式。你可以期望使用 eager + compile API 的性能与现有的 mark_step/sync 方法相同。
推理¶
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
建议在推理时使用 torch.compile 而不是 torch_xla.compile,以减少追踪开销。
训练¶
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
optimizer.zero_grad()
logits = model(data)
loss = loss_fn(logits, target)
loss.backward()
optimizer.step()
return loss
step_fn = torch_xla.compile(step_fn)
在训练过程中,我们要求用户重构step_fn,因为它通常更好,可以将模型的前向、反向传播和优化器一起编译。长期目标是也使用torch.compile进行训练,但目前我们建议用户使用torch_xla.compile(出于性能原因)。