目录

TorchScript 脚本

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。 任何 TorchScript 程序都可以从 Python 保存 进程中加载,并在没有 Python 依赖项的进程中加载。

我们提供了从纯 Python 程序逐步转换模型的工具 添加到可以独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中。 这使得使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后导出 模型通过 TorchScript 部署到 Python 程序可能不利的生产环境 出于性能和多线程的原因。

有关 TorchScript 的简要介绍,请参阅 TorchScript 简介教程。

有关将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行它的端到端示例,请参阅在 C++ 中加载 PyTorch 模型教程。

创建 TorchScript 代码

脚本

编写函数脚本。

跟踪

跟踪函数并返回可执行文件,或者将使用 just-in-time 编译进行优化的可执行文件。

script_if_tracing

在跟踪期间首次调用时进行编译。fn

trace_module

跟踪模块并返回将使用 just-in-time 编译进行优化的可执行文件

创建一个执行 func 的异步任务,并引用此执行的结果的值。

强制完成 torch.jit.Future[T] 异步任务,返回任务结果。

脚本模块

C++ torch::jit::Module 的包装器,带有方法、属性和参数。

脚本函数

在功能上等同于 ,但表示单个函数,并且没有任何属性或参数。

冻结

将 ScriptModule、内联子模块和属性冻结为常量。

optimize_for_inference

执行一组优化过程以优化模型以进行推理。

enable_onednn_fusion

根据启用的参数启用或禁用 onednn JIT 融合。

onednn_fusion_enabled

返回是否启用了 onednn JIT 融合。

set_fusion_strategy

设置融合期间可能发生的特化类型和数量。

strict_fusion

如果不是所有节点都在推理中融合,或者在训练中符号微分,则给出错误。

保存此模块的脱机版本,以便在单独的进程中使用。

负荷

加载 或 以前使用 保存的

忽视

此装饰器向编译器指示应忽略函数或方法并将其保留为 Python 函数。

闲置

此装饰器向编译器指示应忽略函数或方法,并将其替换为引发异常。

接口

Decorate 为不同类型的类或模块添加注释。

is实例

在 TorchScript 中提供容器类型优化。

属性

此方法是返回 value 的传递函数,主要用于向 TorchScript 编译器指示左侧表达式是 type 为 type 的类实例属性。

注释

用于在 TorchScript 编译器中提供 the_value 类型。

混合跟踪和脚本

在许多情况下,跟踪或脚本是将模型转换为 TorchScript 的更简单的方法。 可以组合跟踪和脚本以满足特定要求 的一部分。

脚本化函数可以调用跟踪函数。当您需要时,这特别有用 围绕简单的前馈模型使用 control-flow。例如,光束搜索 序列到序列模型通常用脚本编写,但可以调用 使用 tracing 生成的 encoder 模块。

示例(在脚本中调用跟踪的函数):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

跟踪的函数可以调用脚本函数。这在一小部分 模型需要一些控制流,即使模型的大部分只是前馈 网络。由跟踪函数调用的脚本函数内的控制流是 正确保存。

示例(在跟踪函数中调用脚本函数):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此合成也适用于 s,可用于生成 一个使用跟踪的子模块,可以从 script 模块的方法中调用。nn.Module

示例(使用 traced 模块):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript 语言

TorchScript 是 Python 的静态类型子集,因此许多 Python 功能都适用 直接发送到 TorchScript。有关详细信息,请参阅完整的 TorchScript 语言参考

内置函数和模块

TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。 有关受支持函数的完整参考,请参阅 TorchScript Builtins

PyTorch 函数和模块

TorchScript 支持张量和神经网络的子集 PyTorch 提供的函数。Tensor 上的大多数方法以及 命名空间、和 中的所有函数 TorchScript 支持大多数 的模块。torchtorch.nn.functionaltorch.nn

有关不支持的 PyTorch 函数和模块的列表,请参阅 TorchScript 不支持的 PyTorch 构造

Python 函数和模块

TorchScript 支持许多 Python 的内置函数。 该模块也受支持(有关详细信息,请参见 math Module),但不支持其他 Python 模块 (内置或第三方)。

Python 语言参考比较

有关支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围

调试

禁用 JIT 进行调试

PYTORCH_JIT

设置环境变量将禁用所有脚本 以及描摹注释。如果您的 TorchScript 模型,您可以使用此标志强制所有内容都使用本机运行 蟒。由于使用此标志禁用了 TorchScript(脚本和跟踪),因此 你可以使用 like 工具来调试模型代码。例如:PYTORCH_JIT=0pdb

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

调试此脚本有效,但调用函数时除外。我们可以全局禁用 JIT 的函数,这样我们就可以像普通的 Python 函数一样调用该函数,而不是编译它。如果上述脚本 被调用,我们可以像这样调用它:pdbdisable_jit_example.py

$ PYTORCH_JIT=0 python disable_jit_example.py

