目录

torch.onnx

ONNX exporter.

Open Neural Network eXchange (ONNX) 是一种开放标准 格式来表示机器学习模型。torch.onnx 模块可以导出 PyTorch 模型添加到 ONNX。然后,支持 ONNX 的众多运行时中的任何一个都可以使用该模型。

示例:从 PyTorch 到 ONNX 的 AlexNet

下面是一个简单的脚本,它将预训练的 AlexNet 导出到名为 . 对 的调用运行模型一次以跟踪其执行情况,然后将 traced model 到指定的文件:alexnet.onnxtorch.onnx.export

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的文件包含一个二进制协议缓冲区,其中包含您导出的模型的网络结构和参数 (在本例中为 AlexNet)。该参数会导致 exporter 打印出模型的人类可读表示形式:alexnet.onnxverbose=True

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

您还可以使用 ONNX 库验证输出, 您可以使用 :pip

pip install onnx

然后,您可以运行:

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用支持 ONNX 的众多运行时之一运行导出的模型。 例如,安装 ONNX 运行时后,您可以 加载并运行模型:

import onnxruntime as ort

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这是一个更复杂的教程,介绍如何导出模型并使用 ONNX 运行时运行它

跟踪与脚本

在内部,需要一个而不是 一个 .如果传入的模型还不是 ,将使用跟踪将其转换为 1:ScriptModuleexport()

  • 跟踪:如果使用还不是 的 Module 调用,它首先执行等效于 ,从而执行模型 once 替换为 given 和 记录该执行期间发生的所有操作。这 表示如果您的模型是动态的,例如,根据输入数据更改行为,则导出的 model 不会捕获此动态行为。 我们建议检查导出的模型并确保运算符查找 合理。跟踪将展开循环和 if 语句,导出一个完全是 与 traced run 相同。如果要使用动态控制流导出模型,则可以 需要使用脚本torch.onnx.export()ScriptModuleargs

  • 脚本:通过脚本编译模型可保留动态控制流,并且对输入有效 大小不一。要使用脚本:

    • 用于生成 .ScriptModule

    • 以 作为模型调用 。他们仍然是必需的, 但它们将仅在内部用于生成示例输出,因此 可以捕获输出。不会执行跟踪。torch.onnx.export()ScriptModuleargs

有关更多详细信息,包括如何编写跟踪和脚本以适应 不同型号的特殊要求。

避免陷阱

避免使用 NumPy 和内置 Python 类型

PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写,但 在跟踪期间,NumPy 或 Python 的任何变量 类型(而不是 torch.Tensor)转换为常量,这将产生 如果这些值应根据输入而变化,则结果错误。

例如,不要在 numpy.ndarrays 上使用 numpy 函数:

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch 上使用 torch 操作器。张:

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

而不是使用 (将 Tensor 转换为 Python 内置编号):

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

使用 torch 对单元素张量的隐式转换的支持:

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生不正确的跟踪,从而产生不正确的 ONNX 图。 请改用。(完全删除 Tensor.data 的工作正在进行中)。

在追踪模式下使用 tensor.shape 时避免就地操作

在追踪模式下,从 获得的形状被追踪为张量, 并共享相同的内存。这可能会导致最终输出值不匹配。 解决方法是,在这些情况下避免使用就地操作。 例如,在模型中:tensor.shape

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_length并在跟踪模式下共享相同的内存。 这可以通过重写就地操作来避免:seq_length

real_seq_length = real_seq_length + 2

局限性

类型

  • Only ,可以简单地转换为 torch 的数字类型。张量(例如 float、int)、 支持将这些类型的元组和列表作为模型输入或输出。Dict 和 str 输入以及 在跟踪模式下接受输出,但是:torch.Tensors

    • 任何依赖于 dict 或 str 输入的值的计算都将被替换为 constant 值

    • 任何作为 dict 的输出都将被静默替换为其值的扁平化序列 (键将被删除)。例如 成为。{"foo": 1, "bar": 2}(1, 2)

    • 任何作为 str 的输出都将被静默删除。

  • 由于 ONNX 对嵌套序列的支持有限,因此脚本模式不支持某些涉及元组和列表的操作。 特别是,不支持将 Tuples 附加到列表。在跟踪模式下,嵌套序列 将在跟踪期间自动拼合。

Operator 实现的差异

由于 Operator 的实现存在差异,因此在不同的运行时上运行导出的模型 可能会产生彼此不同的结果,也可能产生不同的 PyTorch 结果。通常这些差异是 数值较小,因此仅当应用程序对这些敏感时,才应考虑此问题 微小的差异。

