目录

守卫概述

从 UX 的角度来看,TorchDynamo 非常易于使用。用户作为注释调用:torchdynamo.optimize

@torchdynamo.optimize(my_compiler)
def fn_foo(bar):

其中完整的示例如下所示:

from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

这允许 TorchDynamo 捕获解释的 Python 帧,抓取 任何和所有相关信息,并尽可能加快速度。 加速来自几个地方,并且可能相当依赖于 backend(上面示例中的 my_compiler)提供,但一个 speedup 在本节中,重要的是缓存。缓存本身不是 直接加速,但关键使能阻止 重新 编译。我们使用 dynamo 挖一个洞,缓存允许我们得到 外。它使我们能够保持性能 中立性,然后启用后端 - 我们的 加速。

甚至提供了直通无操作后端:

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward

我们可以看到 TorchDynamo 加快了 Python 的执行速度,即使在 常规 Python,而不仅仅是 PyTorch。

缓存和守卫概述

TorchDynamo 通过缓存转换(由 TorchDynamo)用户运行 字节码。当 TorchDynamo 收到要评估的帧时,它会检查帧中引用的对象是否以某种方式发生了变化,以及 not,则 TorchDynamo 会读取之前转换的用户字节码来评估它。 在本节中,我们将重点介绍如何识别帧中引用的对象是否已更改。这是一个关键的 的功能,因为它驱动整个 失效生命周期。此功能称为守卫

在非常高的级别上,流程可以总结如下:

  1. TorchDynamo 接收 Python 帧。

  2. 它转换帧 (1),将其传递给 instruction 译本。

  3. 对于在 (2) 中捕获的对象,TorchDynamo 会创建跟踪对象,这些对象 是: * 在输出图上跟踪,这是内部特化 的 torch.fx.Tracer * 守卫

  4. TorchDynamo 处理在(3)中创建的守卫对象,将它们转换为 生成的 Python 函数 check_fn,与一段代码相关联。

  5. 每当我们遇到此代码时,都会对 check_fn 进行评估 subsequent time - 如果check_fn通过并计算为 True,则 TorchDynamo 将缓存中的代码与此处遇到的代码标识为相同,并且 可以安全使用。如果失败且计算结果为 False,则 TorchDynamo 将缓存中的代码标识为无效,并且可以在 通过 recompilation 或 graph break 支持新条目。

Python 帧评估和 PEP 523

TorchDynamo 的功能基于 PEP 523

TorchDynamo 使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装帧评估函数。TorchDynamo 有一个钩子,其中 Python 可以在评估期间将控制权交还给我们。

我们安装的功能是 or in the case,但 glossing 现在,关于这个细微差别,让我们来看看 , 作为它的代理。convert_frameconvert_frame_assertnopython=Trueconvert_frame_assertconvert_frame

我们可以在 convert_frame.py 的第 20 行找到它, 签名如下:

def  convert_frame_assert(compiler_fn: Callable, one_graph=True):

此函数包装 Python 调用 TorchDynamo 的入口点 带框架:

def  _convert_frame_assert(frame: types.FrameType, cache_size: int):

以下是此函数的作用:

  1. 检查它之前是否看到过这个(参见:f_code 这里)并退出 如果它这样做了,就早点。code

  2. 检查代码是否为不受支持的情况。

  3. 检查 (上面的第二个 arg) 是否超过限制 在配置 .如果有,则函数 丢弃该帧并记录警告。这有助于避免常量 重新编译帧,因为它通常意味着该帧是热的 以意想不到的方式进行缓存,并且缓存它会产生不必要的开销, 因为它很可能在下次遇到时被驱逐。cache_sizecache_size_limit

  4. 传递帧,以及创建 through 字节码的函数 转换,通过 .一些关键的事情 发生在后台:InstructionTranslatortransform_code_object

    1. 新代码是通过 .transform_code_object

    2. 名为 的 FX 跟踪器是通过 生成的。outputInstructionTranslator

      这可能有点令人困惑, AS 不是 FX 跟踪器,而是其存储的 在名为 tracer 的变量中,其 output*是一个'fx'tracer。InstructionTranslator

    3. 该函数生成守卫并将它们存储在 above 上。output

    4. 该函数生成并存储它们 above.output_instructionsoutput

    5. 该函数将新生成的转换代码映射到初始代码 it 读出框架。这个映射值得记住,我们将 稍后我们将介绍守卫故障。

  5. 使用 4.1 中转换后的代码和 4.3 中的守卫, 该函数生成 GuardedCode

