目录

使用 TensorDict 预分配内存

作者Tom Begley

在本教程中,您将学习如何利用 中的内存预分配。

假设我们有一个函数,它返回一个

import torch
from tensordict.tensordict import TensorDict


def make_tensordict():
    return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])

也许我们想多次调用这个函数,并使用结果进行填充 单个 .

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])

for i in range(N):
    tensordict[i] = make_tensordict()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

因为我们指定了 of ,所以在第一个 循环的迭代中,我们填充空张量,其第一个 dimension 是 size ,其其余维度由返回 的值。在上面的示例中,我们预先分配了一个 0 数组 key 的 size 和 key 的数组大小。循环的后续迭代是 就地编写。因此,如果不是所有值都已填充,它们将获得默认值 值为零。batch_sizetensordicttensordictNmake_tensordicttorch.Size([10, 3])"a"torch.Size([10, 3, 4])"b"

让我们通过单步执行上述循环来演示发生了什么。我们首先 初始化一个空的 .

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
    fields={
    },
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在第一次迭代之后, 已预填充了 和 的张量。这些张量包含零,除了第一行,我们 已分配随机值。tensordict"a""b"

random_tensordict = make_tensordict()
tensordict[0] = random_tensordict

assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在后续迭代中,我们将就地更新预先分配的张量。

a = tensordict["a"]
random_tensordict = make_tensordict()
tensordict[1] = random_tensordict

# the same tensor is stored under "a", but the values have been updated
assert tensordict["a"] is a
assert (tensordict[:2] != 0).all()

脚本总运行时间:(0 分 0.003 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源