目录

torch.onnx

ONNX 导出器。

开放神经网络交换 (ONNX) 是一种用于表示机器学习模型的开放式标准格式。torch.onnx 模块可以将 PyTorch 模型导出为 ONNX 格式。然后,该模型可以被任何支持 ONNX 的 运行时环境 使用。

示例:从 PyTorch 到 ONNX 的 AlexNet

这是一个简单的脚本,用于将预训练的AlexNet导出为名为 alexnet.onnx 的ONNX文件。 对 torch.onnx.export 的调用会运行模型一次以追踪其执行过程,然后将追踪后的模型导出到指定文件中:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区 其中包含了您导出的模型(在此情况下为 AlexNet)的网络结构和参数。 参数 verbose=True 会导致导出器打印出模型的人类可读表示:

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

You can also verify the output using the ONNX library, which you can install using pip:

pip install onnx

然后,你可以运行:

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用支持ONNX的众多运行时之一来运行导出的模型。 例如,在安装ONNX Runtime之后,您可以 加载并运行该模型:

import onnxruntime as ort

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这里有一个更详细的关于导出模型并使用ONNX Runtime运行的教程

追踪与脚本编写

在内部,torch.onnx.export() 需要一个 torch.jit.ScriptModule 而不是 一个 torch.nn.Module。如果传入的模型还不是 ScriptModuleexport() 将使用 追踪 来将其转换为一个:

  • 追踪: 如果 torch.onnx.export() 被调用时传入的 Module 不是一个 ScriptModule,它首先会执行类似于 torch.jit.trace() 的操作,这会使用给定的 args 运行模型一次,并记录该运行过程中发生的所有操作。这意味着如果你的模型是动态的,例如根据输入数据改变行为,导出的 模型将 不会 捕获这种动态行为。 我们建议检查导出的模型并确保操作符看起来合理。追踪会展开循环和 if 语句,导出一个与追踪运行完全相同的静态图。如果你想以动态控制流导出模型,你需要使用 脚本化

  • 脚本化: 通过脚本化编译模型可以保留动态控制流,并适用于不同大小的输入。要使用脚本化:

    • 使用 torch.jit.script() 来生成一个 ScriptModule

    • 调用 torch.onnx.export() 并将 ScriptModule 作为模型。 args 仍然需要, 但它们将仅用于内部生成示例输出,以便捕获输出的类型和形状。不会执行任何追踪操作。

请参阅 TorchScript简介TorchScript 以获取更多详细信息,包括如何组合追踪和脚本来满足不同模型的特定需求。

避免常见陷阱

避免使用 NumPy 和内置的 Python 类型

PyTorch模型可以使用NumPy或Python类型和函数编写,但在追踪过程中,任何NumPy或Python类型的变量(而不是torch.Tensor)都会被转换为常量,如果这些值应根据输入变化,则会产生错误结果。

例如,与其在 numpy.ndarrays 上使用 numpy 函数:

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 操作符:

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

与其使用 torch.Tensor.item()(它将张量转换为Python内置数字):

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

利用 torch 对单元素张量的隐式类型转换支持:

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生错误的跟踪记录,从而导致生成不正确的 ONNX 图。 请改用 torch.Tensor.detach()。 (目前正在努力 完全移除 Tensor.data)。

在跟踪模式下使用 tensor.shape 时,避免进行原地操作

在追踪模式下,从 tensor.shape 获得的形状会被追踪为张量,并共享相同的内存。这可能会导致最终输出值不匹配。作为解决方法,在这些情况下避免使用原地操作。例如,在模型中:

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追踪模式下共享相同的内存。 通过重写原地操作可以避免这种情况:

real_seq_length = real_seq_length + 2

限制条件

类型

  • torch.Tensors,可以简单转换为 torch.Tensors 的数值类型(例如 float、int), 以及这些类型的元组和列表支持作为模型的输入或输出。字典和字符串输入和 输出在 追踪 模式下被接受,但:

    • 任何依赖字典或字符串输入值的计算都将被替换为在一次跟踪执行中看到的常量值

    • 任何输出为字典的将被静默替换为其 值的扁平序列(键将被移除)。例如 {"foo": 1, "bar": 2} 变为 (1, 2)

    • 任何输出为 str 类型的内容都将被静默移除。

  • 某些涉及元组和列表的操作在脚本模式下不受支持,因为ONNX对嵌套序列的支持有限。 特别是将元组附加到列表的操作不受支持。在追踪模式下,嵌套序列将在追踪过程中自动展平。