现在我们已经了解了帧评估,让我们回顾一下,看看它如何转动我们递给的帧 它转换为 TorchDynamo 内部类型。InstructionTranslator

指令翻译器

InstructionTranslator 做了很多!我们不会介绍以下细节 它所做的一切,但对本文档来说最重要的是,它会产生 的映射维护来自 frame 的 Defined 附加到 TorchDynamo 内部变量对象(更多关于 一会儿。 通过遍历帧的 当地人:symbolic_localsf_localssymbolic_locals

self.symbolic_locals = collections.OrderedDict(
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
    for k in vars
    if k in f_locals
)

这里的重要组件是调用 到。的调用实现 代理到一个名为 的函数中,该函数反过来又构造了 实例并调用它们。更多 稍后再说。VariableBuilderVariableBuilder_wrapVariableTrackermake_guards

反过来,此映射至关重要,因为每个变量都已关联 guards,然后将其传递给 的实例,即上一节 4.2 中提到的 FX 跟踪器实例。如果 你还记得,this 存储在一个名为 is 的变量中,我们的守卫在被传递给 become 之前被存储在这里self.outputOutputGraphOutputGraphoutputGuardedCode

这是怎么做到的呢?它的核心是 一个被抽取的循环,它驱动一个函数 。InstructionTranslatorstep

step就是那个 - 一个加工步骤,正好采取一个 指令并用它做点什么

注意

这些是 TorchDynamo 处理的真实指令,这很酷。transform_code_object

注意

本节特意跳过了 dis.get_instructions 的详细信息。

对于上面的示例,下面是一些 的片段:Instruction

Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)

这是此功能的核心功能。看一下 , 然后从内部查看这个小片段opnamestep;

if not hasattr(self, inst.opname):
    unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)

正如我们所看到的,该函数检查当前类 the 是否具有与运算符名称匹配的属性集 (例如,)。如果是这样,函数将调用它,并将 整个指令对象 in.否则,该函数会将帧作为 未实现。InstructionTranslatorLOAD_CONST

对于这个例子,我们可以看到我们确实支持它, 定义相对简单:LOAD_CONST

def  LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))

我们可以看到,这个函数创建了一个类 的新实例 ,其值在我们的例子中为 -1,然后是 将其推送到堆栈上。ConstantVariable

有几十种这样的方法 - 请参阅所有 他们。通常,我们实现了尽可能多的 Python 匹配方法 字节码指令。symbolic_convert.py

在 logic downstream of 和 logic from incalling-我们现在有很多 s 和 of 当然,我们已经谈到了让后卫安静一点。让我们深入研究一下 什么是 Variables,并更接近于理解 guards。stepVariableBuilderVariableTracker

变量

A 是 的实例。 表示跟踪的 Python 本地值或堆栈值。ConstantVariableVariableTrackerVariableTracker

当在 TorchDynamo 中表示对象时,a 完全按照它所说的 - 它跟踪给定的变量。 这是一个非常灵活的课程,但有几点需要保留 介意:VariableTracker

  • 它管理围绕底层对象的关系 通过:guard

    • make_guard

    • replace_guards

    • add_guard(s)

    • propagate - propagate(*vars: List[List["VariableTracker"]])- 也许最重要的是,它结合了 所有提供的实例都传入。它访问 守卫并将这些守卫组合到自身上。VariableTracker

  • 它充当代表底层对象的代理,实现 方法获取 TorchDynamo 其余部分的 tracked object(被跟踪对象):

    • call_method

    • call_function

    • python_type

    • as_proxy

    • is/as_python_proxy

  • 它存储 类型为 的变量 ,来自 。此源类型是相对自的 包含类,帮助我们组织和记下原始 source 的来源,并帮助为 things 提供便捷的方法 比如得到这个名字,对我们来说重要的是,培养后卫。sourceSourcetorchdynamo/source.py

