torch.onnx¶
概述¶
Open Neural Network eXchange (ONNX) 是一种开放标准
格式来表示机器学习模型。该模块从
native PyTorch 模型并将其转换为 ONNX 图。
torch.onnx
导出的模型可以由支持 ONNX 的许多运行时中的任何一个使用,包括 Microsoft 的 ONNX 运行时。
您可以使用两种风格的 ONNX 导出器 API,如下所示。两者都可以通过 function 调用。
下一个示例演示如何导出简单模型。
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 128, 5)
def forward(self, x):
return torch.relu(self.conv1(x))
input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
model, # model to export
(input_tensor,), # inputs of the model,
"my_model.onnx", # filename of the ONNX model
input_names=["input"], # Rename inputs for the ONNX model
dynamo=True # True or False to select the exporter to use
)
下一节将介绍导出器的两个版本。
基于 TorchDynamo 的 ONNX 导出器¶
基于 TorchDynamo 的 ONNX 导出器是 PyTorch 2.1 及更高版本的最新(和测试版)导出器
TorchDynamo 引擎被用来挂接到 Python 的帧评估 API 中,并动态地重写其 字节码转换为 FX Graph 中。然后,在最终将其转换为 ONNX 图。
这种方法的主要优点是 FX 图表是使用 字节码分析,它保留了模型的动态性质,而不是使用传统的静态跟踪技术。
基于 TorchScript 的 ONNX 导出器¶
基于 TorchScript 的 ONNX 导出器从 PyTorch 1.2.0 开始可用
利用 TorchScript 来跟踪(通过 )
模型并捕获静态计算图。
因此,生成的图形有几个限制:
它不记录任何控制流,如 if 语句或循环;
不处理 和 模式之间的细微差别;
training
eval
不能真正处理动态输入
为了支持静态跟踪限制,导出器还支持 TorchScript 脚本
(通过 ),例如,它增加了对数据依赖型 control-flow 的支持。但是,TorchScript
本身是 Python 语言的子集,因此并非支持 Python 中的所有功能,例如就地操作。