目录

torch.onnx

开放神经网络交换 (ONNX) 是一种用于表示机器学习模型的开放式标准格式。torch.onnx 模块可以将 PyTorch 模型导出为 ONNX 格式。然后,该模型可以被任何支持 ONNX 的 运行时环境 使用。

示例:从 PyTorch 到 ONNX 的 AlexNet

这是一个简单的脚本,用于将预训练的AlexNet导出为名为 alexnet.onnx 的ONNX文件。 对 torch.onnx.export 的调用会运行模型一次以追踪其执行过程,然后将追踪后的模型导出到指定文件中:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区 其中包含了您导出的模型(在此情况下为 AlexNet)的网络结构和参数。 参数 verbose=True 会导致导出器打印出模型的人类可读表示:

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

你也可以使用 ONNX 库来验证输出结果, 你可以通过 conda 安装它:

conda install -c conda-forge onnx

然后,你可以运行:

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用支持ONNX的众多运行时之一来运行导出的模型。 例如,在安装ONNX Runtime之后,您可以 加载并运行该模型:

import onnxruntime as ort

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这里有一个更详细的关于导出模型并使用ONNX Runtime运行的教程

追踪与脚本编写

在内部,torch.onnx.export() 需要一个 torch.jit.ScriptModule 而不是 一个 torch.nn.Module。如果传入的模型还不是 ScriptModuleexport() 将使用 追踪 来将其转换为一个:

  • 追踪: 如果 torch.onnx.export() 被调用时传入的 Module 不是一个 ScriptModule,它会首先执行类似于 torch.jit.trace() 的操作,这会使用给定的 args 运行模型一次,并记录该运行过程中发生的所有操作。这意味着如果你的模型是动态的,例如根据输入数据改变行为,导出的 模型将 不会 捕获这种动态行为。同样,追踪可能仅适用于特定的输入尺寸。我们建议检查导出的模型并确保操作符看起来合理。追踪会展开循环和 if 语句,导出一个静态图,与追踪运行完全相同。如果你想以动态控制流导出模型,你需要使用 脚本化

  • 脚本化: 通过脚本化编译模型可以保留动态控制流,并适用于不同大小的输入。要使用脚本化:

    • 使用 torch.jit.script() 来生成一个 ScriptModule

    • 调用 torch.onnx.export() 并将 ScriptModule 作为模型,设置 example_outputs 参数。这是必需的,这样可以在不执行模型的情况下捕获输出的类型和形状。

请参阅 TorchScript简介TorchScript 以获取更多详细信息,包括如何组合追踪和脚本来满足不同模型的特定需求。

避免常见陷阱

避免使用 NumPy 和内置的 Python 类型

PyTorch模型可以使用NumPy或Python类型和函数编写,但在追踪过程中,任何NumPy或Python类型的变量(而不是torch.Tensor)都会被转换为常量,如果这些值应根据输入变化,则会产生错误结果。

例如,与其在 numpy.ndarrays 上使用 numpy 函数:

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 操作符:

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

而且,与其使用 torch.Tensor.item()(它将张量转换为Python内置数字):

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

利用 torch 对单元素张量的隐式类型转换支持:

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生错误的跟踪记录,从而导致生成不正确的 ONNX 图。 请改用 torch.Tensor.detach()。 (目前正在努力 完全移除 Tensor.data)。

在跟踪模式下使用 tensor.shape 时,避免进行原地操作

在追踪模式下,从 tensor.shape 获得的形状值会被追踪为张量, 并且共享相同的内存。这可能会导致最终输出值的不匹配。 作为一种解决方法,在这些情况下避免使用原地操作。 例如,在模型中:

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追踪模式下共享相同的内存。 通过重写原地操作可以避免这种情况:

real_seq_length = real_seq_length + 2

限制条件

类型

  • 仅支持torch.Tensors、可以简单转换为torch.Tensors的数值类型(例如float、int)以及这些类型的元组和列表作为模型输入或输出。在追踪模式下接受字典和字符串输入和输出,但:

    • 任何依赖于字典或字符串输入值的计算都将被替换为在一次跟踪执行过程中观察到的常数值。

    • 任何输出为字典的内容将被静默替换为其值的扁平化序列(键将被移除)。例如 {"foo": 1, "bar": 2} 变为 (1, 2)

    • 任何输出为 str 类型的内容都将被静默移除。

  • 某些涉及元组和列表的操作在脚本模式下不受支持,因为ONNX对嵌套序列的支持有限。 特别是将元组附加到列表的操作不受支持。在追踪模式下,嵌套序列将在追踪过程中自动展平。