而这个类 () 是围绕子类 介于完整的抽象基类和完全充实的类之间 - 它留下了许多方法,提高了 - 依赖于 子。请参阅所有要实现的分支职业 Contract 和自定义行为。VariableTrackerNotImplementedErrortorchdynamo/variables/

了解我们现在所知道的,我们可以看到一个指令如何 从:disBUILD_TUPLE

BUILD_TUPLE(count)创建一个元组,使用 count 项从 stack 的 v,并将结果元组推送到堆栈上。

在我们的例子中,由于方式的不同,我们的签名会略有不同 我们创建对象,但它的要点是相同的。 我们不是传入 ,而是传入一个带有一点 额外的簿记,当然,我们还处理普通的旧账 python 对象转换为 TorchDynamo 概念:Instructioncount

def BUILD_TUPLE(self, inst):
    items = self.popn(inst.argval)
    options = VariableTracker.propagate(items)
    self.push(TupleVariable(items, **options))

以下是此代码的作用:

  1. 函数读取 ,在本例中为 类似于 pydoc 中的等效指令。argvalcounts

  2. 函数 items ,在本例中,签名是 this 暗示一个 基础合同 - 我们正在返回 。如果我们 仔细看看 和 /我们 看到 唯一被推到我们的堆栈中并从我们的堆栈中弹出的是 S。popndef  popn(self, n: int) -> List[TensorVariable]:TensorVariablessybmolic_convert.pyInstructionTranslatorBaseInstructionTranslatorVariableTracker

  1. 该函数调用 .这 从 2 个堆栈中弹出的每个物品中夺走守卫, 并递归遍历它并将所有守卫组合成 :VariableTracker.propagateoptionspy  return {      "guards": guards,  }

  2. 然后,该函数从 和 中创建 , 的新实例。这 允许我们从组成新的VariableTrackerTupleVariableitemsoptionsitemsTupleVariable

注意

第一批守卫来自哪里?增殖 是一种很好的技术,但我们需要先创造一些东西才能实现 传播。 在创建实例时调用 。这反过来又会调用 ,让它创建 警卫。VariableBuildermake_guardsVariableTrackerf_localssource

在这一切之后,字节码转换完成了,我们离目标又近了一步 到 生产 。我们现在了解了当地人如何成为 s,如何处理指令,以及守卫在哪里 被要求创造。在我们开始了解如何编写代码和 guards 组合成一个 GuardedCode 对象,我们需要挖掘一下 位到上面的那些和调用。我们 然后就可以理解,当我们制作 Guards 时发生了什么 与 instances 并列。GuardedCodeVariableTrackermake_guardsource.make_guardVariableTracker

制作守卫

守卫只是类 .让我们看看它们 更详细地。Guard

查看数据类的定义(因此,ctor signature),我们看到它有一个 name、一个 source 和一个 create 函数。

@dataclasses.dataclass
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable

name 应该是变量的名称。

这里的 source 是一个枚举,指示守卫的 source 类型 属于。

注意

不要与 和其他类型的 混淆 中 ,如 上存储的 。Sourcesource.pyVariableTracker

create_fn提供了从简单的 data类实际生成要调用 了解两次调用之间是否发生了变化,以及 我们是否可以安全地从代码缓存中读取。

获取守卫实例的最常见代码路径是 通过 on 。->''source.make_guard''->''返回守卫(self.name(), self.guard_source(), fn)''make_guardsVariableTrackermake_guards

或者,在一个具体的例子中:

...
elif istype(value, range):
    guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
    return RangeVariable(value=value, guards=guards)

由于是在构建时设置的,因此这里需要做的就是将 ,提供给现场。sourceVariableTrackerfnGuardBuilder.EQUALS_MATCHcreate_fn