不支持的 Tensor 索引模式

下面列出了无法导出的张量索引模式。 如果您在导出不包含任何 下面不支持的模式,请仔细检查您是否正在使用 最新的 .opset_version

读取 / 获取

当索引到张量中进行读取时,不支持以下模式:

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入 / 设置

当索引到 Tensor 中进行写入时,不支持以下模式:

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对运算符的支持

导出包含不支持的运算符的模型时,您将看到如下错误消息:

RuntimeError: ONNX export failed: Couldn't export operator foo

发生这种情况时,您可以执行以下操作:

  1. 更改模型以不使用该运算符。

  2. 创建一个符号函数来转换运算符并将其注册为自定义符号函数。

  3. 为 PyTorch 做出贡献,以向自身添加相同的符号函数

如果您决定实现一个符号函数(我们希望您能将其贡献回 PyTorch!),以下是如何开始:

ONNX exporter internals

“符号函数”是将 PyTorch 运算符分解为 一系列 ONNX 运算符的组合。

在导出过程中,TorchScript 中的每个节点(包含 PyTorch 运算符) graph 由 exporter 按拓扑顺序访问。 访问节点时,导出器会查找已注册的符号函数 那个运算符。符号函数是在 Python 中实现的。一个 名为 的操作 将如下所示:foo

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

这些类型是 ir.h 中 C++ 中定义的类型的 Python 包装器。torch._C

添加符号函数的过程取决于运算符的类型。

ATen 运算符

ATen 是 PyTorch 的内置张量库。 如果运算符是 ATen 运算符(在 TorchScript 图形中显示为前缀 ),请确保它尚未受支持。aten::

支持的运算符列表

有关每个运算符支持哪些运算符的详细信息,请访问自动生成的受支持TorchScript运算符列表opset_version

添加对 aten 或量化运算符的支持

如果运算符不在上面的列表中:

  • 在 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保该函数与 ATen 函数具有相同的名称,该函数可以在 或 (这些文件在 构建时,因此在您构建 PyTorch 之前不会出现在您的检出中)。torch/onnx/symbolic_opset<version>.pytorch/_C/_VariableFunctions.pyitorch/nn/functional.pyi

  • 默认情况下,第一个参数是 ONNX 图。 其他 arg 名称必须与文件中的名称完全匹配, 因为 dispatch 是通过 keyword 参数完成的。.pyi

  • 在 symbolic 函数中,如果运算符在 ONNX 标准运算符集中,则 我们只需要创建一个节点来表示图中的 ONNX 运算符。 如果没有,我们可以组合几个标准运算符,它们具有 与 ATen 运算符等效的语义。

下面是处理运算符缺少的符号函数的示例。ELU

如果我们运行以下代码:

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我们看到类似以下内容:

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图中看到,我们知道这是一个 ATen 运算符。aten::elu

我们检查 ONNX 运算符列表, 并确认 ONNX 中已标准化。Elu

我们在 中找到 的签名 :elutorch/nn/functional.pyi

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们将以下行添加到 :symbolic_opset9.py

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在 PyTorch 能够导出包含 Operator 的模型!aten::elu

有关更多示例,请参阅文件。torch/onnx/symbolic_opset*.py

torch.autograd.函数

如果运算符是 的子类,则有三种方法 以导出它。

静态符号方法

您可以向函数类添加名为 的静态方法。它应该返回 表示函数在 ONNX 中的行为的 ONNX 运算符。例如:symbolic

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联 autograd 函数

如果 static(静态)符号方法未为其后续 或 如果未提供注册为自定义符号函数的函数,则尝试内联与该函数相对应的图形,以便 此函数被分解为函数中使用的单个运算符。 只要支持这些单独的运算符,导出就应该成功。例如:prim::PythonOp

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

此模型不存在 static symbolic method,但导出如下:

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果需要避免内联 ,则应将模型导出为 设置为 或 。operator_export_typeONNX_FALLTHROUGHONNX_ATEN_FALLBACK

自定义运算符

如果模型使用以 C++ 实现的自定义运算符,如使用自定义 C++ 运算符扩展 TorchScript 中所述, 您可以按照以下示例导出它:

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

您可以将模型导出为多个标准 ONNX 运算的一个或组合,或导出为自定义 ONNX 运算符。

上面的示例将其导出为 “custom_domain” opset 中的自定义运算符。 导出自定义运算符时,您可以在导出时使用字典指定自定义域版本。如果未指定,则自定义 opset 版本默认为 1。custom_opsets

