目录

TorchDynamo 深入探究

作者Jason Ansel

什么是警卫?

TorchDynamo 以 just-in-time 方式运行,并基于 动态属性。例如,上面的第一张图有以下内容 警卫:

GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH

如果这些守卫中的任何一个失败,图形将被重新捕获,并且 重新 编译。有趣的守卫类型是 ,它 检查以下属性:TENSOR_MATCHtorch.Tensor

  • 张量的 Python 类(张量子类化等)

  • DTYPE

  • 装置

  • requires_grad

  • dispatch_key(应用了线程本地包含/排除)

  • ndim

  • 尺码*(可选)

  • 步幅* (可选)

对于大小/步幅,您可以通过设置 以下参数:

torch._dynamo.config.dynamic_shapes = True

完全专用化模式允许后端编译器假定 完全静态的图形。不幸的是,大多数后端都需要这个。 返回动态形状的运算符将在 不在动态形状模式下。

Dynamo 在做什么?

如果您想更好地了解 TorchDynamo 在做什么,您可以设置:

import torch._dynamo.config
import logging

torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True

此代码会触发有用 (但垃圾邮件) 的打印输出。

例如,中第一个图形的打印输出为:toy_example

__compiled_fn_0 <eval_with_key>.1
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

ORIGINAL BYTECODE toy_example example.py 9
 10           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)

 11          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       38

 12          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)

 13     >>   38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE

MODIFIED BYTECODE
  9           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2
              8 UNPACK_SEQUENCE          2
             10 STORE_FAST               2 (x)
             12 POP_JUMP_IF_FALSE       24
             14 LOAD_GLOBAL              4 (__resume_at_30_1)
             16 LOAD_FAST                1 (b)
             18 LOAD_FAST                2 (x)
             20 CALL_FUNCTION            2
             22 RETURN_VALUE
        >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
             26 LOAD_FAST                1 (b)
             28 LOAD_FAST                2 (x)
             30 CALL_FUNCTION            2
             32 RETURN_VALUE

GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH

在顶部,您可以看到 FX 图表。 接下来,您将看到函数的原始字节码,后跟 修改了由 TorchDynamo 生成的字节码。最后,你看到了守卫 我们在上面已经介绍过了。

在修改后的字节码中, 是 (编译后的图) 的返回值。 和 都是生成的 continuation 函数,它们选择 Graph break 后向上执行(在字节码偏移量 30 和 38 处)。每 这些函数的形式为:__compiled_fn_0my_compiler()__resume_at_30_1__resume_at_38_2

__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...

通过生成这个 resume_at 函数,我们强制 函数在新的 Python 框架中执行,该框架以递归方式执行 触发 TorchDynamo 在执行达到该目标后重新启动其捕获 点。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源