这必须是 上的方法。原因 这在我们的下一步中变得很明显。一旦我们有了所有的守卫 为帧创建,我们转到 和 。create_fnGuardBuilderCheckFunctionManagercompile_check_fn

在函数可以产生 , 它需要运行带有所有守卫的 , 以 产生一个,然后反过来又会一起传递 将代码转换为 .这与我们在 cache 条目,以及我们运行以了解是否要检索的相同条目 代码一起存储。作为参考,以下是该代码:convert_frameGuardedCodeCheckFunctionManagercheck_fnGuardedCodecheck_fn

static CacheEntry *create_cache_entry(CacheEntry *next,
                                      PyObject *guarded_code) {
  CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
  DEBUG_NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}

我们现在知道函数是如何使用的,以及谁来制造它,并且 它由什么组成,但我们还不知道它是如何组成的。a list of objects 成为我们稍后可以运行的函数?check_fnGuard

首先,我们迭代这些守卫:

for guard in sorted(guards or [], key=Guard.sort_key):
    if not config.guard_nn_modules and guard.is_nn_module():
        continue
    guard.create(local_builder, global_builder)

调用我们在上面的类上设置的运行(不要将其与我们正在处理的 生产,名称相似,所以可能会有点混淆)。在 我们上面的示例,我们的 IS . 所以我们现在调用它,传入 , 守卫本身, 在。guard.createcreate_fnGuardcheck_fncreate_fnGuardBuilder.EQUALS_MATCHself

签名为:def EQUALS_MATCH(self, guard: Guard):

在该功能的内部,我们可以使用 on the guard 来 取回我们的原始对象,查询其数据和类型信息, 这反过来又让我们进入了最重要的部分:附加代码。name

最简单的是,只附加一行代码:.的名称在哪里 变量,并且是值。它可能会生成如下代码:EQUALS_MATCHself.code.append(f"{ref} == {val!r}")refval

y == 2

这是一个基本示例。但是,如果我们附加一些其他类型的函数,然后将它们全部组合在每个语句之间(就像我们所做的那样),我们可能会得到一些东西 喜欢这个:GuardBuilderand

___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)

以下是此代码执行的操作:

  1. 检查.valid

  2. 类型 ID 检查

  3. A 值检查

  4. 张量检查

这成为代码 our 的核心,而 ,这反过来又是 在我们下次遇到此代码时进行评估。它 然后检查:check_fn

  1. 此代码是否仍然有效?

  2. 如果 (1),则仍然具有 ?y94367738391392

  3. 如果 (2) 仍然是 2?y

  4. 如果 (3),我们检查一下张量是否以某些特定方式发生了变化。x

如果所有这些都仍然成立,那么我们可以使用缓存的代码 除了这个 .check_fn

注意

更深入地了解这种情况是如何发生的以及在哪里发生的 您可以阅读 。static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {_eval_frame.c

如果没有,那么,我们可以继续重新编译代码,并将 ,以及一个全新的 再次在另一个后续帧上进行检查。check_fn

还有很多其他这样的函数可以获取 合并成(有时是巨大的)字符串,然后被评估为 Python 代码并存储到 .上面的例子 说明了一个简单的案例。要更好地了解此功能,请阅读 上的其他函数 ,或者更好的是,转储变量 查看正在生产的内容, 尤其是在更大的真实模型上。GuardBuildercheck_fnGuardBuildercodecompile_check_fn

总结

在本节中,我们回顾了:

  • 弱引用的作用和失效(可能很快就会成为 NN Moduleinvalidations)。.valid

  • 守卫函数 (, , etc) 的 C++ 端如何运作___check_type_id___check_tensors

  • 当守卫失败时会发生什么。

  • 如果我们生成无效的 guard 代码会发生什么。

我们介绍了用户提供的代码如何包装在 TorchDynamo 上下文中 继续在内部进行跟踪和跟踪,组织成 S 和随后的 S,以及 在处理 Python 时轮流选择和失效缓存条目 法典。VariableTrackerSourceGuardGuards

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源