目录

torch.export

警告

此功能是一个正在积极开发的原型,未来将会有重大变更。

概述

torch.export.export() 接受任意的Python可调用对象(一个 torch.nn.Module、函数或方法),并生成一个仅表示函数中Tensor计算的跟踪图,以提前编译(AOT)的方式进行,随后可以使用不同的输出执行或序列化。

import torch
from torch.export import export

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    f, args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)

    Graph signature: ExportGraphSignature(
        parameters=[],
        buffers=[],
        user_inputs=['arg0_1', 'arg1_1'],
        user_outputs=['add'],
        inputs_to_parameters={},
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}
    Equality constraints: []

torch.export 生成一个具有以下不变性的清晰中间表示(IR)。有关 IR 的更多规范请参见此处(即将推出!)。

  • 正确性: 它保证是原始程序的正确表示,并且保持了原始程序相同的调用约定。

  • 已规范化: 图中没有Python语义。来自原始程序的子模块被内联以形成一个完全扁平化的计算图。

  • 定义的操作符集: 生成的图仅包含少量定义的 Core ATen IR 操作符集和注册的自定义操作符。

  • 图属性: 该图是纯粹的功能性,意味着它不包含具有副作用的操作,如突变或别名。它不会改变任何中间值、参数或缓冲区。

  • 元数据: 图形包含在跟踪过程中捕获的元数据,例如来自用户代码的堆栈跟踪。

在幕后,torch.export 利用了以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部API,它使用CPython的一个特性 称为帧评估API来安全地追踪PyTorch图。这 提供了一个极大改进的图捕获体验,需要重写的次数大大减少 以便完全追踪PyTorch代码。

  • 提前自动微分(AOT Autograd) 提供了一个函数化的 PyTorch 计算图,并确保该计算图 被分解/转换为小范围定义的核心 ATen 操作符集合。

  • PyTorch FX (torch.fx) 是图的底层表示形式,允许灵活的基于Python的转换。

现有框架

torch.compile() 同样使用了与 torch.export 相同的PT2堆栈,但 略有不同:

  • 即时编译与提前编译: torch.compile() 是一个即时编译器,而 则不打算用于在部署之外生成编译工件。

  • 部分图捕获与全图捕获: 当 torch.compile() 遇到模型中无法追踪的部分时,它将“图中断”并回退到急切的Python运行时执行程序。相比之下,torch.export 旨在获取PyTorch模型的完整图表示,因此当遇到无法追踪的内容时会报错。由于 torch.export 生成的图与任何Python特性或运行时无关,因此该图可以保存、加载并在不同的环境和语言中运行。

  • 可用性权衡: 由于 torch.compile() 在遇到无法追踪的内容时能够回退到Python运行时,因此它更加灵活。而 torch.export 则需要用户提供更多信息或重写代码以使其可追踪。

torch.fx.symbolic_trace()相比,torch.export使用TorchDynamo进行追踪,它在Python字节码级别运行,因此具有追踪任意Python构造的能力,而不受Python运算符重载支持的限制。此外,torch.export对张量元数据进行细粒度跟踪,因此基于张量形状等条件的追踪不会失败。一般来说,torch.export预计可以在更多用户程序上工作,并生成较低级别的图(在torch.ops.aten运算符级别)。请注意,用户仍然可以将torch.fx.symbolic_trace()作为torch.export之前的预处理步骤。

torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,只有直线控制流(除了显式的控制流操作符)。

torch.jit.trace() 相比,torch.export 是可靠的:它能够追踪执行整数计算的代码,并记录所有必要的附加条件,以证明特定的追踪对其他输入也是有效的。

导出PyTorch模型

一个示例

主要入口点是通过 torch.export.export(),它接受一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到一个 torch.export.ExportedProgram 中。一个例子:

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );

            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)

    Graph signature: ExportGraphSignature(
        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
        buffers=[],
        user_inputs=['arg2_1', 'arg3_1'],
        user_outputs=['getitem'],
        inputs_to_parameters={
            'arg0_1': 'L__self___conv.weight',
            'arg1_1': 'L__self___conv.bias',
        },
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}
    Equality constraints: []

