注意
转到末尾 以下载完整示例代码。
TensorDictModule¶
作者: Nicolas Dufour, Vincent Moens
在本教程中,您将学习如何使用 TensorDictModule 和
TensorDictSequential 来创建通用且可重用的模块,这些模块可以接受
TensorDict 作为输入。
为方便将TensorDict类与Module结合使用,
tensordict提供了一个名为TensorDictModule的接口。
The TensorDictModule 类是一个 Module,在调用时接受一个
TensorDict 作为输入。它将读取一系列输入键,将它们传递给包装的
模块或函数作为输入,并在执行完成后将输出写入相同的 tensordict。
用户需要定义读取的输入和输出键。
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
简单示例:编码循环层¶
TensorDictModule 的最简单用法如下所示。
初看时,使用该类似乎会引入不必要的复杂性,但稍后我们将看到,此 API 使用户能够以编程方式将模块串联起来、在模块之间缓存值,或以编程方式构建模块。
此类最简单的示例之一是 ResNet 等架构中的循环模块:该模块的输入被缓存,并与一个小型多层感知机(MLP)的输出相加。
首先,让我们先考虑如何将MLP分块,并使用tensordict.nn进行编码。
堆栈的第一层很可能是Linear层,将一个输入项(我们将其命名为x)作为输入,并输出另一个项(我们将其命名为y)。
要传递给我们的模块,我们有一个 TensorDict 实例,其中包含一个条目,
"x":
tensordict = TensorDict(
x=torch.randn(5, 3),
batch_size=[5],
)
现在,我们使用 tensordict.nn.TensorDictModule 构建我们的简单模块。默认情况下,这个类会在输入的 tensordict 中进行原地写入(意味着条目会写入与输入相同的 tensordict 中,而不是在原地覆盖!),因此我们不需要显式地指示输出是什么:
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
如果模块输出多个张量(或张量字典!),它们的条目必须按正确的顺序传递给
TensorDictModule。
对可调用对象的支持¶
在设计模型时,经常需要将任意的非参数函数融入网络中。例如,你可能希望在将图像传递给卷积网络或视觉变换器时调整其维度顺序,或者将值除以255。
有几种方法可以实现这一点:你可以使用一个forward_hook,例如,或者设计一个新的Module来执行此操作。
TensorDictModule 可与任意可调用对象配合使用,而不仅限于模块,这使得将任意函数轻松集成到模块中变得十分便捷。例如,我们来看一下如何在不使用 ReLU 模块的情况下集成 relu 激活函数:
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
堆叠模块¶
我们的MLP不是由单层组成的,所以我们现在需要向它添加另一层。
这一层将是一个激活函数,例如 ReLU。
我们可以使用 TensorDictSequential 将这个模块和前一个模块堆叠起来。
注意
这里展示了tensordict.nn的真正力量:与Sequential不同,
TensorDictSequential会将所有之前的输入和输出保留在内存中
(之后可以过滤掉),这使得可以轻松地构建动态和程序化的复杂网络结构。
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
我们可以重复这个逻辑来构建一个完整的多层感知机:
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
多个输入键¶
残差网络的最后一步是将输入加到最后一个线性层的输出上。
不需要为这个编写一个特殊的Module子类!TensorDictModule
也可以用来包装简单的函数:
residual = TensorDictModule(
lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
现在我们可以将 block0、block1 和 residual 组合起来,形成一个完整的残差块:
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
一个真正的担忧可能是输入的 tensordict 中条目的累积:在某些情况下(例如,需要梯度时)可能会缓存中间值,但这并不总是如此,并且让垃圾收集器知道可以丢弃某些条目可能会很有用。tensordict.nn.TensorDictModuleBase 和其子类(包括 tensordict.nn.TensorDictModule 和 tensordict.nn.TensorDictSequential)
有选择在执行后过滤输出键的选项。要做到这一点,只需调用
tensordict.nn.TensorDictModuleBase.select_out_keys 方法。这将就地更新模块,并且所有不需要的条目将被丢弃:
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
然而,输入键被保留:
assert "x" in tensordict
作为附注,selected_out_keys 也可以传递给 tensordict.nn.TensorDictSequential 以避免
单独调用此方法。
使用 TensorDictModule 而不使用 tensordict¶
由 tensordict.nn.TensorDictSequential 提供的机会可以在运行时构建复杂的架构
并不意味着必须切换到 tensordict 来表示数据。多亏了
dispatch,来自 tensordict.nn 的模块支持与条目名称匹配的参数和关键字参数:
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
在幕后,dispatch 重建了一个张量字典,运行模块然后解构它。
这可能会导致一些开销,但正如我们稍后将看到的,有一个解决方案可以摆脱这个问题。
运行时¶
tensordict.nn.TensorDictModule 和 tensordict.nn.TensorDictSequential 在执行时确实会带来一些开销,因为它们需要从 tensordict 读取和写入。但是,我们可以通过使用
compile() 大大减少这种开销。为此,让我们比较这三种代码版本在编译前后的情况:
class ResidualBlock(nn.Module):
def __init__(self):
super().__init__()
self.linear0 = nn.Linear(3, 128)
self.relu0 = nn.ReLU()
self.linear1 = nn.Linear(128, 128)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(128, 3)
def forward(self, x):
y = self.linear0(x)
y = self.relu0(y)
y = self.linear1(y)
y = self.relu1(y)
return self.linear2(y) + x
print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer
print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular: 219.5165 us
TDM: 260.3091 us
Sequential: 375.0590 us
Compiled versions
Compiled regular: 326.0555 us
Compiled TDM: 333.1850 us
Compiled sequential: 342.4750 us
可以看出,由 TensorDictSequential 引入的开销已完全解决。
TensorDictModule的使用指南¶
不要在
Sequence中使用来自tensordict.nn的模块。这会破坏输入/输出键结构。 总是尝试依赖于nn:TensorDictSequential。不要将输出的 TensorDict 赋值给一个新变量,因为输出的 TensorDict 仅是对输入 TensorDict 的原地修改。 为输出 TensorDict 赋予一个新变量名虽非严格禁止,但可能造成误解:你或许希望其中一个变量被删除时,另一个也一并消失;而实际上,垃圾回收器仍能看到工作区中的张量,因此内存不会被释放。
>>> tensordict = module(tensordict) # ok! >>> tensordict_out = module(tensordict) # don't!
与分布一起工作:ProbabilisticTensorDictModule¶
ProbabilisticTensorDictModule 是一个非参数模块,表示概率分布。分布参数从 tensordict 输入中读取,并将输出写入输出 tensordict。根据输入 default_interaction_type 参数和 interaction_type() 全局函数指定的某些规则进行采样。如果它们冲突,则上下文管理器优先。
它可以与一个返回张量字典(tensordict)的 TensorDictModule 组合使用,该张量字典已通过 ProbabilisticTensorDictSequential 更新了分布参数。这是 TensorDictSequential 的一种特殊情况,其最后一层为一个 ProbabilisticTensorDictModule 实例。
ProbabilisticTensorDictModule 负责构建分布(通过 get_dist() 方法)和/或从该分布中采样(通过对该模块进行常规 forward 调用)。相同的 get_dist() 方法在 ProbabilisticTensorDictSequential 中暴露。
用户可以在输出的 tensordict 中找到参数,如有需要,也可获取对数概率。
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
结论¶
我们已经看到如何使用tensordict.nn来动态构建复杂的神经网络架构。 这为构建与模型签名无关的管道提供了可能性,即编写通用代码, 以灵活的方式使用具有任意数量输入或输出的网络。
我们还看到了如何 dispatch 使我们能够使用 tensordict.nn 来构建这样的网络并使用它们,而无需直接诉诸于 TensorDict。由于 compile(),由 tensordict.nn.TensorDictSequential 引入的开销可以完全消除,为用户留下一个整洁、无 tensordict 版本的模块。
在下一个教程中,我们将看到如何使用torch.export来隔离一个模块并导出它。
脚本总运行时间: (0 分钟 18.375 秒)