我们将能够像普通的 Python 函数一样单步执行该函数。要禁用 TorchScript 编译器,请参阅

检查代码

TorchScript 为所有实例提供 code pretty-printer。这 pretty-printer 将 script 方法的代码解释为有效 Python 语法。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

具有单个方法的 a 将具有一个属性 ,您可以使用该属性来检查 的代码。 如果 有多个方法,则需要访问方法本身而不是模块。我们可以检查 通过访问 在 上命名的方法的代码。 上面的示例生成以下输出:forwardcode.codefoo.foo.code

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是 TorchScript 对该方法代码的编译。 您可以使用它来确保 TorchScript(跟踪或脚本)已捕获 您的模型代码正确。forward

解释图形

TorchScript 也有一个比代码 pretty- 低级别的表示 打印机,以 IR 图形的形式。

TorchScript 使用静态单一赋值 (SSA) 中间表示 (IR) 来表示计算。此格式的说明包括 ATen(PyTorch 的 C++ 后端)运算符和其他原始运算符, 包括 Loop 和 Conditional 的控制流运算符。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph遵循 Inspecting Code 部分中描述的相同规则 关于方法查找。forward

上面的示例脚本生成了图表:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

按照说明操作 例。%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10

  • %rv.1 : Tensor意味着我们将输出分配给一个名为 的(唯一)值,该值的类型,并且我们不知道它的具体形状。rv.1Tensor

  • aten::zeros是运算符(相当于 ),而 input list 指定 scope 中的哪些值应作为 inputs传递。内置函数的架构可以在 Builtin Functions 中找到。torch.zeros(%4, %6, %6, %10, %12)aten::zeros

  • # test.py:9:10是生成此指令的原始源文件中的位置。在本例中,它是一个名为 test.py 的文件,位于第 9 行和字符 10 处。

请注意,运算符也可以具有关联的 ,即 和 运算符。在图形打印输出中,这些 运算符的格式设置以反映其等效的源代码格式 以便于调试。blocksprim::Loopprim::If

可以如图所示检查图形,以确认所描述的计算 by a 是正确的,无论是自动还是手动方式,因为 如下所述。

示 踪

跟踪边缘案例

存在一些边缘情况,其中给定 Python 的跟踪 function/module 不代表底层代码。这些 案例可能包括:

  • 跟踪依赖于输入(例如张量形状)的控制流

  • 跟踪张量视图的就地操作(例如,在赋值左侧进行索引)

请注意,这些情况实际上可能在将来可以追踪。

自动跟踪检查

自动捕获跟踪中许多错误的一种方法是在 API 上使用。 接受 Tuples 列表 的输入,这些输入将用于重新跟踪计算并验证 结果。例如:check_inputstorch.jit.trace()check_inputs

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

为我们提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此消息向我们表明,当 我们首先跟踪它,当我们使用 .事实上 body of 内的 loop 取决于 shape 的输入,因此当我们尝试另一个具有不同 形状,则跟踪会有所不同。check_inputsloop_in_traced_fnxx

在这种情况下,可以使用以下方法捕获像这样的数据依赖型控制流:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

这会产生:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

Tracer 警告

tracer 为 traced 中的几个有问题的模式生成警告 计算。例如,以包含 对 Tensor 的切片(视图)进行就地赋值:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

生成多个警告和一个仅返回输入的图形:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码来解决这个问题,使其不使用就地更新,但 而是用 :torch.cat

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题解答

问:我想在 GPU 上训练模型并在 CPU 上进行推理。什么是 最佳实践?

首先将模型从 GPU 转换为 CPU,然后保存,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

建议这样做,因为跟踪器可能会在 特定设备,因此强制转换已加载的模型可能会产生意外 影响。在保存模型之前强制转换模型可确保跟踪器具有 正确的设备信息。

Q: 如何在 ?

假设我们有一个这样的模型:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果实例化,则会导致编译错误 由于编译器不知道 .有 4 种方法可以通知 属性的编译器 :Modelx

1. - 包装的值将按其方式工作 对 S 执行nn.Parameternn.Parameternn.Module

2. - 括入的值将用作 他们在 S 上这样做。这相当于 类型的属性(请参阅 4)。register_bufferregister_buffernn.ModuleTensor

3. 常量 - 将类成员注释为 (或将其添加到类定义级别调用的列表中) 将标记包含的名称 作为常量。常量直接保存在模型的代码中。有关详细信息,请参阅 builtin-constantsFinal__constants__

4. 属性 - 支持类型的值可以添加为可变值 属性。大多数类型都可以推断,但有些类型可能需要指定,有关详细信息,请参阅 module attributes

Q: 我想跟踪 module 的方法,但一直收到这个错误:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

