概述¶
TensorDict 可以轻松组织数据和编写可重用的通用 PyTorch 代码。最初是为 TorchRL 开发的,我们将其衍生到一个单独的库中。
TensorDict 主要是一个字典,但也是一个类似张量的类:它支持多个张量操作,这些操作主要与形状和存储相关。它旨在有效地序列化或从一个节点传输到另一个节点或从一个过程传输到另一个过程。最后,它附带了自己的模块,该模块与模型集成和参数操作兼容,旨在简化模型集成和参数操作。tensordict.nn
functorch
在此页面上,我们将激励并提供一些示例来说明它可以做什么。TensorDict
赋予动机¶
TensorDict 允许您编写可跨范例重用的通用代码模块。例如,以下循环可以在大多数 SL、SSL、UL 和 RL 任务中重复使用。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
通过其模块,该包提供了许多工具,可以毫不费力地在代码库中使用。tensordict.nn
TensorDict
在多处理或分布式设置中,允许您将数据无缝分派给每个工作程序:tensordict
>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
... idx = splits[worker]
... pipe[worker].send(tensordict[idx])
TensorDict 提供的一些操作也可以通过 tree_map 完成,但复杂程度更高:
>>> td = TensorDict(
... {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
... for i in range(3)]
嵌套案例更引人注目:
>>> td = TensorDict(
... {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
... for i in range(3)
在应用 unbind 操作后,将输出字典分解为三个结构相似的字典在天真地使用 pytree 时很快就会变得非常麻烦。通过 tensordict,我们为想要解绑或拆分嵌套结构的用户提供一个简单的 API,而不是计算嵌套拆分/未绑定的嵌套结构。
特征¶
A 是张量的类似 dict 的容器。要实例化 ,必须指定键值对以及批处理大小。中任何值的 leading dimensions 都必须与批量大小兼容。TensorDict
TensorDict
TensorDict
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
设置或检索值的语法与常规字典的语法非常相似。
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
还可以沿 tensordict 的 batch_size 为其编制索引,这样就可以在 几个字符(请注意,使用省略号对带有 tree_map 的第 n 个前导维度进行索引需要更多的编码):
>>> sub_tensordict = tensordict[..., :2]
还可以使用 set 方法和 或 method 对内容进行就地更新。
前者是后者的容错版本:如果未找到匹配的 key,它将写入一个新的 key。inplace=True
set_
现在可以集体操作 TensorDict 的内容。 例如,要将所有内容都放置在特定设备上,只需执行
>>> tensordict = tensordict.to("cuda:0")
要重塑批次维度,可以执行以下操作
>>> tensordict = tensordict.reshape(6)
该类支持许多其他操作,包括 squeeze、unsqueeze、view、permute、unbind、stack、cat 等等。如果不存在操作,则 TensorDict.apply 方法通常会提供所需的解决方案。
命名维度¶
TensorDict 和相关类也支持维度名称。 这些名称可以在构建时给出,也可以在以后进行优化。语义是 类似于火把。Tensor 维度名称功能:
>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")
嵌套 TensorDict¶
a 中的值本身可以是 TensorDicts(下面示例中的嵌套字典将转换为嵌套的 TensorDicts)。TensorDict
>>> tensordict = TensorDict(
... {
... "inputs": {
... "image": torch.rand(100, 28, 28),
... "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
... },
... "outputs": {"logits": torch.randn(100, 10)},
... },
... batch_size=[100],
... )
访问或设置嵌套键可以通过字符串 Tuples 来完成
>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits")) # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)
惰性评估¶
某些操作会延迟执行,直到访问项为止。例如,堆叠、压缩、取消压缩、排列批处理维度和创建视图不会立即对 的所有内容执行。相反,当访问 中的值时,它们会延迟执行。如果 包含许多值,这可以节省大量不必要的计算。TensorDict
TensorDict
TensorDict
TensorDict
>>> tensordicts = [TensorDict({
... "a": torch.rand(10),
... "b": torch.rand(10, 1000, 1000)}, [10])
... for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0) # no stacking happens here
>>> stacked_a = stacked["a"] # we stack the a values, b values are not stacked
它还有一个优点,即我们可以在堆栈中操作原始 tensordicts:
>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()
需要注意的是,get 方法现在已成为一项昂贵的操作,如果多次重复,可能会导致一些开销。只需在执行 stack 后调用 tensordict.contiguous() 即可避免这种情况。为了进一步缓解这种情况,TensorDict 附带了自己的元数据类 (MetaTensor),该类跟踪字典每个条目的类型、形状、数据类型和设备,而无需执行昂贵的操作。
延迟预分配¶
假设我们有一些函数 foo() -> TensorDict,并且我们执行以下操作:
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
... tensordict[i] = foo()
当空时,将自动填充批量大小为 N 的空张量。在循环的后续迭代中,更新都将就地写入。i == 0
TensorDict
TensorDictModule¶
为了便于集成到代码库中,我们提供了一个 tensordict.nn 包,允许用户将实例传递给对象。TensorDict
TensorDict
nn.Module
TensorDictModule
包装并接受 single 作为输入。您可以指定底层模块应从何处获取其 Importing,以及它应从何处写入其 output。这是我们可以编写可重用的通用高级代码(例如 motivation 部分的 training loop)的关键原因。nn.Module
TensorDict
>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.LazyLinear(1)
...
... def forward(self, x):
... logits = self.linear(x)
... return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
... Net(),
... in_keys=["input"],
... out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))
为了便于采用此类,还可以将张量作为 kwargs 传递:
>>> tensordict = module(input=torch.randn(32, 100))
它将返回与上一个代码框中的相同。TensorDict
多个 PyTorch 用户的关键痛点是 nn.Sequential 用于处理具有多个 inputs 的模块。使用基于键的图形可以轻松解决这个问题,因为序列中的每个节点都知道需要读取哪些数据以及将其写入何处。
为此,我们提供了一个类,它通过一系列 .序列中的每个模块都从原始模块获取其输入,并将其输出写入原始模块,这意味着序列中的模块可以忽略其前身的输出,或根据需要从 tensordict 获取其他输入。下面是一个示例。TensorDictSequential
TensorDictModules
TensorDict
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
... class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]
在此示例中,第二个模块将第一个模块的输出与存储在 .TensorDict
TensorDictSequential
提供了许多其他功能:可以通过查询 in_keys 和 out_keys 属性来访问输入和输出键的列表。还可以通过使用所需的输入和输出键集进行查询来请求子图。这将返回另一个,其中仅包含满足这些要求所必需的模块。还兼容 和其他功能。select_subsequence()
TensorDictSequential
TensorDictModule
vmap
functorch
函数式编程¶
我们提供 API 以与 结合使用。例如,可以轻松连接模型权重以进行模型集成:TensorDict
functorch
TensorDict
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(separator=".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(separator=".")
>>> params = make_functional(model)
>>> # params provided by make_functional match state_dict:
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params) # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])
函数式 API 即使不比 中实现的当前 API 更快,也是相当的。FunctionalModule
functorch