目录

切片、索引和掩码

作者Tom Begley

在本教程中,您将学习如何对 .

如教程操作 TensorDict 的形状中所述,当我们创建 a 时,我们会指定一个 ,它必须同意 替换为 .由于我们有 保证所有条目共享这些共同的维度,我们能够索引 并屏蔽批处理维度,其方式与索引 a 相同。索引将沿批处理维度应用于所有 .batch_size

例如,给定具有两个批次维度的 a,返回具有相同结构的 new,并且 其值对应于 original 中每个条目的第一个 “row” 。tensordict[0]

import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

print(tensordict[0])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

语法与常规张量相同。例如,如果我们想删除 我们可以按如下方式索引每个条目的第一行

print(tensordict[1:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 4]),
    device=None,
    is_shared=False)

我们可以同时为多个维度编制索引

print(tensordict[:, 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

我们还可以使用 to 表示所需的数量,以使 Selection 元组的长度与 相同。Ellipsis:tensordict.batch_dims

print(tensordict[..., 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

使用索引设置值

一般来说,只要 batch 就可以工作 尺寸兼容。tensordict[index] = new_tensordict

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.]])

掩蔽

我们像屏蔽张量一样进行掩码。TensorDict

mask = torch.BoolTensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])
tensordict[mask]
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([6]),
    device=None,
    is_shared=False)

脚本总运行时间:(0 分 0.004 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源