算子实现的差异

由于操作符的实现存在差异,在不同的运行时上运行导出的模型,可能会产生彼此之间或与 PyTorch 不同的结果。通常这些差异在数值上很小,因此只有在您的应用程序对这些微小差异敏感时,才需要关注这一点。

不支持的张量索引模式

无法导出的张量索引模式列表如下。 如果你在导出一个不包含以下任何不支持模式的模型时遇到问题,请确认你是否使用最新版本的 opset_version 进行导出。

读取 / 获取

当对张量进行读取索引时,不支持以下模式:

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入 / 设置

当对张量进行写入操作时,以下索引模式不被支持:

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对运算符的支持

当导出包含不支持操作符的模型时,你会看到类似以下的错误信息:

RuntimeError: ONNX export failed: Couldn't export operator foo

当发生这种情况时,你需要要么修改模型以不使用该运算符, 要么为该运算符添加支持。

添加对运算符的支持需要对PyTorch的源代码进行贡献。 请参阅 CONTRIBUTING 以获取一般性指导,以下部分则提供针对支持运算符所需代码更改的具体说明。

在导出过程中,TorchScript 图中的每个节点都会按拓扑顺序被访问。 访问一个节点时,导出器会尝试查找该节点的已注册符号函数。 符号函数是用 Python 实现的。对于名为 foo 的操作,其符号函数可能如下所示:

def foo(
  g: torch._C.Graph,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Modifies g (e.g., using "g.op()"), adding the ONNX operations representing
  this PyTorch function.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.
    Returns None if it cannot be converted to ONNX.
  """
  ...

The torch._C types are Python wrappers around the types defined in C++ in ir.h.

添加符号函数的过程取决于操作符的类型。

ATen 操作符

ATen 是 PyTorch 内置的张量库。 如果操作符是 ATen 操作符(在 TorchScript 图中以前缀 aten:: 显示):

  • torch/onnx/symbolic_opset<version>.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保该函数与ATen函数的名称相同,ATen函数可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件在构建时生成,因此在你检出代码后直到构建PyTorch之前都不会出现)。

  • 第一个参数始终是要导出的ONNX图。 其他参数名称必须与.pyi文件中的名称完全匹配, 因为使用关键字参数进行分发。

  • 在符号函数中,如果运算符在 ONNX标准运算符集中, 我们只需要创建一个节点来表示图中的ONNX运算符。 如果不是,我们可以创建一个由几个具有等效语义的标准运算符组成的图来表示ATen运算符。

  • 如果输入参数是一个张量,但 ONNX 需要一个标量,我们必须显式地进行转换。symbolic_helper._scalar() 可以将标量张量转换为 Python 标量,而 symbolic_helper._if_scalar_type_as() 可以将 Python 标量转换为 PyTorch 张量。

这是一个处理 ELU 运算符缺失符号函数的示例。

如果我们运行以下代码:

print(
  torch.jit.trace(torch.nn.ELU(), # module
                  torch.ones(1)   # example input
                  ).graph)

我们看到类似这样的内容:

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图表中看到 aten::elu,我们知道这是一个ATen操作符。

我们检查了ONNX运算符列表, 并确认Elu在ONNX中是标准化的。

我们在 elu 中找到 torch/nn/functional.pyi 的特征:

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们向 symbolic_opset9.py 添加以下行:

def elu(g, input, alpha, inplace=False):
    return g.op("Elu", input, alpha_f=_scalar(alpha))

现在 PyTorch 可以导出包含 aten::elu 运算符的模型!

查看 symbolic_opset*.py 个文件以获取更多示例。

torch.autograd.Functions

如果操作符是 torch.autograd.Function 的子类,有两种方法可以导出它。

静态符号方法

你可以向函数类中添加一个名为 symbolic 的静态方法。它应该返回 表示该函数在 ONNX 中行为的 ONNX 运算符。例如:

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

PythonOp 符号

或者,您可以注册一个自定义的符号函数。 这使得符号函数可以通过原始操作的 TorchScript Node 对象获取更多信息,该对象作为第二个参数传入(在 Graph 对象之后)。

所有 autograd Function 都会以 prim::PythonOp 节点的形式出现在 TorchScript 图中。 为了区分不同的 Function 子类,符号函数应使用 name 参数,该参数会被设置为类的名称。

自定义符号函数应在返回 Value 对象之前,通过调用 setType(...) 为 Value 对象添加类型和形状信息(由 C++ 中的 torch::jit::Value::setType 实现)。这并非必需,但它可以帮助导出器对下游节点进行形状和类型推断。关于 setType 的非平凡示例,请参见 test_aten_embedding_2test_operators.py 中的实现。

以下示例展示了如何通过 requires_grad 访问 Node 对象:

class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min):
        ctx.save_for_backward(input)
        return input.clamp(min=min)

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

def symbolic_python_op(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
    print("original node: ", n)
    for i, out in enumerate(n.outputs()):
        print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
    import torch.onnx.symbolic_helper as sym_helper
    for i, arg in enumerate(args):
        requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
        print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))

    name = kwargs["name"]
    ret = None
    if name == "MyClip":
        ret = g.op("Clip", args[0], args[1])
    elif name == "MyRelu":
        ret = g.op("Relu", args[0])
    else:
        # Logs a warning and returns None
        return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
    # Copy type and shape from original node.
    ret.setType(n.type())
    return ret

from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

自定义操作符

如果模型使用了如使用自定义C++运算符扩展TorchScript中描述的自定义C++运算符实现, 您可以按照此示例进行导出:

from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# Define custom symbolic function
@parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)

# Register custom symbolic function
register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)

class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super(FooModule, self).__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)

model = FooModel(attr1, attr2)
torch.onnx.export(
  model,
  (example_input1, example_input1),
  "model.onnx",
  # only needed if you want to specify an opset version > 1.
  custom_opsets={"custom_domain": 2})

您可以将其导出为一个或多个标准ONNX操作符,或者作为自定义操作符。 上面的例子将其导出为“custom_domain”操作集中的自定义操作符。 在导出自定义操作符时,您可以在导出时使用custom_opsets字典指定自定义域版本。如果没有指定,默认的自定义操作集版本为1。 消耗模型的运行时需要支持自定义操作符。请参阅 Caffe2自定义操作符, ONNX Runtime自定义操作符, 或您选择的运行时文档。

一次性发现所有无法转换的 ATen 操作

当由于无法转换的 ATen 操作导致导出失败时,实际上可能存在多个这样的操作,但错误信息只会提到第一个。若要一次性发现所有无法转换的操作,你可以:

from torch.onnx import utils as onnx_utils

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = onnx_utils.unconvertible_ops(
    model, args, opset_version=opset_version)

print(set(unconvertible_ops))

常见问题解答

Q: 我已经导出了我的LSTM模型,但它的输入大小似乎被固定了?

The tracer records the shapes of the example inputs. If the model should accept inputs of dynamic shapes, set dynamic_axes when calling torch.onnx.export().

Q: 如何导出包含循环的模型?

Q: 如何导出带有原始类型输入(例如 int、float)的模型?

Support for primitive numeric type inputs was added in PyTorch 1.9. However, the exporter does not support models with str inputs.

Q: ONNX 是否支持隐式的标量数据类型转换?

No, but the exporter will try to handle that part. Scalars are exported as constant tensors. The exporter will try to figure out the right datatype for scalars. However when it is unable to do so, you will need to manually specify the datatype. This often happens with scripted models, where the datatypes are not recorded. For example:

class ImplicitCastType(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # Exporter knows x is float32, will export "2" as float32 as well.
        y = x + 2
        # Currently the exporter doesn't know the datatype of y, so
        # "3" is exported as int64, which is wrong!
        return y + 3
        # To fix, replace the line above with:
        # return y + torch.tensor([3], dtype=torch.float32)

x = torch.tensor([1.0], dtype=torch.float32)
torch.onnx.export(ImplicitCastType(), x, "implicit_cast.onnx",
                  example_outputs=ImplicitCastType()(x))

We are trying to improve the datatype propagation in the exporter such that implicit casting is supported in more cases.

Q: Tensor 列表可以导出到 ONNX 吗?

Yes, for opset_version >= 11, since ONNX introduced the Sequence type in opset 11.

函数

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=None, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False)[source]

将模型导出为ONNX格式。如果 model 既不是 torch.jit.ScriptModule 也不是 torch.jit.ScriptFunction,则会运行 model 一次以将其转换为TorchScript图以便导出 (等同于 torch.jit.trace())。因此,这与 torch.jit.trace() 一样,对动态控制流的支持也有限。

Parameters
  • 模型 (torch.nn.Module, torch.jit.ScriptModuletorch.jit.ScriptFunction) – 要导出的模型。

  • 参数 (元组torch.Tensor) –

    参数可以以以下两种方式之一进行结构化:

    1. 仅包含一组参数:

      args = (x, y, z)
      

    该元组应包含模型输入,使得 model(*args) 是对模型的有效调用。任何非张量参数将被硬编码到导出的模型中;任何张量参数将成为导出模型的输入,按照它们在元组中出现的顺序。

    1. 一个张量:

      args = torch.Tensor([1])
      

    这相当于该 Tensor 的一个一元元组。

    1. 以一个包含命名参数字典的参数元组结尾:

      args = (x,
              {'y': input_y,
               'z': input_z})
      

    元组中除最后一个元素外的所有元素都将作为非关键字参数传递, 而命名参数将从最后一个元素中设置。如果字典中没有某个命名参数, 则将其赋值为默认值,如果没有提供默认值,则赋值为 None。

    注意

    如果一个字典是 args 元组的最后一个元素,它将被解释为包含命名参数。如果要将字典作为最后一个非关键字参数传递,请在 args 元组的最后一个元素中提供一个空字典。例如,不要使用:

    torch.onnx.export(
        model,
        (x,
         # WRONG: will be interpreted as named arguments
         {y: z}),
        "test.onnx.pb")
    

    Write:

    torch.onnx.export(
        model,
        (x,
         {y: z},
         {}),
        "test.onnx.pb")
    

  • f – 一个类似文件的对象(例如 f.fileno() 返回一个文件描述符) 或包含文件名的字符串。二进制协议缓冲区将被写入此文件。

  • export_params (bool, default True) – 如果为 True,将导出所有参数。如果你想导出一个未训练的模型,请将此值设为 False。 在这种情况下,导出的模型将首先将其所有参数作为参数传入,顺序由 model.state_dict().values() 指定

  • verbose (bool, 默认 False) – 如果为 True,会将导出模型的描述打印到标准输出。此外,最终的 ONNX 图将包含从导出模型中获取的字段 doc_string`,其中提到了 model 的源代码位置。

  • 训练 (枚举, 默认 TrainingMode.EVAL) –

    • TrainingMode.EVAL: 以推理模式导出模型。

    • TrainingMode.PRESERVE: 在模型处于推理模式(即 model.training 为 False)时导出模型,在模型处于训练模式(即 model.training 为 True)时进行训练。

    • TrainingMode.TRAINING: 以训练模式导出模型。禁用可能干扰训练的优化选项。

  • input_names (str 列表, 默认为空列表) – 按顺序分配给图输入节点的名称。

  • output_names (str 列表, 默认为空列表) – 按顺序分配给图输出节点的名称。

  • operator_export_type (枚举, 默认 None) –

    None 通常表示 OperatorExportTypes.ONNX. 然而,如果 PyTorch 是使用 -DPYTORCH_ONNX_CAFFE2_BUNDLE 构建的,None 表示 OperatorExportTypes.ONNX_ATEN_FALLBACK

    • OperatorExportTypes.ONNX: 将所有操作导出为常规ONNX操作 (在默认的opset域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH: 尝试将所有操作转换为默认opset域中的标准ONNX操作。如果无法这样做(例如,因为尚未添加将特定torch操作转换为ONNX的支持),则回退到将操作导出到自定义opset域而不进行转换。适用于自定义操作以及ATen操作。对于导出的模型要可用,运行时必须支持这些非标准操作。

    • OperatorExportTypes.ONNX_ATEN: 所有 ATen 操作(在 TorchScript 命名空间 “aten” 中) 都会以 ATen 操作的形式导出(在 opset 域 “org.pytorch.aten” 中)。 ATen 是 PyTorch 内置的张量库,因此 这指示运行时使用 PyTorch 对这些操作的实现。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

      这在操作符实现中的数值差异导致 PyTorch 和 Caffe2 之间行为出现较大差异时可能会有帮助(这种情况在未训练的模型中更为常见)。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK: 尝试将每个ATen操作 (在TorchScript命名空间“aten”中)导出为常规的ONNX操作。如果无法做到这一点 (例如,因为尚未添加将特定torch操作转换为ONNX的支持), 则回退到导出ATen操作。有关OperatorExportTypes.ONNX_ATEN的上下文,请参阅文档。 例如:

      graph(%0 : Float):
        %3 : int = prim::Constant[value=0]()
        # conversion unsupported
        %4 : Float = aten::triu(%0, %3)
        # conversion supported
        %5 : Float = aten::mul(%4, %0)
        return (%5)
      

      假设 aten::triu 不被 ONNX 支持,这将被导出为:

      graph(%0 : Float):
        %1 : Long() = onnx::Constant[value={0}]()
        # not converted
        %2 : Float = aten::ATen[operator="triu"](%0, %1)
        # converted
        %3 : Float = onnx::Mul(%2, %0)
        return (%3)
      

      如果PyTorch是使用Caffe2构建的(即使用BUILD_CAFFE2=1),那么将启用Caffe2特有的行为,包括对由量化模块描述的操作的支持。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_version (int, 默认值 9) – 要针对的 默认 (ai.onnx) opset 版本。必须 >= 7 且 <= 15。

  • do_constant_folding (bool, default True) – 应用常量折叠优化。 常量折叠将用预先计算的常量节点替换所有输入均为常量的操作。

  • dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict) –

    默认情况下,导出的模型将把所有输入和输出张量的形状设置为与 args 中给出的完全一致。若要指定张量的某些轴为动态(即仅在运行时才知道),请将 dynamic_axes 设置为具有以下模式的字典:

    • KEY (str): 一个输入或输出名称。每个名称也必须在 input_namesoutput_names 中提供。

    • VALUE (字典或列表):如果是字典,键是轴索引,值是轴名称。如果是列表,每个元素是一个轴索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"])
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    While:

    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"],
                      dynamic_axes={
                          # dict value: manually named axes
                          "x": {0: "my_custom_axis_name"},
                          # list value: automatic names
                          "sum": [0],
                      })
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool, 默认 None) –

    如果为 True,则导出图中的所有初始化器(通常对应参数)也将作为图的输入添加。如果为 False,则初始化器不会作为图的输入添加,仅非参数输入会被添加为输入。 这可能会允许后端/运行时进行更好的优化(例如常量折叠)。

    如果 opset_version < 9,初始化器必须是图输入的一部分,此参数将被忽略,行为将等同于将此参数设置为 True。

    如果为 None,则行为将按以下方式自动选择:

    • 如果 operator_export_type=OperatorExportTypes.ONNX,则行为等同 于将此参数设置为 False。

    • 否则,此参数的行为等同于将其设置为 True。

  • custom_opsets (dict<str, int>, default empty dict) –

    一个具有模式的字典:

    • KEY (str): opset 域名

    • VALUE (int): opset 版本

    如果自定义 opset 被 model 引用但未在此字典中提及, 则 opset 版本将被设置为 1。仅应通过此参数指定自定义 opset 的域名和版本。

  • export_modules_as_functions (boolset of python:type of nn.Module, 默认 False) –

    启用标志以将所有 nn.Module 前向调用作为 ONNX 中的本地函数导出。或是一个集合,用于指定要作为 ONNX 中本地函数导出的特定模块类型。 此功能需要 opset_version >= 15,否则导出将失败。这是因为 opset_version < 15 表示 IR 版本 < 8,这意味着不支持本地函数。

    • False``(default): export ``nn.Module 前向调用作为细粒度节点。

    • True: 导出所有 nn.Module 前向调用作为本地函数节点。

    • Set of type of nn.Module: export nn.Module forward calls as local function nodes, only if the type of the nn.Module is found in the set。

Raises

CheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图。即使引发此错误,仍会将模型导出到文件 f

torch.onnx.export_to_pretty_string(*args, **kwargs)[source]

export() 类似,但返回 ONNX 模型的文本表示形式。仅列出参数的不同之处。其他所有参数与 export() 相同。

Parameters
  • add_node_names (bool, 默认 True) – 是否设置 NodeProto.name。除非 google_printer=True,否则这不会有任何影响。

  • google_printer (bool, 默认 False) – 如果为 False,将返回模型的自定义、紧凑表示形式。如果为 True,将返回 protobuf 的 Message::DebugString(),这会更加详细。

Returns

一个包含可读的ONNX模型表示的UTF-8字符串。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]

Registers symbolic_fn to handle symbolic_name。请参阅模块文档中的“自定义操作符”以了解示例用法。

Parameters
  • symbolic_name (str) – 自定义操作符的名称,格式为“<domain>::<op>” 。

  • symbolic_fn (Callable) – 一个函数,它接收ONNX图和当前操作符的输入参数,并返回要添加到图中的新操作符节点。

  • opset_version (int) – 注册时使用的ONNX opset版本。

torch.onnx.select_model_mode_for_export(model, mode)[source]

一个上下文管理器,用于临时将 model 的训练模式设置为 mode,在退出 with 块时将其重置。如果 mode 为 None,则不执行任何操作。

Parameters
  • 模型 – 同类型和含义与 model 参数传递给 export()

  • 模式 – 与 training 参数传递给 export() 的类型和含义相同。

torch.onnx.is_in_onnx_export()[source]

如果 export() 正在当前线程中运行,则返回 True

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源