注意
转到末尾下载完整的示例代码。
TensorDictModule¶
作者: Nicolas Dufour, Vincent Moens
在本教程中,您将学习如何使用和
创建可以接受作为输入的
通用和可重用模块。
为了方便使用类,在
两者之间提供了一个名为
的接口。
tensordict
该类是一个
在调用时将 a
作为输入的类。它将读取一系列输入键,将它们传递给包装的
module 或 function 作为 input,并在执行完成后将输出写入同一个 tensordict 中。
由用户定义要作为输入和输出读取的键。
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
简单示例:编写循环层¶
最简单的用法如下所示。
如果乍一看使用这个类似乎引入了前所未有的复杂程度,我们将看到
稍后,此 API 使用户能够以编程方式将模块连接在一起,缓存值
或以编程方式构建一个模块。
最简单的例子之一是 ResNet 等架构中的 recurrent 模块,其中
模块被缓存并添加到一个微型多层感知器 (MLP) 的输出中。
首先,让我们首先考虑我们将 MLP 分块,并使用 .
堆栈的第一层可能是一个层,将条目作为输入
(我们将其命名为 x)并输出另一个条目(我们将它命名为 y)。
tensordict.nn
为了提供给我们的模块,我们有一个带有单个条目的实例:"x"
tensordict = TensorDict(
x=torch.randn(5, 3),
batch_size=[5],
)
现在,我们使用 .默认情况下,此类在
input tensordict in-place(意味着条目与输入写入同一个 tensordict 中,而不是该条目)
就地覆盖!),这样我们就不需要明确指出输出是什么:
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
如果模块输出多个张量(或 tensordicts),则必须以正确的顺序传递它们的条目。
支持可调用对象¶
在设计模型时,经常会发生您希望将任意非参数函数合并到
网络。例如,您可能希望在将图像传递到卷积网络时排列图像的维度
或视觉转换器,或将值除以 255。
有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个执行此操作的 new。
适用于任何可调用对象,而不仅仅是模块,这使得
将任意函数合并到模块中。例如,让我们看看如何集成激活
函数:
relu
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
堆叠模块¶
我们的 MLP 不是由单层组成的,因此我们现在需要为其添加另一层。
该层将是一个激活函数,例如 .
我们可以使用
堆叠这个模块和前一个模块。
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
我们可以重复这个逻辑来获得完整的 MLP:
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
多个输入键¶
残差网络的最后一步是将输入添加到最后一个线性层的输出中。
无需为此编写特殊的子类!
也可以用于包装简单的函数:
residual = TensorDictModule(
lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
我们现在可以将 和 for 一个完全充实的残差块放在一起:block0
block1
residual
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
真正的问题可能是用作输入的 tensordict 中条目的累积:在某些情况下(例如,当
gradients 是必需的)中间值可能被缓存,但情况并非总是如此,它可能很有用
让垃圾回收器知道某些条目可以被丢弃。 和
其子类(包括
和
)
可以选择在执行后查看其 output keys filtered 的筛选结果。为此,只需调用
method.这将就地更新模块,并且所有
不需要的条目将被丢弃:
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
但是,将保留输入键:
assert "x" in tensordict
在没有 tensordict 的情况下使用 TensorDictModule¶
随时随地构建复杂架构的机会
并不意味着必须切换到 tensordict 来表示数据。多亏了
,tensordict.nn 中的模块支持与
条目名称也:
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
在后台,重新构建 tensordict,运行模块,然后解构它。
这可能会导致一些开销,但正如我们稍后将看到的那样,有一个解决方案可以消除这种情况。
运行¶
并且
在以下情况下会产生一些开销
执行,因为它们需要从 Tensordict 中读取和写入。但是,我们可以通过使用
.为此,让我们比较一下这段代码的三个版本(有 compile 和没有 compile):
class ResidualBlock(nn.Module):
def __init__(self):
super().__init__()
self.linear0 = nn.Linear(3, 128)
self.relu0 = nn.ReLU()
self.linear1 = nn.Linear(128, 128)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(128, 3)
def forward(self, x):
y = self.linear0(x)
y = self.relu0(y)
y = self.linear1(y)
y = self.relu1(y)
return self.linear2(y) + x
print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer
print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular: 219.5165 us
TDM: 260.3091 us
Sequential: 375.0590 us
Compiled versions
Compiled regular: 326.0555 us
Compiled TDM: 333.1850 us
Compiled sequential: 342.4750 us
使用 TensorDictModule 的该做和不该做¶
不要使用 .它会破坏输入/输出 key 结构。 总是尝试依赖。
Sequence
tensordict.nn
nn:TensorDictSequential
不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是就地修改的输入。 分配新的变量名称并不是严格禁止的,但这意味着您可能希望它们都消失 当一个被删除时,实际上垃圾回收器仍会看到工作区中的张量和无内存 将被释放:
>>> tensordict = module(tensordict) # ok! >>> tensordict_out = module(tensordict) # don't!
使用分配:
¶
是一个非参数模块,表示
概率分布。分布参数是从 tensordict 中读取的
input,并将输出写入输出 tensordict。输出为
sampled 给定一些规则,由 input 参数和全局函数指定。如果它们发生冲突,
上下文管理器 (Context Manager) 位于其前面。
default_interaction_type
interaction_type()
它可以与返回
使用 .这是最后一个层是
实例的特例。
ProbabilisticTensorDictSequential
负责构造
分发(通过
方法)和/或
从此发行版中采样(通过对 Module 的常规 forward 调用)。相同的
方法在 中公开。
ProbabilisticTensorDictSequential
可以在 output tensordict 和 log 中找到参数 概率。
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
结论¶
我们已经看到了如何使用 tensordict.nn 动态构建复杂的神经架构。 这为构建不关心模型签名的管道提供了可能性,即编写通用代码 以灵活的方式使用具有任意数量的输入或输出的网络。
我们还看到了如何使用 tensordict.nn 来构建这样的网络并使用
他们没有直接重复。
多亏了
,开销
introduced by
可以完全删除,给用户留下一个整洁的、
tensordict free 版本。
在下一个教程中,我们将了解如何使用它来隔离模块并导出它。torch.export
脚本总运行时间:(0 分 18.375 秒)