TensorDict 在分布式环境中的使用¶
TensorDict 可用于分布式环境中,以在不同节点之间传递张量。 如果两个节点能够访问共享的物理存储,则可使用内存映射张量,高效地在运行中的进程之间传递数据。 此处,我们将介绍如何在分布式 RPC 环境中实现这一目标的一些细节。 有关分布式 RPC 的更多详细信息,请参阅 PyTorch 官方文档。
创建内存映射的TensorDict¶
内存映射张量(和数组)具有很大的优势,可以存储大量数据,并允许在不将整个文件读入内存的情况下轻松访问数据切片。
TensorDict 提供了内存映射数组与 torch.Tensor 类名为 MemmapTensor 的接口。
MemmapTensor 实例可以存储在 TensorDict 对象中,允许 tensordict 表示存储在磁盘上的大型数据集,并且可以轻松地以批处理方式跨节点访问。
一个内存映射的张量字典可以通过以下方式简单创建:(1) 使用内存映射张量填充 TensorDict 或 (2) 通过调用 tensordict.memmap_() 将其置于物理存储中。
可以通过查询 tensordict.is_memmap() 来轻松检查张量字典是否已置于物理存储中。
创建内存映射张量本身可通过多种方式实现。 首先,可以直接创建一个空张量:
>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)
The prefix 属性指示临时文件应存储的位置。
重要的是,张量必须存储在一个所有节点都能访问的目录中!
另一种选择是在磁盘上表示一个现有的张量:
>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")
当张量较大或无法完全载入内存时,将优先采用前一种方法: 它适用于规模极大、且在各节点间作为通用存储的张量。 例如,可以创建一个数据集,供单个或多个节点快速访问,其访问速度远高于每个文件需独立加载至内存的情形:
>>> dataset = TensorDict({
... "images": MemmapTensor(50000, 480, 480, 3),
... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
... "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
fields={
images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)
请注意,我们已经指定了MemmapTensor的设备。
这种语法糖允许在需要时直接将查询的张量加载到设备上。
另一个需要考虑的因素是,目前MemmapTensor与autograd操作不兼容。
在节点间操作内存映射张量¶
我们提供了一个分布式脚本的简单示例:其中一个进程创建一个内存映射张量,并将其引用发送给另一个负责更新该张量的工作进程。您可以在 基准测试目录中找到此示例。
简而言之,我们的目标是展示当节点能够访问共享的物理存储时,如何处理大型张量的读写操作。具体步骤包括:
Creating the empty tensor on disk;
Setting the local and remote operations to be executed;
Passing commands from worker to worker using RPC to read and write the shared data.
此示例首先编写一个函数,该函数使用全 1 张量在特定索引处更新 TensorDict 实例:
>>> def fill_tensordict(tensordict, idx):
... tensordict[idx] = TensorDict(
... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
... )
... return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
The CloudpickleWrapper 确保该函数可序列化。
接下来,我们创建一个相当大的张量字典,以说明如果必须通过常规的tensorpipe传递,
那么从一个工作者到另一个工作者传递这个张量字典将非常困难:
>>> tensordict = TensorDict(
... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )
最后,在主节点上,我们调用远程节点上的函数,然后 检查数据是否已写入所需位置:
>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
... worker_info,
... fill_tensordict_cp,
... args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())
尽管调用 rpc.rpc_sync 时需传入整个 TensorDict,
并更新该对象的特定索引后将其返回至原始工作节点,
但此代码片段的执行速度极快(若内存地址引用已在先前传递,则速度更快;详情请参阅 TorchRL 的分布式重放缓冲区文档)。
该脚本包含额外的 RPC 配置步骤,超出了本文档的范围。