检查ExportedProgram,我们可以注意到以下内容:

  • torch.fx.Graph 包含原始程序的计算图,以及原始代码记录,便于调试。

  • 该图仅包含在Core ATen IR opset和自定义运算符中找到的torch.ops.aten个运算符,并且是完全功能性的,没有任何就地运算符,如torch.add_

  • 参数(权重和偏差到卷积)被提升为图的输入,
    导致图中没有 get_attr 个节点,这些节点以前存在于 torch.fx.symbolic_trace() 的结果中。

  • torch.export.ExportGraphSignature 模型描述了输入和输出的签名,并指定了哪些输入是参数。

  • 图中每个节点生成的张量的形状和数据类型都被标注出来。例如,convolution 节点将生成一个数据类型为 torch.float32 且形状为 (1, 16, 256, 256) 的张量。

表达动态性

默认情况下,torch.export 会假设所有输入形状都是 静态 的,并将导出的程序专门化为这些维度。但是,某些维度(例如批次维度)可以是动态的,并且可以从运行到运行变化。这样的维度必须使用 torch.export.dynamic_dim() API 标记为动态,并通过 torch.export.export()constraints 参数传递。一个例子:

import torch
from torch.export import export, dynamic_dim

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))
constraints = [
    # First dimension of each input is a dynamic batch size
    dynamic_dim(example_args[0], 0),
    dynamic_dim(example_args[1], 0),
    # The dynamic batch size between the inputs are equal
    dynamic_dim(example_args[0], 0) == dynamic_dim(example_args[1], 0),
]

exported_program: torch.export.ExportedProgram = export(
  M(), args=example_args, constraints=constraints
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)

    Graph signature: ExportGraphSignature(
        parameters=[
            'branch1.0.weight',
            'branch1.0.bias',
            'branch2.0.weight',
            'branch2.0.bias',
        ],
        buffers=['L__self___buffer'],
        user_inputs=['arg5_1', 'arg6_1'],
        user_outputs=['add', 'relu_1'],
        inputs_to_parameters={
            'arg0_1': 'branch1.0.weight',
            'arg1_1': 'branch1.0.bias',
            'arg2_1': 'branch2.0.weight',
            'arg3_1': 'branch2.0.bias',
        },
        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
    Equality constraints: [(InputDim(input_name='arg5_1', dim=0), InputDim(input_name='arg6_1', dim=0))]

需要注意的一些其他事项:

  • 通过 torch.export.dynamic_dim() API,我们指定了每个输入的第一个 维度是动态的。查看输入 arg5_1arg6_1,它们具有符号形状 (s0, 64) 和 (s0, 128),而不是 我们作为示例输入传递的 (32, 64) 和 (32, 128) 形状的张量。 s0 是一个符号,表示该维度可以是一系列值。

  • exported_program.range_constraints 描述了图表中每个符号的范围。 在这种情况下,我们看到 s0 的范围是 [2, inf]。由于技术原因,在这里难以解释,它们被假定为不是 0 或 1。这不是一个错误,并不一定意味着 导出的程序在维度为 0 或 1 时无法工作。请参阅 0/1 特殊化问题 以深入了解此主题。

  • exported_program.equality_constraints 描述了哪些维度需要相等。由于我们在约束中指定了每个参数的第一个维度是等价的, (dynamic_dim(example_args[0], 0) == dynamic_dim(example_args[1], 0)), 我们在相等性约束中看到指定 arg5_1 维度 0 和 arg6_1 维度 0 相等的元组。

序列化

要保存ExportedProgram,用户可以使用torch.export.save()torch.export.load() API。一种约定是使用ExportedProgram 并以.pt2文件扩展名保存。

一个示例:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

专业领域

输入形状

如前所述,默认情况下,torch.export 将追踪针对输入张量形状优化的程序,除非通过 torch.export.dynamic_dim() API 指定某个维度为动态。这意味着如果存在依赖形状的控制流,torch.export 将根据给定的示例输入所选择的分支进行优化。例如:

import torch
from torch.export import export

def fn(x):
    if x.shape[0] > 5:
        return x + 1
    else:
        return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)

