Note

Go to the end to download the full example code.
Exporting tensordict modules
Author: Vincent Moens
Prerequisites

Reading the TensorDictModule tutorial is preferable to fully benefit from this tutorial.
Once a module has been written with tensordict.nn, it is often useful to isolate its computational graph and export that graph. The goal may be to run the model on hardware (e.g., robots, drones, edge devices) or to eliminate the dependency on tensordict altogether. PyTorch provides multiple ways of exporting modules, including onnx and torch.export, both of which are compatible with tensordict.

In this short tutorial, we will see how torch.export can be used to isolate a model's computational graph. torch.onnx support follows the same logic.
Key learning outcomes
import time
import torch
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as Prob,
set_interaction_type,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import distributions as dists, nn
Designing the model
In many applications, it is useful to work with stochastic models, i.e., models whose output variable is not defined deterministically but is sampled from a parametric distribution. For instance, generative AI models often produce different outputs when given the same input, because they sample the output from a distribution whose parameters are defined by the input.
The tensordict library handles this through the ProbabilisticTensorDictModule class. This primitive is built using a distribution class (Normal in our case) and an indicator of its input keys.
The network we are building is therefore the composition of three main components:

a network that maps the input to latent parameters;

a tensordict.nn.NormalParamExtractor module that splits the input into location ("loc") and scale ("scale") parameters to be passed to the Normal distribution;

a distribution constructor module.
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
# 3. Probabilistic module
Prob(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=dists.Normal,
),
)
Let's run this model and see what the output looks like:
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0000, 0.2604, 0.0000, 0.0000]], grad_fn=<ReluBackward0>), tensor([[-0.1580, -0.5222, -0.3319, 0.5519]], grad_fn=<AddmmBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>), tensor([[0.8046, 1.3804]], grad_fn=<ClampMinBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>))
As expected, running the model with a tensor input returns as many tensors as the module has output keys! For large models, this can be quite annoying and wasteful. Later, we will see how to limit the outputs of the model to deal with this issue.
Exporting a TensorDictModule with torch.export
Now that we have successfully built our model, we want to extract its computational graph into a single object that is independent of tensordict. torch.export is a PyTorch module dedicated to isolating the graph of a module and representing it in a standardized way. Its main entry point is export(), which returns an ExportedProgram object. In turn, this object has several attributes of interest that we will explore below: a graph_module, which represents the FX graph captured by export; a graph_signature, which contains the inputs, outputs, etc. of the graph; and finally a module() method, which returns a callable that can be used in place of the original module.
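For orientation, here is what accessing those attributes looks like once the export call below has run (a short sketch; all three are part of the ExportedProgram API):

>>> model_export.graph_module      # the FX GraphModule captured by export
>>> model_export.graph_signature   # inputs, outputs and parameters of the graph
>>> exported = model_export.module()  # a callable usable in place of the original module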
Although our module accepts both args and kwargs, we will focus on its usage with kwargs, as this is clearer.
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x})
Let's look at the module:
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
This module can be run exactly like our original module (with lower overhead):
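The benchmark that produced the timings below was omitted from this copy; it would look roughly like the following (a minimal sketch using the time module imported above):

>>> exported = model_export.module()
>>> t0 = time.time()
>>> model(x=x)
>>> print(f"Time for TDModule: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")
>>> t0 = time.time()
>>> exported(x=x)
>>> print(f"Time for exported module: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")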
Time for TDModule: 469.45 micro-seconds
Time for exported module: 340.70 micro-seconds
And the FX graph:
print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem: "f32[1, 2]" = split[0]
getitem_1: "f32[1, 2]" = split[1]; split = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
getitem: "f32[1, 2]" = split[0]
getitem_1: "f32[1, 2]" = split[1]; split = None
# File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
Working with nested keys
Nested keys are a core feature of the tensordict library, so being able to export modules that read and write nested entries is an important feature to support. Because keyword arguments must be regular strings, dispatch cannot work with nested keys directly. Instead, dispatch will unpack nested keys joined with a regular underscore ("_"), as the following example shows.
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
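The exported callable can then be fed the flattened key name directly (a quick usage sketch, reusing the x tensor defined earlier):

>>> model_nested_export.module()(some_key=x)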
Saving the exported module
torch.export has its own serialization protocol, torch.export.save() and torch.export.load(). Conventionally, the ".pt2" extension is used:
>>> torch.export.save(model_export, "model.pt2")
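Loading the program back is symmetric (a short sketch reusing the file name above):

>>> model_export = torch.export.load("model.pt2")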
Selecting the outputs
Recall that tensordict.nn keeps every intermediate value in the output, unless the user specifically asks for a particular value only. During training, this can be very useful: one can easily log intermediate values of the graph, or use them for other purposes (e.g., reconstructing a distribution from its saved parameters rather than saving the Distribution object itself). One could also argue that, during training, the memory impact of registering intermediate values is negligible, since they are part of the torch.autograd computational graph used to compute the parameter gradients.
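For instance, if the "loc" and "scale" entries were logged during training, the distribution can be rebuilt later without serializing the Distribution object itself (a minimal sketch; loc, scale and sample stand for hypothetical logged tensors):

>>> d = dists.Normal(loc, scale)
>>> d.log_prob(sample)  # e.g., evaluate the likelihood of a logged sample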
During inference, however, it is likely that we are only interested in the final sample of the model. Because we want to extract the model for usages that are independent of the tensordict library, it makes sense to isolate the single output we desire. To do this, we have several options: use the select_out_keys() method, which modifies the out_keys attribute in place (this can be reverted with reset_out_keys()), or wrap the existing instance in a Seq that filters out the unwanted keys:
>>> module_filtered = Seq(module, selected_out_keys=["sample"])
Let us test the model after selecting its output keys. When provided with an x input, we expect our model to output a single tensor corresponding to a sample of the distribution:
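The call that produced the output below was dropped from this copy; with the in-place option it would read (a sketch using the select_out_keys() method mentioned above):

>>> model.select_out_keys("sample")
>>> print(model(x=x))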
tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>)
We see that the output is now a single tensor corresponding to a sample of the distribution. We can create a new exported graph from this; its computational graph should be simplified:
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]; broadcast_tensors = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
Controlling the sampling strategy
We have not yet discussed how the probabilistic module samples from the distribution. Sampling means obtaining a value within the space defined by the distribution, according to a specific strategy. For instance, you may want stochastic samples during training but deterministic samples (e.g., the mean or the mode) at inference time. To address this, tensordict provides the set_interaction_type decorator and context manager, which accepts InteractionType enum inputs:
>>> with set_interaction_type(InteractionType.MEAN):
...     output = module(input)  # takes the mean of the distribution, if ProbabilisticTensorDictModule is invoked
The default InteractionType is InteractionType.DETERMINISTIC which, if not implemented directly by the distribution, is the mean of distributions with a real domain, or the mode of distributions with a discrete domain. This default can be changed using the default_interaction_type keyword argument of ProbabilisticTensorDictModule.
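As a sketch, fixing the default at construction time in our model would look like this (only the default_interaction_type argument is new compared to the Prob module built above):

>>> Prob(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=dists.Normal,
...     default_interaction_type=InteractionType.MEAN,
... )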
Let's recap: to control the sampling strategy of our network, we can either define the default strategy in the constructor, or override it at runtime through the set_interaction_type context manager.
As the following example shows, torch.export responds correctly to the use of the decorator: if we ask for a random sample, the exported graph differs from the one obtained when we ask for the mean:
with set_interaction_type(InteractionType.RANDOM):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
with set_interaction_type(InteractionType.MEAN):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
normal_functional = torch.ops.aten.normal_functional.default(empty); empty = None
mul = torch.ops.aten.mul.Tensor(normal_functional, getitem_3); normal_functional = getitem_3 = None
add_2 = torch.ops.aten.add.Tensor(getitem_2, mul); getitem_2 = mul = None
return pytree.tree_unflatten((add_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
split = torch.ops.aten.split.Tensor(linear_1, 2, -1); linear_1 = None
getitem = split[0]
getitem_1 = split[1]; split = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]; broadcast_tensors = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
This is all you need to know to use torch.export. Please refer to the official documentation for more info.
Next steps and further reading
Check the torch.export tutorial, available here;

ONNX support: check the ONNX tutorials to learn more about this feature. Exporting to ONNX is very similar to the torch.export workflow explained here (see the sketch after this list);

To deploy PyTorch code on servers without a Python environment, check the AOTInductor documentation.
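As a rough illustration of the ONNX path (a sketch assuming the dynamo-based exporter, whose exact entry point varies across PyTorch versions):

>>> onnx_program = torch.onnx.dynamo_export(model, x=x)
>>> onnx_program.save("model.onnx")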
Total running time of the script: (0 minutes 1.695 seconds)