序列化语义¶
本说明介绍如何保存和加载 PyTorch 张量和模块状态 以及如何序列化 Python 模块以便可以用 C++ 加载它们。
目录
保存和加载张量¶
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常使用“.pt”或“.pth”扩展名编写。
并
默认使用 Python 的 pickle,
因此,您还可以将多个张量保存为 Python 对象(如元组)的一部分,
lists 和 dicts:
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
如果 数据结构是可 pickle 的。
保存和加载张量会保留视图¶
保存张量会保留它们的视图关系:
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在幕后,这些张量共享相同的 “存储空间”。有关更多信息,请参阅 Tensor 视图 在 views 和 storage 上。
当 PyTorch 保存 tensor 时,它会保存它们的存储对象和张量 元数据。这是一个实现细节,在 future,但它通常会节省空间并允许 PyTorch 轻松 重建加载的 Tensor 之间的视图关系。在上述 snippet 中,只有一个存储被写入 'tensors.pt' 中。
但是,在某些情况下,可能不需要保存当前存储对象 并创建大得令人望而却步的文件。在下面的代码片段中,存储很多 大于保存的张量写入文件:
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
而不是仅将小张量中的 5 个值保存到 'small.pt' 保存并加载了它与 large 共享的 Storage 中的 999 个值。
当保存元素少于其存储对象的张量时, 可以通过首先克隆张量来减少保存的文件。克隆张量 生成一个新 Tensor,其中包含仅包含值的新 Storage 对象 在 Tensor 中:
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
但是,由于克隆的张量彼此独立,因此它们具有 原始张量没有 view relationships 。如果文件大小和 在保存小于其 storage 对象,则必须小心构造新的 Tensor,以最小化 其存储对象的大小,但仍具有所需的视图关系 在保存之前。
保存和加载 torch.nn.Modules¶
另请参阅:教程:保存和加载模块
在 PyTorch 中,模块的状态经常使用“状态字典”进行序列化。 模块的状态字典包含其所有参数和持久缓冲区:
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
出于兼容性原因,建议不要直接保存模块
改为仅保存其 state dict。Python 模块甚至有一个函数 ,可以从 state dict 恢复它们的状态:
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
即使是自定义模块和包含其他模块的模块也有 state dicts 和 可以使用以下模式:
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
的序列化文件格式
¶
从 PyTorch 1.6.0 开始,默认返回未压缩的 ZIP64
archive 除非用户设置 .torch.save
_use_new_zipfile_serialization=False
在此存档中,文件按以下顺序排序
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 这些条目如下:
data.pkl
是封存传递给的对象的结果,以排除它所包含的对象torch.save
torch.Storage
byteorder
包含保存时带有 (“little” 或 “big”) 的字符串sys.byteorder
data/
包含对象中的所有存储,其中每个存储都是一个单独的文件version
包含可在加载时使用的保存版本号
保存时,PyTorch 会确保每个文件的本地文件头都已填充 设置为 64 字节的倍数的偏移量,确保每个文件的偏移量 是 64 字节对齐的。
注意
某些设备(如 XLA)上的张量被序列化为腌制的 numpy 数组。如
因此,它们的存储不会被序列化。在这些情况下,可能不存在
在 checkpoint 中。data/
跟
¶
从版本 2.6 开始,如果未传递参数,将使用。torch.load
weights_only=True
pickle_module
如 , restricts 的文档中所述
unpickler 用于仅执行 plain 以及其他一些原始类型所需的函数/构建类。进一步
与模块提供的默认值不同,Unpickler
不允许在 unpickling 期间动态导入任何内容。
weights_only=True
torch.load
state_dicts
torch.Tensors
Unpickler
pickle
weights_only
如上所述,使用 .如果加载旧的
检查点,我们建议使用 。当加载包含
Tensor 子类,则可能会有需要列入允许列表的函数/类,有关更多详细信息,请参阅下文。state_dict
torch.save
nn.Module
weights_only=False
如果 Unpickler 遇到未列入允许列表的函数或类
默认情况下,在 pickle 文件中,您应该会看到一个可操作的错误,如下所示weights_only
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
请按照错误消息中的步骤操作,并仅在您信任函数或类的情况下将其列入允许列表。
要获取检查点中尚未列入允许列表的所有 GLOBAL(函数/类),您可以使用 它将返回以下形式的字符串列表。如果您信任这些函数/类,则可以导入它们并将其列入允许列表
错误消息 VIA
或 Context Manager
。
{__module__}.{__name__}
要访问用户允许列表的函数/类列表,您可以使用 和
要清除当前列表,请参阅
。
故障 排除
¶
获取不安全的全局变量¶
需要注意的是,它会静态地分析 checkpoint,
某些类型可能是在 unpickling 过程中动态构建的,因此不会由
.一个这样的例子是在 numpy 中。在将您报告
的所有函数/类列入允许列表后,可能会看到类似
dtypes
numpy < 1.25
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
这可以通过 列入允许列表。{add_}safe_globals([type(np.dtype(np.float32))])
在你会看到numpy >=1.25
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
这可以通过 列入允许列表。{add_}safe_globals([np.dtypes.Float32DType])
序列化 torch.nn.Modules 并在 C++ 中加载它们¶
Смотритетакже: 教程:在 C++ 中加载 TorchScript 模型
ScriptModule 可以序列化为 TorchScript 程序并加载
使用 .
这种序列化对所有模块的方法、子模块、参数、
和属性,它允许用 C++ 加载序列化程序
(即没有 Python)。
和 may not 之间的区别
立即明确。
使用 pickle 保存 Python 对象。
这对于原型设计、研究和培训特别有用。
另一方面,将 ScriptModules 序列化为格式
可以用 Python 或 C++ 加载。这在保存和加载 C++ 时非常有用
模块,或者用于运行使用 C++ 在 Python 中训练的模块,这是一种常见的做法
部署 PyTorch 模型时。
要在 Python 中编写脚本、序列化和加载模块:
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
跟踪的模块也可以用 保存,但需要注意
仅序列化跟踪的代码路径。以下示例演示
这:
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上面的模块有一个 if 语句,它不是由跟踪的输入触发的, ,因此不是 traced 模块的一部分,也不与它一起序列化。 但是,脚本化模块包含 if 语句并用它进行序列化。 有关脚本编写和跟踪的更多信息,请参阅 TorchScript 文档。
最后,要在 C++ 中加载模块:
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档。
跨 PyTorch 版本保存和加载 ScriptModule¶
PyTorch 团队建议保存和加载具有相同版本的 PyTorch 的 Torch 中。旧版本的 PyTorch 可能不支持较新的模块,而较新的 版本可能已删除或修改较旧的行为。这些变化是 在 PyTorch 的发行说明, 依赖于已更改功能的模块可能需要更新 以继续正常工作。在有限的情况下,PyTorch 将 保留序列化 ScriptModule 的历史行为,因此它们不需要 更新。
torch.div 执行整数除法¶
在 PyTorch 1.5 及更早版本中,当
给定两个整数输入:
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
但是,在 PyTorch 1.7 中,将始终执行真正的除法
的输入,就像 Python 3 中的 division 一样:
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
的行为保留在序列化的 ScriptModules 中。
也就是说,使用 1.6 之前的 PyTorch 版本序列化的 ScriptModule 将继续
查看
在给定两个整数输入时执行向下取整除法
即使加载了较新版本的 PyTorch 也是如此。在 PyTorch 1.6 及更高版本上使用
和序列化的 ScriptModule 无法加载到早期版本的
但是,PyTorch 的 API 版本无法理解新行为。
torch.full 始终推断 float dtype¶
在 PyTorch 1.5 及更早版本中,始终返回浮点张量,
无论给定的 fill 值如何:
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
但是,在 PyTorch 1.7 中,将推断返回的张量的
dtype 从 fill 值获取:
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
的行为保留在序列化的 ScriptModules 中。那是
使用 1.6 之前的 PyTorch 版本序列化的 ScriptModule 将继续看到
torch.full 默认返回浮点张量,即使给定 bool 或
整数填充值。ScriptModules using
and serialized on PyTorch 1.6
及更高版本无法加载到早期版本的 PyTorch 中,因为这些
早期版本不理解新行为。
效用函数¶
以下实用程序函数与序列化相关:
- torch.serialization 中。register_package(priority, tagger, deserializer)[来源][来源]¶
注册可调用对象,以便标记和反序列化具有关联优先级的存储对象。 标记会在保存时将设备与存储对象相关联,而反序列化会将 storage 对象添加到适当的设备中。 并按照它们给出的顺序运行,直到 tagger/deserializer 返回 值,该值不是 None。
tagger
deserializer
priority
要覆盖全局注册表中设备的反序列化行为,可以注册一个 优先级高于现有标记器的标记器。
此函数还可用于为新设备注册 tagger 和 deserializer。
- 参数
priority (int) – 表示与标记器和反序列化器关联的优先级,其中较低的 value 表示优先级较高。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 可调用的,它接收存储对象并将其标记的设备作为字符串返回 或 None。
反序列化器 (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 可调用对象,它接受存储对象和设备字符串并返回存储 object 或 None。
- 返回
没有
例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization 中。set_crc32_options(compute_crc32)[来源][来源]¶
-
注意
将此项设置为 this 可能会解压缩输出 由于 CRC32 损坏而失败或发出警告。然而 能够加载文件。
False
torch.save
torch.load
- 参数
compute_crc32 (bool) – 设置 CRC32 计算标志
- torch.serialization 中。get_default_load_endianness()[来源][来源]¶
获取加载文件的回退字节顺序
如果保存的 checkpoint 中不存在 byteorder mark,则 此字节顺序用作回退。 默认情况下,它是 “本机” 字节顺序。
- 返回
可选 [LoadEndianness]
- 返回类型
default_load_endian
- torch.serialization 中。set_default_load_endianness(字节序)[source][source]¶
设置加载文件的回退字节顺序
如果保存的 checkpoint 中不存在 byteorder mark,则 此字节顺序用作回退。 默认情况下,它是 “本机” 字节顺序。
- 参数
字节序 – 新的回退字节顺序
- torch.serialization 中。get_default_mmap_options()[来源][来源]¶
获取 的默认 mmap 选项
。
mmap=True
默认为 。
mmap.MAP_PRIVATE
- 返回
int
- 返回类型
default_mmap_options
- torch.serialization 中。set_default_mmap_options(标志)[来源][来源]¶
上下文管理器或函数,用于
为 with to 标志设置默认 mmap 选项。
mmap=True
目前,仅支持 or 。 如果您需要在此处添加任何其他选项,请打开一个问题。
mmap.MAP_PRIVATE
mmap.MAP_SHARED
注意
Windows 目前不支持此功能。
- 参数
flags (int) – 或
mmap.MAP_PRIVATE
mmap.MAP_SHARED
- torch.serialization 中。add_safe_globals(safe_globals)[来源][来源]¶
将给定的 globals 标记为 safe for load。例如,函数 添加到此列表中的 unpickling 可以在 unpickling 期间调用,类可以实例化 并设置了 state 。
weights_only
列表中的每个项目都可以是函数/类,也可以是 (function/class, string),其中 string 是函数/类的完整路径。
在序列化格式中,每个函数都用其完整的 path 设置为 .调用此 API 时,你可以提供此 full path 的 path 匹配 checkpoint 中的路径,否则将使用 default。
{__module__}.{__name__}
{fn.__module__}.{fn.__name__}
- 参数
safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表
例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization 中。get_unsafe_globals_in_checkpoint(f)[来源][来源]¶
返回对象中对 . 不安全的函数/类的字符串列表。
torch.save
weights_only
对于给定的函数或类 ,相应的字符串将采用 格式。
f
{f.__module__}.{f.__name__}
此函数将返回检查点中不在标记为 safe 的集合中的任何 GLOBALs for(通过
OR
Context 或 默认被允许列表)。
weights_only
torch
注意
此函数将静态反汇编 checkpoint 中的 pickle 文件。 这意味着在解封期间动态推送到堆栈上的任何类 将不包含在输出中。
- 类 torch.serialization 中。safe_globals(safe_globals)[来源][来源]¶
Context-manager 将某些全局变量添加为安全加载。
weights_only
例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- 类 torch.serialization 中。skip_data(materialize_fake_tensors=False)[来源][来源]¶
跳过为调用写入存储字节的上下文管理器。
torch.save
存储仍将被保存,但其字节通常会写入的空间 将是空白区域。然后,可以在单独的传递中填充存储字节。
警告
上下文管理器是一个早期原型,可能会发生更改。
skip_data
- 参数
materialize_fake_tensors (bool) - 是否具体化 FakeTensors。
例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])