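Manipulating the keys of a TensorDict

Author: Tom Begley

In this tutorial you will learn how to work with and manipulate the keys in a TensorDict, including getting and setting keys, iterating over keys, manipulating nested values, and flattening the keys.

Setting and getting keys

We can set and get keys using the same syntax as a Python dict.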

import torch
from tensordict import TensorDict

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict["a"] = a

# retrieve the value stored under "a"
assert tensordict["a"] is a

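Note

Unlike a Python dict, all keys in a TensorDict must be strings. However, as we will see, it is also possible to use tuples of strings to manipulate nested values.

We can also use the methods .get() and .set() to accomplish the same thing.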

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict.set("a", a)

# retrieve the value stored under "a"
assert tensordict.get("a") is a

Like dict, we can provide a default value to get, which is returned if the requested key is not found.

assert tensordict.get("banana", a) is a

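Similarly, like dict, we can use TensorDict.setdefault() to get the value of a particular key. If that key is not found, the default value is returned and is also stored in the TensorDict under that key.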

assert tensordict.setdefault("banana", a) is a
# a is now stored under "banana"
assert tensordict["banana"] is a

Deleting keys works the same way as with a Python dict, using the del statement and the chosen key. Equivalently, we could use the TensorDict.del_ method.

del tensordict["banana"]

Furthermore, when setting keys with .set() we can use the keyword argument inplace=True to make an in-place update, or equivalently use the .set_() method.

tensordict.set("a", torch.zeros(10), inplace=True)

# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a

# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a

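Renaming keys

To rename a key, simply use the TensorDict.rename_key_ method. The value stored under the original key will remain in the TensorDict, but the key will be changed to the specified new key.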

tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
    fields={
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

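Updating multiple values

The TensorDict.update method can be used to update a TensorDict with another one or with a dict. Keys that already exist are overwritten, and keys that do not already exist are created.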

tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

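Nested values

The values of a TensorDict can themselves be a TensorDict. We can add nested values during instantiation, either by adding a TensorDict directly or by using nested dictionaries.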

# creating nested values with a nested dict
nested_tensordict = TensorDict(
    {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

To access these nested values, we can use tuples of strings. For example:

double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))

Similarly, we can set nested values using tuples of strings:

tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Iterating over a TensorDict's contents

We can iterate over the keys of a TensorDict using the .keys() method.

for key in tensordict.keys():
    print(key)
a
nested

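By default this will iterate only over the top-level keys in the TensorDict. However, it is possible to iterate recursively over all of the keys with the keyword argument include_nested=True. This visits the keys of any nested TensorDicts as well, returning nested keys as tuples of strings.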

for key in tensordict.keys(include_nested=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested

If you only want to iterate over keys corresponding to Tensor values, you can additionally specify leaves_only=True.

for key in tensordict.keys(include_nested=True, leaves_only=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')

Much like dict, there are also .values() and .items() methods, which accept the same keyword arguments.

for key, value in tensordict.items(include_nested=True):
    if isinstance(value, TensorDict):
        print(f"{key} is a TensorDict")
    else:
        print(f"{key} is a Tensor")
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor

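Checking for existence of a key

To check if a key exists in a TensorDict, use the in operator in conjunction with .keys().

Note

Performing key in tensordict.keys() does efficient dict lookups of keys (recursively at each level in the nested case), so performance is not negatively impacted when there is a large number of keys in the TensorDict.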

assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)

Flattening and unflattening nested keys

We can flatten a TensorDict with nested values using the .flatten_keys() method.

print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Given a TensorDict that has been flattened, it is possible to unflatten it again with the .unflatten_keys() method.

flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

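This can be particularly useful when manipulating the parameters of a torch.nn.Module, as we can end up with a TensorDict whose structure mimics the module structure.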

import torch.nn as nn

module = nn.Sequential(
    nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
    nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()

print(params)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                1: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

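Selecting and excluding keys

We can obtain a new TensorDict with a subset of the keys by using TensorDict.select, which returns a new TensorDict containing only the specified keys, or TensorDict.exclude, which returns a new TensorDict with the specified keys omitted.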

print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Exclude:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
