目录

TensorDictModule

作者Nicolas DufourVincent Moens

在本教程中,您将学习如何使用创建可以接受作为输入的通用和可重用模块。

为了方便使用类,在 两者之间提供了一个名为 的接口。tensordict

类是一个在调用时将 a 作为输入的类。它将读取一系列输入键,将它们传递给包装的 module 或 function 作为 input,并在执行完成后将输出写入同一个 tensordict 中。

由用户定义要作为输入和输出读取的键。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

简单示例:编写循环层

最简单的用法如下所示。 如果乍一看使用这个类似乎引入了前所未有的复杂程度,我们将看到 稍后,此 API 使用户能够以编程方式将模块连接在一起,缓存值 或以编程方式构建一个模块。 最简单的例子之一是 ResNet 等架构中的 recurrent 模块,其中 模块被缓存并添加到一个微型多层感知器 (MLP) 的输出中。

首先,让我们首先考虑我们将 MLP 分块,并使用 . 堆栈的第一层可能是一个层,将条目作为输入 (我们将其命名为 x)并输出另一个条目(我们将它命名为 y)。tensordict.nn

为了提供给我们的模块,我们有一个带有单个条目的实例:"x"

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

现在,我们使用 .默认情况下,此类在 input tensordict in-place(意味着条目与输入写入同一个 tensordict 中,而不是该条目) 就地覆盖!),这样我们就不需要明确指出输出是什么:

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模块输出多个张量(或 tensordicts),则必须以正确的顺序传递它们的条目。

支持可调用对象

在设计模型时,经常会发生您希望将任意非参数函数合并到 网络。例如,您可能希望在将图像传递到卷积网络时排列图像的维度 或视觉转换器,或将值除以 255。 有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个执行此操作的 new

适用于任何可调用对象,而不仅仅是模块,这使得 将任意函数合并到模块中。例如,让我们看看如何集成激活 函数:relu

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆叠模块

我们的 MLP 不是由单层组成的,因此我们现在需要为其添加另一层。 该层将是一个激活函数,例如 . 我们可以使用 堆叠这个模块和前一个模块。

注意

这就是真正的强大之处:与 不同,它将所有先前的输入和输出保留在内存中 (事后可以过滤掉它们),从而轻松拥有复杂的网络结构 以编程方式即时构建。tensordict.nn

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我们可以重复这个逻辑来获得完整的 MLP:

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多个输入键

残差网络的最后一步是将输入添加到最后一个线性层的输出中。 无需为此编写特殊的子类!也可以用于包装简单的函数:

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

我们现在可以将 和 for 一个完全充实的残差块放在一起:block0block1residual

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

真正的问题可能是用作输入的 tensordict 中条目的累积:在某些情况下(例如,当 gradients 是必需的)中间值可能被缓存,但情况并非总是如此,它可能很有用 让垃圾回收器知道某些条目可以被丢弃。 和 其子类(包括 ) 可以选择在执行后查看其 output keys filtered 的筛选结果。为此,只需调用 method.这将就地更新模块,并且所有 不需要的条目将被丢弃:

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,将保留输入键:

assert "x" in tensordict

作为旁注,也可以传递给以避免 单独调用该方法。selected_out_keys

在没有 tensordict 的情况下使用 TensorDictModule

随时随地构建复杂架构的机会 并不意味着必须切换到 tensordict 来表示数据。多亏了 ,tensordict.nn 中的模块支持与 条目名称也:

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在后台,重新构建 tensordict,运行模块,然后解构它。 这可能会导致一些开销,但正如我们稍后将看到的那样,有一个解决方案可以消除这种情况。

运行

并且在以下情况下会产生一些开销 执行,因为它们需要从 Tensordict 中读取和写入。但是,我们可以通过使用 .为此,让我们比较一下这段代码的三个版本(有 compile 和没有 compile):

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular:  219.5165 us
TDM:  260.3091 us
Sequential:  375.0590 us
Compiled versions
Compiled regular:  326.0555 us
Compiled TDM:  333.1850 us
Compiled sequential:  342.4750 us

可以看出,由 onverhead 引入的问题已经完全解决了。

使用 TensorDictModule 的该做和不该做

  • 不要使用 .它会破坏输入/输出 key 结构。 总是尝试依赖。Sequencetensordict.nnnn:TensorDictSequential

  • 不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是就地修改的输入。 分配新的变量名称并不是严格禁止的,但这意味着您可能希望它们都消失 当一个被删除时,实际上垃圾回收器仍会看到工作区中的张量和无内存 将被释放:

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

使用分配:

是一个非参数模块,表示 概率分布。分布参数是从 tensordict 中读取的 input,并将输出写入输出 tensordict。输出为 sampled 给定一些规则,由 input 参数和全局函数指定。如果它们发生冲突, 上下文管理器 (Context Manager) 位于其前面。default_interaction_typeinteraction_type()

它可以与返回 使用 .这是最后一个层是实例的特例。ProbabilisticTensorDictSequential

负责构造 分发(通过方法)和/或 从此发行版中采样(通过对 Module 的常规 forward 调用)。相同的方法在 中公开。ProbabilisticTensorDictSequential

可以在 output tensordict 和 log 中找到参数 概率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

结论

我们已经看到了如何使用 tensordict.nn 动态构建复杂的神经架构。 这为构建不关心模型签名的管道提供了可能性,即编写通用代码 以灵活的方式使用具有任意数量的输入或输出的网络。

我们还看到了如何使用 tensordict.nn 来构建这样的网络并使用 他们没有直接重复。多亏了 ,开销 introduced by 可以完全删除,给用户留下一个整洁的、 tensordict free 版本。

在下一个教程中,我们将了解如何使用它来隔离模块并导出它。torch.export

脚本总运行时间:(0 分 18.375 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源