算子实现的差异

由于操作符的实现存在差异,在不同的运行时上运行导出的模型,可能会产生彼此之间或与 PyTorch 不同的结果。通常这些差异在数值上很小,因此只有在您的应用程序对这些微小差异敏感时,才需要关注这一点。

不支持的张量索引模式

无法导出的张量索引模式列表如下。 如果你在导出一个不包含以下任何不支持模式的模型时遇到问题,请确认你是否使用最新版本的 opset_version 进行导出。

读取 / 获取

当对张量进行读取索引时,不支持以下模式:

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入 / 设置

当对张量进行写入操作时,以下索引模式不被支持:

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对运算符的支持

当导出包含不支持操作符的模型时,你会看到类似以下的错误信息:

RuntimeError: ONNX export failed: Couldn't export operator foo

当这种情况发生时,你可以采取以下几种措施:

  1. 更改模型以不使用该运算符。

  2. 创建一个符号函数以转换操作符,并将其注册为自定义符号函数。

  3. 为 PyTorch 做贡献,将相同的符号函数添加到 torch.onnx 本身。

如果你决定实现一个符号函数(我们希望你能将它贡献回 PyTorch!),以下是你可以开始的方式:

ONNX 导出器内部机制

“符号函数”是一种将 PyTorch 操作符分解为一系列 ONNX 操作符组合的函数。

在导出过程中,TorchScript 图中的每个节点(包含一个 PyTorch 操作符)会按照拓扑顺序被导出器访问。 当访问一个节点时,导出器会查找该操作符的已注册符号函数。 符号函数是用 Python 实现的。对于名为 foo 的操作符,其符号函数可能如下所示:

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

The torch._C types are Python wrappers around the types defined in C++ in ir.h.

添加符号函数的过程取决于操作符的类型。

ATen 操作符

ATen 是 PyTorch 内置的张量库。 如果该操作是 ATen 操作(在 TorchScript 图中以前缀 aten:: 显示),请确保它尚未被支持。

支持的操作符列表

访问自动生成的 支持的TorchScript操作符列表 以了解每个 opset_version 中支持的操作符详情。

添加对 aten 或量化操作符的支持

如果操作符不在上述列表中:

  • torch/onnx/symbolic_opset<version>.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保该函数与ATen函数的名称相同,ATen函数可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件在构建时生成,因此在你检出代码后直到构建PyTorch之前都不会出现)。

  • 默认情况下,第一个参数是ONNX图。 其他参数名称必须与.pyi文件中的名称完全匹配, 因为使用关键字参数进行分发。

  • 在符号函数中,如果运算符在 ONNX标准运算符集中, 我们只需要创建一个节点来表示图中的ONNX运算符。 如果不是,我们可以组合几个具有等效语义的标准运算符来表示ATen运算符。

这是一个处理 ELU 运算符缺失符号函数的示例。

如果我们运行以下代码:

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我们看到类似这样的内容:

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图表中看到 aten::elu,我们知道这是一个ATen操作符。

我们检查了ONNX运算符列表, 并确认Elu在ONNX中是标准化的。

我们在 elu 中找到 torch/nn/functional.pyi 的特征:

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们向 symbolic_opset9.py 添加以下行:

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在 PyTorch 可以导出包含 aten::elu 运算符的模型!

查看 torch/onnx/symbolic_opset*.py 个文件以获取更多示例。

torch.autograd.Functions

如果操作符是 torch.autograd.Function 的子类,有三种方法可以导出它。

静态符号方法

你可以向函数类中添加一个名为 symbolic 的静态方法。它应该返回 表示该函数在 ONNX 中行为的 ONNX 运算符。例如:

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联 Autograd 函数

在未为其后续的 torch.autograd.Function 提供静态符号方法的情况下, 或者在未提供将 prim::PythonOp 注册为自定义符号函数的函数时, torch.onnx.export() 会尝试内联与该 torch.autograd.Function 对应的图,使得 该函数被分解为函数内部使用的各个操作符。 只要这些单独的操作符得到支持,导出应该会成功。例如:

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

此模型没有静态符号方法,但导出方式如下:

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果你需要避免 torch.autograd.Function 的内联,你应该使用 operator_export_type 设置为 ONNX_FALLTHROUGHONNX_ATEN_FALLBACK 导出模型。