The conditional of (x.shape[0] > 5) does not appear in the ExportedProgram because the example inputs have the static shape of (10, 2). Since torch.export specializes on the inputs’ static shapes, the else branch (x - 1) will never be reached. To preserve the dynamic branching behavior based on the shape of a tensor in the traced graph, torch.export.dynamic_dim() will need to be used to specify the dimension of the input tensor (x.shape[0]) to be dynamic, and the source code will need to be rewritten.

非张量输入

torch.export 还根据非 torch.Tensor 的输入值专门化跟踪的图,例如 intfloatboolstr。 然而,我们可能会在不久的将来对此进行更改,不再对基本类型输入进行专门化处理。

例如:

import torch
from torch.export import export

def fn(x: torch.Tensor, const: int, times: int):
    for i in range(times):
        x = x + const
    return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)

由于整数是专门化的,torch.ops.aten.add.Tensor 操作 都是使用内联常量 1 计算的,而不是 arg1_1。 此外,在 for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用在图中“内联”了, 并且输入 arg2_1 从未被使用过。

torch.export的限制

图中断点

由于torch.export是一个一次性过程,用于从PyTorch程序中捕获计算图,因此它最终可能会遇到程序中无法追踪的部分,因为几乎不可能支持追踪所有PyTorch和Python功能。在torch.compile的情况下,不支持的操作将导致“图形中断”,并且不支持的操作将使用默认的Python评估运行。相比之下,torch.export将要求用户提供额外的信息或重写代码的部分以使其可追踪。由于追踪是基于TorchDynamo的,后者在Python字节码级别进行评估,因此与以前的追踪框架相比,所需的重写将显著减少。

当遇到图中断时,ExportDB 是一个很好的资源,可以了解支持和不支持的程序类型,以及如何重写程序以使其可跟踪。

数据/形状依赖的控制流

当形状未被专门化时,图中断也可能出现在数据依赖的控制流中 (if x.shape[0] > 2),因为追踪编译器无法在不为组合爆炸数量的路径生成代码的情况下处理这种情况。在这种情况下,用户将需要使用特殊的控制流操作符重写他们的代码(即将推出!)。

数据相关访问

依赖数据的行为,例如使用张量中的值来构建另一个张量,或者使用张量的值对另一个张量进行切片,也是追踪器无法完全确定的事情。用户需要使用内联约束API torch.export.constrain_as_size()torch.export.constrain_as_value() 重写他们的代码。

运算符缺少元内核

在追踪时,所有操作符都需要一个META实现(或称为“元内核”)。这用于推断该操作符的输入/输出形状。

请注意,为自定义操作注册自定义元内核的官方API目前正在开发中。在最终API完善之前,您可以参考此处的文档。

在不幸的情况下,如果你的模型使用了一个尚未实现元内核的ATen操作符,请提交问题。

了解更多

额外链接供导出用户使用

深入探索PyTorch开发者

API 参考

torch.export.export(f, args, kwargs=None, *, constraints=None)[source]

export() 接收一个任意的 Python 可调用对象(例如 nn.Module、函数或方法),并以提前编译(Ahead-of-Time,AOT)的方式生成一个仅表示该函数张量计算的追踪图。此追踪图随后可以使用不同的输出执行或序列化。(1) 生成的规范化操作符集只包含功能性的 核心 ATen 操作符集 和用户指定的自定义操作符;(2) 已消除所有 Python 控制流和数据结构(某些特定情况除外);以及 (3) 包含一组形状约束,用于证明这种规范化和控制流消除对于未来输入是合理的。

健全性保证

在跟踪过程中,export() 会记录用户程序和底层PyTorch运算符内核所做的与形状相关的假设。 只有当这些假设成立时,输出 ExportedProgram 才被认为是有效的。

在跟踪过程中会做出两种类型的假设

  • 输入张量的形状(不包括值)。

  • 通过.item()或直接索引从中间张量提取的值的范围(下限和上限)。

在捕获图时,必须验证所有假设才能使 export() 成功。具体来说:

  • 对输入张量的静态形状假设会自动进行验证,无需额外的努力。

  • 对输入张量动态形状的假设需要显式 Input Constraint 使用 dynamic_dim() API 构建

  • 对中间值范围的假设需要显式地 Inline Constraint, 构造使用 constrain_as_size()constraint_as_value() API。