使用模型的运行时需要支持自定义运算。请参阅 Caffe2 自定义操作ONNX 运行时自定义操作、 或您选择的运行时的文档。

一次发现所有不可转换的 ATen 运算

当由于不可转换的 ATen 操作而导致导出失败时,实际上可能会有更多 而不是一个这样的 OP,但错误消息只提到了第一个。探索 一次性完成所有不可转换的操作,您可以:

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是近似的,因为在转换过程中可能会删除一些操作 处理,无需转换。其他一些 Ops 可能提供部分支持 这将导致特定输入的转换失败,但这应该会为您提供一个 不支持哪些 OPS 的一般概念。请随时打开 Github Issues 用于 OP 支持请求。

常见问题解答

Q: 我已经导出了 LSTM 模型,但其输入大小似乎是固定的?

跟踪器记录示例输入的形状。如果模型应该接受 动态形状的输入,在调用 时设置。dynamic_axes

Q: 如何导出包含 Loop 的模型?

Q: 如何导出具有基元类型输入(例如 int、float)的模型?

PyTorch 1.9 中添加了对基元数字类型输入的支持。 但是,导出器不支持具有 str 输入的模型。

问:ONNX 是否支持隐式标量数据类型转换?

ONNX 标准没有,但导出器将尝试处理该部分。 标量导出为常量张量。 导出器将为标量找出正确的数据类型。在极少数情况下,当它无法时 为此,您需要手动指定数据类型,例如 dtype=torch.float32。 如果您看到任何错误,请 [创建 GitHub issue](https://github.com/pytorch/pytorch/issues)。

问:张量列表是否可以导出到 ONNX?

是的,对于 >= 11,因为 ONNX 在操作集 11 中引入了 Sequence 类型。opset_version

贡献 / 开发

开发人员文档

功能

torch.onnx 中。exportmodelargsfexport_params=Trueverbose=Falsetraining=<TrainingMode.EVAL: 0>input_names=无output_names=无operator_export_type=<OperatorExportTypes.ONNX: 0>opset_version=无do_constant_folding=真dynamic_axes=无keep_initializers_as_inputs=无custom_opsets=无export_modules_as_functions=假[来源]

将模型导出为 ONNX 格式。

如果不是 a 或 a ,则运行一次,以便将其转换为要导出的 TorchScript 图形 (相当于 )。因此,这具有相同的有限支持 对于动态控制流,如 .modelmodel

参数
  • model ) – 要导出的模型。

  • args元组torch.张量) –

    ARGS 的结构可以是:

    1. 只有一个参数元组:

      args = (x, y, z)
      

    元组应包含模型输入,以便 调用模型。任何非 Tensor 参数都将被硬编码到 导出的模型;任何 Tensor 参数都将成为导出模型的输入, 按照它们在 Tuples 中出现的顺序。model(*args)

    1. 一个张量:

      args = torch.Tensor([1])
      

    这相当于该 Tensor 的 1-ary 元组。

    1. 以命名参数字典结尾的参数元组:

      args = (
          x,
          {
              "y": input_y,
              "z": input_z
          }
      )
      

    除了元组的最后一个元素之外,所有元素都将作为非关键字参数传递。 命名参数将从最后一个元素开始设置。如果命名参数为 不存在,则为其分配默认值,如果 default 值。

    注意

    如果字典是 args 元组的最后一个元素,则它将是 解释为包含命名参数。为了将 dict 作为 last 非关键字 arg,则提供一个空 dict 作为 args 的最后一个元素 元。例如,而不是:

    torch.onnx.export(
        model,
        (
            x,
            # WRONG: will be interpreted as named arguments
            {y: z}
        ),
        "test.onnx.pb"
    )
    

    写:

    torch.onnx.export(
        model,
        (
            x,
            {y: z},
            {}
        ),
        "test.onnx.pb"
    )
    

  • fUnion[strBytesIO]) – 一个类似文件的对象(这样返回一个文件描述符) 或包含文件名的字符串。将写入二进制协议缓冲区 添加到此文件。f.fileno()

  • export_paramsbooldefault True) – 如果为 True,则所有参数都将 被导出。如果要导出未经训练的模型,请将此项设置为 False。 在这种情况下,导出的模型将首先采用其所有参数 作为参数,其顺序由model.state_dict().values()

  • verbosebooldefault False) – 如果为 True,则打印 模型正在导出到 stdout 中。此外,最终的 ONNX 图将包括 字段,其中提到了源代码位置 为。如果为 True,则将打开 ONNX 导出器日志记录。doc_string`model

  • trainingenum默认 TrainingMode.EVAL)

    • TrainingMode.EVAL:在推理模式下导出模型。

    • TrainingMode.PRESERVE:如果 model.training 为

      如果 model.training 为 True,则为 False,则处于训练模式。

    • TrainingMode.TRAINING:在训练模式下导出模型。禁用优化

      这可能会干扰训练。

  • input_nameslist of strdefault empty list) – 要分配给 input 节点。

  • output_nameslist of strdefault empty list) – 要分配给 output 节点。

  • operator_export_typeenum默认 OperatorExportTypes.ONNX) –

    • OperatorExportTypes.ONNX:将所有操作导出为常规 ONNX 操作

      (在默认的 Opset 域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH: 尝试转换所有操作

      到默认 opset 域中的标准 ONNX 操作。如果无法执行此操作 (例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持), 回退到将运算导出到自定义 OpSet 域而不进行转换。适用 自定义操作以及 ATen 操作。要使导出的模型可用,运行时必须支持 这些非标准 Ops.

    • OperatorExportTypes.ONNX_ATEN:所有 ATen 操作(在 TorchScript 命名空间 “aten” 中)

      导出为 ATen 操作(在 opset 域 “org.pytorch.aten” 中)。ATen 是 PyTorch 的内置张量库,因此 这会指示运行时使用 PyTorch 的这些运算实现。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

      如果运算符实现中的数值差异为 导致 PyTorch 和 Caffe2 之间的行为存在很大差异(即 常见于未经训练的模型)。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK:尝试导出每个 ATen op

      (在 TorchScript 命名空间 “aten” 中)作为常规 ONNX 操作。如果我们无法做到这一点 (例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持), 回退到导出 ATen 运算。请参阅 OperatorExportTypes.ONNX_ATEN 上的文档 上下文。 例如:

      graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      # conversion unsupported
      %4 : Float = aten::triu(%0, %3)
      # conversion supported
      %5 : Float = aten::mul(%4, %0)
      return (%5)
      

      假设 ONNX 不支持,这将导出为:aten::triu

      graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      # not converted
      %2 : Float = aten::ATen[operator="triu"](%0, %1)
      # converted
      %3 : Float = onnx::Mul(%2, %0)
      return (%3)
      

      如果 PyTorch 是使用 Caffe2 构建的(即使用 ),则 将启用 Caffe2 特定的行为,包括特殊支持 for ops 由 Quantization 中描述的模块生成。BUILD_CAFFE2=1

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_versionintdefault 14) – 目标的默认 (ai.onnx) opset 的版本。必须为 >= 7 且 <= 16。

  • do_constant_foldingbooldefault True) – 应用常量折叠优化。 常量折叠将替换一些具有所有常量输入的 operations 具有预先计算的常量节点。

  • dynamic_axesdict[stringdict[intstring]] 或 dict[stringlistint]默认为空 dict) –

    默认情况下,导出的模型将具有所有输入和输出张量的形状 设置为与 中给出的值完全匹配。要将张量的轴指定为 dynamic(即仅在运行时已知),设置为具有 schema 的 dict:argsdynamic_axes

    • KEY (str):输入或输出名称。每个名称还必须在 或input_names

      output_names.

    • VALUE (dict or list):如果是 dict,则键是轴索引,值是轴名称。如果

      list 中,每个元素都是一个轴索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"]
    )
    

    生产:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    而:

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        }
    )
    

    生产:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputsbool默认 None) –

    如果为 True,则所有 初始化器(通常对应于参数)在 导出的图形也将作为输入添加到图形中。如果为 False,则 然后,初始化器不会作为输入添加到图形中,并且只会 非参数输入将添加为输入。 这可能允许通过以下方式进行更好的优化(例如常量折叠) backends/runtimes 的 Runtimes。

    如果 ,则初始化器必须是图形的一部分 inputs,并且此参数将被忽略,行为将为 等效于将此参数设置为 True。opset_version < 9

    如果为 None,则自动选择行为,如下所示:

    • 如果 ,则行为是等效的operator_export_type=OperatorExportTypes.ONNX

      将此参数设置为 False。

    • 否则,该行为等效于将此参数设置为 True。

  • custom_opsetsdict[strint]默认为空 dict) –

    具有 schema 的 dict:

    • KEY (str):opset 域名

    • VALUE (int):opset 版本

    如果自定义 opset 被此字典引用但未提及,则 Opset Version 设置为 1。只有自定义 opset 域名和版本应为 通过此参数表示。model

  • export_modules_as_functionsboolpython:type of nn.模块默认 False)–

    要启用的标志 将所有转发调用导出为 ONNX 中的本地函数。或者一个集合来指示 在 ONNX 中导出为本地函数的特定类型的模块。 此功能需要 >= 15,否则导出将失败。这是因为 < 15 意味着 IR 版本 < 8,这意味着不支持本地函数。 模块变量将导出为函数属性。函数分为两类 属性。nn.Moduleopset_versionopset_version

    1. 带注释的属性:通过 PEP 526 样式具有类型注释的类变量将导出为属性。 带注释的属性不在 ONNX 本地函数的子图中使用,因为 它们不是由 PyTorch JIT 跟踪创建的,但可供使用者使用 来确定是否将函数替换为特定的融合内核。

    2. 推断属性:模块内部运算符使用的变量。属性名称 将具有前缀 “inferred::”。这是为了与从 Python 模块注释。推断的属性在 ONNX 本地函数的子图中使用。

    • False(默认):将转发调用导出为细粒度节点。nn.Module

    • True:将所有 forward 调用导出为本地函数节点。nn.Module

    • nn 类型的集合。Module:将 forward call 导出为本地函数节点,nn.Module

      仅当在 set 中找到 the 的类型时。nn.Module

提升
  • torch.onnx.errors.CheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图形。

  • torch.onnx.errors.UnsupportedOperatorError – 如果 ONNX 图形无法导出,因为它 使用导出器不支持的运算符。

  • torch.onnx.errors.OnnxExporterError – 导出过程中可能发生的其他错误。 所有错误都是 的子类。errors.OnnxExporterError

torch.onnx 中。export_to_pretty_stringmodelargsexport_params=Trueverbose=Falsetraining=<TrainingMode.EVAL: 0>input_names=Noneoutput_names=Noneoperator_export_type=<OperatorExportTypes.ONNX: 0>export_type=无google_printer=Falseopset_version=无keep_initializers_as_inputs=无custom_opsets=无add_node_names=真do_constant_folding=真dynamic_axes=无[来源]

类似,但返回 ONNX 的文本表示形式 型。仅在下面列出了 args 中的差异。所有其他参数都是相同的 作为 .

参数
  • add_node_namesbooldefault True) – 是否设置 NodeProto.name. 除非 .google_printer=True

  • google_printerbooldefault False) – 如果为 False,将返回一个自定义 模型的紧凑表示。如果为 True,则返回 protobuf 的 Message::D ebugString(),它更详细。

结果

一个 UTF-8 str,包含 ONNX 模型的人类可读表示形式。

torch.onnx 中。register_custom_op_symbolicsymbolic_namesymbolic_fnopset_version[来源]

为自定义运算符注册符号函数。

当用户为 custom/contrib 操作注册 symbolic 时, 强烈建议通过 setType API 为该算子添加形状推断, 否则,在某些极端情况下,导出的图形可能会有不正确的形状推断。 setType 的一个示例是 test_operators.py 中的 test_aten_embedding_2

有关示例用法,请参阅模块文档中的“自定义运算符”。

参数
  • symbolic_namestr) – “<domain>::<op>” 中自定义运算符的名称 格式。

  • symbolic_fnCallable) – 一个函数,它接受 ONNX 图和 input 参数传递给 current 运算符,并返回新的 要添加到图表中的 operator 节点。

  • opset_versionint) – 要在其中注册的 ONNX opset 版本。

torch.onnx 中。unregister_custom_op_symbolicsymbolic_nameopset_version[来源]

Unregisters (取消注册) 。symbolic_name

有关示例用法,请参阅模块文档中的“自定义运算符”。

参数
  • symbolic_namestr) – “<domain>::<op>” 中自定义运算符的名称 格式。

  • opset_versionint) – 要取消注册的 ONNX opset 版本。

torch.onnx 中。select_model_mode_for_export模型模式[来源]

一个上下文管理器,用于临时将训练模式设置为 ,当我们退出 with-block 时重置它。modelmode

参数
  • model – 与 arg to 的类型和含义相同。model

  • modeTrainingMode) – 与 arg to 的类型和含义相同。training

torch.onnx 中。is_in_onnx_export[来源]

返回它是否在 ONNX 导出过程中。

返回类型

布尔

torch.onnx 中。enable_log[来源]

启用 ONNX 日志记录。

torch.onnx 中。disable_log[来源]

禁用 ONNX 日志记录。

JitScalarType

在 torch 中定义的标量类型。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源