此错误通常意味着您正在跟踪的方法使用模块的参数和 您传递的是 Module 的方法而不是 Module 实例(例如 vs )。my_module_instance.forwardmy_module_instance

  • 使用模块的方法调用会将模块参数(可能需要梯度)捕获为常量trace

  • 另一方面,使用 module's instance (e.g. ) 调用会创建一个新模块并正确地将参数复制到新模块中,以便它们可以在需要时累积梯度。tracemy_module

要跟踪模块上的特定方法,请参阅

已知问题

如果你与 TorchScript 一起使用,则一些 的子模块可能会被错误地推断为 be ,即使它们另有注释。规范 解决方案是 subclass 并使用正确键入的 input 重新声明。SequentialSequentialTensornn.Sequentialforward

附录

迁移到 PyTorch 1.2 递归脚本 API

本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。如果您是 TorchScript 的新手,您可以 跳过本节。PyTorch 1.2 的 TorchScript API 有两个主要变化。

1. 现在将尝试递归编译函数, 方法和它遇到的类。调用 后, 编译是 “opt-out”,而不是 “opt-in”。torch.jit.script

2. 现在是创建 S 的首选方式,而不是从 继承 . 这些更改结合在一起,提供了一个更简单、易于使用的 API 用于转换 您的 S 转换为 S,准备好在 非 Python 环境。torch.jit.script(nn_module_instance)torch.jit.ScriptModulenn.Module

新用法如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 默认情况下,该模块的 is 编译的。调用 from 的方法按照它们在 中的使用顺序进行延迟编译。forwardforwardforward

  • 要编译不是从 、 、 add .forwardforward@torch.jit.export

  • 要阻止编译器编译方法,请添加 。 离开@ignore

  • method 作为对 Python 的调用,并将其替换为 exception。 无法导出; 能。@unused@ignored@unused

  • 大多数属性类型都可以推断,因此不是必需的。对于空容器类型,请使用 PEP 526 样式的类注释来注释它们的类型。torch.jit.Attribute

  • 常量可以使用类注释进行标记,而不是将成员的名称添加到 .Final__constants__

  • Python 3 类型提示可以代替torch.jit.annotate

由于这些更改,以下项目被视为已弃用,不应出现在新代码中:
  • 装饰器@torch.jit.script_method

  • 继承自torch.jit.ScriptModule

  • wrapper 类torch.jit.Attribute

  • 数组__constants__

  • 函数torch.jit.annotate

模块

警告

注释的行为在 PyTorch 1.2.在 PyTorch 1.2 之前,@ignore 装饰器用于创建函数 或从导出的代码中调用的方法。要恢复此功能, 用。 现在是等效的 自。有关详细信息,请参阅 @torch.jit.unused()@torch.jit.ignore@torch.jit.ignore(drop=False)

当传递给函数时,a 的数据为 复制到 a 并且 TorchScript 编译器会编译该模块。 默认情况下,该模块的 is 编译的。调用 from 的方法包括 按照它们在 中的使用顺序以及任何方法进行延迟编译。torch.nn.Moduleforwardforwardforward@torch.jit.export

torch.jit 中。导出 fn[来源]

此装饰器指示 an 上的方法用作 a 的入口点,并且应该进行编译。nn.Module

forwardimplicitly 被假定为一个入口点,因此它不需要这个装饰器。 调用 from 的函数和方法按其可见的方式进行编译 由编译器,因此它们也不需要这个装饰器。forward

示例(在方法上使用):@torch.jit.export

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

功能

函数没有太大变化,它们可以装饰或在需要时进行装饰。

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

警告

TorchScript 类支持是实验性的。目前最适合 对于简单的类似记录的类型(想想 with 方法 附件)。NamedTuple

用户定义的 TorchScript 类中的所有内容都是 exported 时,如果需要,可以使用 Functions 进行修饰

属性

TorchScript 编译器需要知道模块属性的类型。大多数类型 可以从 member 的值推断出来。空列表和字典不能具有 类型,并且必须使用 PEP 526 样式的类注释来注释它们的类型。 如果无法推断类型且未显式注释,则不会将其添加为属性 到生成的

旧 API:

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新 API:

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常数

类型构造函数可用于将成员标记为 constant。如果成员未标记为常量,则它们将作为属性复制到结果中。如果已知该值是固定的,则 using 会打开优化机会,并提供额外的类型安全性。FinalFinal

旧 API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新 API:

from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

变量

容器假定具有 type 且不可选(有关更多信息,请参阅 Default Types )。以前,习惯于 告诉 TorchScript 编译器类型应该是什么。Python 3 样式的类型提示是 现在支持。Tensortorch.jit.annotate

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

Fusion 后端

有几个融合后端可用于优化 TorchScript 执行。CPU 上的默认定影器是 NNC,它可以对 CPU 和 GPU 执行融合。GPU 上的默认熔融器是 NVFuser,它支持更广泛的运算符,并已演示生成的内核具有更高的吞吐量。有关使用和调试的更多详细信息,请参阅 NVFuser 文档

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源