如果任何假设无法验证,将引发致命错误。当发生这种情况时, 错误信息将包括建议的代码,用于构建必要的约束以验证这些假设,例如 export() 将建议 以下代码用于输入约束:

def specify_constraints(x):
    return [
        # x:
        dynamic_dim(x, 0) <= 5,
    ]

此示例表示程序要求输入的 x 维度 0 必须小于或等于 5 才有效。您可以检查所需的约束条件,然后将此精确函数复制到您的代码中,以生成需要传递给 constraints 参数的约束条件。

Parameters
  • f (Callable) – 要跟踪的可调用对象。

  • args (元组[任意类型, ...]) – 示例位置输入。

  • kwargs (可选[字典[字符串, 任意类型]]) – 可选的示例关键字输入。

  • constraints (Optional[List[Constraint]]) – 一个可选的约束列表,用于指定动态参数可能的形状范围。默认情况下,输入 torch.Tensors 的形状被认为是静态的。如果某个输入 torch.Tensor 预期具有动态形状,请使用 dynamic_dim() 来定义 Constraint 对象,这些对象指定了动态性和可能的形状范围。有关如何使用它的示例,请参见 dynamic_dim() 的文档字符串。

Returns

一个 ExportedProgram 包含被跟踪的可调用对象。

Return type

ExportedProgram

可接受的输入/输出类型

可接受的输入类型(对于argskwargs)和输出包括:

  • 基本类型,即torch.Tensorintfloatboolstr

  • (嵌套) 包含 dict, list, tuple, namedtupleOrderedDict 的数据结构,包含以上所有类型。

torch.export.dynamic_dim(t, index)[source]

dynamic_dim() 构造一个 Constraint 对象,该对象描述了张量 t 的维度 index 的动态性。应将 Constraint 对象传递给 export()constraints 参数。

Parameters
  • t (torch.Tensor) – 示例输入张量,具有动态维度大小

  • 索引 (int) – 动态维度的索引

Returns

一个 Constraint 对象,用于描述形状的动态性。它可以传递给 export() , 这样 export() 就不会假设指定张量的静态大小,即保持其动态性作为符号大小, 而不是根据示例跟踪输入的大小进行专门化。

具体来说,dynamic_dim() 可以用来表达以下类型的动态性。

  • 维度的大小是动态且无界的:

    t0 = torch.rand(2, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size rather than always being static size 2
    constraints = [dynamic_dim(t0, 0)]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 维度的大小是动态的,并且有一个下限:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
    # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) >= 5,
        dynamic_dim(t1, 1) > 2,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 维度的大小是动态的,并且有一个上限:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
    # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) <= 16,
        dynamic_dim(t1, 1) < 8,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 一个维度的大小是动态的,并且它总是等于另一个动态维度的大小:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # Sizes of second dimension of t0 and first dimension are always equal
    constraints = [
        dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 混合和匹配以上所有类型,只要它们不表达相互冲突的要求

torch.export.constrain_as_size(symbol, min=None, max=None)[source]

提示 export() 关于一个中间标量值的约束,该值表示张量的形状,以便后续张量构造器可以正确跟踪,因为许多操作符需要对尺寸范围做出假设。

Parameters
  • symbol – 用于应用范围约束的中间标量值(目前仅支持整数)。

  • 最小值 (可选[整数]) – 给定符号的最小可能值(包含)

  • 最大值 (可选[整数]) – 给定符号可能的最大值(包含)

Returns

请提供需要翻译的单词列表。

例如,下面的程序在不使用 constrain_as_size() 来向 export() 提供形状范围提示的情况下无法被正确追踪:

def fn(x):
    d = x.max().item()
    return torch.ones(v)

export() 会引发以下错误:

torch._dynamo.exc.Unsupported: guard on data-dependent symbolic int/float

假设d的实际范围可以在[3, 10]之间,你可以在源代码中添加对 constrain_as_size()的调用,如下所示:

def fn(x):
    d = x.max().item()
    torch.export.constrain_as_size(d, min=3, max=10)
    return torch.ones(d)

有了额外的提示,export() 将能够通过选择 else 分支正确地追踪程序,从而生成以下图表:

graph():
    %arg0_1 := placeholder[target=arg0_1]

    # d = x.max().item()
    %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
    %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

    # Asserting 3 <= d <= 10
    %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 3))
    %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
    %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
        args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [3, 10].))
    %le := call_function[target=operator.le](args = (%_local_scalar_dense, 10))
    %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
    %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
        args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [3, 10].))
    %sym_constrain_range_for_size := call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](
        args = (%_local_scalar_dense,), kwargs = {min: 3, max: 10})

    # Constructing new tensor with d
    %full := call_function[target=torch.ops.aten.full.default](
        args = ([%_local_scalar_dense], 1),
        kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})

    ......

