TorchDynamo 深入探究¶
作者: Jason Ansel
什么是警卫?¶
TorchDynamo 以 just-in-time 方式运行,并基于 动态属性。例如,上面的第一张图有以下内容 警卫:
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
如果这些守卫中的任何一个失败,图形将被重新捕获,并且
重新 编译。有趣的守卫类型是 ,它
检查以下属性:TENSOR_MATCH
torch.Tensor
张量的 Python 类(张量子类化等)
DTYPE
装置
requires_grad
dispatch_key(应用了线程本地包含/排除)
ndim
尺码*(可选)
步幅* (可选)
对于大小/步幅,您可以通过设置 以下参数:
torch._dynamo.config.dynamic_shapes = True
完全专用化模式允许后端编译器假定 完全静态的图形。不幸的是,大多数后端都需要这个。 返回动态形状的运算符将在 不在动态形状模式下。
Dynamo 在做什么?¶
如果您想更好地了解 TorchDynamo 在做什么,您可以设置:
import torch._dynamo.config
import logging
torch._dynamo.config.log_level = logging.INFO
torch._dynamo.config.output_code = True
此代码会触发有用 (但垃圾邮件) 的打印输出。
例如,中第一个图形的打印输出为:toy_example
__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
ORIGINAL BYTECODE toy_example example.py 9
10 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
11 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
12 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
13 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
MODIFIED BYTECODE
9 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 24
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
在顶部,您可以看到 FX 图表。 接下来,您将看到函数的原始字节码,后跟 修改了由 TorchDynamo 生成的字节码。最后,你看到了守卫 我们在上面已经介绍过了。
在修改后的字节码中, 是 (编译后的图) 的返回值。 和 都是生成的 continuation 函数,它们选择
Graph break 后向上执行(在字节码偏移量 30 和 38 处)。每
这些函数的形式为:__compiled_fn_0
my_compiler()
__resume_at_30_1
__resume_at_38_2
__resume_at_<offset>:
... restore stack state if needed ...
JUMP_ABSOLUTE <offset> into toy_example
... original bytecode of toy_example ...
通过生成这个 resume_at 函数,我们强制 函数在新的 Python 框架中执行,该框架以递归方式执行 触发 TorchDynamo 在执行达到该目标后重新启动其捕获 点。