目录

TensorDictModule

作者Nicolas DufourVincent Moens

在本教程中,您将学习如何使用TensorDictModuleTensorDictSequential创建通用且可重用的模块,这些模块可以接受TensorDict作为输入。

为了方便使用TensorDictclass 替换为Module在两个名为tensordictTensorDictModule.

TensorDictModuleclass 是一个Module这需要TensorDict作为 input 调用时。它将读取一系列输入键,将它们传递给包装的 module 或 function 作为 input,并在执行完成后将输出写入同一个 tensordict 中。

由用户定义要作为输入和输出读取的键。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

简单示例:编写循环层

最简单的用法TensorDictModule示例如下。 如果乍一看使用这个类似乎引入了前所未有的复杂程度,我们将看到 稍后,此 API 使用户能够以编程方式将模块连接在一起,缓存值 或以编程方式构建一个模块。 最简单的例子之一是 ResNet 等架构中的 recurrent 模块,其中 模块被缓存并添加到一个微型多层感知器 (MLP) 的输出中。

首先,让我们首先考虑我们将 MLP 分块,并使用 . 堆栈的第一层大概是一个tensordict.nnLinear层,将条目作为输入 (我们将其命名为 x)并输出另一个条目(我们将它命名为 y)。

为了馈送到我们的模块,我们有一个TensorDict实例,其中包含单个条目 :"x"

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

现在,我们使用tensordict.nn.TensorDictModule.默认情况下,此类在 input tensordict in-place(意味着条目与输入写入同一个 tensordict 中,而不是该条目) 就地覆盖!),这样我们就不需要明确指出输出是什么:

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模块输出多个张量(或 tensordicts),则必须将它们的条目传递给TensorDictModule以正确的顺序。

支持可调用对象

在设计模型时,经常会发生您希望将任意非参数函数合并到 网络。例如,您可能希望在将图像传递到卷积网络时排列图像的维度 或视觉转换器,或将值除以 255。 有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个新的Module执行此作。

TensorDictModule适用于任何可调用对象,而不仅仅是模块,这使得 将任意函数合并到模块中。例如,让我们看看如何集成激活 函数而不使用reluReLU模块:

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆叠模块

我们的 MLP 不是由单层组成的,因此我们现在需要为其添加另一层。 例如,该层将是一个激活函数ReLU. 我们可以使用TensorDictSequential.

注意

这就是 的真正力量 : 不同tensordict.nnSequential,TensorDictSequential将之前的所有输入和输出保存在内存中 (事后可以过滤掉它们),从而轻松拥有复杂的网络结构 以编程方式即时构建。

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我们可以重复这个逻辑来获得完整的 MLP:

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多个输入键

残差网络的最后一步是将输入添加到最后一个线性层的输出中。 无需编写特殊的Module子类!TensorDictModule也可以用于包装简单的函数:

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

我们现在可以将 和 for 一个完全充实的残差块放在一起:block0block1residual

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

真正的问题可能是用作输入的 tensordict 中条目的累积:在某些情况下(例如,当 gradients 是必需的)中间值可能被缓存,但情况并非总是如此,它可能很有用 让垃圾回收器知道某些条目可以被丢弃。tensordict.nn.TensorDictModuleBase和 其子类(包括tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential) 可以选择在执行后查看其 output keys filtered 的筛选结果。为此,只需调用tensordict.nn.TensorDictModuleBase.select_out_keys方法。这将就地更新模块,并且所有 不需要的条目将被丢弃:

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,将保留输入键:

assert "x" in tensordict

作为旁注,也可以传递给selected_out_keystensordict.nn.TensorDictSequential避免 单独调用该方法。

在没有 tensordict 的情况下使用 TensorDictModule

提供的机会tensordict.nn.TensorDictSequential随时随地构建复杂的架构 并不意味着必须切换到 tensordict 来表示数据。由于dispatch中,tensordict.nn 中的模块支持与 条目名称也:

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在引擎盖下,dispatch重新构建 Tensordict,运行该模块,然后解构它。 这可能会导致一些开销,但正如我们稍后将看到的那样,有一个解决方案可以消除这种情况。

运行

tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential在以下情况下会产生一些开销 执行,因为它们需要从 Tensordict 中读取和写入。但是,我们可以通过使用compile().为此,让我们比较一下这段代码的三个版本(有 compile 和没有 compile):

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular:  219.5165 us
TDM:  260.3091 us
Sequential:  375.0590 us
Compiled versions
Compiled regular:  326.0555 us
Compiled TDM:  333.1850 us
Compiled sequential:  342.4750 us

如您所见,由TensorDictSequential已完全解决。

使用 TensorDictModule 的该做和不该做

  • 不要使用 .它会破坏输入/输出 key 结构。 总是尝试依赖。Sequencetensordict.nnnn:TensorDictSequential

  • 不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是就地修改的输入。 分配新的变量名称并不是严格禁止的,但这意味着您可能希望它们都消失 当一个被删除时,实际上垃圾回收器仍会看到工作区中的张量和无内存 将被释放:

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

使用分配:ProbabilisticTensorDictModule

ProbabilisticTensorDictModule是一个非参数模块,表示 概率分布。分布参数是从 tensordict 中读取的 input,并将输出写入输出 tensordict。输出为 sampled 给定一些规则,由 input 参数和全局函数指定。如果它们发生冲突, 上下文管理器 (Context Manager) 位于其前面。default_interaction_typeinteraction_type()

它可以与TensorDictModule返回 使用 .这是ProbabilisticTensorDictSequentialTensorDictSequential其最后一层是ProbabilisticTensorDictModule实例。

ProbabilisticTensorDictModule负责构造 分发(通过get_dist()方法)和/或 从此发行版中采样(通过对 Module 的常规 forward 调用)。一样get_dist()方法在 中公开。ProbabilisticTensorDictSequential

可以在 output tensordict 和 log 中找到参数 概率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

结论

我们已经看到了如何使用 tensordict.nn 动态构建复杂的神经架构。 这为构建不关心模型签名的管道提供了可能性,即编写通用代码 以灵活的方式使用具有任意数量的输入或输出的网络。

我们还看到了dispatch允许使用 tensordict.nn 构建此类网络并使用 他们没有重复到TensorDict径直。由于compile()、开销 介绍tensordict.nn.TensorDictSequential可以完全移除,给用户留下整洁、 tensordict free 版本。

在下一个教程中,我们将了解如何使用它来隔离模块并导出它。torch.export

脚本总运行时间:(0 分 18.375 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源