目录

TorchScript 脚本

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 保存 进程中加载,并在没有 Python 依赖项的进程中加载。

我们提供了从纯 Python 程序逐步转换模型的工具 添加到可以独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中。 这使得使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后导出 模型通过 TorchScript 部署到 Python 程序可能不利的生产环境 出于性能和多线程的原因。

有关 TorchScript 的简要介绍,请参阅 TorchScript 简介教程。

有关将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行它的端到端示例,请参阅在 C++ 中加载 PyTorch 模型教程。

创建 TorchScript 代码

script

编写函数脚本或将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回一个nn.ModuleScriptModuleScriptFunction.

trace

跟踪函数并返回可执行文件或ScriptFunction这将使用 Just-in-Time 编译进行优化。

script_if_tracing

在跟踪期间首次调用时进行编译。fn

trace_module

跟踪模块并返回可执行文件ScriptModule这将使用 Just-in-Time 编译进行优化。

fork

创建一个执行 func 的异步任务,并创建一个对此执行结果值的引用。

wait

强制完成 torch.jit.Future[T] 异步任务,返回任务的结果。

ScriptModule

C++ 的包装器。torch::jit::Module

ScriptFunction

在功能上等同于ScriptModule,但表示单个函数,并且没有任何属性或参数。

freeze

冻结ScriptModule将克隆它并尝试将克隆模块的子模块、参数和属性内联为 TorchScript IR Graph 中的常量。

optimize_for_inference

执行一组优化过程以优化模型以进行推理。

enable_onednn_fusion

根据启用的参数启用或禁用 onednn JIT 融合。

onednn_fusion_enabled

返回是否启用 onednn JIT 融合

set_fusion_strategy

设置融合期间可能发生的特化类型和数量。

strict_fusion

如果不是所有节点都在推理中融合,或者在训练中符号区分,则此类错误。

save

保存此模块的脱机版本,以便在单独的进程中使用。

load

加载ScriptModuleScriptFunction之前保存为torch.jit.save

ignore

此装饰器向编译器指示应忽略函数或方法并将其保留为 Python 函数。

unused

此装饰器向编译器指示应忽略函数或方法,并将其替换为引发异常。

isinstance

此函数在 TorchScript 中提供 conatiner 类型优化。

Attribute

此方法是返回 value 的传递函数,主要用于向 TorchScript 编译器指示左侧表达式是 type 为 type 的类实例属性。

annotate

此方法是返回 the_value 的直通函数,用于提示 TorchScript 编译器the_value的类型。

混合跟踪和脚本

在许多情况下,跟踪或脚本是将模型转换为 TorchScript 的更简单的方法。 可以组合跟踪和脚本以满足特定要求 的一部分。

脚本化函数可以调用跟踪函数。当您需要时,这特别有用 围绕简单的前馈模型使用 control-flow。例如,光束搜索 序列到序列模型通常用脚本编写,但可以调用 使用 tracing 生成的 encoder 模块。

示例(在脚本中调用跟踪的函数):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

跟踪的函数可以调用脚本函数。这在一小部分 模型需要一些控制流,即使模型的大部分只是前馈 网络。由跟踪函数调用的脚本函数内的控制流是 正确保存。

示例(在跟踪函数中调用脚本函数):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此合成也适用于 s,可用于生成 一个使用跟踪的子模块,可以从 script 模块的方法中调用。nn.Module

示例(使用 traced 模块):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript 语言

TorchScript 是 Python 的静态类型子集,因此许多 Python 功能都适用 直接发送到 TorchScript。有关详细信息,请参阅完整的 TorchScript 语言参考

内置函数和模块

TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。 有关受支持函数的完整参考,请参阅 TorchScript Builtins

PyTorch 函数和模块

TorchScript 支持张量和神经网络的子集 PyTorch 提供的函数。Tensor 上的大多数方法以及 命名空间、和 中的所有函数 TorchScript 支持大多数 的模块。torchtorch.nn.functionaltorch.nn

有关不支持的 PyTorch 函数和模块的列表,请参阅 TorchScript 不支持的 Pytorch 构造

Python 函数和模块

TorchScript 支持许多 Python 的内置函数。 这mathmodule 也受支持(有关详细信息,请参阅 math Module ),但不支持其他 Python 模块 (内置或第三方)。

Python 语言参考比较

有关支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围

调试

禁用 JIT 进行调试

PYTORCH_JIT

设置环境变量将禁用所有脚本 以及描摹注释。如果您的 TorchScript 模型,您可以使用此标志强制所有内容都使用本机运行 蟒。由于使用此标志禁用了 TorchScript(脚本和跟踪),因此 你可以使用 like 工具来调试模型代码。例如:PYTORCH_JIT=0pdb

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

调试此脚本的工作情况除外,但当我们调用pdb@torch.jit.script功能。我们可以全局禁用 JIT 的 API 调用,以便我们可以调用@torch.jit.script函数作为普通的 Python 函数,而不是编译它。如果上述脚本 被调用,我们可以像这样调用它:disable_jit_example.py

