保存 TensorDict 和 tensorclass 对象¶
虽然我们可以使用 ,但 this
将创建一个包含数据结构全部内容的单个文件。
人们可以很容易地想象出这是次优的情况!
TensorDict 序列化 API 主要依赖于 which,用于以数据结构在磁盘上独立写入张量
它模拟了 TensorDict 的 one。
TensorDict 的序列化速度__faster__可以比
PyTorch 的 具有 的 pickle 依赖。本文档说明
如何使用 TensorDict 创建存储在磁盘上的数据并与之交互。
保存内存映射的 TensorDict¶
当 tensordict 作为 mmap 数据结构转储时,每个条目对应
添加到单个文件中,目录结构由
Key 结构:通常,嵌套 key 对应于子目录。*.memmap
将数据结构保存为一组结构化的内存映射张量具有以下 优势:
保存的数据可以部分加载。如果大型模型保存在磁盘上,但 只需要将其权重的一部分加载到在单独的 脚本,则只有这些权重才会加载到内存中。
保存数据是安全的:使用 pickle 库序列化大数据结构 可能是不安全的,因为 unpickling 可以执行任何任意代码。TensorDict 的加载 API 仅从保存的 json 文件和内存缓冲区中读取预选字段 保存在磁盘上。
保存速度快:因为数据被写入多个独立的文件中, 我们可以通过启动多个并发线程来分摊 IO 开销,这些线程 每个 VPN 都自行访问一个专用文件。
保存数据的结构是显而易见的:目录树是指示性的 的数据内容。
但是,这种方法也有一些缺点:
并非每种数据类型都可以保存。
允许保存 任何非张量数据:如果这些数据可以用 JSON 文件表示,则 JSON 格式。否则,非 Tensor 数据将被独立保存 with
作为后备。 该
类可用于表示非张量 data 的实例。
tensordict 的内存映射 API 依赖于四个核心方法:、
和
。
and
方法将在修改或不修改 tensordict 的情况下将数据写入磁盘
实例。这些方法可用于序列化模型
在磁盘上(我们使用多个线程来加速序列化):
>>> model = nn.Transformer()
>>> weights = TensorDict.from_module(model)
>>> weights_disk = weights.memmap("/path/to/saved/dir", num_threads=32)
>>> new_weights = TensorDict.load_memmap("/path/to/saved/dir")
>>> assert (weights_disk == new_weights).all()
>>> def make_datum(): # used for illustration purposes
... return TensorDict({"image": torch.randint(255, (3, 64, 64)), "label": 0}, batch_size=[])
>>> dataset_size = 1_000_000
>>> datum = make_datum() # creates a single instance of a TensorDict datapoint
>>> data = datum.expand(dataset_size) # does NOT require more memory usage than datum, since it's only a view on datum!
>>> data_disk = data.memmap_like("/path/to/data") # creates the two memory-mapped tensors on disk
>>> del data # data is not needed anymore
如上所示,当将 a 的条目转换为 时,可以控制其中的位置
内存映射保存在磁盘上,以便它们持久存在并且可以
在以后加载。另一方面,也可以使用文件系统。
要使用它,只需丢弃 three serialization 中的参数
方法。
TensorDict`
prefix
指定 a 后,数据结构遵循 TensorDict 的 1:prefix
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
>>> td.memmap_(prefix="tensordict")
生成以下目录结构
tensordict
├── a.memmap
├── b
│ ├── c.memmap
│ └── meta.json
└── meta.json
这些文件包含用于重建
tensordict,例如 device、batch-size,但也包括 tensordict 子类型。
这意味着将能够
重建 sub-tensordict 具有不同类型的复杂嵌套结构
比父母:
meta.json
>>> from tensordict import TensorDict, tensorclass, TensorDictBase
>>> from tensordict.utils import print_directory_tree
>>> import torch
>>> import tempfile
>>> td_list = [TensorDict({"item": i}, batch_size=[]) for i in range(4)]
>>> @tensorclass
... class MyClass:
... data: torch.Tensor
... metadata: str
>>> tc = MyClass(torch.randn(3), metadata="some text", batch_size=[])
>>> data = TensorDict({"td_list": torch.stack(td_list), "tensorclass": tc}, [])
>>> with tempfile.TemporaryDirectory() as tempdir:
... data.memmap_(tempdir)
...
... loaded_data = TensorDictBase.load_memmap(tempdir)
... assert (loaded_data == data).all()
... print_directory_tree(tempdir)
tmpzy1jcaoq/
tensorclass/
_tensordict/
data.memmap
meta.json
meta.json
td_list/
0/
item.memmap
meta.json
1/
item.memmap
meta.json
3/
item.memmap
meta.json
2/
item.memmap
meta.json
meta.json
meta.json
处理现有
¶
如果 已经包含条目,则有一些
可能的行为。
TensorDict`
如果未指定并
调用 两次,则生成的 TensorDict 将包含与原始 TensorDict 相同的数据。
prefix
>>> td = TensorDict({"a": 1}, []) >>> td0 = td.memmap() >>> td1 = td0.memmap() >>> td0["a"] is td1["a"] True
如果指定了 ID 并且与现有
实例的前缀不同,则会引发异常。 除非传递 copy_existing=True:
prefix
>>> with tempfile.TemporaryDirectory() as tmpdir_0: ... td0 = td.memmap(tmpdir_0) ... td0 = td.memmap(tmpdir_0) # works, results are just overwritten ... with tempfile.TemporaryDirectory() as tmpdir_1: ... td1 = td0.memmap(tmpdir_1) ... td_load = TensorDict.load_memmap(tmpdir_1) # works! ... assert (td_load == td).all() ... with tempfile.TemporaryDirectory() as tmpdir_1: ... td_load = TensorDict.load_memmap(tmpdir_1) # breaks!
实现此功能是为了防止用户无意中复制 memorymapped 张量从一个位置到另一个位置。
TorchSnapshot 兼容性¶
警告
由于 torchsnapshot 维护已停止。因此,我们不会实施 Tensordict 与此库兼容的新功能。
TensorDict 与 torchsnapshot 兼容,
PyTorch 检查点库。
TorchSnapshot 将独立保存每个张量,其数据结构为
模拟您的 TensorDict 或 TensorClass 之一。此外,TensorDict 自然具有
在磁盘上保存和加载大型数据集所需的工具,而无需
在内存中加载完整的张量:换句话说,tensordict + torchsnapshot 的组合
可以将几百 Gb 的张量加载到
pre-allocated 而不在 RAM 上的一个块中传递它。MemmapTensor
有两个主要用例:保存和加载适合内存的 tensordict,
以及使用 保存和加载存储在磁盘上的 Tensordict。MemmapTensor
一般用例:内存中加载¶
如果您的目标 tensordict 未预先分配,则此方法适用。 这提供了灵活性(您可以将任何 tensordict 加载到 tensordict 上,您 不需要提前知道它的内容),而这种方法只是勉强 比其他更容易编码。 但是,如果您的张量非常大且无法放入内存,则这可能会中断。 此外,它不允许您直接加载到您选择的设备上。
保存操作要记住的两个主要命令是:
>>> state = {"state": tensordict_source}
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path="/path/to/my/snapshot")
要加载到目标 tensordict 上,您只需加载快照并更新
tensordict 的在后台,此方法将调用 、
这意味着 将首先完全放入内存中,然后加载到
目标 tensordict 中:tensordict_target.load_state_dict(state_dict)
state_dict
>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> state_target = {"state": tensordict_target}
>>> snapshot.restore(app_state=state_target)
下面是一个完整的示例:
>>> import uuid
>>> import torchsnapshot
>>> from tensordict import TensorDict
>>> import torch
>>>
>>> tensordict_source = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, [])
>>> state = {"state": tensordict}
>>> path = f"/tmp/{uuid.uuid4()}"
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict()
>>> target_state = {
>>> "state": tensordict2
>>> }
>>> snapshot.restore(app_state=target_state)
>>> assert (tensordict == tensordict2).all()
保存和加载大数据集¶
如果数据集太大而无法放入内存,则上述方法很容易中断。 我们利用 torchsnapshot 的功能以小块加载张量 在其预先分配的目标上。 这需要您知道目标数据将具有和存在的形状、设备等, 但是,要能够对模型进行检查点或数据加载,只需付出很小的代价!
与前面的示例相比,我们不会使用
of 的,而是从目标对象获取的
我们将使用保存的数据重新填充。load_state_dict()
TensorDict
state_dict
同样,两行代码足以保存数据:
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path="/path/to/my/snapshot")
与前面的示例不同,我们一直在使用并显式调用 。
现在,要将其加载到目标 tensordict 上:torchsnapshot.StateDict
my_tensordict_source.state_dict(keep_vars=True)
>>> snapshot = Snapshot(path="/path/to/my/snapshot")
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
在这个例子中,加载完全由 torchsnapshot 处理,即。有
不调用 .TensorDict.load_state_dict()
注意
这有两个重要含义:
Since (和其他惰性 tensordict 类) 在执行某些操作后返回数据的副本,加载到 state-dict 不会更新原始类。但是,由于 state_dict() 操作 ,则不会引发错误。
LazyStackedTensorDict.state_dict()
同样,由于 state-dict 是就地更新的,但 tensordict 不是 使用 或 进行更新,缺少 key 中,则不会被注意到。
TensorDict.update()
TensorDict.set()
下面是一个完整的示例:
>>> td = TensorDict({"a": torch.randn(3), "b": TensorDict({"c": torch.randn(3, 1)}, [3, 1])}, [3])
>>> td.memmap_()
>>> assert isinstance(td["b", "c"], MemmapTensor)
>>>
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=td.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>>
>>> td_dest = TensorDict({"a": torch.zeros(3), "b": TensorDict({"c": torch.zeros(3, 1)}, [3, 1])}, [3])
>>> td_dest.memmap_()
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>> # sanity check
>>> assert (td_dest == td).all()
>>> assert (td_dest["b"].batch_size == td["b"].batch_size)
>>> assert isinstance(td_dest["b", "c"], MemmapTensor)
最后,tensorclass 也支持此功能。该代码与上面的代码非常相似:
>>> from __future__ import annotations
>>> import uuid
>>> from typing import Union, Optional
>>>
>>> import torchsnapshot
>>> from tensordict import TensorDict, MemmapTensor
>>> import torch
>>> from tensordict.prototype import tensorclass
>>>
>>> @tensorclass
>>> class MyClass:
... x: torch.Tensor
... y: Optional[MyClass]=None
...
>>> tc = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc.memmap_()
>>> assert isinstance(tc.y.x, MemmapTensor)
>>>
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
... }
>>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}")
>>>
>>> tc_dest = MyClass(x=torch.randn(3), y=MyClass(x=torch.randn(3), batch_size=[]), batch_size=[])
>>> tc_dest.memmap_()
>>> assert isinstance(tc_dest.y.x, MemmapTensor)
>>> app_state = {
... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
... }
>>> snapshot.restore(app_state=app_state)
>>>
>>> assert (tc_dest == tc).all()
>>> assert (tc_dest.y.batch_size == tc.y.batch_size)
>>> assert isinstance(tc_dest.y.x, MemmapTensor)