torch.export¶
警告
此功能是正在积极开发的原型,将会有 未来的重大变化。
概述¶
采用任意 Python 可调用对象(、函数
或方法)并生成跟踪图
仅表示 Ahead-of-Time 中函数的 Tensor 计算
(AOT) 方式执行,随后可以使用不同的输出执行,或者
序列 化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
# code: a = torch.sin(x)
sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
# code: b = torch.cos(y)
cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
return (add,)
Graph signature: ExportGraphSignature(
parameters=[],
buffers=[],
user_inputs=['arg0_1', 'arg1_1'],
user_outputs=['add'],
inputs_to_parameters={},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
torch.export
生成一个干净的中间表示 (IR),其中
遵循不变量。有关 IR 的更多规格,请点击此处。
健全性:保证是原件的合理表示 program 的调用,并保持与原始程序相同的调用约定。
规范化:图中没有 Python 语义。子模块 从原始程序内联以形成一个完全展平的程序 计算图。
图形属性:图形是纯函数式的,这意味着它不是 包含具有副作用(如突变或别名)的操作。它确实 不改变任何中间值、参数或缓冲区。
元数据:该图包含在跟踪期间捕获的元数据,例如 stacktrace 从用户的代码中获取。
在后台,利用以下最新技术:torch.export
TorchDynamo (torch._dynamo) 是一个使用 CPython 功能的内部 API 调用了帧评估 API 以安全地跟踪 PyTorch 图形。这 提供大幅改进的图形捕获体验,而 需要重写才能完全跟踪 PyTorch 代码。
AOT Autograd 提供功能化的 PyTorch 图,并确保该图 分解/降低到 ATen 运算符集。
Torch FX (torch.fx) 是图形的底层表示, 允许基于 Python 的灵活转换。
现有框架¶
也使用与 相同的 PT2 堆栈,但
略有不同:
torch.export
部分图形捕获与完整图形捕获:当
遇到 无法追踪的部分,它将 “graph break” 并回退到正在运行的 急切的 Python 运行时中的程序。相比之下,目标 来获取 PyTorch 模型的完整图形表示,因此它会出错 当到达无法追踪的东西时。由于会生成一个完整的 graph 与任何 Python 功能或运行时不相交,那么这个图形可以是 在不同的环境和语言中保存、加载和运行。
torch.export
torch.export
可用性权衡:由于
能够回退到 每当 Python 运行时达到无法追踪的程度时,它就会多得多 灵活。 将要求用户提供更多 信息或重写其代码以使其可跟踪。
torch.export
与 , 使用
TorchDynamo 在 Python 字节码级别运行,使其能够
跟踪不受 Python 运算符限制的任意 Python 构造
超载支持。此外,还可以对
Tensor 元数据,因此 Tensor 形状等内容上的条件不会
失败跟踪。一般来说,预期会对更多的用户起作用
程序生成较低级别的图形(在运算符
级别)。请注意,用户仍然可以将
预处理步骤 。
torch.export
torch.export
torch.export
torch.ops.aten
torch.export
与 相比,不捕获 Python
控制流或数据结构,但它支持更多的 Python 语言功能
比 TorchScript 多(因为它更容易全面覆盖 Python
字节码)。生成的图形更简单,并且只有直线控制
flow (显式控制流运算符除外)。
torch.export
与 相比,是声音:它能够
跟踪代码,该代码对 sizes 执行整数计算并记录所有
side-条件,以表明特定跟踪对其他
输入。
torch.export
导出 PyTorch 模型¶
示例¶
主入口点是通过 ,它采用
callable (
、 function 或 method) 和示例输入,以及
将计算图捕获到
.一
例:
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
# code: a = self.conv(x)
convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
);
# code: a.add_(constant)
add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
# code: return self.maxpool(self.relu(a))
relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
relu, [3, 3], [3, 3]
);
getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
return (getitem,)
Graph signature: ExportGraphSignature(
parameters=['L__self___conv.weight', 'L__self___conv.bias'],
buffers=[],
user_inputs=['arg2_1', 'arg3_1'],
user_outputs=['getitem'],
inputs_to_parameters={
'arg0_1': 'L__self___conv.weight',
'arg1_1': 'L__self___conv.bias',
},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
检查 ,我们可以注意到以下内容:ExportedProgram
该图仅包含在此处找到的运算符和自定义运算符,并且功能齐全,没有任何就地运算符 如。
torch.ops.aten
torch.add_
图中每个节点生成的张量的最终形状和 dtype 为 著名的。例如,该节点将产生 dtype 和 shape (1, 16, 256, 256) 的张量。
convolution
torch.float32
非严格导出¶
在 PyTorch 2.3 中,我们引入了一种新的跟踪模式,称为非严格模式。 它仍在进行强化,因此如果您遇到任何问题,请提交 他们发送到 Github 并使用 “oncall: export” 标签。
在非严格模式下,我们使用 Python 解释器跟踪程序。 您的代码将完全按照 Eager 模式执行;唯一的区别是 所有 Tensor 对象都将替换为 ProxyTensors,后者将记录所有 它们的操作合并到一个图表中。
在当前默认的 strict 模式下,我们首先通过 程序。TorchDynamo 不会 实际执行您的 Python 代码。相反,它象征性地分析了它,并且 根据结果构建图形。此分析允许 torch.export 为 提供更强的安全保证,但并非所有 Python 代码都受支持。
您可能希望使用非严格模式的一个示例是,如果运行 转换为可能不容易解决的不受支持的 TorchDynamo 功能,并且您 知道 Python 代码并不是计算所完全需要的。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在此示例中,使用非严格模式(通过标志)的第一次调用成功跟踪,而使用strict
mode(默认)导致失败,其中 TorchDynamo 无法支持
上下文管理器。一种选择是重写代码(参见 torch.export 的限制),但因为上下文管理器不会影响张量
计算,我们可以采用非严格模式的结果。strict=False
表达活力¶
默认情况下,将跟踪程序,假设所有输入形状都是静态的,并将导出的程序专门化到这些维度。然而
某些维度(如批次维度)可以是动态的,并且会因 Run 到 而异
跑。必须使用 API 创建此类维度,并通过参数将它们传递给来指定此类维度。一个例子:
torch.export
torch.export.Dim()
dynamic_shapes
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
# code: out1 = self.branch1(x1)
permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
# code: out2 = self.branch2(x2)
permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# code: return (out1 + self.buffer, out2)
add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
return (add, relu_1)
Graph signature: ExportGraphSignature(
parameters=[
'branch1.0.weight',
'branch1.0.bias',
'branch2.0.weight',
'branch2.0.bias',
],
buffers=['L__self___buffer'],
user_inputs=['arg5_1', 'arg6_1'],
user_outputs=['add', 'relu_1'],
inputs_to_parameters={
'arg0_1': 'branch1.0.weight',
'arg1_1': 'branch1.0.bias',
'arg2_1': 'branch2.0.weight',
'arg3_1': 'branch2.0.bias',
},
inputs_to_buffers={'arg4_1': 'L__self___buffer'},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
一些需要注意的其他事项:
通过 API 和参数,我们指定了第一个 维度设置为动态的。查看输入 和 ,它们的符号形状为 (s0, 64) 和 (s0, 128),而不是 我们作为示例输入传入的 (32, 64) 和 (32, 128) 形状的张量。 是一个符号,表示此维度可以是范围 的值。
torch.export.Dim()
dynamic_shapes
arg5_1
arg6_1
s0
exported_program.range_constraints
描述每个元件的范围 显示在图表中。在本例中,我们看到 具有 [2, inf].由于此处难以解释的技术原因,它们是 假定不是 0 或 1。这不是一个错误,也不一定意味着 导出的程序将不适用于维度 0 或 1。有关此主题的深入讨论,请参阅 0/1 特化问题。s0
我们还可以在输入形状之间指定更具表现力的关系,例如 一对形状可能相差 1 时,形状可能是 另一个或一个形状是偶数。一个例子:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
return (add,)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
需要注意的一些事项:
通过指定第一个输入,我们可以看到生成的 第一个输入的形状现在是动态的,为 。现在,通过指定第二个输入,我们看到 第二个输入也是动态的。然而,因为我们表达了 , 的 形状包含新元件,而不是 ,我们看到它是 现在用 , 中使用的相同符号表示。我们可以 请看 的关系是通过 显示的。
{0: dimx}
[s0]
{0: dimy}
dimy = dimx + 1
arg1_1
arg0_1
s0
dimy = dimx + 1
s0 + 1
查看范围约束,我们看到 的范围为 [3, 6], 它最初指定,我们可以看到 已解决 范围 [4, 7]。
s0
s0 + 1
序列化¶
要保存 ,用户可以使用 和
API。惯例是使用文件扩展名保存。
ExportedProgram
ExportedProgram
.pt2
一个例子:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
专门化¶
理解 的行为的一个关键概念是
static 和 dynamic 值之间的差异。torch.export
动态值是可以在每次运行之间变化的值。这些 cookie 的行为类似于 normal 参数传递给 Python 函数 - 您可以为 参数,并期望你的函数做正确的事情。Tensor 数据为 被视为 Dynamic。
静态值是在导出时固定且无法更改的值 在导出程序的执行之间。当在 tracing 时,导出器会将其视为常量并将其硬编码到 图。
当执行操作(例如 )并且所有 inputs 都是静态的时,则
操作的输出将直接硬编码到图形中,并且
operation 不会显示(即它会被常量折叠)。x + y
当一个值被硬编码到图形中时,我们说该图形已经专门化了该值。
以下值是静态的:
输入 Tensor 形状¶
默认情况下,将跟踪专门处理 input 的程序
张量的形状,除非通过参数将维度指定为 dynamic。这意味着如果存在
形状相关的控制流,将专门用于分支
这是使用给定的样本输入获取的。例如:torch.export
dynamic_shapes
torch.export
torch.export
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 2]):
add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
return (add,)
() 的条件不会出现在 中,因为示例输入具有静态
的形状 (10, 2)。Since 专门研究输入的 static
shapes 时,将永远不会到达 else 分支 ()。要保留动态
需要使用基于跟踪图中张量形状的分支行为来指定维度
的 API 设置为 dynamic 的,源代码将
需要重写。x.shape[0] > 5
ExportedProgram
torch.export
x - 1
torch.export.Dim()
x.shape[0]
请注意,作为模块状态一部分的张量(例如 parameters 和 buffers) 始终具有静态形状。
Python 基元¶
torch.export
还专门研究 Python 原始
如 、 、 和 。但是,它们确实具有动态
变体,如 、 和 。int
float
bool
str
SymInt
SymFloat
SymBool
例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
return (add_2,)
由于整数是专用的,因此运算
都使用硬编码常量 而不是 进行计算。如果
用户在运行时传递的值(如 2)与使用的值不同
在导出时间 1 期间,这将导致错误。
此外,循环中使用的迭代器也是 “inlined”
在图中,通过 3 个重复调用,以及
从不使用 input。torch.ops.aten.add.Tensor
1
arg1_1
arg1_1
times
for
torch.ops.aten.add.Tensor
arg2_1
Python 容器¶
Python 容器(、、 、 等)被视为
具有静态结构。List
Dict
NamedTuple
torch.export 的限制¶
图形中断¶
从
PyTorch 程序中,它最终可能会遇到程序中无法追踪的部分,如
几乎不可能支持跟踪所有 PyTorch 和 Python 功能。在
的情况下,不支持的操作将导致“图
break“,则不支持的操作将使用默认 Python 评估运行。
相反,将要求用户提供额外的
信息或重写其代码的某些部分以使其可跟踪。由于
跟踪基于 TorchDynamo,它在 Python
bytecode 级别,则与
以前的跟踪框架。torch.export
torch.compile
torch.export
当遇到图形中断时,ExportDB 非常有用 用于了解支持的程序类型的资源,以及 unsupported,以及重写程序以使其可跟踪的方法。
避免处理此图形分隔线的一个选项是使用非严格导出
数据/形状相关控制流¶
当形状没有专用化时,在数据依赖的控制流 () 上也会遇到图形中断,因为跟踪编译器不能
可能无需为组合爆炸生成代码即可处理
路径数。在这种情况下,用户需要使用
特殊控制流操作器。目前,我们支持 torch.cond 来表示类似 if-else 的控制流(更多功能即将推出!if
x.shape[0] > 2
缺少运算符的 fake/meta/abstract 内核¶
跟踪时,FakeTensor 内核(又名 meta kernel,abstract impl)是 所有运算符都需要。这用于推断输入/输出形状 对于此运算符。
不幸的是,如果您的模型使用了 ATen 运算符,而 则 没有 尚未实现 FakeTensor 内核,请提交 issue。
API 参考¶
- torch.export 中。export(mod, args, kwargs=无, *, dynamic_shapes=无, strict=True, preserve_module_call_signature=())[来源]¶
采用任意 Python 可调用对象(nn.Module、函数或 方法)以及示例输入,并生成一个表示 仅以 Ahead-of-Time (AOT) 方式对函数进行 Tensor 计算, 随后可以使用不同的 inputs 执行或序列化。这 描摹图 (1) 在函数式 ATen 运算符集中生成归一化运算符 (以及任何用户指定的自定义运算符),(2) 已消除所有 Python 控件 flow 和数据结构(有一些例外),以及 (3) 记录 形状约束需要表明这种归一化和控制流消除 对于将来的输入来说是声音。
稳健性保证
追踪时,
注意与形状相关的假设 由用户程序和底层 PyTorch 算子内核制作。 只有在以下情况下,输出
才被视为有效 假设是正确的。
跟踪对输入张量的形状(而不是值)进行假设。 必须在图形捕获时验证此类假设才能
成功。具体说来:
对输入张量的静态形状的假设会自动验证,无需额外工作。
对输入张量动态形状的假设需要明确指定 通过使用 API 构建动态维度,并通过关联 它们通过参数提供示例输入。
Dim()
dynamic_shapes
如果任何假设无法验证,则会引发致命错误。当这种情况发生时, 错误消息将包含对规范的所需修复建议 来验证假设。例如
,可能会建议 以下修复动态维度的定义,假设出现在 形状 与 input 关联,该 input 之前定义为:
dim0_x
x
Dim("dim0_x")
dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入的维度 0 小于 大于或等于 5 才有效。您可以检查对动态维度的建议修复 定义,然后将它们逐字复制到您的代码中,而无需将参数更改为您的
调用。
x
dynamic_shapes
- 参数
mod (Module) – 我们将跟踪此模块的 forward 方法。
dynamic_shapes (可选[Union[Dict[str, any], tuple[any], list[any]]]) –
一个可选参数,其中类型应为: 1) 从参数名称到其动态形状规范的字典, 2) 一个元组,它按原始顺序为每个输入指定动态形状规范。 如果要在关键字 arg 上指定动态性,则需要按照 在原始函数签名中定义。
f
张量参数的动态形状可以指定为 (1) 从动态维度索引到类型的字典,其中 不需要在此字典中包含静态维度索引,但当它们包含时, 它们应该映射到 None;或 (2) 元组/类型列表或 None, 其中,类型对应于动态维度和静态维度 由 None 表示。作为 dict 或 tuples / 张量列表的参数是 通过使用包含的规范的映射或序列递归指定。
Dim()
Dim()
Dim()
strict (bool) – 启用(默认)后,导出函数将跟踪程序 TorchDynamo,它将确保结果图形的健全性。否则, exported program 不会验证 Graph 中烘焙的隐式假设,并且 可能会导致原始模型和导出的模型之间出现行为差异。这是 当用户需要解决 Tracer 中的错误,或者只是希望逐步解决时,这很有用 在他们的模型中启用安全性。请注意,这不会影响生成的 IR 规范 不同,并且无论值如何,模型都将以相同的方式序列化 在此处传递。 警告:此选项是实验性的,使用此选项的风险由您自己承担。
- 返回
- 返回类型
可接受的输入/输出类型
可接受的输入类型(for 和 )和输出包括:
args
kwargs
- torch.export 中。save(ep, f, *, extra_files=无, opset_version=无)[来源]¶
警告
在积极开发中,保存的文件可能无法在较新的版本中使用 PyTorch 中。
将 保存到
类似文件的对象。然后它可以是 使用 Python API
加载。
- 参数
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike、io.BytesIO) – 一个类似文件的对象 (必须 implement write 和 flush)或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 从文件名映射到内容 将作为 F 的一部分存储。
opset_version (Optional[Dict[str, int]]) – opset 名称的映射 到这个 opset 的版本
例:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export 中。load(f, *, extra_files=无, expected_opset_version=无)[来源]¶
警告
在积极开发中,保存的文件可能无法在较新的版本中使用 PyTorch 中。
- 参数
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike、io.BytesIO) – 一个类似文件的对象 (必须 implement write 和 flush)或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 在 此映射将被加载,其内容将存储在 提供的地图。
expected_opset_version (Optional[Dict[str, int]]) – opset 名称的映射 到预期的 Opset 版本
- 返回
- 返回类型
例:
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export 中。register_dataclass(cls, *, serialized_type_name=无)[来源]¶
-
- 参数
例:
@dataclass class InputDataClass: feature: torch.Tensor bias: int class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) def fn(o: InputDataClass) -> torch.Tensor: res = res=o.feature + o.bias return OutputDataClass(res=res) ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes。Dim(name, *, min=None, max=None)[来源]¶
构造一个类型,该类型类似于具有 range 的命名符号整数。 它可用于描述动态张量维度的多个可能值。 请注意,同一张量或不同张量的不同动态维度, 可以用相同的类型来描述。
- torch.export.dynamic_shapes 类。ShapesCollection[来源]¶
Builder for dynamic_shapes。 用于将动态形状规范分配给 inputs 中出现的张量。
- 例::
args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})
dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (暗淡, 暗淡 + 1, 8) dynamic_shapes[tensor_y] = {0: 暗淡 * 2} # 这相当于以下内容(现在是自动生成的): # dynamic_shapes = {“x”: (dim, dim + 1, 8), “others”: [{0: dim * 2}, None]}
torch.export(..., args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes。refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[来源]¶
对于导出的动态形状,建议的修复和/或自动动态形状。 在给定 ConstraintViolation 错误消息和原始动态形状的情况下,优化给定的动态形状规范。
在大多数情况下,行为很简单 - 即对于专门化或优化 Dim 范围的建议修复, 或建议派生关系的修复,则新的 Dynamic Shapes 规范将按此方式更新。
例如 建议的修复方法:
dim = Dim('dim', min=3, max=6) -> 这只是优化了 dim 的范围 dim = 4 -> 这专门用于常数 dy = dx + 1 -> dy 被指定为独立的 dim,但实际上通过此关系与 dx 相关联
但是,与派生的 dim 相关的建议修复可能更复杂。 例如,如果为根 dim 提供了建议的修复,则根据根评估新的派生 dim 值。
例如 dx = Dim('dx') dy = dx + 2 dynamic_shapes = {“x”: (dx,), “y”: (dy,)}
建议的修复方法:
dx = 4 # 专业化将导致 dy 也专业化 = 6 dx = Dim('dx', max=6) # dy 现在有 max = 8
派生的 dims 建议的修复也可用于表示整除约束。 这涉及创建不绑定到特定输入形状的新根暗淡。 在这种情况下,根 dims 不会直接显示在新等级库中,而是作为 其中一个 Dims。
例如 建议的修复方法:
_dx = Dim('_dx', max=1024) # 这不会出现在返回结果中,但 dx 会 dx = 4*_dx # dx 现在可以被 4 整除,最大值为 4096
- torch.export 中。约束¶
- 类 torch.export 中。ExportedProgram(根, 图形, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=无, 常量=无, *, verifiers=无)[来源]¶
来自
的程序的包 。它包含 一个
表示 Tensor 计算的 Tensor 计算,一个state_dict包含 所有提升的参数和缓冲区的张量值,以及各种元数据。
您可以调用ExportedProgram,就像使用相同的调用约定跟踪的原始
可调用对象一样。
要对图形执行转换,请使用 property 访问 一个
.然后,您可以使用 FX 转换来重写图形。之后,您可以简单地再次使用
来构建正确的 ExportedProgram。
.module
- run_decompositions(decomp_table=无,_preserve_ops=()))[来源]¶
在导出的程序上运行一组分解,并返回一个新的 exported 程序。默认情况下,我们将运行 Core ATen 分解以 获取 Core ATen Operator Set 中的运算符。
目前,我们不分解关节图。
- 返回类型
- 类 torch.export 中。ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[来源]¶
- 类 torch.export 中。ExportGraphSignature(input_specs, output_specs)[来源]¶
对 Export Graph 的输入/输出签名进行建模, 这就是 FX。具有更强不变量保证的图形。
Export Graph 是功能性的,并且不访问参数之类的“状态” 或通过节点在图中的缓冲区。相反,
保证参数、缓冲区和常量张量从 将图形作为输入。同样,缓冲区的任何更改也不包括 在图中,mutated buffers 的更新值为 建模为 Export Graph 的附加输出。
getattr
所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的 Graph 将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将为:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- 类 torch.export 中。ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument torch.export.graph_signature。SymIntArgument, torch.export.graph_signature.ConstantArgument 的 torch.export.graph_signature。CustomObjArgument torch.export.graph_signature.TokenArgument]],输出:List[Union[torch.export.graph_signature.TensorArgument torch.export.graph_signature。SymIntArgument, torch.export.graph_signature.ConstantArgument 的 torch.export.graph_signature。CustomObjArgument torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec,out_spec:torch.utils._pytree。TreeSpec)[来源]¶
- 类 torch.export 中。ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[来源]¶
- torch.export.graph_signature 类。InputSpec(种类: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument torch.export.graph_signature。SymIntArgument 的 torch.export.graph_signature.ConstantArgument 的 torch.export.graph_signature。CustomObjArgument torch.export.graph_signature.TokenArgument], 目标: Optional[str], 持久: Optional[bool] = 无)[来源]¶
- torch.export.graph_signature 类。OutputSpec(种类: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument torch.export.graph_signature。SymIntArgument 的 torch.export.graph_signature.ConstantArgument 的 torch.export.graph_signature。CustomObjArgument torch.export.graph_signature.TokenArgument],目标:Optional[str])[源]¶
- torch.export.graph_signature 类。ExportGraphSignature(input_specs, output_specs)[来源]¶
对 Export Graph 的输入/输出签名进行建模, 这就是 FX。具有更强不变量保证的图形。
Export Graph 是功能性的,并且不访问参数之类的“状态” 或通过节点在图中的缓冲区。相反,保证参数、缓冲区和常量张量从 将图形作为输入。同样,缓冲区的任何更改也不包括 在图中,mutated buffers 的更新值为 建模为 Export Graph 的附加输出。
getattr
export()
所有输入和输出的顺序为:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出了以下模块:
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的 Graph 将为:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将为:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- torch.export.graph_signature 类。CustomObjArgument(name: str, class_fqn: str, fake_val: 可选[torch._library.fake_class_registry.FakeScriptObject] = None)[来源]¶
- class torch.export.unflatten 中。InterpreterModule(图)[来源]¶
使用 torch.fx.Interpreter 执行的模块,而不是通常的 codegen 中。这提供了更好的堆栈跟踪信息 ,并且更容易调试执行。
- torch.export.unflatten 的unflatten(module, flat_args_adapter=None)[来源]¶
取消展平 ExportedProgram,生成具有相同模块的模块 hierarchy 作为原始 Eager 模块。如果您正在尝试,这可能很有用 与另一个需要 module 的系统一起使用
hierachy 而不是通常生成的平面图
。
注意
未拼合的模块的 args/kwargs 不一定匹配 eager 模块,因此进行 module swap (例如 ) 不一定有效。如果您需要换出模块,您可以 需要设置 的参数
。
self.submod = new_mod
preserve_module_call_signature
- 参数
module (ExportedProgram) – 要展平的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出的模块不匹配,则调整平面参数。
- 返回
的实例 ,具有相同的模块 hierarchy 作为原始的 Eager Module 预导出。
UnflattenedModule
- 返回类型
未展平的模块
- torch.export.passes 的move_to_device_pass(EP, Location)[来源]¶
将导出的程序移动到给定的设备。
- 参数
ep (ExportedProgram) – 要移动的导出程序。
location (Union[torch.device, str, Dict[str, str]]) - 要将导出的程序移动到的设备。 如果是字符串,则将其解释为设备名称。 如果是 dict,则将其解释为 将现有设备转换为预期设备
- 返回
移动的导出程序。
- 返回类型