$ PYTORCH_JIT=0 python disable_jit_example.py

我们将能够步入@torch.jit.script函数作为普通的 Python 函数。要禁用 TorchScript 编译器,请参阅@torch.jit.ignore.

检查代码

TorchScript 为所有ScriptModule实例。这 pretty-printer 将 script 方法的代码解释为有效 Python 语法。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

一个ScriptModule将具有一个 attribute ,您可以使用它来检查forwardcodeScriptModule的代码。 如果ScriptModule有多个方法,则需要访问方法本身而不是模块。我们可以检查 code 中名为.codefooScriptModule通过访问 . 上面的示例生成以下输出:.foo.code

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是 TorchScript 对该方法代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)已捕获 您的模型代码正确。forward

解释图形

TorchScript 也有一个比代码 pretty- 低级别的表示 打印机,以 IR 图形的形式。

TorchScript 使用静态单一赋值 (SSA) 中间表示 (IR) 来表示计算。此格式的说明包括 ATen(PyTorch 的 C++ 后端)运算符和其他原始运算符, 包括 Loop 和 Conditional 的控制流运算符。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph遵循 Inspecting Code 部分中描述的相同规则 关于方法查找。forward

上面的示例脚本生成了图表:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

按照说明作 例。%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10

  • %rv.1 : Tensor意味着我们将输出分配给一个名为 的(唯一)值,该值的类型,并且我们不知道它的具体形状。rv.1Tensor

  • aten::zeros是运算符(相当于 ),而 input list 指定 scope 中的哪些值应作为 inputs传递。内置函数的架构可以在 Builtin Functions 中找到。torch.zeros(%4, %6, %6, %10, %12)aten::zeros

  • # test.py:9:10是生成此指令的原始源文件中的位置。在本例中,它是一个名为 test.py 的文件,位于第 9 行和字符 10 处。

请注意,运算符也可以具有关联的 ,即 和 运算符。在图形打印输出中,这些 运算符的格式设置以反映其等效的源代码格式 以便于调试。blocksprim::Loopprim::If

可以如图所示检查图形,以确认所描述的计算 由ScriptModule在自动和手动方式上都是正确的,因为 如下所述。

示 踪

跟踪边缘案例

存在一些边缘情况,其中给定 Python 的跟踪 function/module 不代表底层代码。这些 案例可能包括:

  • 跟踪依赖于输入(例如张量形状)的控制流

  • 跟踪张量视图的就地作(例如,在赋值左侧进行索引)

请注意,这些情况实际上可能在将来可以追踪。

自动跟踪检查

自动捕获跟踪中许多错误的一种方法是在 API 上使用。 接受 Tuples 列表 的输入,这些输入将用于重新跟踪计算并验证 结果。例如:check_inputstorch.jit.trace()check_inputs

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

为我们提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此消息向我们表明,当 我们首先跟踪它,当我们使用 .事实上 body of 内的 loop 取决于 shape 的输入,因此当我们尝试另一个具有不同 形状,则跟踪会有所不同。check_inputsloop_in_traced_fnxx

在这种情况下,可以使用torch.jit.script()相反:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

这会产生:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

Tracer 警告

tracer 为 traced 中的几个有问题的模式生成警告 计算。例如,以包含 对 Tensor 的切片(视图)进行就地赋值:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

生成多个警告和一个仅返回输入的图形:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码来解决这个问题,使其不使用就地更新,但 而是用 :torch.cat

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题解答

问:我想在 GPU 上训练模型并在 CPU 上进行推理。什么是 最佳实践?

首先将模型从 GPU 转换为 CPU,然后保存,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

建议这样做,因为跟踪器可能会在 特定设备,因此强制转换已加载的模型可能会产生意外 影响。在保存模型之前强制转换模型可确保跟踪器具有 正确的设备信息。

问:如何在ScriptModule?

假设我们有一个这样的模型:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果实例化,则会导致编译错误 由于编译器不知道 .有 4 种方法可以通知 Compiler of attributes onModelxScriptModule:

1. - 包装的值将按其方式工作 对 S 执行nn.Parameternn.Parameternn.Module

2. - 括入的值将用作 他们在 S 上这样做。这相当于 类型的属性(请参阅 4)。register_bufferregister_buffernn.ModuleTensor

3. 常量 - 将类成员注释为 (或将其添加到类定义级别调用的列表中) 将标记包含的名称 作为常量。常量直接保存在模型的代码中。有关详细信息,请参阅 builtin-constantsFinal__constants__

4. 属性 - 支持类型的值可以添加为可变值 属性。大多数类型都可以推断,但有些类型可能需要指定,有关详细信息,请参阅 module attributes

Q: 我想跟踪 module 的方法,但一直收到这个错误:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

此错误通常意味着您正在跟踪的方法使用模块的参数和 您传递的是 Module 的方法而不是 Module 实例(例如 vs )。my_module_instance.forwardmy_module_instance

  • 使用模块的方法调用会将模块参数(可能需要梯度)捕获为常量trace

  • 另一方面,使用 module's instance (e.g. ) 调用会创建一个新模块并正确地将参数复制到新模块中,以便它们可以在需要时累积梯度。tracemy_module