警告

如果你的尺寸是动态的,请不要测试尺寸是否等于 0 或 1, 这些操作将静默返回 false 并被跳过

torch.export.constrain_as_value(symbol, min=None, max=None)[source]

提示 export() 关于中间标量值的约束,以便后续检查上述标量值范围的分支行为可以被可靠地追踪。

警告

(请注意,如果中间标量值将被用作尺寸(size),包括作为张量工厂或view的尺寸参数传递时,请改用 constrain_as_size()。)

Parameters
  • symbol – 用于应用范围约束的中间标量值(目前仅支持整数)。

  • 最小值 (可选[整数]) – 给定符号的最小可能值(包含)

  • 最大值 (可选[整数]) – 给定符号可能的最大值(包含)

Returns

请提供需要翻译的单词列表。

例如,以下程序无法正确跟踪:

def fn(x):
    v = x.max().item()
    if v > 1024:
        return x
    else:
        return x * 2

v 是一个依赖于数据的值,其范围假定为 (-inf, inf)。 export() 关于应该选择哪个分支的提示将无法确定 跟踪的分支决策是否正确。因此 export() 会给出以下错误:

torch._dynamo.exc.UserError: Consider annotating your code using
torch.export.constrain_as_size() or torch.export().constrain_as_value() APIs.
It appears that you're trying to get a value out of symbolic int/float whose value
is data-dependent (and thus we do not know the true value.)  The expression we were
trying to evaluate is f0 > 1024 (unhinted: f0 > 1024).

假设v的实际范围在[10, 200]之间,你可以在源代码中添加对 constrain_as_value()的调用,如下所示:

def fn(x):
    v = x.max().item()

    # Give export() a hint
    torch.export.constrain_as_value(v, min=10, max=200)

    if v > 1024:
        return x
    else:
        return x * 2

有了额外的提示,export() 将能够通过选择 else 分支正确地追踪程序,从而生成以下图表:

graph():
    %arg0_1 := placeholder[target=arg0_1]

    # v = x.max().item()
    %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
    %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

    # Asserting 10 <= v <= 200
    %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 10))
    %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
    %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
        args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [10, 200].))
    %le := call_function[target=operator.le](args = (%_local_scalar_dense, 200))
    %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
    %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
        args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [10, 200].))
    %sym_constrain_range := call_function[target=torch.ops.aten.sym_constrain_range.default](
        args = (%_local_scalar_dense,), kwargs = {min: 10, max: 200})

    # Always taking `else` branch to multiply elements `x` by 2 due to hints above
    %mul := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
    return (mul,)
torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source]

警告

正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。

将一个 ExportedProgram 保存到类似文件的对象中。然后可以使用Python API torch.export.load 加载它。

