目录

保存TensorDict和tensorclass对象

虽然我们可以直接使用 save() 保存一个 TensorDict,但这会将整个数据结构的内容写入单个文件。 可以很容易地想到,这种做法在某些情况下并不理想!

TensorDict序列化API主要依赖于MemoryMappedTensor 它用于将张量独立地写入磁盘,数据结构模仿了TensorDict的结构。

TensorDict 的序列化速度可比 PyTorch 依赖 save() 的 pickle 方式快一个数量级。本文档将介绍如何创建和操作存储在磁盘上的 TensorDict 数据。

保存内存映射的TensorDicts

当一个张量字典以mmap数据结构转储时,每个条目对应于单个*.memmap文件,并且目录结构由键结构决定:通常,嵌套键对应于子目录。

将数据结构保存为结构化的内存映射张量集合具有以下优势:

  • 保存的数据可以部分加载。如果一个大型模型已保存在磁盘上,但仅需将其中部分权重加载到另一个脚本中创建的模块中,则只有这些权重会被加载到内存中。

  • 保存数据是安全的:使用 pickle 库对大型数据结构进行序列化可能存在安全隐患,因为反序列化过程可能执行任意代码。TensorDict 的加载 API 仅从已保存的 JSON 文件及磁盘上的内存缓冲区(memorybuffers)中读取预选字段。

  • 保存速度快:由于数据被写入多个独立的文件, 我们可以通过启动多个并发线程来分摊 I/O 开销,每个线程各自访问一个专属文件。

  • 保存的数据结构清晰明了:目录树反映了数据内容。

然而,这种方法也有一些缺点:

  • 并非所有数据类型都可以保存。tensorclass 允许保存任何非张量数据:如果这些数据可以用 json 文件表示,则使用 json 格式。否则,非张量数据将独立保存,并以 save() 作为后备。 可以使用 NonTensorData 类来在常规 TensorDict 实例中表示非张量数据。

tensordict的内存映射API依赖于四个核心方法: memmap_(), memmap(), memmap_like()load_memmap()

The memmap_()memmap() 方法将在磁盘上写入数据,无论是否修改包含数据的 tensordict 实例。这些方法可以用于将模型序列化到磁盘(我们使用多个线程来加速序列化):

>>> model = nn.Transformer()
>>> weights = TensorDict.from_module(model)
>>> weights_disk = weights.memmap("/path/to/saved/dir", num_threads=32)
>>> new_weights = TensorDict.load_memmap("/path/to/saved/dir")
>>> assert (weights_disk == new_weights).all()

The memmap_like() 应在需要预先分配数据集到磁盘时使用,典型用法如下:

>>> def make_datum(): # used for illustration purposes
...    return TensorDict({"image": torch.randint(255, (3, 64, 64)), "label": 0}, batch_size=[])
>>> dataset_size = 1_000_000
>>> datum = make_datum() # creates a single instance of a TensorDict datapoint
>>> data = datum.expand(dataset_size) # does NOT require more memory usage than datum, since it's only a view on datum!
>>> data_disk = data.memmap_like("/path/to/data")  # creates the two memory-mapped tensors on disk
>>> del data # data is not needed anymore

如上所示,当将TensorDict`的条目转换为MemoryMappedTensor时,可以控制内存映射保存在磁盘上的位置,以便它们持久化并可以在以后加载。另一方面,也可以使用文件系统。要使用此功能,只需在上述三种序列化方法中丢弃prefix参数即可。

当指定 prefix 时,数据结构遵循 TensorDict 的结构:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
>>> td.memmap_(prefix="tensordict")

生成以下目录结构

tensordict
├── a.memmap
├── b
│   ├── c.memmap
│   └── meta.json
└── meta.json

meta.json 个文件包含重建 tensordict 所需的所有相关信息,例如设备、批次大小,还包括 tensordict 子类型。 这意味着 load_memmap() 将能够重建复杂的嵌套结构,其中子 tensordicts 的类型与父级不同:

>>> from tensordict import TensorDict, tensorclass, TensorDictBase
>>> from tensordict.utils import print_directory_tree
>>> import torch
>>> import tempfile
>>> td_list = [TensorDict({"item": i}, batch_size=[]) for i in range(4)]
>>> @tensorclass
... class MyClass:
...     data: torch.Tensor
...     metadata: str
>>> tc = MyClass(torch.randn(3), metadata="some text", batch_size=[])
>>> data = TensorDict({"td_list": torch.stack(td_list), "tensorclass": tc}, [])
>>> with tempfile.TemporaryDirectory() as tempdir:
...     data.memmap_(tempdir)
...
...     loaded_data = TensorDictBase.load_memmap(tempdir)
...     assert (loaded_data == data).all()
...     print_directory_tree(tempdir)
tmpzy1jcaoq/
    tensorclass/
        _tensordict/
            data.memmap
            meta.json
        meta.json
    td_list/
        0/
            item.memmap
            meta.json
        1/
            item.memmap
            meta.json
        3/
            item.memmap
            meta.json
        2/
            item.memmap
            meta.json
        meta.json
    meta.json

处理现有 MemoryMappedTensor

如果 TensorDict` 已经包含 MemoryMappedTensor 项,则有几种 可能的行为。

  • 如果 prefix 没有指定并且 memmap() 被调用两次,那么生成的 TensorDict 将包含与原始数据相同的数据。

    >>> td = TensorDict({"a": 1}, [])
    >>> td0 = td.memmap()
    >>> td1 = td0.memmap()
    >>> td0["a"] is td1["a"]
    True
    
  • 如果 prefix 被指定并且与现有 MemoryMappedTensor 实例的前缀不同,则会引发异常, 除非传递了 copy_existing=True

    >>> with tempfile.TemporaryDirectory() as tmpdir_0:
    ...     td0 = td.memmap(tmpdir_0)
    ...     td0 = td.memmap(tmpdir_0)  # works, results are just overwritten
    ...     with tempfile.TemporaryDirectory() as tmpdir_1:
    ...         td1 = td0.memmap(tmpdir_1)
    ...         td_load = TensorDict.load_memmap(tmpdir_1)  # works!
    ...     assert (td_load == td).all()
    ...     with tempfile.TemporaryDirectory() as tmpdir_1:
    ...         td_load = TensorDict.load_memmap(tmpdir_1)  # breaks!
    

    此功能的实现旨在防止用户无意中将内存映射张量从一个位置复制到另一个位置。

