TorchScript¶
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 进程中保存,并在没有 Python 依赖的进程中加载。
我们提供工具,逐步将模型从纯 Python 程序过渡到可以独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中运行。 这使得可以在使用熟悉的 Python 工具训练模型后,通过 TorchScript 将模型导出到生产环境中,在该环境中,由于性能和多线程等原因,Python 程序可能处于不利地位。
对于TorchScript的简明介绍,请参阅TorchScript简介教程。
要将PyTorch模型转换为TorchScript并在C++中运行的完整示例,请参阅 在C++中加载PyTorch模型 教程。
创建 TorchScript 代码¶
脚本化一个函数或 |
|
绘制一个函数并返回一个可执行的代码或 |
|
在第一次调用时编译 |
|
跟踪一个模块并返回一个可执行的 |
|
创建一个异步任务来执行 func 并引用此执行结果的值。 |
|
强制完成一个 torch.jit.Future[T] 异步任务,返回该任务的结果。 |
|
围绕 C++ 的包装器 |
|
功能上等同于一个 |
|
冻结一个 |
|
对模型执行一组优化操作,以优化其推理性能。 |
|
设置在融合过程中可能发生的专业化类型和数量。 |
|
保存此模块的离线版本以供单独进程使用。 |
|
加载之前使用 |
|
此装饰器指示编译器忽略该函数或方法,并将其保留在 Python 函数形式。 |
|
此装饰器指示编译器忽略该函数或方法,并用抛出异常来替代。 |
|
此函数为 TorchScript 提供容器类型细化功能。 |
|
此方法是一个透传函数,返回value,主要用于指示TorchScript编译器左侧表达式是一个类型为type的类实例属性。 |
|
此方法是一个直通函数,返回 the_value,用于提示 TorchScript 编译器 the_value 的类型。 |
混合追踪和脚本化¶
在许多情况下,无论是通过追踪还是脚本编写,将模型转换为TorchScript的方法都更为简便。 追踪和脚本编写可以结合使用,以满足模型某一部分的特定需求。
脚本函数可以调用追踪函数。这在你需要在简单的前馈模型周围使用控制流时特别有用。例如,序列到序列模型的束搜索通常会以脚本形式编写,但可以调用使用追踪生成的编码器模块。
示例(在脚本中调用一个经过追踪的函数):
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
可追溯函数可以调用脚本函数。这在模型的某一小部分需要一些控制流时很有用,即使大部分模型只是一个前馈网络。由可追溯函数调用的脚本函数内部的控制流会被正确保留。
示例(在追踪函数中调用脚本函数):
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
此组合同样适用于nn.Modules,其中可以使用通过跟踪生成的子模块,并可以从脚本模块的方法中调用。
示例(使用跟踪模块):
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript 语言¶
TorchScript 是 Python 的静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。详见完整的 TorchScript 语言参考。
内置函数和模块¶
TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。 请参阅 TorchScript 内置函数 以获取支持的函数的完整参考。
PyTorch函数和模块¶
TorchScript 支持 PyTorch 提供的部分张量和神经网络函数。大多数 Tensor 的方法以及 torch 命名空间中的函数、torch.nn.functional 中的所有函数,还有来自 torch.nn 的大部分模块都在 TorchScript 中得到了支持。
请参阅TorchScript 不支持的 PyTorch 构造,以获取不支持的 PyTorch 函数和模块列表。
Python 函数和模块¶
Python 的许多 内置函数 在 TorchScript 中都受支持。
math 模块也受支持(详情请参见 math 模块),但其他 Python 模块(内置或第三方)不受支持。
Python 语言参考对比¶
有关支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围。
调试¶
禁用JIT进行调试¶
-
PYTORCH_JIT¶
设置环境变量 PYTORCH_JIT=0 将禁用所有脚本和跟踪注释。如果你的某个
TorchScript 模型中出现难以调试的错误,可以使用此标志强制使用原生 Python 运行一切。由于此标志禁用了 TorchScript(脚本和跟踪),你可以使用像 pdb 这样的工具来调试模型代码。例如:
@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x
def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
通过调试此脚本 pdb 可行,除了当我们调用
@torch.jit.script 函数时。我们可以全局禁用
JIT,以便可以像调用普通 Python 函数一样调用 @torch.jit.script 函数而不进行编译。如果上述脚本名为 disable_jit_example.py,我们可以这样调用它:
$ PYTORCH_JIT=0 python disable_jit_example.py
并且我们将能够像调用普通Python函数一样进入@torch.jit.script函数。要禁用特定函数的TorchScript编译器,请参见@torch.jit.ignore。
检查代码¶
TorchScript 提供了一个代码美化器,用于所有 ScriptModule 实例。这个美化器将脚本方法的代码解释为有效的
Python 语法。例如:
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.code)
一个ScriptModule带有一个单独的forward方法将具有一个属性
code,你可以用它来检查ScriptModule的代码。
如果ScriptModule有多个方法,你需要访问
该方法本身的.code而不是模块。我们可以通过访问.foo.code来检查
名为foo的方法在一个ScriptModule上的代码。
上面的例子会产生这样的输出:
def foo(len: int) -> Tensor:
rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
rv0 = rv
for i in range(len):
if torch.lt(i, 10):
rv1 = torch.sub(rv0, 1., 1)
else:
rv1 = torch.add(rv0, 1., 1)
rv0 = rv1
return rv0
这是TorchScript对forward方法代码的编译。
你可以使用它来确保TorchScript(跟踪或脚本编写)正确捕获了你的模型代码。
解释图表¶
TorchScript 还有一种比代码美化打印器更低层次的表示形式,即以 IR 图的形式呈现。
TorchScript 使用静态单赋值 (SSA) 中间表示 (IR) 来表示计算。这种格式的指令由 ATen(PyTorch 的 C++ 后端)操作符和其他基本操作符组成,包括用于循环和条件的操作符。例如:
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.graph)
graph 遵循 检查代码 部分中描述的相同规则
关于 forward 方法查找。
上述示例脚本生成了以下图形:
graph(%len.1 : int):
%24 : int = prim::Constant[value=1]()
%17 : bool = prim::Constant[value=1]() # test.py:10:5
%12 : bool? = prim::Constant()
%10 : Device? = prim::Constant()
%6 : int? = prim::Constant()
%1 : int = prim::Constant[value=3]() # test.py:9:22
%2 : int = prim::Constant[value=4]() # test.py:9:25
%20 : int = prim::Constant[value=10]() # test.py:11:16
%23 : float = prim::Constant[value=1]() # test.py:12:23
%4 : int[] = prim::ListConstruct(%1, %2)
%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
block0(%i.1 : int, %rv.14 : Tensor):
%21 : bool = aten::lt(%i.1, %20) # test.py:11:12
%rv.13 : Tensor = prim::If(%21) # test.py:11:9
block0():
%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
-> (%rv.3)
block1():
%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
-> (%rv.6)
-> (%17, %rv.13)
return (%rv)
以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 为例。
%rv.1 : Tensor表示我们将输出分配给一个(唯一)名为rv.1的值,该值的类型为Tensor,并且我们不知道它的具体形状。aten::zeros是操作符(等同于torch.zeros),输入列表(%4, %6, %6, %10, %12)指定了作用域中应作为输入传递的值。内置函数(如aten::zeros)的模式可以在 内置函数中找到。# test.py:9:10是原始源文件中生成此指令的位置。在这种情况下,它是一个名为test.py的文件,在第9行,第10个字符处。
注意,操作符也可以有相关的blocks,即prim::Loop和prim::If操作符。在图形输出中,这些操作符会被格式化为与其等效的源代码形式,以方便调试。
可以通过如上所示的图表检查确认由ScriptModule描述的计算是否正确,无论是自动化还是手动方式,如下所述。
跟踪器¶
追踪边缘情况¶
在某些边缘情况下,给定 Python 函数/模块的调用栈可能无法代表底层代码。这些情况包括:
依赖于输入(例如张量形状)的控制流跟踪
张量视图的就地操作跟踪(例如,在赋值左侧进行索引)
请注意,这些情况将来可能会有迹可循。
自动跟踪检查¶
一种自动捕获许多跟踪错误的方法是使用check_inputs
在torch.jit.trace() API上。check_inputs接受一个元组列表
作为输入,这些输入将用于重新跟踪计算并验证结果。例如:
def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
提供给我们以下诊断信息:
ERROR: Graphs differed across invocations!
Graph diff:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
此消息表明,我们在首次追踪和使用check_inputs进行追踪时的计算有所不同。
实际上,在loop_in_traced_fn主体内的循环依赖于输入x的形状,
因此当我们尝试另一个具有不同形状的x时,追踪结果会有所不同。
在这种情况下,可以使用
torch.jit.script() 来捕获这样的数据依赖控制流:
def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))
这会产生:
graph(%x : Tensor) {
%5 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %1)
%4 : int = aten::size(%x, %1)
%result : Tensor = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Tensor) {
%10 : Tensor = aten::select(%x, %1, %i)
%result.2 : Tensor = aten::mul(%7, %10)
-> (%5, %result.2)
}
return (%result);
}
追踪警告¶
跟踪器会对追踪计算中出现的几种有问题的模式发出警告。例如,考虑一个函数的跟踪记录,该函数在一个张量的切片(视图)上进行了就地赋值操作:
def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
产生多个警告并绘制一个简单的图形,该图形只是返回输入:
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
我们可以通过修改代码来避免使用就地更新,而是通过torch.cat构建结果张量:
def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常见问题解答¶
我想在 GPU 上训练模型并在 CPU 上进行推理。有哪些最佳实践?
First convert your model from GPU to CPU and then save it, like so:
cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pt") traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pt") # ... later, when using the model: if use_gpu: model = torch.jit.load("gpu.pt") else: model = torch.jit.load("cpu.pt") model(input)This is recommended because the tracer may witness tensor creation on a specific device, so casting an already-loaded model may have unexpected effects. Casting the model before saving it ensures that the tracer has the correct device information.
问:如何在一个 ScriptModule 上存储属性?
Say we have a model like:
import torch class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() self.x = 2 def forward(self): return self.x m = torch.jit.script(Model())If
Modelis instantiated it will result in a compilation error since the compiler doesn’t know aboutx. There are 4 ways to inform the compiler of attributes onScriptModule:1.
nn.Parameter- Values wrapped innn.Parameterwill work as they do onnn.Modules2.
register_buffer- Values wrapped inregister_bufferwill work as they do onnn.Modules. This is equivalent to an attribute (see 4) of typeTensor.3. Constants - Annotating a class member as
Final(or adding it to a list called__constants__at the class definition level) will mark the contained names as constants. Constants are saved directly in the code of the model. See builtin-constants for details.4. Attributes - Values that are a supported type can be added as mutable attributes. Most types can be inferred but some may need to be specified, see module attributes for details.
我想跟踪模块的方法,但我总是收到这个错误:
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
This error usually means that the method you are tracing uses a module’s parameters and you are passing the module’s method instead of the module instance (e.g.
my_module_instance.forwardvsmy_module_instance).
Invoking
tracewith a module’s method captures module parameters (which may require gradients) as constants.On the other hand, invoking
tracewith module’s instance (e.g.my_module) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required.To trace a specific method on a module, see
torch.jit.trace_module
已知问题¶
如果你在使用 Sequential 和 TorchScript 时,某些 Sequential 子模块的输入可能会被错误地推断为 Tensor,即使它们有其他注解。标准解决方案是子类化 nn.Sequential 并重新声明 forward,以正确指定输入类型。
附录¶
迁移到 PyTorch 1.2 递归脚本 API¶
本节详细介绍了 PyTorch 1.2 中 TorchScript 的更改。如果您是第一次接触 TorchScript,可以跳过本节。PyTorch 1.2 对 TorchScript API 进行了两项主要更改。
1. torch.jit.script 将会尝试递归编译它遇到的函数、方法和类。一旦你调用 torch.jit.script,编译是“退出”模式,而不是“进入”模式。
2. torch.jit.script(nn_module_instance) 现在是创建
ScriptModule的首选方式,而不是继承自torch.jit.ScriptModule。
这些更改结合起来提供了一个更简单、更易于使用的API,用于将您的nn.Module转换为ScriptModule,准备在非Python环境中进行优化和执行。
新用法如下所示:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
该模块的
forward默认情况下会被编译。从forward调用的方法会在forward中按使用顺序惰性编译。要编译除
forward以外且不是从forward调用的方法,请添加@torch.jit.export。为了阻止编译器编译某个方法,请添加
@torch.jit.ignore或@torch.jit.unused。@ignore保留该方法被视为对python的调用,并且
@unused用一个异常替换了它。@ignored不能被导出;@unused可以。大多数属性类型可以被推断,因此
torch.jit.Attribute是不必要的。对于空的容器类型,使用 PEP 526 风格的类注释来标注它们的类型。常量可以用
Final类注释标记,而不是将成员名称添加到__constants__中。Python 3 类型提示可以替代
torch.jit.annotate
- As a result of these changes, the following items are considered deprecated and should not appear in new code:
装饰器
@torch.jit.script_method继承自
torch.jit.ScriptModule的类The
torch.jit.Attribute封装类数组
__constants__函数
torch.jit.annotate
模块¶
警告
在PyTorch 1.2中,@torch.jit.ignore 注解的行为发生了变化。
在PyTorch 1.2之前,使用@ignore装饰器可以让一个函数或方法从导出的代码中调用。
要恢复此功能,请使用@torch.jit.unused()。@torch.jit.ignore 现在等同于@torch.jit.ignore(drop=False)。
有关详细信息,请参见@torch.jit.ignore 和 @torch.jit.unused。
当传递给torch.jit.script函数时,一个torch.nn.Module的数据会被复制到一个ScriptModule中,并且TorchScript编译器会编译该模块。
模块的forward默认会被编译。从forward调用的方法会在forward中按使用顺序惰性编译,以及任何@torch.jit.export方法。
-
torch.jit.export(fn)[source]¶ 此装饰器表示一个方法在
nn.Module上被用作进入ScriptModule的入口点,并应进行编译。forward默认被视为入口点,因此不需要这个装饰器。 从forward调用的函数和方法在编译器看到时会被编译,所以它们也不需要这个装饰器。示例(在方法上使用
@torch.jit.export):import torch import torch.nn as nn class MyModule(nn.Module): def implicitly_compiled_method(self, x): return x + 99 # `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effect def forward(self, x): return x + 10 @torch.jit.export def another_forward(self, x): # When the compiler sees this call, it will compile # `implicitly_compiled_method` return self.implicitly_compiled_method(x) def unused_method(self, x): return x - 20 # `m` will contain compiled methods: # `forward` # `another_forward` # `implicitly_compiled_method` # `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export` m = torch.jit.script(MyModule())
功能¶
函数变化不大,如果需要,可以用@torch.jit.ignore或torch.jit.unused进行装饰。
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
return 2
# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
return 2
# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
import pdb; pdb.set_trace()
return 4
# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
return 2
TorchScript 类¶
警告
TorchScript 类支持是实验性的。目前它最适合用于简单的记录类型(想象一个带有方法的 NamedTuple)。
用户定义的TorchScript 类中的所有内容默认都会被导出,如果需要,函数可以使用 @torch.jit.ignore 进行装饰。
属性¶
TorchScript编译器需要知道module attributes的类型。大多数类型可以从成员的值推断出来。
空列表和字典无法推断其类型,必须使用PEP 526风格的类注释进行类型标注。
如果一个类型无法推断且未明确标注,则它不会被添加为结果ScriptModule的属性。
旧API:
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super(MyModule, self).__init__()
self.my_dict = torch.jit.Attribute({}, Dict[str, int])
self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新 API:
from typing import Dict
class MyModule(torch.nn.Module):
my_dict: Dict[str, int]
def __init__(self):
super(MyModule, self).__init__()
# This type cannot be inferred and must be specified
self.my_dict = {}
# The attribute type here is inferred to be `int`
self.my_int = 20
def forward(self):
pass
m = torch.jit.script(MyModule())
常量¶
构造函数Final可以用于将成员标记为constant。如果成员未被标记为常量,则它们将以属性的形式复制到结果ScriptModule中。使用Final可以在值已知且固定的情况下提供优化机会,并增加类型安全性。
旧API:
class MyModule(torch.jit.ScriptModule):
__constants__ = ['my_constant']
def __init__(self):
super(MyModule, self).__init__()
self.my_constant = 2
def forward(self):
pass
m = MyModule()
新 API:
try:
from typing_extensions import Final
except:
# If you don't have `typing_extensions` installed, you can use a
# polyfill from `torch.jit`.
from torch.jit import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self):
super(MyModule, self).__init__()
self.my_constant = 2
def forward(self):
pass
m = torch.jit.script(MyModule())
变量¶
容器假定类型为Tensor且为非可选类型(有关更多信息,请参见Default Types)。之前,使用torch.jit.annotate来告知TorchScript编译器应使用的类型。现在支持Python 3风格的类型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
x: Dict[str, int] = {}
x['hi'] = 2
b: Optional[int] = None
if flag:
b = 2
return x, b