Manipulating the shape of a TensorDict

Author: Tom Begley

In this tutorial you will learn how to manipulate a TensorDict and its contents.

When we create a TensorDict we specify a batch_size, which must agree with the leading dimensions of all entries in the TensorDict. Since we have a guarantee that all entries share those dimensions in common, TensorDict is able to expose a number of methods with which we can manipulate the TensorDict and its contents.

import torch
from tensordict import TensorDict
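
As a quick aside, the batch_size must match the leading dimensions of every entry. A minimal sketch (the exact exception type may vary between versions of tensordict):

try:
    TensorDict({"a": torch.rand(3, 4)}, batch_size=[4])
except RuntimeError:
    # "a" has leading dimension 3, which is incompatible with batch_size=[4]
    print("batch_size must match the leading dimensions of every entry")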

Indexing a TensorDict

Since the batch dimensions are guaranteed to exist on all of the entries, we can index them as we please, and each entry of the TensorDict will be indexed in the same way.

a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])

indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])

Reshaping a TensorDict

TensorDict.reshape works just like torch.Tensor.reshape(). It applies to all of the contents of the TensorDict along the batch dimensions - note the shape of b in the example below. It also updates the batch_size attribute.

reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])

Splitting a TensorDict

TensorDict.split is similar to torch.Tensor.split(). It splits the TensorDict into chunks. Each chunk is a TensorDict with the same structure as the original one, but whose entries are views of the corresponding entries of the original TensorDict.

chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
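
Since the entries of each chunk are views (as with torch.Tensor.split), an in-place write to a chunk is visible in the original TensorDict. A minimal check, assuming view semantics hold for every entry:

chunks[0]["a"][0, 0] = 0.0
assert tensordict["a"][0, 0] == 0.0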

Note

Whenever a function or method accepts a dim argument, negative dimensions are interpreted relative to the batch_size of the TensorDict that the function or method is called on. In particular, if there are nested TensorDict values with different batch sizes, the negative dimension is always interpreted relative to the batch dimensions of the root.

>>> tensordict = TensorDict(
...     {
...         "a": torch.rand(3, 4),
...         "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
...     },
...     [3, 4],
... )
>>> # dim = -2 will be interpreted as the first dimension throughout, as the root
>>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
>>> chunks = tensordict.split([2, 1], dim=-2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])

As can be seen from this example, the TensorDict.split method behaves as though we had called it with dim=tensordict.batch_dims - 2, even though we called it with dim=-2.
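
We can verify this equivalence directly (a minimal check building on the example above):

>>> chunks = tensordict.split([2, 1], dim=tensordict.batch_dims - 2)
>>> assert chunks[0].batch_size == torch.Size([2, 4])
>>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])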

TensorDict.unbind is similar to torch.Tensor.unbind(), and conceptually similar to TensorDict.split. It removes the specified dimension and returns a tuple of all slices along that dimension.

slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])

Stacking and concatenating

TensorDict can be used in conjunction with torch.cat and torch.stack.

Stacking TensorDicts

Stacking can be done lazily or contiguously. A lazy stack is just a list of tensordicts presented as if they were a stacked tensordict. It allows users to carry a bag of tensordicts with different content shapes, devices or key sets. Another advantage is that the stack operation can be expensive, and if only a small subset of keys is required, a lazy stack will be much faster than a proper one. It relies on the LazyStackedTensorDict class. In this case, values will only be stacked on demand when they are accessed.

from tensordict import LazyStackedTensorDict

cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
    [tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)

# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:

from tensordict.utils import set_lazy_legacy

with set_lazy_legacy(True):  # old behaviour
    lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)

with set_lazy_legacy(False):  # new behaviour
    dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False,
    stack_dim=0)

If we index a LazyStackedTensorDict along the stack dimension, we recover the original TensorDicts.

assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict

Accessing a key in a LazyStackedTensorDict results in those values being stacked. If the key corresponds to a nested TensorDict, we will instead recover another LazyStackedTensorDict.

assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])

Note

Since values are stacked on demand, accessing an item multiple times means it will be stacked multiple times, which is inefficient. If you need to access a value in the stacked TensorDict more than once, you may want to consider converting the LazyStackedTensorDict to a contiguous TensorDict, which can be done with the LazyStackedTensorDict.to_tensordict or LazyStackedTensorDict.contiguous methods.

>>> assert isinstance(stacked_tensordict.to_tensordict(), TensorDict)
>>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)

After either of these operations, we will have a regular TensorDict containing the stacked values, and no additional computation is performed when values are accessed.
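
For instance (a minimal sketch), after materialising the stack the values already have the stacked shape and no further stacking happens on access:

contiguous_td = stacked_tensordict.to_tensordict()
assert isinstance(contiguous_td, TensorDict)
assert contiguous_td["a"].shape == torch.Size([2, 3, 4])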

Concatenating TensorDicts

Concatenation is not done lazily; calling torch.cat() on a list of TensorDict instances simply returns a TensorDict whose entries are the concatenated entries of the elements of the list.

concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])

Expanding TensorDicts

We can expand all of the entries of a TensorDict using TensorDict.expand.

exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])

Squeezing and unsqueezing a TensorDict

We can squeeze or unsqueeze the contents of a TensorDict using the squeeze() and unsqueeze() methods.

tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")

unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 1, 4, 1]),
    device=None,
    is_shared=False)

Note

Until now, operations like unsqueeze(), squeeze(), view(), permute() and transpose() all returned lazy versions of these operations (i.e. a tensordict that stores a reference to the original tensordict and applies the operation each time a key is accessed). This behaviour will be deprecated in the future and can already be controlled with the set_lazy_legacy() function:

>>> with set_lazy_legacy(True):
...     lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
...     dense_unsqueeze = tensordict.unsqueeze(0)

Keep in mind that, as always, these methods apply only to the batch dimensions. Any non-batch dimensions of the entries will be unaffected:

tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])

Viewing a TensorDict

TensorDict also supports view. This creates a _ViewedTensorDict which lazily creates views of its contents when they are accessed.

tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))

# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])

Permuting batch dimensions

The TensorDict.permute method can be used to permute the batch dimensions, much like torch.permute(). Non-batch dimensions are left untouched.

This operation is lazy, so the batch dimensions are only permuted when we try to access the entries. As always, if you are likely to need to access a particular entry multiple times, consider converting to a TensorDict.

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])

assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])

Using tensordicts as decorators

For a bunch of reversible operations, tensordicts can be used as decorators. These operations include to_module() for functional calls, unlock_() and lock_(), or shape operations such as view(), permute(), transpose(), squeeze() and unsqueeze(). Here is a quick example with the transpose function:

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

with tensordict.transpose(1, 0) as tdt:
    tdt.set("c", torch.ones(4, 3))  # we have permuted the dims

# the ``"c"`` entry is now in the tensordict we used as decorator:
#

assert (tensordict.get("c") == 1).all()

Gathering values in a TensorDict

The TensorDict.gather method can be used to index along a batch dimension and gather the results into a single dimension, much like torch.gather().

index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
 tensor([[2, 3, 2, 1],
        [3, 3, 0, 0],
        [3, 1, 1, 2]])

tensordict['a']:
 tensor([[0.1814, 0.2808, 0.2381, 0.4003],
        [0.1536, 0.0138, 0.4464, 0.6981],
        [0.9308, 0.0727, 0.3552, 0.4791]])

gathered_tensordict['a']:
 tensor([[0.2381, 0.4003, 0.2381, 0.2808],
        [0.6981, 0.6981, 0.1536, 0.1536],
        [0.4791, 0.0727, 0.0727, 0.3552]])
