注意
转到末尾 以下载完整示例代码。
操作张量字典的形状¶
作者: Tom Begley
在这个教程中,您将学习如何操作 TensorDict 的形状及其内容。
当我们创建一个 TensorDict 时,我们指定了一个 batch_size,它必须与 TensorDict 中所有条目的前导维度一致。由于我们保证所有条目共享这些共同的维度,TensorDict 能够提供一系列方法,通过这些方法我们可以操作 TensorDict 的形状及其内容。
import torch
from tensordict.tensordict import TensorDict
索引 TensorDict¶
由于所有条目的批量维度都保证存在,我们可以随意索引它们,并且 TensorDict 的每个条目都将以相同的方式进行索引。
a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])
重塑一个 TensorDict¶
TensorDict.reshape 的工作方式与
torch.Tensor.reshape() 相同。它适用于所有内容的
TensorDict 沿着批量维度 - 注意 b 在下面示例中的形状。它还更新了 batch_size 属性。
reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])
Splitting a TensorDict¶
TensorDict.split 与 torch.Tensor.split() 类似。它将 TensorDict 分割成块。每个块是一个 TensorDict,其结构与原始对象相同,但其条目是原始 TensorDict 中相应条目的视图。
chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
注意
每当某个函数或方法接受一个 dim 参数时,负维度将相对于该函数或方法所作用的 TensorDict 的 batch_size 进行解释。特别地,如果存在具有不同批量大小的嵌套 TensorDict 值,则负维度始终相对于根节点的批量维度进行解释。
>>> tensordict = TensorDict(
... {
... "a": torch.rand(3, 4),
... "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
... },
... [3, 4],
... )
>>> # dim = -2 will be interpreted as the first dimension throughout, as the root
>>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
>>> chunks = tensordict.split([2, 1], dim=-2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])
从这个例子可以看出,TensorDict.split 方法的行为与我们在调用之前将 dim=-2 替换为 dim=tensordict.batch_dims - 2 完全相同。
解除绑定¶
TensorDict.unbind 类似于
torch.Tensor.unbind(),并且概念上类似于
TensorDict.split。它移除了指定的
维度并返回该维度上所有切片的 tuple。
slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])
堆叠和连接¶
TensorDict 可以与 torch.cat 和 torch.stack 结合使用。
Stacking TensorDict¶
堆叠可以懒惰地或连续地完成。懒惰堆叠只是一个张量字典列表,以堆叠的张量字典形式呈现。它允许用户携带具有不同内容形状、设备或键集的张量字典包。另一个优点是堆叠操作可能很昂贵,如果只需要一小部分键,则懒惰堆叠将比适当的堆叠快得多。
它依赖于LazyStackedTensorDict类。
在这种情况下,值仅在访问时按需堆叠。
from tensordict import LazyStackedTensorDict
cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
[tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)
# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:
from tensordict.utils import set_lazy_legacy
with set_lazy_legacy(True): # old behaviour
lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)
with set_lazy_legacy(False): # new behaviour
dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 3, 4]),
device=None,
is_shared=False,
stack_dim=0)
如果我们沿着堆叠维度索引一个 LazyStackedTensorDict,我们恢复
原始的 TensorDict。
assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict
访问 LazyStackedTensorDict 中的键会导致这些值被堆叠。如果键对应于嵌套的 TensorDict,那么我们将恢复另一个 LazyStackedTensorDict。
assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])
注意
由于值是按需堆叠的,多次访问一个项目意味着它会被多次堆叠,这是低效的。如果你需要在堆叠的TensorDict中多次访问一个值,你可能需要考虑将LazyStackedTensorDict转换为连续的TensorDict,这可以通过LazyStackedTensorDict.to_tensordict或LazyStackedTensorDict.contiguous方法完成。
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)
调用这些方法中的任何一个之后,我们将得到一个包含堆叠值的常规TensorDict,并且在访问值时不会执行额外的计算。
Concatenating TensorDict¶
Concatenation is not done lazily, instead calling torch.cat() on a list of
TensorDict instances simply returns a TensorDict whose entries
are the concatenated entries of the elements of the list.
concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])
扩展 TensorDict¶
我们可以扩展所有 TensorDict 的条目,使用
TensorDict.expand。
exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])
挤压和取消挤压 TensorDict¶
我们可以压缩或解压缩 TensorDict 的内容,使用
squeeze() 和
unsqueeze() 方法。
tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")
unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 1, 4, 1]),
device=None,
is_shared=False)
注意
直到现在,操作如 unsqueeze(),
squeeze(), view(),
permute(), transpose()
都返回这些操作的懒惰版本(即,存储原始张量字典的容器,并且每次访问键时都会应用这些操作)。
这种行为将在未来被弃用,并且已经可以通过
set_lazy_legacy() 函数进行控制:
>>> with set_lazy_legacy(True):
... lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
... dense_unsqueeze = tensordict.unsqueeze(0)
请注意,与以往一样,这些方法仅适用于批处理维度。条目的任何非批处理维度均不受影响。
tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])
查看 TensorDict¶
TensorDict 还支持 view. 这创建了一个 _ViewedTensorDict
当访问其内容时会懒惰地创建视图。
tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))
# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])
交换批次维度¶
The TensorDict.permute 方法可以用于像 torch.permute() 一样交换批量维度。非批量维度保持不变。
此操作是懒惰的,因此只有在尝试访问条目时才会交换批量维度。一如既往,如果您可能需要多次访问特定条目,请考虑转换为 TensorDict。
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])
assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])
使用张量字典作为装饰器¶
对于一系列可逆操作,tensordicts 可以用作装饰器。
这些操作包括 to_module() 用于函数调用,unlock_() 和 lock_()
或形状操作,如 view()、permute()
transpose()、squeeze() 和
unsqueeze()。
这里有一个使用 transpose 函数的快速示例:
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
with tensordict.transpose(1, 0) as tdt:
tdt.set("c", torch.ones(4, 3)) # we have permuted the dims
# the ``"c"`` entry is now in the tensordict we used as decorator:
#
assert (tensordict.get("c") == 1).all()
Gathering values in TensorDict¶
The TensorDict.gather 方法可以用于沿批处理维度进行索引并将结果聚集到一个维度中,类似于 torch.gather()。
index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
tensor([[2, 3, 2, 1],
[3, 3, 0, 0],
[3, 1, 1, 2]])
tensordict['a']:
tensor([[0.1814, 0.2808, 0.2381, 0.4003],
[0.1536, 0.0138, 0.4464, 0.6981],
[0.9308, 0.0727, 0.3552, 0.4791]])
gathered_tensordict['a']:
tensor([[0.2381, 0.4003, 0.2381, 0.2808],
[0.6981, 0.6981, 0.1536, 0.1536],
[0.4791, 0.0727, 0.0727, 0.3552]])
脚本总运行时间: (0 分钟 0.008 秒)