TorchDynamo 深入解析¶
作者: Jason Ansel
什么是保护机制?¶
TorchDynamo 采用即时编译方式,根据动态属性对图进行专门化处理。例如,上面的第一个图具有以下保护条件:
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
如果其中任何一个保护条件失败,图将被重新捕获并重新编译。其中有趣的保护类型是 TENSOR_MATCH,它检查以下 torch.Tensor 个属性:
张量的 Python 类(如张量子类等)
数据类型(dtype)
设备
requires_grad
dispatch_key(已应用线程本地的包含/排除)
维度数
sizes* (可选)
步长* (可选)
对于尺寸/步长,你可以通过设置以下参数来禁用此专用功能:
torch._dynamo.config.dynamic_shapes = True
完整专用模式允许后端编译器假定一个完全静态的图。不幸的是,大多数后端都需要这种模式。在非动态形状模式下,返回动态形状的操作符将触发图中断。
Dynamo在做什么?¶
如果你想更清楚地了解 TorchDynamo 的运作方式,你可以设置:
import torch._dynamo.config
import logging
torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True
这段代码会触发有用的(但可能过多)打印输出。
例如,第一个图的toy_example
的输出为:
__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
ORIGINAL BYTECODE toy_example example.py 9
10 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
11 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
12 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
13 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
MODIFIED BYTECODE
9 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 24
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
在顶部你可以看到 FX 图。 接下来,你将看到函数的原始字节码,随后是 TorchDynamo 生成的修改后的字节码。最后,你将看到我们上面提到的 guards。
在修改后的字节码中,__compiled_fn_0 是 my_compiler()(编译后的图)的返回值。__resume_at_30_1 和 __resume_at_38_2 都是生成的延续函数,在图中断后(字节码偏移量 30 和 38 处)继续执行。这些函数的形式如下:
__resume_at_<offset>:
... restore stack state if needed ...
JUMP_ABSOLUTE <offset> into toy_example
... original bytecode of toy_example ...
通过生成这个 resume_at 函数,我们强制函数的其余部分在新的 Python 帧中执行,这会递归地触发 TorchDynamo 在首次执行到达该点时重新开始捕获。