自定义操作符

如果模型使用了如使用自定义C++运算符扩展TorchScript中描述的自定义C++运算符实现, 您可以按照此示例进行导出:

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

你可以将模型导出为一个或多个标准 ONNX 操作的组合,或者作为自定义的 ONNX 操作符。

上面的示例将其导出为“custom_domain”操作集中的自定义运算符。 在导出自定义运算符时,可以使用导出时的 custom_opsets 字典来指定自定义操作集的版本。如果不指定,自定义操作集版本默认为 1。

运行时需要支持自定义操作。请参阅 Caffe2 自定义操作, ONNX Runtime 自定义操作, 或您选择的运行时文档。

一次性发现所有无法转换的 ATen 操作

当由于无法转换的 ATen 操作导致导出失败时,实际上可能存在多个这样的操作,但错误信息只会提到第一个。若要一次性发现所有无法转换的操作,你可以:

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是近似值,因为在转换过程中可能会移除一些操作(ops),这些操作不需要转换。另外一些操作可能只支持部分功能,使用特定输入时会导致转换失败,但这也应该能让你大致了解哪些操作不受支持。如需请求操作支持,请随时在 GitHub 上提交问题。

常见问题解答

Q: 我已经导出了我的LSTM模型,但它的输入大小似乎被固定了?

The tracer records the shapes of the example inputs. If the model should accept inputs of dynamic shapes, set dynamic_axes when calling torch.onnx.export().

Q: 如何导出包含循环的模型?

Q: 如何导出带有原始类型输入(例如 int、float)的模型?

Support for primitive numeric type inputs was added in PyTorch 1.9. However, the exporter does not support models with str inputs.

Q: ONNX 是否支持隐式的标量数据类型转换?

The ONNX standard does not, but the exporter will try to handle that part. Scalars are exported as constant tensors. The exporter will figure out the right data type for scalars. In rare cases when it is unable to do so, you will need to manually specify the datatype with e.g. dtype=torch.float32. If you see any errors, please [create a GitHub issue](https://github.com/pytorch/pytorch/issues).

Q: Tensor 列表可以导出到 ONNX 吗?

Yes, for opset_version >= 11, since ONNX introduced the Sequence type in opset 11.

贡献 / 开发

开发者文档.

函数

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False)[source]

将模型导出为 ONNX 格式。

如果 model 不是 torch.jit.ScriptModule 也不是 torch.jit.ScriptFunction,则会运行 model 一次以将其转换为可导出的 TorchScript 图(相当于 torch.jit.trace())。因此,这与 torch.jit.trace() 一样,对动态控制流的支持也有限。

