目录

分布式设置中的 TensorDict

TensorDict 可用于分布式设置,以从一个节点传递张量 到另一个。 如果两个节点可以访问共享的物理存储,则内存映射张量可以 用于有效地将数据从一个正在运行的进程传递到另一个正在运行的进程。 在这里,我们提供了有关如何在分布式 RPC 设置中实现此目的的一些详细信息。 有关分布式 RPC 的更多详细信息,请查看官方 pytorch 文档

创建内存映射的 TensorDict

内存映射张量(和数组)具有一个巨大的优势,它们可以存储 大量数据,并允许随时访问数据切片,而无需 读取内存中的整个文件。 TensorDict 在内存映射 数组和torch.Tensor名为 的类。 实例可以存储在对象中,从而允许 tensordict 来表示一个大数据集,存储在磁盘上,在 batched 方式跨节点。MemmapTensorMemmapTensorTensorDict

内存映射的 tensordict 只需通过 (1) 填充 TensorDict 来创建 memory-mapped 张量,或者 (2) 通过调用将其置于 物理存储。 通过查询 tensordict.is_memmap() 可以轻松检查是否将 tensordict 放在物理存储上。tensordict.memmap_()

创建内存映射张量本身可以通过多种方式完成。 首先,可以简单地创建一个空张量:

>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)

该属性指示临时文件的存储位置。 将张量存储在每个 节点!prefix

另一种选择是表示磁盘上的现有张量:

>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")

当张量很大或不适合内存时,前一种方法将是首选: 它适用于非常大的 Tensor 并用作公共存储 跨节点。例如,可以创建一个易于访问的数据集 通过单个或不同的节点,比每个文件都必须 在内存中独立加载:

在磁盘上创建空数据集
>>> dataset = TensorDict({
...      "images": MemmapTensor(50000, 480, 480, 3),
...      "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
...      "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
    fields={
        images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
        labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
        masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

请注意,我们已经指示了 的设备。 这种语法 sugar 允许直接加载查询的张量 在设备上。MemmapTensor

另一个需要考虑的因素是 currently 与 autograd作不兼容。MemmapTensor

跨节点对 Memory-mapped 张量进行作

我们提供了一个分布式脚本的简单示例,其中一个进程创建一个 memory-mapped 张量,并将其引用发送给另一个负责 更新它。您可以在 benchmark 目录中找到此示例。

简而言之,我们的目标是展示如何处理 big 上的读写作 Tensors (当节点有权访问共享物理存储时)。这些步骤包括:

  • 在磁盘上创建空张量;

  • 设置要执行的本地和远程作;

  • 使用 RPC 将命令从 worker 传递到 worker 以读取和写入 共享数据。

此示例首先编写一个更新 TensorDict 实例的函数 在具有 1 填充张量的特定索引处:

>>> def fill_tensordict(tensordict, idx):
...     tensordict[idx] = TensorDict(
...         {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
...     )
...     return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)

这可确保函数是可序列化的。 接下来,我们创建一个相当大的 tensordict,以表明 如果必须传递,这将很难在 worker 之间传递 常规 TensorPipe:CloudpickleWrapper

>>> tensordict = TensorDict(
...     {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )

最后,还是在主节点上,我们在远程节点上调用函数,然后 检查数据是否已写入需要的位置:

>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
...     worker_info,
...     fill_tensordict_cp,
...     args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())

尽管 to 的调用涉及传递整个 tensordict, 更新此对象的特定索引并将其返回给原始 worker, 此代码段的执行速度非常快(如果引用 到内存位置已经提前传递了,参见 torchrl 的分布式 replay buffer documentation 以了解更多信息)。rpc.rpc_sync

该脚本包含超出 本文档的用途。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源