目录

TorchDynamo 深入解析

作者: Jason Ansel

什么是保护机制?

TorchDynamo 采用即时编译方式,根据动态属性对图进行专门化处理。例如,上面的第一个图具有以下保护条件:

GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH

如果其中任何一个保护条件失败,图将被重新捕获并重新编译。其中有趣的保护类型是 TENSOR_MATCH,它检查以下 torch.Tensor 个属性:

  • 张量的 Python 类(如张量子类等)

  • 数据类型(dtype)

  • 设备

  • requires_grad

  • dispatch_key(已应用线程本地的包含/排除)

  • 维度数

  • sizes* (可选)

  • 步长* (可选)

对于尺寸/步长,你可以通过设置以下参数来禁用此专用功能:

torch._dynamo.config.dynamic_shapes = True

完整专用模式允许后端编译器假定一个完全静态的图。不幸的是,大多数后端都需要这种模式。在非动态形状模式下,返回动态形状的操作符将触发图中断。

Dynamo在做什么?

如果你想更清楚地了解 TorchDynamo 的运作方式,你可以设置:

import torch._dynamo.config
import logging

torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True

这段代码会触发有用的(但可能过多)打印输出。

例如,第一个图的toy_example 的输出为:

__compiled_fn_0 <eval_with_key>.1
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

ORIGINAL BYTECODE toy_example example.py 9
 10           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)

 11          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       38

 12          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)

 13     >>   38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE

MODIFIED BYTECODE
  9           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2
              8 UNPACK_SEQUENCE          2
             10 STORE_FAST               2 (x)
             12 POP_JUMP_IF_FALSE       24
             14 LOAD_GLOBAL              4 (__resume_at_30_1)
             16 LOAD_FAST                1 (b)
             18 LOAD_FAST                2 (x)
             20 CALL_FUNCTION            2
             22 RETURN_VALUE
        >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
             26 LOAD_FAST                1 (b)
             28 LOAD_FAST                2 (x)
             30 CALL_FUNCTION            2
             32 RETURN_VALUE

GUARDS:
 - local 'a' TENSOR_MATCH
 - local 'b' TENSOR_MATCH
 - global 'torch' FUNCTION_MATCH

在顶部你可以看到 FX 图。 接下来,你将看到函数的原始字节码,随后是 TorchDynamo 生成的修改后的字节码。最后,你将看到我们上面提到的 guards。

在修改后的字节码中,__compiled_fn_0my_compiler()(编译后的图)的返回值。__resume_at_30_1__resume_at_38_2 都是生成的延续函数,在图中断后(字节码偏移量 30 和 38 处)继续执行。这些函数的形式如下:

__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...

通过生成这个 resume_at 函数,我们强制函数的其余部分在新的 Python 帧中执行,这会递归地触发 TorchDynamo 在首次执行到达该点时重新开始捕获。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源