目录

Eager 模式 + Compile API

在本文档中,我们将介绍如何使用PyTorch/XLA的新实验性eager模式与compileAPI。目标是使PyTorch/XLA的体验更加接近原生PyTorch,并简化开发过程。

背景

目前,默认情况下 PyTorch/XLA 运行在 LazyTensor 跟踪模式下。在以下代码中

import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# model tracing
res = model(input)

# model execution, same as `xm.mark_step`
torch_xla.sync()

实际模型编译和设备执行发生在调用 torch_xla.sync 时。这种做法有多个缺点。

  1. 用户往往会对框架何时进行跟踪何时执行感到困惑。

  2. 非核心模型代码(例如数据预处理)通常会生成一些小型的待执行任务,并将其泄露到主图(step函数)中,从而导致重新编译。整个图的重新编译通常非常昂贵。

  3. 重新编译发生的时间和原因很难调试。

为缓解上述问题,我们想引入新的用户体验,结合 eager 和 compile 的优势。

基本用法

import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happens right away.
res = compiled_model(input)

请注意

  1. 目前用户必须手动启用激进模式由 torch_xla.experimental.eager_mode(True)

  2. 需要编译的代码区域应该用 torch_xla.compile 包裹。

The implementation of the torch_xla.compile 实际上非常简单,它在进入目标函数时禁用 eager 模式并开始跟踪。当目标函数返回时,它会调用 torch_xla.sync() 并重新启用 eager 模式。你可以期望使用 eager + compile API 的性能与现有的 mark_step/sync 方法相同。

推理

torch_xla.experimental.eager_mode(True)

compiled_model = torch.compile(model, backend="openxla")

建议在推理时使用 torch.compile 而不是 torch_xla.compile,以减少追踪开销。

训练

torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)

在训练过程中,我们要求用户重构step_fn,因为它通常更好,可以将模型的前向、反向传播和优化器一起编译。长期目标是也使用torch.compile进行训练,但目前我们建议用户使用torch_xla.compile(出于性能原因)。

基准

我在一块v4-8芯片上运行了一种仅解码器的2层模型训练(基本上就是一个llama2),使用的是假数据,共进行了300步。以下是观察到的数值。

token/s
追踪模式(基线) 147
懒模式 65
Eager + torch_xla 编译 147

即时模式可以实现仅解码器模型的完全编译模型约45%的性能。我用来测试的训练器可以在这里找到这里。请注意,即时模式的性能非常依赖于模型。当我尝试运行resnet50时,即时模式的性能约为编译模式的1%。我们不期望用户使用即时模式来执行主要的训练循环。即时模式旨在用于处理训练/推理逻辑的非核心部分(数据预处理、随机数生成等)或进行调试。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源