TorchSnapshot 兼容性

警告

由于 torchsnapshot 的维护工作即将停止,因此我们将不再为此库实现与 tensordict 兼容的新功能。

TensorDict 与 torchsnapshot 兼容, 一个 PyTorch 检查点库。 TorchSnapshot 将独立保存您的每个张量,其数据结构 模仿您的 tensordict 或 tensorclass。此外,TensorDict 自然地 内置了在不将整个张量加载到内存的情况下在磁盘上保存和加载大型数据集的工具: 换句话说,tensordict + torchsnapshot 的组合 使得可以将大小为数百 GB 的张量加载到预分配的 MemmapTensor 上, 而无需一次性将其传递到 RAM 中。

有两个主要用例:保存和加载内存中的张量字典,以及使用MemmapTensor保存和加载存储在磁盘上的张量字典。

通用用例:内存加载

如果目标 TensorDict 尚未预先分配内存,则此方法适用。 该方法具有灵活性(您可以将任意 TensorDict 加载到您的 TensorDict 中,无需事先了解其具体内容),且相较于其他方法,编码略为简单。 然而,若您的张量极大、无法全部装入内存,则此方法可能失效。 此外,它也不支持直接将数据加载到您指定的设备上。

保存操作的两个主要命令是:

>>> state = {"state": tensordict_source}
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path="/path/to/my/snapshot")

要加载到目标张量字典中,您可以简单地加载快照并更新张量字典。在幕后,此方法将调用tensordict_target.load_state_dict(state_dict), 这意味着state_dict将首先完全放入内存中,然后加载到目标张量字典上:

>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> state_target = {"state": tensordict_target}
>>> snapshot.restore(app_state=state_target)

这里是一个完整的示例:

>>> import uuid
>>> import torchsnapshot
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> tensordict_source = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, [])
>>> state = {"state": tensordict}
>>> path = f"/tmp/{uuid.uuid4()}"
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict()
>>> target_state = {
>>>     "state": tensordict2
>>> }
>>> snapshot.restore(app_state=target_state)
>>> assert (tensordict == tensordict2).all()

保存和加载大型数据集

如果数据集过大,无法完全加载到内存中,上述方法很容易失效。 我们利用 torchsnapshot 的功能,将张量分块加载到其预先分配的目标位置上。 这要求您事先明确目标数据的形状(shape)、所在设备(device)等信息, 但相较于能够对模型或数据加载过程进行断点保存(checkpoint),这点额外工作是值得的!

与前一个示例不同,我们将不会使用load_state_dict()方法 的TensorDict,而是从目标对象中获取的state_dict 并将保存的数据重新填充到该对象中。

再次,只需两行代码即可保存数据:

>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path="/path/to/my/snapshot")

我们一直在使用 torchsnapshot.StateDict 并且我们显式调用了 my_tensordict_source.state_dict(keep_vars=True),与之前的示例不同。 现在,将其加载到目标张量字典上:

>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)

在这个示例中,加载完全由torchsnapshot处理,即没有调用TensorDict.load_state_dict()

注意

这有两个重要的含义:

  1. 自从 LazyStackedTensorDict.state_dict()(和其他懒惰的张量字典类) 在执行某些操作后返回数据的副本,加载到状态字典中不会更新原始类。但是,由于支持state_dict()操作, 这不会引发错误。

  2. 同样地,由于状态字典在原地更新,而张量字典没有使用TensorDict.update()TensorDict.set()进行更新,因此目标张量字典中缺失的键将不会被注意到。

这里是一个完整的示例:

>>> td = TensorDict({"a": torch.randn(3), "b": TensorDict({"c": torch.randn(3, 1)}, [3, 1])}, [3])
>>> td.memmap_()
>>> assert isinstance(td["b", "c"], MemmapTensor)
>>>
>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=td.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>>
>>> td_dest = TensorDict({"a": torch.zeros(3), "b": TensorDict({"c": torch.zeros(3, 1)}, [3, 1])}, [3])
>>> td_dest.memmap_()
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>> # sanity check
>>> assert (td_dest == td).all()
>>> assert (td_dest["b"].batch_size == td["b"].batch_size)
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)

最后,tensorclass 也支持这个功能。代码与上面的非常相似:

>>> from __future__ import annotations
>>> import uuid
>>> from typing import Union, Optional
>>>
>>> import torchsnapshot
>>> from tensordict import TensorDict, MemmapTensor
>>> import torch
>>> from tensordict.prototype import tensorclass
>>>
>>> @tensorclass
>>> class MyClass:
...      x: torch.Tensor
...      y: Optional[MyClass]=None
...
>>> tc = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc.memmap_()
>>> assert isinstance(tc.y.x, MemmapTensor)
>>>
>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>> tc_dest = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc_dest.memmap_()
>>> assert isinstance(tc_dest.y.x, MemmapTensor)
>>> app_state = {
...     "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>>
>>> assert (tc_dest == tc).all()
>>> assert (tc_dest.y.batch_size == tc.y.batch_size)
>>> assert isinstance(tc_dest.y.x, MemmapTensor)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源