Parameters
  • ep (导出的程序) – 要保存的导出程序。

  • f (Union[str, pathlib.Path, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。

  • extra_files (可选[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为f的一部分进行存储。

  • opset_version (可选[Dict[str, int]]) – 操作集名称到该操作集版本的映射

Example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), torch.randn(5))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]

警告

正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。

加载一个 ExportedProgram 之前使用 torch.export.save 保存的。

Parameters
  • ep (导出的程序) – 要保存的导出程序。

  • f (Union[str, pathlib.Path, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。

  • extra_files (可选[Dict[str, Any]]) – 此映射中给出的额外文件名将被加载,其内容将存储在提供的映射中。

  • 预期的opset版本 (可选[Dict[str, int]]) – 操作集名称到预期的操作集版本的映射

Returns

一个 ExportedProgram 对象

Return type

ExportedProgram

Example:

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
class torch.export.Constraint(*args, **kwargs)[source]

警告

不要直接构造 Constraint,请改用 dynamic_dim()

这表示对输入张量维度的约束,例如,要求它们完全多态或在某个范围内。

class torch.export.ExportedProgram(root, graph, graph_signature, call_spec, state_dict, range_constraints, equality_constraints, module_call_graph, example_inputs=None)[source]

程序包来自 export()。它包含一个 torch.fx.Graph,表示张量计算,一个state_dict包含所有提升参数和缓冲区的张量值,以及各种元数据。

你可以像调用原始可追踪的 export() 一样,以相同的调用约定调用ExportedProgram。

要对图进行转换,请使用.module属性来访问 一个torch.fx.GraphModule。然后你可以使用 FX转换 来重写图。之后,你可以简单地再次使用export() 来构建一个正确的ExportedProgram。

module()[source]

返回一个包含所有参数/缓冲区的独立GraphModule。

Return type

模块

class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[source]
class torch.export.ExportGraphSignature(parameters, buffers, user_inputs, user_outputs, inputs_to_parameters, inputs_to_buffers, buffers_to_mutate, backward_signature, assertion_dep_token=None)[source]

ExportGraphSignature 模型的输入/输出签名是导出图的,这是一个具有更强不变性保证的fx.Graph。

导出图是功能性的,不会通过 getattr 节点访问图中类似参数或缓冲区的“状态”。相反,export() 保证将参数和缓冲区作为输入从图中提取出来。同样,对缓冲区的任何修改也不会包含在图中,而是将修改后的缓冲区值建模为导出图的额外输出。

所有输入和输出的顺序为:

Inputs = [*parameters_buffers, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出了以下模块:

class CustomModule(nn.Module):
    def __init__(self):
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图将为:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的ExportGraphSignature将是:

ExportGraphSignature(
    # Indicates that there is one parameter named `my_parameter`
    parameters=['L__self___my_parameter'],

    # Indicates that there are two buffers, `my_buffer1` and `my_buffer2`
    buffers=['L__self___my_buffer1', 'L__self___my_buffer2'],

    # Indicates that the nodes `arg3_1` and `arg4_1` in produced graph map to
    # original user inputs, ie. x1 and x2
    user_inputs=['arg3_1', 'arg4_1'],

    # Indicates that the node `add_tensor_1` maps to output of original program
    user_outputs=['add_tensor_1'],

    # Indicates that there is one parameter (self.my_parameter) captured,
    # its name is now mangled to be `L__self___my_parameter`, which is now
    # represented by node `arg0_1` in the graph.
    inputs_to_parameters={'arg0_1': 'L__self___my_parameter'},

    # Indicates that there are two buffers (self.my_buffer1, self.my_buffer2) captured,
    # their name are now mangled to be `L__self___my_my_buffer1` and `L__self___my_buffer2`.
    # They are now represented by nodes `arg1_1` and `arg2_1` in the graph.
    inputs_to_buffers={'arg1_1': 'L__self___my_buffer1', 'arg2_1': 'L__self___my_buffer2'},

    # Indicates that one buffer named `L__self___my_buffer2` is mutated during execution,
    # its new value is output from the graph represented by the node named `add_tensor_2`
    buffers_to_mutate={'add_tensor_2': 'L__self___my_buffer2'},

    # Backward graph not captured
    backward_signature=None,

    # Work in progress feature, please ignore now.
    assertion_dep_token=None
)
class torch.export.ArgumentKind(value)[source]

一个枚举。

class torch.export.ArgumentSpec(kind: torch.export.ArgumentKind, value: Any)[source]
class torch.export.ModuleCallSignature(inputs: List[torch.export.ArgumentSpec], outputs: List[torch.export.ArgumentSpec], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[source]
class torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.ModuleCallSignature, NoneType] = None)[source]

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源