要跟踪模块上的特定方法,请参阅torch.jit.trace_module

已知问题

如果你与 TorchScript 一起使用,则一些 的子模块可能会被错误地推断为 be ,即使它们另有注释。规范 解决方案是 subclass 并使用正确键入的 input 重新声明。SequentialSequentialTensornn.Sequentialforward

附录

迁移到 PyTorch 1.2 递归脚本 API

本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。如果您是 TorchScript 的新手,您可以 跳过本节。PyTorch 1.2 的 TorchScript API 有两个主要变化。

1.torch.jit.script现在将尝试递归编译函数, 方法和它遇到的类。调用 后, 编译是 “opt-out”,而不是 “opt-in”。torch.jit.script

2. 现在是首选的创建torch.jit.script(nn_module_instance)ScriptModules 继承,而不是继承自 . 这些更改结合在一起,提供了一个更简单、易于使用的 API 用于转换 你的 S 进入torch.jit.ScriptModulenn.ModuleScriptModules,准备好在 非 Python 环境。

新用法如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 默认情况下,该模块的 is 编译的。调用 from 的方法按照它们在 中的使用顺序进行延迟编译。forwardforwardforward

  • 要编译不是从 、 、 add .forwardforward@torch.jit.export

  • 要阻止编译器编译方法,请添加@torch.jit.ignore@torch.jit.unused. 离开@ignore

  • method 作为对 Python 的调用,并将其替换为 exception。 无法导出; 能。@unused@ignored@unused

  • 大多数属性类型都可以推断,因此不是必需的。对于空容器类型,请使用 PEP 526 样式的类注释来注释它们的类型。torch.jit.Attribute

  • 常量可以使用类注释进行标记,而不是将成员的名称添加到 .Final__constants__

  • Python 3 类型提示可以代替torch.jit.annotate

由于这些更改,以下项目被视为已弃用,不应出现在新代码中:
  • 装饰器@torch.jit.script_method

  • 继承自torch.jit.ScriptModule

  • wrapper 类torch.jit.Attribute

  • 数组__constants__

  • 函数torch.jit.annotate

模块

警告

@torch.jit.ignoreAnnotation 的行为在 PyTorch 1.2.在 PyTorch 1.2 之前,@ignore 装饰器用于创建函数 或从导出的代码中调用的方法。要恢复此功能, 用。 现在是等效的 自。看@torch.jit.unused()@torch.jit.ignore@torch.jit.ignore(drop=False)@torch.jit.ignore@torch.jit.unused了解详情。

当传递给torch.jit.script函数,则 a 的数据为 复制到torch.nn.ModuleScriptModuleTorchScript 编译器编译模块。 默认情况下,该模块的 is 编译的。调用 from 的方法包括 按照它们在 中的使用顺序以及任何方法进行延迟编译。forwardforwardforward@torch.jit.export

torch.jit.export(fn[来源]

此装饰器指示 an 上的方法用作nn.ModuleScriptModule并且应该进行编译。

forwardimplicitly 被假定为一个入口点,因此它不需要这个装饰器。 调用 from 的函数和方法按其可见的方式进行编译 由编译器,因此它们也不需要这个装饰器。forward

示例(在方法上使用):@torch.jit.export

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

功能

函数变化不大,可以用@torch.jit.ignoretorch.jit.unused如果需要。

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

警告

TorchScript 类支持是实验性的。目前最适合 对于简单的类似记录的类型(想想 with 方法 附件)。NamedTuple

用户定义的 TorchScript 类中的所有内容都是 exported 时,函数可以用@torch.jit.ignore如果需要。

属性

TorchScript 编译器需要知道模块属性的类型。大多数类型 可以从 member 的值推断出来。空列表和字典不能具有 类型,并且必须使用 PEP 526 样式的类注释来注释它们的类型。 如果无法推断类型且未显式注释,则不会将其添加为属性 到生成的ScriptModule

旧 API:

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新 API:

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super(MyModule, self).__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常数

类型构造函数可用于将成员标记为 constant。如果成员未标记为常量,则它们将被复制到生成的FinalScriptModule作为属性。如果已知该值是固定的,则 using 会打开优化机会,并提供额外的类型安全性。Final

旧 API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新 API:

try:
    from typing_extensions import Final
except:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

变量

容器假定具有 type 且不可选(有关更多信息,请参阅 Default Types )。以前,习惯于 告诉 TorchScript 编译器类型应该是什么。Python 3 样式的类型提示是 现在支持。Tensortorch.jit.annotate

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

Fusion 后端

有几个融合后端可用于优化 TorchScript 执行。CPU 上的默认定影器是 NNC,它可以对 CPU 和 GPU 执行融合。GPU 上的默认熔融器是 NVFuser,它支持更广泛的运算符,并已演示生成的内核具有更高的吞吐量。有关使用和调试的更多详细信息,请参阅 NVFuser 文档

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源