torch.export¶
警告
此功能是一个正在积极开发的原型,未来将会有重大变更。
概述¶
torch.export.export() 接受任意的Python可调用对象(一个
torch.nn.Module、函数或方法),并生成一个仅表示函数中Tensor计算的跟踪图,以提前编译(AOT)的方式进行,随后可以使用不同的输出执行或序列化。
import torch
from torch.export import export
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
f, args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
# code: a = torch.sin(x)
sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
# code: b = torch.cos(y)
cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
return (add,)
Graph signature: ExportGraphSignature(
parameters=[],
buffers=[],
user_inputs=['arg0_1', 'arg1_1'],
user_outputs=['add'],
inputs_to_parameters={},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
Equality constraints: []
torch.export 生成具有以下不变量的干净中间表示(IR)。更多关于 IR 的规范可以在
这里找到。
正确性: 它保证是原始程序的正确表示,并且保持了原始程序相同的调用约定。
已规范化: 图中没有Python语义。来自原始程序的子模块被内联以形成一个完全扁平化的计算图。
定义的操作符集: 生成的图仅包含少量定义的 Core ATen IR 操作符集和注册的自定义操作符。
图属性: 该图是纯粹的功能性,意味着它不包含具有副作用的操作,如突变或别名。它不会改变任何中间值、参数或缓冲区。
元数据: 图形包含在跟踪过程中捕获的元数据,例如来自用户代码的堆栈跟踪。
在幕后,torch.export 利用了以下最新技术:
TorchDynamo (torch._dynamo) 是一个内部API,它使用CPython的一个特性 称为帧评估API来安全地追踪PyTorch图。这 提供了一个极大改进的图捕获体验,需要重写的次数大大减少 以便完全追踪PyTorch代码。
提前自动微分(AOT Autograd) 提供了一个函数化的 PyTorch 计算图,并确保该计算图 被分解/转换为小范围定义的核心 ATen 操作符集合。
PyTorch FX (torch.fx) 是图的底层表示形式,允许灵活的基于Python的转换。
现有框架¶
torch.compile() 同样使用了与 torch.export 相同的PT2堆栈,但
略有不同:
即时编译与提前编译:
torch.compile()是一个即时编译器,而 则不打算用于在部署之外生成编译工件。部分图捕获与全图捕获: 当
torch.compile()遇到模型中无法追踪的部分时,它将“图中断”并回退到急切的Python运行时执行程序。相比之下,torch.export旨在获取PyTorch模型的完整图表示,因此当遇到无法追踪的内容时会报错。由于torch.export生成的图与任何Python特性或运行时无关,因此该图可以保存、加载并在不同的环境和语言中运行。可用性权衡: 由于
torch.compile()在遇到无法追踪的内容时能够回退到Python运行时,因此它更加灵活。而torch.export则需要用户提供更多信息或重写代码以使其可追踪。
与torch.fx.symbolic_trace()相比,torch.export使用TorchDynamo进行追踪,它在Python字节码级别运行,因此具有追踪任意Python构造的能力,而不受Python运算符重载支持的限制。此外,torch.export对张量元数据进行细粒度跟踪,因此基于张量形状等条件的追踪不会失败。一般来说,torch.export预计可以在更多用户程序上工作,并生成较低级别的图(在torch.ops.aten运算符级别)。请注意,用户仍然可以将torch.fx.symbolic_trace()作为torch.export之前的预处理步骤。
与 torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,只有直线控制流(除了显式的控制流操作符)。
与 torch.jit.trace() 相比,torch.export 是可靠的:它能够追踪执行整数计算的代码,并记录所有必要的附加条件,以证明特定的追踪对其他输入也是有效的。
导出PyTorch模型¶
一个示例¶
主要入口点是通过 torch.export.export(),它接受一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到一个 torch.export.ExportedProgram 中。一个例子:
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
# code: a = self.conv(x)
convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
);
# code: a.add_(constant)
add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
# code: return self.maxpool(self.relu(a))
relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
relu, [3, 3], [3, 3]
);
getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
return (getitem,)
Graph signature: ExportGraphSignature(
parameters=['L__self___conv.weight', 'L__self___conv.bias'],
buffers=[],
user_inputs=['arg2_1', 'arg3_1'],
user_outputs=['getitem'],
inputs_to_parameters={
'arg0_1': 'L__self___conv.weight',
'arg1_1': 'L__self___conv.bias',
},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
Equality constraints: []
检查ExportedProgram,我们可以注意到以下内容:
该
torch.fx.Graph包含原始程序的计算图,以及原始代码记录,便于调试。该图仅包含在Core ATen IR opset和自定义运算符中找到的
torch.ops.aten个运算符,并且是完全功能性的,没有任何就地运算符,如torch.add_。参数(权重和偏差到卷积)被提升为图的输入,
导致图中没有get_attr个节点,这些节点以前存在于torch.fx.symbolic_trace()的结果中。该
torch.export.ExportGraphSignature模型描述了输入和输出的签名,并指定了哪些输入是参数。图中每个节点生成的张量的形状和数据类型都被标注出来。例如,
convolution节点将生成一个数据类型为torch.float32且形状为 (1, 16, 256, 256) 的张量。
表达动态性¶
默认情况下,torch.export 会假设所有输入形状都是 静态 的,并将导出的程序专门化为这些维度。但是,某些维度(例如批次维度)可能是动态的,并且可以从运行到运行变化。这样的维度必须通过使用 torch.export.Dim() API 来创建它们,并通过 torch.export.export() 的 dynamic_shapes 参数传递它们。一个例子:
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
# code: out1 = self.branch1(x1)
permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
# code: out2 = self.branch2(x2)
permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# code: return (out1 + self.buffer, out2)
add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
return (add, relu_1)
Graph signature: ExportGraphSignature(
parameters=[
'branch1.0.weight',
'branch1.0.bias',
'branch2.0.weight',
'branch2.0.bias',
],
buffers=['L__self___buffer'],
user_inputs=['arg5_1', 'arg6_1'],
user_outputs=['add', 'relu_1'],
inputs_to_parameters={
'arg0_1': 'branch1.0.weight',
'arg1_1': 'branch1.0.bias',
'arg2_1': 'branch2.0.weight',
'arg3_1': 'branch2.0.bias',
},
inputs_to_buffers={'arg4_1': 'L__self___buffer'},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
Equality constraints: [(InputDim(input_name='arg5_1', dim=0), InputDim(input_name='arg6_1', dim=0))]
需要注意的一些其他事项:
通过
torch.export.Dim()API 和dynamic_shapes参数,我们指定了每个输入的第一个维度是动态的。查看输入arg5_1和arg6_1,它们具有 (s0, 64) 和 (s0, 128) 的符号形状,而不是我们作为示例输入传递的 (32, 64) 和 (32, 128) 形状的张量。s0是一个符号,表示该维度可以是一系列值。exported_program.range_constraints描述了图表中每个符号的范围。 在这种情况下,我们看到s0的范围是 [2, inf]。由于技术原因,在这里难以解释,它们被假定为不是 0 或 1。这不是一个错误,并不一定意味着 导出的程序在维度为 0 或 1 时无法工作。请参阅 0/1 特殊化问题 以深入了解此主题。exported_program.equality_constraints描述了哪些维度需要相等。由于我们在约束中指定了每个参数的第一个维度是等价的, (dynamic_dim(example_args[0], 0) == dynamic_dim(example_args[1], 0)), 我们在相等性约束中看到指定arg5_1维度 0 和arg6_1维度 0 相等的元组。
(一种用于指定动态形状的遗留机制
涉及使用 torch.export.dynamic_dim() API 标记和约束动态维度,并通过 torch.export.export()
传递它们到 constraints 参数。该机制现在已 废弃,将来将不再支持。)
序列化¶
要保存ExportedProgram,用户可以使用torch.export.save()和
torch.export.load() API。一种约定是使用ExportedProgram
并以.pt2文件扩展名保存。
一个示例:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
专业领域¶
输入形状¶
如前所述,默认情况下,torch.export 将追踪针对输入张量形状优化的程序,除非通过 torch.export.dynamic_dim() API 指定某个维度为动态。这意味着如果存在依赖形状的控制流,torch.export 将根据给定的示例输入所选择的分支进行优化。例如:
import torch
from torch.export import export
def fn(x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 2]):
add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
return (add,)
The conditional of (x.shape[0] > 5) does not appear in the
ExportedProgram because the example inputs have the static
shape of (10, 2). Since torch.export specializes on the inputs’ static
shapes, the else branch (x - 1) will never be reached. To preserve the dynamic
branching behavior based on the shape of a tensor in the traced graph,
torch.export.dynamic_dim() will need to be used to specify the dimension
of the input tensor (x.shape[0]) to be dynamic, and the source code will
need to be rewritten.
非张量输入¶
torch.export 还根据非 torch.Tensor 的输入值专门化跟踪的图,例如 int、float、bool 和 str。
然而,我们可能会在不久的将来对此进行更改,不再对基本类型输入进行专门化处理。
例如:
import torch
from torch.export import export
def fn(x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
return (add_2,)
由于整数是专门化的,torch.ops.aten.add.Tensor 操作
都是使用内联常量 1 计算的,而不是 arg1_1。
此外,在 for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用在图中“内联”了,
并且输入 arg2_1 从未被使用过。
torch.export的限制¶
图中断点¶
由于torch.export是一个一次性过程,用于从PyTorch程序中捕获计算图,因此它最终可能会遇到程序中无法追踪的部分,因为几乎不可能支持追踪所有PyTorch和Python功能。在torch.compile的情况下,不支持的操作将导致“图形中断”,并且不支持的操作将使用默认的Python评估运行。相比之下,torch.export将要求用户提供额外的信息或重写代码的部分以使其可追踪。由于追踪是基于TorchDynamo的,后者在Python字节码级别进行评估,因此与以前的追踪框架相比,所需的重写将显著减少。
当遇到图中断时,ExportDB 是一个很好的资源,可以了解支持和不支持的程序类型,以及如何重写程序以使其可跟踪。
数据/形状依赖的控制流¶
图中断点也可能在数据依赖的控制流(if
x.shape[0] > 2)中遇到,当形状没有被专门化时,因为跟踪编译器不可能处理这种情况,而不会生成组合爆炸数量的路径代码。在这种情况下,用户需要使用特殊的控制流操作符重写他们的代码。目前,我们支持 torch.cond
来表达类似if-else的控制流(更多即将推出!)。
API 参考¶
- torch.export.export(f, args, kwargs=None, *, constraints=None, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source]¶
export()接受任意的Python可调用对象(一个nn.Module、函数或方法)以及示例输入,并生成一个仅表示函数中Tensor计算的跟踪图,以提前编译(AOT)的方式进行,随后可以使用不同的输入执行或序列化。跟踪图 (1) 生成功能ATen运算符集中的标准化运算符(以及任何用户指定的自定义运算符),(2) 已消除了所有Python控制流和数据结构(某些例外情况除外),并且 (3) 记录了一组形状约束,以证明这种规范化和控制流消除对于未来的输入是合理的。健全性保证
在跟踪过程中,
export()会记录用户程序和底层PyTorch运算符内核所做的与形状相关的假设。 只有当这些假设成立时,输出ExportedProgram才被认为是有效的。跟踪对输入张量的形状(而不是值)做出假设。 这些假设必须在图形捕获时进行验证,以使
export()成功。具体来说:对输入张量的静态形状假设会自动进行验证,无需额外的努力。
对输入张量动态形状的假设需要通过使用
Dim()API 来显式指定以构建动态维度,并通过dynamic_shapes参数将其与示例输入相关联。
如果任何假设无法验证,将引发致命错误。当这种情况发生时, 错误消息将包括对规范的建议修复,这些修复是验证假设所需的。 例如
export()可能建议对动态维度定义进行以下修复dim0_x, 比如说出现在与输入x相关的形状中,该输入之前被定义为Dim("dim0_x"):dim = Dim("dim0_x", max=5)
这个示例意味着生成的代码要求输入的第0维
x必须小于或等于5才能有效。您可以检查对动态维度定义的建议修复,然后将它们逐字复制到您的代码中,而无需更改dynamic_shapes参数到您的export()调用。- Parameters
f (Callable) – 要跟踪的可调用对象。
约束条件 (可选[列表[约束条件]]) – [已废弃:请使用
dynamic_shapes替代,详见下方] 一个可选的动态参数约束列表,用于指定它们可能的形状范围。默认情况下,输入 torch.Tensor 的形状被视为静态。如果预期某个输入 torch.Tensor 具有动态形状,请使用dynamic_dim()来定义Constraint对象,以指定动态性和可能的形状范围。有关如何使用的示例,请参阅dynamic_dim()的文档字符串。动态形状 (可选[Union[Dict[str, Any], Tuple[Any]]]) –
应该是以下两种情况之一: 1) 一个字典,从
f的参数名称到它们的动态形状规范; 2) 一个元组,按原始顺序为每个输入指定动态形状规范。 如果你要对关键字参数指定动态性,则需要按照原始函数签名中定义的顺序传递它们。张量参数的动态形状可以指定为以下两种方式之一: (1) 一个从动态维度索引到
Dim()类型的字典,其中不需要包含静态维度索引,但如果包含的话, 它们应该映射到 None;或者 (2) 一个由Dim()类型或 None 组成的元组 / 列表, 其中Dim()类型对应于动态维度,而静态维度用 None 表示。对于是张量的字典或元组 / 列表类型的参数, 可以通过使用嵌套的映射或序列规范递归地进行指定。严格 (布尔值) – 当启用(默认)时,导出函数将通过TorchDynamo跟踪程序,这将确保生成的图的健全性。否则,导出的程序将不会验证图中隐含的假设,可能会导致原始模型和导出模型之间的行为差异。当用户需要绕过跟踪器中的错误,或者只是希望逐步在其模型中启用安全性时,这很有用。请注意,这不会影响生成的IR规范不同,模型将以相同的方式序列化,而不论此处传递的值是什么。 警告:此选项是实验性的,使用它需自行承担风险。
- Returns
一个
ExportedProgram包含被跟踪的可调用对象。- Return type
可接受的输入/输出类型
可接受的输入类型(对于
args和kwargs)和输出包括:基本类型,即
torch.Tensor、int、float、bool和str。数据类,但它们必须通过调用
register_dataclass()首先进行注册。(嵌套) 包含
dict,list,tuple,namedtuple和OrderedDict的数据结构,包含以上所有类型。
- torch.export.dynamic_dim(t, index)[source]¶
警告
(此功能已废弃。请参见
Dim()代替。)dynamic_dim()构造一个Constraint对象,该对象描述了张量t的维度index的动态性。应将Constraint对象传递给export()的constraints参数。- Parameters
t (torch.Tensor) – 示例输入张量,具有动态维度大小
索引 (int) – 动态维度的索引
- Returns
一个
Constraint对象,用于描述形状的动态性。它可以传递给export(), 这样export()就不会假设指定张量的静态大小,即保持其动态性作为符号大小, 而不是根据示例跟踪输入的大小进行专门化。
具体来说,
dynamic_dim()可以用来表达以下类型的动态性。维度的大小是动态且无界的:
t0 = torch.rand(2, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size rather than always being static size 2 constraints = [dynamic_dim(t0, 0)] ep = export(fn, (t0, t1), constraints=constraints)
维度的大小是动态的,并且有一个下限:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive) # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive) constraints = [ dynamic_dim(t0, 0) >= 5, dynamic_dim(t1, 1) > 2, ] ep = export(fn, (t0, t1), constraints=constraints)
维度的大小是动态的,并且有一个上限:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive) # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive) constraints = [ dynamic_dim(t0, 0) <= 16, dynamic_dim(t1, 1) < 8, ] ep = export(fn, (t0, t1), constraints=constraints)
一个维度的大小是动态的,并且它总是等于另一个动态维度的大小:
t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # Sizes of second dimension of t0 and first dimension are always equal constraints = [ dynamic_dim(t0, 1) == dynamic_dim(t1, 0), ] ep = export(fn, (t0, t1), constraints=constraints)
混合和匹配以上所有类型,只要它们不表达相互冲突的要求
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source]¶
警告
正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。
将一个
ExportedProgram保存到类似文件的对象中。然后可以使用Python APItorch.export.load加载它。- Parameters
ep (导出的程序) – 要保存的导出程序。
f (Union[str, pathlib.Path, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。
extra_files (可选[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为f的一部分进行存储。
Example:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]¶
警告
正在积极开发中,保存的文件可能无法在较新版本的PyTorch中使用。
加载一个
ExportedProgram之前使用torch.export.save保存的。- Parameters
ep (导出的程序) – 要保存的导出程序。
f (Union[str, pathlib.Path, io.BytesIO) – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。
extra_files (可选[Dict[str, Any]]) – 此映射中给出的额外文件名将被加载,其内容将存储在提供的映射中。
- Returns
一个
ExportedProgram对象- Return type
Example:
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls)[source]¶
将一个数据类注册为
torch.export.export()的有效输入/输出类型。Example:
@dataclass class InputDataClass: feature: torch.Tensor bias: int class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) def fn(o: InputDataClass) -> torch.Tensor: res = res=o.feature + o.bias return OutputDataClass(res=res) ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.Dim(name, *, min=None, max=None)[source]¶
Dim()构建了一种类似于命名符号整数的类型,具有一定的范围。 它可以用来描述动态张量维度的多种可能值。 请注意,同一张量或不同张量的不同动态维度, 可以用相同的类型来描述。
- class torch.export.Constraint(*args, **kwargs)[source]¶
警告
不要直接构造
Constraint,请改用dynamic_dim()。这表示对输入张量维度的约束,例如,要求它们完全多态或在某个范围内。
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, equality_constraints, module_call_graph, example_inputs=None, verifier=None, tensor_constants=None)[source]¶
程序包来自
export()。它包含一个torch.fx.Graph,表示张量计算,一个state_dict包含所有提升参数和缓冲区的张量值,以及各种元数据。你可以像调用原始可追踪的
export()一样,以相同的调用约定调用ExportedProgram。要对图进行转换,请使用
.module属性来访问 一个torch.fx.GraphModule。然后你可以使用 FX转换 来重写图。之后,你可以简单地再次使用export()来构建一个正确的ExportedProgram。
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[source]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[source]¶
ExportGraphSignature模型的输入/输出签名是导出图的,这是一个具有更强不变性保证的fx.Graph。导出图是功能性的,并不通过
getattr节点访问像参数或缓冲区这样的“状态”。相反,export()确保参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不包含在图中,而是将更新后的缓冲区值建模为导出图的额外输出。所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self): super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[source]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)[source]¶
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument], target: Union[str, NoneType])[source]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument], target: Union[str, NoneType])[source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]¶
ExportGraphSignature模型的输入/输出签名是导出图的,这是一个具有更强不变性保证的fx.Graph。导出图是功能性的,并不通过
getattr节点访问像参数或缓冲区这样的“状态”。相反,export()保证参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不包含在图中,而是将更新后的缓冲区值建模为导出图的额外输出。所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self): super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )