torch.onnx¶
Open Neural Network eXchange (ONNX) 是一种开放标准 格式来表示机器学习模型。torch.onnx 模块可以导出 PyTorch 模型添加到 ONNX。然后,支持 ONNX 的众多运行时中的任何一个都可以使用该模型。
示例:从 PyTorch 到 ONNX 的 AlexNet¶
下面是一个简单的脚本,它将预训练的 AlexNet 导出到名为 .
对 的调用运行模型一次以跟踪其执行情况,然后将
traced model 到指定的文件:alexnet.onnx
torch.onnx.export
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
生成的文件包含一个二进制协议缓冲区,其中包含您导出的模型的网络结构和参数
(在本例中为 AlexNet)。该参数会导致
exporter 打印出模型的人类可读表示形式:alexnet.onnx
verbose=True
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
您还可以使用 ONNX 库验证输出, 您可以使用 conda 安装:
conda install -c conda-forge onnx
然后,您可以运行:
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the model is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
您还可以使用支持 ONNX 的众多运行时之一运行导出的模型。 例如,安装 ONNX 运行时后,您可以 加载并运行模型:
import onnxruntime as ort
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
这是一个更复杂的教程,介绍如何导出模型并使用 ONNX 运行时运行它。
跟踪与脚本¶
在内部,需要一个而不是
一个
.如果传入的模型还不是 ,将使用跟踪将其转换为 1:
torch.onnx.export()
ScriptModule
export()
跟踪:如果使用还不是 的 Module 调用,它首先执行等效于
,从而执行模型 once 替换为 given 和 记录该执行期间发生的所有操作。这 表示如果您的模型是动态的,例如,根据输入数据更改行为,则导出的 model 不会捕获此动态行为。同样,跟踪可能仅对 特定的输入大小。我们建议检查导出的模型并确保运算符查找 合理。跟踪将展开循环和 if 语句,导出一个完全是 与 traced run 相同。如果要使用动态控制流导出模型,则可以 需要使用脚本。
torch.onnx.export()
ScriptModule
args
脚本:通过脚本编译模型可保留动态控制流,并且对输入有效 大小不一。要使用脚本:
有关更多详细信息,包括如何编写跟踪和脚本以适应 不同型号的特殊要求。
避免陷阱¶
避免使用 NumPy 和内置 Python 类型¶
PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写,但 在跟踪期间,NumPy 或 Python 的任何变量 类型(而不是 torch.Tensor)转换为常量,这将产生 如果这些值应根据输入而变化,则结果错误。
例如,不要在 numpy.ndarrays 上使用 numpy 函数:
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
在 torch 上使用 torch 操作器。张:
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
而不是使用 (将 Tensor 转换为 Python
内置编号):
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
使用 torch 对单元素张量的隐式转换的支持:
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
避免使用 Tensor.data¶
使用 Tensor.data 字段可能会产生不正确的跟踪,从而产生不正确的 ONNX 图。
请改用。(完全删除 Tensor.data 的工作正在进行中)。
局限性¶
类型¶
只有Torch。Tensors,可以简单地转换为 torch 的数字类型。张量(例如 float、int)、 支持将这些类型的元组和列表作为模型输入或输出。Dict 和 str 输入以及 在跟踪模式下接受输出,但是:
任何依赖于 dict 或 str 输入的值的计算都将被替换为 constant 值。
任何作为 dict 的输出都将被静默替换为其值的扁平化序列 (键将被删除)。例如 成为。
{"foo": 1, "bar": 2}
(1, 2)
任何作为 str 的输出都将被静默删除。
由于 ONNX 对嵌套序列的支持有限,因此脚本模式不支持某些涉及元组和列表的操作。 特别是,不支持将 Tuples 附加到列表。在跟踪模式下,嵌套序列 将在跟踪期间自动拼合。
Operator 实现的差异¶
由于 Operator 的实现存在差异,因此在不同的运行时上运行导出的模型 可能会产生彼此不同的结果,也可能产生不同的 PyTorch 结果。通常这些差异是 数值较小,因此仅当应用程序对这些敏感时,才应考虑此问题 微小的差异。
不支持的 Tensor 索引模式¶
下面列出了无法导出的张量索引模式。
如果您在导出不包含任何
下面不支持的模式,请仔细检查您是否正在使用
最新的 .opset_version
读取 / 获取¶
当索引到张量中进行读取时,不支持以下模式:
# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
写入 / 设置¶
当索引到 Tensor 中进行写入时,不支持以下模式:
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
添加对运算符的支持¶
导出包含不支持的运算符的模型时,您将看到如下错误消息:
RuntimeError: ONNX export failed: Couldn't export operator foo
发生这种情况时,您需要更改模型以不使用该运算符, 或添加对 Operator 的支持。
添加对 Operator 的支持需要对 PyTorch 的源代码进行更改。 有关此的一般说明,请参阅 CONTRIBUTING ,有关代码的具体说明,请参阅下面的 支持 Operator 所需的更改。
在导出过程中,将按拓扑顺序访问 TorchScript 图中的每个节点。
访问节点时,导出器会尝试查找已注册的符号函数
那个节点。符号函数是在 Python 中实现的。一个
名为 的操作 将如下所示:foo
def foo(
g: torch._C.Graph,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Modifies g (e.g., using "g.op()"), adding the ONNX operations representing
this PyTorch function.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
Returns None if it cannot be converted to ONNX.
"""
...
这些类型是 ir.h 中 C++ 中定义的类型的 Python 包装器。torch._C
添加符号函数的过程取决于运算符的类型。
ATen 运算符¶
ATen 是 PyTorch 的内置张量库。
如果运算符是 ATen 运算符(在 TorchScript 图中显示为前缀):aten::
在 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。 确保该函数与 ATen 函数具有相同的名称,该函数可以在 或 (这些文件在 构建时,因此在您构建 PyTorch 之前不会出现在您的检出中)。
torch/onnx/symbolic_opset<version>.py
torch/_C/_VariableFunctions.pyi
torch/nn/functional.pyi
第一个 arg 始终是为导出而构建的 ONNX 图。 其他 arg 名称必须与文件中的名称完全匹配, 因为 dispatch 是通过 keyword 参数完成的。
.pyi
在 symbolic 函数中,如果运算符在 ONNX 标准运算符集中,则 我们只需要创建一个节点来表示图中的 ONNX 运算符。 如果不是,我们可以创建一个包含多个标准运算符的图形,这些运算符具有 与 ATen 运算符等效的语义。
如果输入参数是 Tensor,但 ONNX 要求提供标量,则必须 显式执行转换。 可以将 scalar tensor 转换为 Python 标量,并且可以将 Python 标量转换为 PyTorch 张量。
symbolic_helper._scalar()
symbolic_helper._if_scalar_type_as()
下面是处理运算符缺少的符号函数的示例。ELU
如果我们运行以下代码:
print(
torch.jit.trace(torch.nn.ELU(), # module
torch.ones(1) # example input
).graph)
我们看到类似以下内容:
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
由于我们在图中看到,我们知道这是一个 ATen 运算符。aten::elu
我们检查 ONNX 运算符列表,
并确认 ONNX 中已标准化。Elu
我们在 中找到 的签名 :elu
torch/nn/functional.pyi
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
我们将以下行添加到 :symbolic_opset9.py
def elu(g, input, alpha, inplace=False):
return g.op("Elu", input, alpha_f=_scalar(alpha))
现在 PyTorch 能够导出包含 Operator 的模型!aten::elu
有关更多示例,请参阅文件。symbolic_opset*.py
torch.autograd.函数¶
静态符号方法¶
您可以向函数类添加名为 的静态方法。它应该返回
表示函数在 ONNX 中的行为的 ONNX 运算符。例如:symbolic
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
PythonOp 符号¶
或者,您可以注册自定义符号函数。
这使符号函数可以通过
TorchScript 对象,该对象作为第二个
参数(在对象之后)。Node
Graph
所有 autograd 在 TorchScript 图中都显示为节点。
为了区分不同的子类,
symbolic 函数应使用设置为类名称的 kwarg。Function
prim::PythonOp
Function
name
不允许在
命名空间,因此对于此用例,有一个后门:注册
symbolic 表示 .
prim
"::prim_PythonOp"
自定义符号函数应在返回 Value 对象之前通过调用 Value 对象来添加类型和形状信息(在 C++ 中由 实现)。这不是必需的,但它可以帮助导出者的
下游节点的形状和类型推理。有关 的重要示例,请参见 test_operators.py。setType(...)
torch::jit::Value::setType
setType
test_aten_embedding_2
下面的示例显示了如何通过对象进行访问:requires_grad
Node
class MyClip(torch.autograd.Function):
@staticmethod
def forward(ctx, input, min):
ctx.save_for_backward(input)
return input.clamp(min=min)
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
def symbolic_pythonop(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
print("original node: ", n)
for i, out in enumerate(n.outputs()):
print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
import torch.onnx.symbolic_helper as sym_helper
for i, arg in enumerate(args):
requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
name = kwargs["name"]
ret = None
if name == "MyClip":
ret = g.op("Clip", args[0], min_f=args[1])
elif name == "MyRelu":
ret = g.op("Relu", args[0])
else:
# Logs a warning and returns None
return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
# Copy type and shape from original node.
ret.setType(n.type())
return ret
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 1)
自定义运算符¶
如果模型使用以 C++ 实现的自定义运算符,如使用自定义 C++ 运算符扩展 TorchScript 中所述, 您可以按照以下示例导出它:
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
# Define custom symbolic function
@parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super(FooModule, self).__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input1),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2})
您可以将其导出为标准 ONNX 操作的一个或组合,或者导出为自定义运算符。
上面的示例将其导出为 “custom_domain” opset 中的自定义运算符。
导出自定义运算符时,您可以在导出时使用字典指定自定义域版本。如果未指定,则自定义 opset 版本默认为 1。
连接模型的运行时需要支持自定义运算。请参阅 Caffe2 自定义操作、ONNX 运行时自定义操作、
或您选择的运行时的文档。custom_opsets
常见问题解答¶
Q: 我已经导出了 LSTM 模型,但其输入大小似乎是固定的?
Q: 如何导出包含 Loop 的模型?
请参阅跟踪与脚本。
Q: 如何导出具有基元类型输入(例如 int、float)的模型?
PyTorch 1.9 中添加了对基元数字类型输入的支持。 但是,导出器不支持具有 str 输入的模型。
问:ONNX 是否支持隐式标量数据类型转换?
否,但导出器将尝试处理该部分。标量导出为常量张量。 导出器将尝试找出标量的正确数据类型。但是,当它无法时 为此,您需要手动指定 DataType。这通常发生在 脚本化模型,其中不记录数据类型。例如:
class ImplicitCastType(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): # Exporter knows x is float32, will export "2" as float32 as well. y = x + 2 # Currently the exporter doesn't know the datatype of y, so # "3" is exported as int64, which is wrong! return y + 3 # To fix, replace the line above with: # return y + torch.tensor([3], dtype=torch.float32) x = torch.tensor([1.0], dtype=torch.float32) torch.onnx.export(ImplicitCastType(), x, "implicit_cast.onnx", example_outputs=ImplicitCastType()(x))我们正在尝试改进 exporter 中的数据类型传播,以便隐式转换 在更多情况下受支持。
问:张量列表是否可以导出到 ONNX?
是的,对于 >= 11,因为 ONNX 在操作集 11 中引入了 Sequence 类型。
opset_version
功能¶
-
torch.onnx.
export
(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=无, output_names=无, operator_export_type=无, opset_version=无, _retain_param_name=无, do_constant_folding=真,example_outputs=无,strip_doc_string=无,dynamic_axes=无,keep_initializers_as_inputs=无,custom_opsets=无,enable_onnx_checker=无,use_external_data_format=无)[来源]¶ 将模型导出为 ONNX 格式。如果不是 a
或 a
,则运行一次,以便将其转换为要导出的 TorchScript 图形 (相当于
)。因此,这具有相同的有限支持 对于动态控制流,如
.
model
model
- 参数
model (torch.nn.Module、torch.jit.ScriptModule 或 torch.jit.ScriptFunction) – 要导出的模型。
ARGS 的结构可以是:
只有一个参数元组:
args = (x, y, z)
元组应包含模型输入,以便 调用模型。任何非 Tensor 参数都将被硬编码到 导出的模型;任何 Tensor 参数都将成为导出模型的输入, 按照它们在 Tuples 中出现的顺序。
model(*args)
一个张量:
args = torch.Tensor([1])
这相当于该 Tensor 的 1-ary 元组。
以命名参数字典结尾的参数元组:
args = (x, {'y': input_y, 'z': input_z})
除了元组的最后一个元素之外,所有元素都将作为非关键字参数传递。 命名参数将从最后一个元素开始设置。如果命名参数为 不存在,则为其分配默认值,如果 default 值。
注意
如果字典是 args 元组的最后一个元素,则它将是 解释为包含命名参数。为了将 dict 作为 last 非关键字 arg,则提供一个空 dict 作为 args 的最后一个元素 元。例如,而不是:
torch.onnx.export( model, (x, # WRONG: will be interpreted as named arguments {y: z}), "test.onnx.pb")
写:
torch.onnx.export( model, (x, {y: z}, {}), "test.onnx.pb")
f – 类文件对象(例如返回文件描述符) 或包含文件名的字符串。将写入二进制协议缓冲区 添加到此文件。
f.fileno()
export_params (bool, default True) – 如果为 True,则所有参数都将 被导出。如果要导出未经训练的模型,请将此项设置为 False。 在这种情况下,导出的模型将首先采用其所有参数 作为参数,其顺序由
model.state_dict().values()
verbose (bool, default False) – 如果为 True,则打印 模型正在导出到 stdout 中。此外,最终的 ONNX 图将包括 字段,其中提到了源代码位置 为。
doc_string`
model
training (enum,默认 TrainingMode.EVAL) –
TrainingMode.EVAL
:在推理模式下导出模型。TrainingMode.PRESERVE
:如果 model.training 为 如果 model.training 为 True,则为 False,则处于训练模式。TrainingMode.TRAINING
:在训练模式下导出模型。禁用优化 这可能会干扰训练。
input_names (list of str, default empty list) – 要分配给 input 节点。
output_names (list of str, default empty list) – 要分配给 output 节点。
operator_export_type (enum, default None) –
none 通常表示 。 但是,如果 PyTorch 是使用 构建的,则 None 表示 .
OperatorExportTypes.ONNX
-DPYTORCH_ONNX_CAFFE2_BUNDLE
OperatorExportTypes.ONNX_ATEN_FALLBACK
OperatorExportTypes.ONNX
:将所有操作导出为常规 ONNX 操作 (在默认的 Opset 域中)。OperatorExportTypes.ONNX_FALLTHROUGH
: 尝试转换所有操作 到默认 opset 域中的标准 ONNX 操作。如果无法执行此操作 (例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持), 回退到将运算导出到自定义 OpSet 域而不进行转换。适用 自定义操作以及 ATen 操作。要使导出的模型可用,运行时必须支持 这些非标准 Ops.OperatorExportTypes.ONNX_ATEN
:所有 ATen 操作(在 TorchScript 命名空间 “aten” 中) 导出为 ATen 操作(在 opset 域 “org.pytorch.aten” 中)。ATen 是 PyTorch 的内置张量库,因此 这会指示运行时使用 PyTorch 的这些运算实现。警告
以这种方式导出的模型可能只能由 Caffe2 运行。
如果运算符实现中的数值差异为 导致 PyTorch 和 Caffe2 之间的行为存在很大差异(即 常见于未经训练的模型)。
OperatorExportTypes.ONNX_ATEN_FALLBACK
:尝试导出每个 ATen op (在 TorchScript 命名空间 “aten” 中)作为常规 ONNX 操作。如果我们无法做到这一点 (例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持), 回退到导出 ATen 运算。请参阅 OperatorExportTypes.ONNX_ATEN 上的文档 上下文。 例如:graph(%0 : Float): %3 : int = prim::Constant[value=0]() # conversion unsupported %4 : Float = aten::triu(%0, %3) # conversion supported %5 : Float = aten::mul(%4, %0) return (%5)
假设 ONNX 不支持,这将导出为:
aten::triu
graph(%0 : Float): %1 : Long() = onnx::Constant[value={0}]() # not converted %2 : Float = aten::ATen[operator="triu"](%0, %1) # converted %3 : Float = onnx::Mul(%2, %0) return (%3)
如果运算在 TorchScript 命名空间 “quantized” 中,则将被导出 在 ONNX opset 域 “caffe2” 中。这些操作由 Quantization 中描述的模块。
警告
以这种方式导出的模型可能只能由 Caffe2 运行。
opset_version (int, default 9) – 必须为 , 在 torch/onnx/symbolic_helper.py 中定义。
== _onnx_main_opset or in _onnx_stable_opsets
_retain_param_name (bool, default True) – [已弃用并忽略。将在下一个 PyTorch 中删除 release]
do_constant_folding (bool, default False) – 应用常量折叠优化。 常量折叠将替换一些具有所有常量输入的 operations 具有预先计算的常量节点。
example_outputs (T 或 T 的元组,其中 T 是 Tensor 或可转换为 Tensor,默认为 None) – [已弃用并忽略。将在下一个 PyTorch 版本中删除]。 必须在导出 ScriptModule 或 ScriptFunction 时提供,否则忽略。 用于确定输出的类型和形状,而不跟踪 模型。单个对象被视为等效于一个元素的元组。
strip_doc_string (bool, default True) – [已弃用并忽略。将在下一个 PyTorch 版本中删除]
dynamic_axes (dict<string, dict<python:int, string>> 或 dict<string, list(int)>,默认为空 dict) –
默认情况下,导出的模型将具有所有输入和输出张量的形状 设置为完全匹配 中给定的那些(并且当该 arg 为 必需)。要将张量轴指定为动态轴(即仅在运行时已知),请使用架构设置为 dict:
args
example_outputs
dynamic_axes
KEY (str):输入或输出名称。每个名称也必须在 或 中提供。
input_names
output_names
VALUE (dict or list):如果是 dict,则键是轴索引,值是轴名称。如果 list 中,每个元素都是一个轴索引。
例如:
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"])
生产:
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
而:
torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], })
生产:
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
keep_initializers_as_inputs (bool,默认 None) –
如果为 True,则所有 初始化器(通常对应于参数)在 导出的图形也将作为输入添加到图形中。如果为 False,则 然后,初始化器不会作为输入添加到图形中,并且只会 非参数输入将添加为输入。 这可能允许通过以下方式进行更好的优化(例如常量折叠) backends/runtimes 的 Runtimes。
如果< 9,则初始化器必须是 graph 的一部分 inputs,并且此参数将被忽略,行为将为 等效于将此参数设置为 True。
opset_version
如果为 None,则自动选择行为,如下所示:
如果 ,则行为是等效的 将此参数设置为 False。
operator_export_type=OperatorExportTypes.ONNX
否则,该行为等效于将此参数设置为 True。
custom_opsets (dict<str, int> , 默认为空 dict) –
要指示的字典
具有 schema 的 dict:
KEY (str):opset 域名
VALUE (int):opset 版本
如果自定义 opset 被此字典引用但未提及,则 Opset Version 设置为 1。
model
enable_onnx_checker (bool, default True) – 已弃用并忽略。将在下次删除 Pytorch 版本。
use_external_data_format (bool, default False) – [已弃用并忽略。将在 下一个 Pytorch 版本。 如果为 True,则某些模型参数存储在外部数据文件中,而不是 ONNX 模型文件本身。大于 2GB 的模型不能导出到一个文件中,因为 Protocol Buffers 施加的大小限制。 有关详细信息,请参阅 onnx.proto。 如果为 True,则 argument 必须是指定模型位置的字符串。 外部数据文件将存储在与 相同的目录中。 除非 .
f
f
operator_export_type=OperatorExportTypes.ONNX
- 提高
ONNXCheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图形。仍将导出 model 添加到文件中,即使这被引发。
f
-
torch.onnx.
register_custom_op_symbolic
(symbolic_name、symbolic_fn、opset_version)[来源]¶ 要处理的寄存器。看 模块文档中的 “Custom Operators” 以获取示例用法。
symbolic_fn
symbolic_name