编译的 Autograd:捕获更大的反向图用于 torch.compile¶
创建时间:2024年10月09日 | 最后更新时间:2024年10月23日 | 最后验证时间:2024年10月09日
作者: Simon Fan
编译后的autograd如何与
torch.compile交互如何使用编译后的 autograd API
如何使用
TORCH_LOGS检查日志
PyTorch 2.4
阅读 TorchDynamo 和 AOTAutograd 部分的 开始使用 PyTorch 2.x
概览¶
编译的自动微分(Autograd)是 PyTorch 2.4 中引入的一个 torch.compile 扩展,
它允许捕获更大的反向图。
当 torch.compile 捕获反向图时,它只是**部分地**捕获。AOTAutograd 组件提前捕获反向图,但有一定的限制:
前向中的图断裂会导致反向中的图断裂
反向钩子 不会被捕获
编译型 Autograd 通过直接与 autograd 引擎集成,解决了这些限制,使其能够在运行时捕获完整的反向计算图。具有这两个特性的模型应尝试使用编译型 Autograd,并且可能会观察到更好的性能。
然而,编译自动求导引入了自身的限制:
在反向传播开始时增加了运行时开销,用于缓存查找
在 dynamo 中,由于捕获范围更大,更容易出现重新编译和图中断的情况
注意
已编译的自动微分(Autograd)正在积极开发中,目前尚不与所有现有 PyTorch 功能兼容。如需了解特定功能的最新状态,请参考 已编译自动微分页面。
设置¶
在这个教程中,我们将基于这个简单的神经网络模型来举例说明。 它接受一个10维的输入向量,通过一个单一的线性层进行处理,并输出另一个10维的向量。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
基本用法¶
在调用 torch.compile API 之前,请确保将 torch._dynamo.config.compiled_autograd 设置为 True:
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)
在上面的代码中,我们创建了 Model 类的一个实例,并通过使用 torch.randn(10) 生成了一个随机的 10 维张量 x。
我们定义了训练循环函数 train 并用 @torch.compile 装饰器对其进行优化以提高其执行效率。
当 train(model, x) 被调用时:
Python解释器调用Dynamo,因为此调用被装饰为
@torch.compile。Dynamo 拦截 Python 字节码,模拟其执行过程,并将操作记录到一个图中。
AOTDispatcher禁用钩子并调用自动微分引擎来计算model.linear.weight和model.linear.bias的梯度,并将操作记录到图中。使用torch.autograd.Function,AOTDispatcher 重写train的前向和后向实现。Inductor 生成一个函数,对应于 AOTDispatcher 的前向和反向的优化实现。
Dynamo 通过 Python 解释器设置下一个要评估的优化函数。
Python解释器执行优化后的函数,该函数执行
loss = model(x).sum()。Python解释器执行
loss.backward(),调用自动微分引擎,由于我们设置了torch._dynamo.config.compiled_autograd = True,因此路由到编译后的自动微分引擎。Compiled Autograd 为
model.linear.weight和model.linear.bias计算梯度,并将操作记录到一个图中,包括它遇到的任何钩子。在此过程中,它会记录 AOTDispatcher 之前重写的反向传播。Compiled Autograd 随后生成一个新的函数,该函数对应于loss.backward()的完全追踪实现,并以推理模式使用torch.compile执行它。相同的步骤递归地应用于编译后的自动求导图,但这次 AOTDispatcher 不需要对图进行分割。
检查编译的自动微分日志¶
使用环境变量 TORCH_LOGS 运行脚本:
要仅打印编译的autograd图,请使用
TORCH_LOGS="compiled_autograd" python example.py要打印带有更多张量元数据和重新编译原因的图表,以牺牲性能为代价,请使用
TORCH_LOGS="compiled_autograd_verbose" python example.py
重新运行上面的代码片段,编译后的autograd图现在应该被记录到stderr。某些图节点将带有以aot0_开头的名称,这些对应于之前在AOTAutograd反向图0中提前编译的节点,例如,aot0_view_2对应于ID为0的AOT反向图中的view_2。
在下面的图像中,红色框包含了由 torch.compile 捕获的 AOT 反向图,该图是在未启用 Compiled Autograd 的情况下捕获的。
注意
这是我们将调用 torch.compile 的图,而不是优化后的图。编译的 Autograd 实际上生成了一些未优化的 Python 代码来表示整个 C++ autograd 执行过程。
使用不同的标志编译前向和后向传递¶
你可以为两次编译使用不同的编译器配置,例如,反向传播可能是一个完整图,即使正向传播中存在图断裂。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()
或者你可以使用上下文管理器,它会对其作用域内的所有 autograd 调用生效。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()
编译式 Autograd 解决了 AOTAutograd 的某些限制¶
前向传递中的图断裂不再必然导致反向传递中的图断裂:
@torch.compile(backend="aot_eager")
def fn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)
在第一个 torch.compile 种情况下,我们看到由于编译函数中的 2 个图中断,生成了 3 个反向图 fn。
而在第二个 torch.compile 种带有编译自动微分的情况中,尽管存在图中断,仍然追踪到了一个完整的反向图。
注意
在跟踪由Compiled Autograd捕获的反向钩子时,Dynamo仍可能发生图断裂。
反向钩子现在可以被捕获了
@torch.compile(backend="aot_eager")
def fn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
图中应该有一个 call_hook 节点,dynamo 之后会将其内联到以下内容中:
Compiled Autograd常见重新编译原因¶
由于损失值的自动求导结构发生了变化:
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的例子中,我们在每次迭代上调用不同的操作符,导致 loss 每次跟踪不同的自动微分历史。你应该会看到一些重新编译的消息:由于新的自动微分节点导致缓存未命中。
由于张量形状发生变化:
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的例子中,x 改变了形状,编译后的 autograd 会在第一次改变后将 x 标记为动态形状张量。你应该会看到重新编译的消息:由于形状变化导致缓存未命中。
结论¶
在本教程中,我们介绍了 torch.compile 的高级生态系统,包括编译后的自动微分、编译后自动微分的基础知识以及一些常见的重新编译原因。敬请关注 dev-discuss 上的深入探讨。