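Manipulating the keys of a TensorDict¶
Author: Tom Begley
In this tutorial you will learn how to work with and manipulate the keys in a TensorDict, including getting and setting keys, iterating over keys, manipulating nested values, and flattening keys.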
Setting and getting keys¶
We can set and get keys using the same syntax as a Python dict.
import torch
from tensordict.tensordict import TensorDict
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict["a"] = a
# retrieve the value stored under "a"
assert tensordict["a"] is a
Note
Unlike a Python dict, all keys in a TensorDict must be strings. However, as we will see, it is also possible to use tuples of strings to manipulate nested values.
We can also use the methods .get() and .set() to accomplish the same thing.
tensordict = TensorDict()
# set a key
a = torch.rand(10)
tensordict.set("a", a)
# retrieve the value stored under "a"
assert tensordict.get("a") is a
As with dict, we can pass get a default value that is returned when the requested key is not found.
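Similarly, as with dict, we can use TensorDict.setdefault() to get the value of a particular key. If the key is not found, the default value is returned and is also stored under that key in the TensorDict.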
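Deleting keys also works the same way as for a Python dict, using the del statement and the chosen key. Equivalently, we could use the TensorDict.del_() method.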
del tensordict["banana"]
Furthermore, when setting keys with .set() we can pass the keyword argument inplace=True to make an in-place update, or equivalently use the .set_() method.
tensordict.set("a", torch.zeros(10), inplace=True)
# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a
# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a
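Renaming keys¶
To rename a key, simply use the TensorDict.rename_key_ method. The value stored under the original key remains in the TensorDict, but the key is changed to the specified new key.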
tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
fields={
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
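Updating multiple values¶
The TensorDict.update method can be used to update a TensorDict with another TensorDict or with a dict. Keys that already exist are overwritten, and keys that do not yet exist are created.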
tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)
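Nested values¶
The values of a TensorDict can themselves be TensorDicts. We can add nested values at instantiation, either by adding a TensorDict directly or by using nested dictionaries.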
# creating nested values with a nested dict
nested_tensordict = TensorDict(
{"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
To access these nested values, we can use tuples of strings. For example:
double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))
Similarly, we can set nested values using tuples of strings:
tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Iterating over a TensorDict's contents¶
We can iterate over the keys of a TensorDict using the .keys() method.
a
nested
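By default this iterates only over the top-level keys of the TensorDict. It is, however, possible to iterate over all of the keys in the TensorDict with the keyword argument include_nested=True. This iterates recursively over all keys in any nested TensorDicts, returning nested keys as tuples of strings.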
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested
If you are only interested in iterating over keys that correspond to Tensor values, you can additionally specify leaves_only=True.
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')
Much like dict, there are also .values and .items methods, which accept the same keyword arguments.
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor
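Checking for existence of a key¶
To check whether a key exists in a TensorDict, use the in operator in conjunction with .keys().
Note
Performing key in tensordict.keys() does efficient dict lookups of keys (recursively at each level in the nested case), so performance is not negatively impacted when the TensorDict contains a large number of keys.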
assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)
Flattening and unflattening nested keys¶
We can flatten a TensorDict with nested values using the .flatten_keys() method.
print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Given a TensorDict that has been flattened, it is possible to unflatten it again with the .unflatten_keys() method.
flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
double_nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
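This can be particularly useful when manipulating the parameters of a torch.nn.Module, as we can end up with a TensorDict whose structure mimics the module structure.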
import torch.nn as nn
module = nn.Sequential(
nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()
print(params)
TensorDict(
fields={
0: TensorDict(
fields={
0: TensorDict(
fields={
bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
1: TensorDict(
fields={
bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
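Selecting and excluding keys¶
We can obtain a new TensorDict with a subset of the keys by using TensorDict.select, which returns a new TensorDict containing only the specified keys, or TensorDict.exclude, which returns a new TensorDict with the specified keys omitted.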
print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
Exclude:
TensorDict(
fields={
a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=None,
is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)