Eager 模式 + 编译 API¶
在本文档中,我们将介绍如何通过 API 使用 PyTorch/XLA 的新实验模式。目标是使 PyTorch/XLA 体验与本机 PyTorch 更加一致,并使开发过程更容易。eager
compile
背景¶
目前,PyTorch/XLA 默认在 LazyTensor 跟踪模式下运行。在下面的代码中
import torch
import torch_xla
import torchvision
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)
# model execution, same as `xm.mark_step`
torch_xla.sync()
实际的模型编译和设备执行发生在 被调用时。这种方法有多个缺点。torch_xla.sync
用户经常对框架何时进行跟踪以及框架何时执行感到困惑。
非核心模型代码(例如数据预处理)通常会生成一些小的待处理执行,这些执行会泄漏到主图(step function)中并导致重新编译。整个图的重新编译通常非常昂贵。
很难调试何时/为什么会发生重新编译。
为了缓解上述问题,我们希望引入带有 eager 和 compile 的新 UX。
基本用法¶
import torch
import torch_xla
import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)
# Compilation and execution happens right away.
res = compiled_model(input)
请注意,
当前用户必须通过 手动启用 Eager 模式。
torch_xla.experimental.eager_mode(True)
要编译的代码区域应由 .
torch_xla.compile
它的实现其实很简单,它在进入 target 函数并开始跟踪时禁用 eager 模式。它将调用 when target 函数返回并重新启用 Eager 模式。与现有方法相比,使用 + API 可以预期相同的性能。torch_xla.compile
torch_xla.sync()
eager
compile
mark_step/sync
推理¶
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
建议使用 而不是 进行推理,以减少跟踪过载。torch.compile
torch_xla.compile
训练¶
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
optimizer.zero_grad()
logits = model(data)
loss = loss_fn(logits, target)
loss.backward()
optimizer.step()
return loss
step_fn = torch_xla.compile(step_fn)
在训练中,我们要求 User 重构 out,因为通常最好将模型的 forward 、 backward 和 optimizer 一起编译。长期目标也是用于培训,但现在我们建议用户使用 (出于性能原因)。step_fn
torch.compile
torch_xla.compile
基准¶
我在 v4-8 的单个芯片上运行一个 2 层解码器模型训练(它几乎只是一个 llama2),使用了 300 个步骤的假数据。以下是我观察到的数字。
代币 / s | |
跟踪模式(基线) | 147 |
Eager 模式 | 65 |
Eager + torch_xla 编译 | 147 |
对于仅解码器模型,Eager 模式可以实现完全编译模型的性能 ~45%。我用来测试的教练可以在这里和这里找到。请注意, eager 模式的性能非常依赖于模型。当我尝试运行 resnet50 时,急切模式的性能是编译模式的 ~1%。我们不鼓励用户使用 Eager Mode 来执行主训练循环。Eager 模式用于处理训练/推理逻辑的非核心部分(数据预处理、随机数生成等)或调试。