注意
转到末尾 以下载完整示例代码。
切片、索引和掩码¶
作者: Tom Begley
在这个教程中,您将学习如何切片、索引和掩码一个 TensorDict。
如教程操作 TensorDict 的形状中所述,当我们创建一个TensorDict时,需指定一个batch_size,该值必须与TensorDict中所有条目的前导维度一致。由于我们能确保所有条目均共享这些维度,因此可以像索引torch.Tensor一样,对批次维度进行索引和掩码操作。索引将沿批次维度应用于TensorDict中的所有条目。
例如,给定一个 TensorDict 并且有两个批量维度,
tensordict[0] 返回一个新的 TensorDict 具有相同的结构,并且
其值对应于原始
TensorDict 中每个条目的第一行。
import torch
from tensordict import TensorDict
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
print(tensordict[0])
TensorDict(
fields={
a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)
语法与普通张量相同。例如,若要删除每个条目的第一行,可按如下方式索引:
print(tensordict[1:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 4]),
device=None,
is_shared=False)
我们可以同时索引多个维度
print(tensordict[:, 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
我们也可以使用 Ellipsis 来表示需要的任意数量的 :,以使选择元组与 tensordict.batch_dims 的长度相同。
print(tensordict[..., 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
通过索引设置值¶
一般来说,tensordict[index] = new_tensordict 只要批量大小兼容就会起作用。
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]])
掩码¶
我们掩码 TensorDict 就像我们掩码张量一样。
TensorDict(
fields={
a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([6]),
device=None,
is_shared=False)
脚本总运行时间: (0 分钟 0.004 秒)