Parameters:
  • 模型 (torch.nn.Module, torch.jit.ScriptModuletorch.jit.ScriptFunction) – 要导出的模型。

  • 参数 (元组torch.Tensor) –

    参数可以以以下两种方式之一进行结构化:

    1. 仅包含一组参数:

      args = (x, y, z)
      

    该元组应包含模型输入,使得 model(*args) 是对模型的有效调用。任何非张量参数将被硬编码到导出的模型中;任何张量参数将成为导出模型的输入,按照它们在元组中出现的顺序。

    1. 一个张量:

      args = torch.Tensor([1])
      

    这相当于该 Tensor 的一个一元元组。

    1. 以一个包含命名参数字典的参数元组结尾:

      args = (
          x,
          {
              "y": input_y,
              "z": input_z
          }
      )
      

    元组中除最后一个元素外的所有元素都将作为非关键字参数传递, 而命名参数将从最后一个元素中设置。如果字典中没有某个命名参数, 则将其赋值为默认值,如果没有提供默认值,则赋值为 None。

    注意

    如果一个字典是 args 元组的最后一个元素,它将被解释为包含命名参数。如果要将字典作为最后一个非关键字参数传递,请在 args 元组的最后一个元素中提供一个空字典。例如,不要使用:

    torch.onnx.export(
        model,
        (
            x,
            # WRONG: will be interpreted as named arguments
            {y: z}
        ),
        "test.onnx.pb"
    )
    

    Write:

    torch.onnx.export(
        model,
        (
            x,
            {y: z},
            {}
        ),
        "test.onnx.pb"
    )
    

  • f (Union[str, BytesIO]) – 一个类似文件的对象(例如 f.fileno() 返回文件描述符) 或包含文件名的字符串。二进制协议缓冲区将写入此文件。

  • export_params (bool, default True) – 如果为 True,将导出所有参数。如果你想导出一个未训练的模型,请将此值设为 False。 在这种情况下,导出的模型将首先将其所有参数作为参数传入,顺序由 model.state_dict().values() 指定

  • verbose (bool, 默认 False) – 如果为 True,会将导出模型的描述打印到标准输出。此外,最终的 ONNX 图将包含从导出模型中获取的字段 doc_string`,该字段提到了 model 的源代码位置。如果为 True,将启用 ONNX 导出器日志记录。

  • 训练 (枚举, 默认 TrainingMode.EVAL) –

    • TrainingMode.EVAL: 以推理模式导出模型。

    • TrainingMode.PRESERVE: export the model in inference mode if model.training is

      在评估模式下为 False,在训练模式下(当 model.training 为 True 时)为 True。

    • TrainingMode.TRAINING: export the model in training mode. Disables optimizations

      这可能会干扰训练。

  • input_names (str 列表, 默认为空列表) – 按顺序分配给图输入节点的名称。

  • output_names (str 列表, 默认为空列表) – 按顺序分配给图输出节点的名称。

  • operator_export_type (枚举, 默认 OperatorExportTypes.ONNX) –

    • OperatorExportTypes.ONNX: Export all ops as regular ONNX ops

      (在默认的 opset 域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH: Try to convert all ops

      to standard ONNX ops in the default opset domain. If unable to do so (e.g. because support has not been added to convert a particular torch op to ONNX), fall back to exporting the op into a custom opset domain without conversion. Applies to custom ops as well as ATen ops. For the exported model to be usable, the runtime must support these non-standard ops.

    • OperatorExportTypes.ONNX_ATEN: All ATen ops (in the TorchScript namespace “aten”)

      are exported as ATen ops (in opset domain “org.pytorch.aten”). ATen is PyTorch’s built-in tensor library, so this instructs the runtime to use PyTorch’s implementation of these ops.

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

      这在操作符实现中的数值差异导致 PyTorch 和 Caffe2 之间行为出现较大差异时可能会有帮助(这种情况在未训练的模型中更为常见)。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK: Try to export each ATen op

      (在 TorchScript 命名空间“aten”中)作为常规的 ONNX 操作。如果我们无法做到这一点 (例如,因为尚未添加将特定 torch 操作转换为 ONNX 的支持), 则回退到导出一个 ATen 操作。有关 OperatorExportTypes.ONNX_ATEN 的背景信息,请参阅相关文档。 例如:

      graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      # conversion unsupported
      %4 : Float = aten::triu(%0, %3)
      # conversion supported
      %5 : Float = aten::mul(%4, %0)
      return (%5)
      

      假设 aten::triu 不被 ONNX 支持,这将被导出为:

      graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      # not converted
      %2 : Float = aten::ATen[operator="triu"](%0, %1)
      # converted
      %3 : Float = onnx::Mul(%2, %0)
      return (%3)
      

      如果PyTorch是使用Caffe2构建的(即使用BUILD_CAFFE2=1),那么将启用Caffe2特有的行为,包括对由量化模块描述的操作的支持。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_version (int, 默认值 14) – 要针对的 默认 (ai.onnx) opset 版本。必须 >= 7 且 <= 16。

  • do_constant_folding (bool, default True) – 应用常量折叠优化。 常量折叠将用预先计算的常量节点替换所有输入均为常量的操作。

  • dynamic_axes (dict[string, dict[int, string]] or dict[string, list(int)], 默认为空字典) –

    默认情况下,导出的模型将把所有输入和输出张量的形状设置为与 args 中给出的完全一致。若要指定张量的某些轴为动态(即仅在运行时才知道),请将 dynamic_axes 设置为具有以下模式的字典:

    • KEY (str): an input or output name. Each name must also be provided in input_names or

      output_names.

    • VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a

      列表,每个元素是一个轴索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"]
    )
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    While:

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        }
    )
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool, 默认 None) –

    如果为 True,则导出图中的所有初始化器(通常对应参数)也将作为图的输入添加。如果为 False,则初始化器不会作为图的输入添加,仅非参数输入会被添加为输入。 这可能会允许后端/运行时进行更好的优化(例如常量折叠)。

    如果 opset_version < 9,初始化器必须是图输入的一部分,此参数将被忽略,行为将等同于将此参数设置为 True。

    如果为 None,则行为将按以下方式自动选择:

    • If operator_export_type=OperatorExportTypes.ONNX, the behavior is equivalent

      到将此参数设置为 False。

    • 否则,此参数的行为等同于将其设置为 True。

  • custom_opsets (dict[str, int], 默认空字典) –

    一个具有模式的字典:

    • KEY (str): opset 域名

    • VALUE (int): opset 版本

    如果自定义 opset 被 model 引用但未在此字典中提及, 则 opset 版本将被设置为 1。仅应通过此参数指定自定义 opset 的域名和版本。

  • export_modules_as_functions (boolset of python:type of nn.Module, 默认 False) –

    启用标志以将所有 nn.Module 前向调用作为 ONNX 中的本地函数导出。或者是一个集合,用于指定要作为 ONNX 中本地函数导出的特定模块类型。 此功能需要 opset_version >= 15,否则导出将失败。这是因为 opset_version < 15 表示 IR 版本 < 8,这意味着不支持本地函数。 模块变量将作为函数属性导出。函数属性分为两类。

    1. 带注释的属性:通过PEP 526风格进行类型注解的类变量将作为属性导出。 带注释的属性不会在ONNX本地函数的子图中使用,因为它们不是由PyTorch JIT跟踪创建的,但它们可能被消费者用于确定是否用特定融合内核替换该函数。

    2. 推断属性:在模块内部操作符中使用的变量。属性名称将带有前缀“inferred::”。这是为了与从Python模块注解中获取的预定义属性区分开来。推断属性用于ONNX本地函数的子图内部。

    • False (默认): 导出 nn.Module 次前向调用为细粒度节点。

    • True: 导出所有 nn.Module 前向调用作为本地函数节点。

    • Set of type of nn.Module: export nn.Module forward calls as local function nodes,

      仅当 nn.Module 的类型在集合中找到时。

Raises:
  • torch.onnx.errors.CheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图。

  • torch.onnx.errors.UnsupportedOperatorError – 如果无法导出ONNX图,因为它使用了导出器不支持的操作符。

  • torch.onnx.errors.OnnxExporterError – 导出过程中可能发生其他错误。 所有错误都是 errors.OnnxExporterError 的子类。

torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[source]

export() 类似,但返回 ONNX 模型的文本表示形式。仅列出参数的不同之处。其他所有参数与 export() 相同。

Parameters:
  • add_node_names (bool, 默认 True) – 是否设置 NodeProto.name。除非 google_printer=True,否则这不会有任何影响。

  • google_printer (bool, 默认 False) – 如果为 False,将返回模型的自定义、紧凑表示形式。如果为 True,将返回 protobuf 的 Message::DebugString(),这会更加详细。

Returns:

一个包含可读的ONNX模型表示的UTF-8字符串。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]

为自定义操作符注册一个符号函数。

当用户为自定义/贡献操作符注册符号时, 强烈建议通过setType API为该操作符添加形状推断, 否则在某些极端情况下,导出的图可能具有错误的形状推断。 setType的一个示例是 test_aten_embedding_2test_operators.py 中。

参见模块文档中的“自定义操作符”部分以了解示例用法。

Parameters:
  • symbolic_name (str) – 自定义操作符的名称,格式为“<domain>::<op>” 。

  • symbolic_fn (Callable) – 一个函数,它接收ONNX图和当前操作符的输入参数,并返回要添加到图中的新操作符节点。

  • opset_version (int) – 注册时使用的ONNX opset版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source]

取消注册 symbolic_name

参见模块文档中的“自定义操作符”部分以了解示例用法。

Parameters:
  • symbolic_name (str) – 自定义操作符的名称,格式为“<domain>::<op>” 。

  • opset_version (int) – 要注销的ONNX opset版本。

torch.onnx.select_model_mode_for_export(model, mode)[source]

一个上下文管理器,用于临时将 model 设置为 mode 模式,在退出 with 块时将其重置。

Parameters:
  • 模型 – 同类型和含义与 model 参数传递给 export()

  • 模式 (训练模式) – 与 training 参数传递给 export() 的类型和含义相同。

torch.onnx.is_in_onnx_export()[source]

返回是否正处于 ONNX 导出过程中。

Return type:

布尔

torch.onnx.enable_log()[source]

启用 ONNX 日志记录。

torch.onnx.disable_log()[source]

禁用 ONNX 日志记录。

JitScalarType

在 torch 中定义的标量类型。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源