目录

序列化语义

此说明介绍了如何在Python中保存和加载PyTorch张量和模块状态,以及如何序列化Python模块以便它们可以在C++中加载。

保存和加载张量

torch.save()torch.load() 让你可以轻松地保存和加载张量:

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch文件通常使用‘.pt’或‘.pth’扩展名编写。

torch.save()torch.load() 默认使用 Python 的 pickle, 因此你也可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分进行保存:

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

包含PyTorch张量的自定义数据结构也可以保存,只要该数据结构是可pickle的。

保存和加载张量会保留视图

保存张量会保留它们的视图关系:

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图

当PyTorch保存张量时,它会分别保存存储对象和张量元数据。这是一个可能在未来发生变化的实现细节,但它通常可以节省空间,并让PyTorch轻松地重建加载张量之间的视图关系。例如,在上述代码片段中,只有一个存储被写入到‘tensors.pt’。

在某些情况下,保存当前的存储对象可能是不必要的,并且会创建非常大的文件。在下面的代码片段中,写入文件的存储比保存的张量大得多:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

不是只保存small张量中的五个值到‘small.pt’, 而是保存并加载了它与large共享的999个值。

当保存的张量元素少于其存储对象时,可以通过先克隆张量来减少保存文件的大小。克隆一个张量会产生一个新的张量,该张量具有一个新的存储对象,其中仅包含张量中的值:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

由于克隆的张量彼此独立,因此它们没有任何原始张量所具有的视图关系。如果在保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么在保存之前必须小心地构建新的张量,以最小化其存储对象的大小,同时仍然具有所需的视图关系。

保存和加载torch.nn模块

另请参阅:教程:保存和加载模块

在PyTorch中,模块的状态通常使用“状态字典”进行序列化。 模块的状态字典包含其所有参数和持久缓冲区:

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

为了兼容性原因,建议只保存模块的状态字典而不是直接保存模块。Python模块甚至有一个函数, load_state_dict(),用于从状态字典恢复它们的状态:

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

请注意,状态字典首先从其文件中加载 torch.load() 然后使用 load_state_dict() 恢复状态。

即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用这种模式:

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

序列化文件格式用于 torch.save

自PyTorch 1.6.0起,torch.save 默认返回未压缩的ZIP64 存档,除非用户设置 _use_new_zipfile_serialization=False

在这个存档中,文件的排列顺序如下

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
The entries are as follows:
  • data.pkl 是将对象传递给 torch.save 后进行pickle操作的结果 不包括其中包含的 torch.Storage 个对象

  • byteorder 包含一个字符串,在保存时指示 sys.byteorder(“小端”或“大端”)

  • data/ 包含对象中的所有存储,其中每个存储都是一个单独的文件

  • version 在保存时包含一个版本号,该版本号可以在加载时使用

在保存时,PyTorch 将确保每个文件的本地文件头被填充到一个偏移量,该偏移量是 64 字节的倍数,从而确保每个文件的偏移量都是 64 字节对齐的。

注意

在某些设备(如XLA)上的张量被序列化为pickle的numpy数组。因此,它们的存储不会被序列化。在这种情况下,data/ 可能不会存在于检查点中。

torch.loadweights_only=True

从2.6版本开始,如果未传递pickle_module参数,torch.load将使用weights_only=True

如文档中所述 torch.load()weights_only=True 限制了在 torch.load 中使用的 unpickler 只能执行构建 state_dicts 的普通 torch.Tensors 以及一些其他基本类型所需的函数/类。此外,与 pickle 模块提供的默认 Unpickler 不同,weights_only Unpickler 在反序列化过程中不允许动态导入任何内容。

如上所述,在使用torch.save时,保存模块的state_dict是一个最佳实践。如果加载包含nn.Module的旧检查点,我们建议使用weights_only=False。当加载包含张量子类的检查点时,可能需要允许列表中的函数/类,请参阅以下详细信息。

如果weights_only Unpickler在pickle文件中遇到默认不允许的函数或类,你应该看到类似以下可操作的错误

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

请按照错误消息中的步骤操作,并仅在您信任的情况下将函数或类列入白名单。

要获取检查点中尚未列入白名单的所有全局(函数/类),可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它将返回一个字符串列表,形式为 {__module__}.{__name__}。如果您信任这些函数/类,可以根据错误消息通过 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals 导入它们并将其列入白名单。

要访问用户允许的函数/类列表,您可以使用 torch.serialization.get_safe_globals() 并且 要清除当前列表,请参见 torch.serialization.clear_safe_globals()

故障排除 weights_only

获取不安全的全局变量

需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 静态分析检查点, 某些类型可能在反序列化过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。一个这样的例子是 dtypes 在 numpy 中。在 numpy < 1.25 之后,在允许所有由 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的函数/类后,你可能会看到类似以下的错误:

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 加入白名单。

numpy >=1.25 你会看到

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) 加入白名单。

环境变量

有两个环境变量将影响torch.load的行为。如果无法访问torch.load的调用位置,这些可能会有所帮助。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 将覆盖所有 torch.load 调用点以使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 将使 torch.load 个调用点仅使用 weights_only=False 仅当 如果未将 weights_only 作为参数传递。

序列化torch.nn.Modules并在C++中加载它们

另请参阅:教程:在C++中加载TorchScript模型

ScriptModules 可以被序列化为 TorchScript 程序,并使用 torch.jit.load() 加载。 这种序列化编码了所有模块的方法、子模块、参数和属性,允许序列化的程序在 C++ 中加载(即无需 Python)。

torch.jit.save()torch.save() 之间的区别可能并不立即明显。torch.save() 使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。torch.jit.save() 则将 ScriptModules 序列化为一种可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时非常有用,这是部署 PyTorch 模型时的常见做法。

要编写脚本、序列化和加载Python中的模块:

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟踪的模块也可以使用 torch.jit.save() 保存,但需要注意的是,只有跟踪的代码路径会被序列化。以下示例演示了这一点:

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上述模块包含一个if语句,该语句不会被跟踪的输入触发, 因此不包含在跟踪的模块中,也不会与之序列化。 然而,脚本化的模块包含if语句,并且会与之序列化。 有关脚本化和跟踪的更多信息,请参阅TorchScript文档

最后,要在C++中加载模块:

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有关如何在C++中使用PyTorch模块的详细信息,请参阅PyTorch C++ API文档

在不同版本的PyTorch之间保存和加载ScriptModules

PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。旧版本的 PyTorch 可能不支持新模块,而新版本可能已移除或修改了旧的行为。这些更改在 PyTorch 的 发行说明 中有明确描述, 依赖于已更改功能的模块可能需要更新才能继续正常工作。在以下有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。

torch.div 执行整数除法

在PyTorch 1.5及更早版本中,torch.div() 在给定两个整数输入时会执行向下取整除法:

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

在PyTorch 1.7中,torch.div() 将始终对其输入执行真正的除法运算,就像Python 3中的除法一样:

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

PyTorch 1.6 之前版本序列化 ScriptModules 中 torch.div() 的行为得以保留。 也就是说,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 在加载时,即使使用更新版本的 PyTorch,当给定两个整数输入时,torch.div() 仍然会执行地板除法。 然而,使用 torch.div() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新的行为。

torch.full 总是推断为浮点数据类型

在PyTorch 1.5及更早版本中 torch.full() 总是返回一个浮点数张量, 无论给定的填充值是什么:

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

在PyTorch 1.7中,torch.full() 将根据填充值推断返回张量的 dtype:

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full()的行为在序列化的ScriptModules中得以保留。也就是说,使用1.6版本之前的PyTorch序列化的ScriptModules将继续默认返回浮点张量,即使给定布尔值或整数值填充。然而,在PyTorch 1.6及之后版本上序列化的使用torch.full()的ScriptModules无法在早期版本的PyTorch中加载,因为这些早期版本不了解新的行为。

实用函数

以下实用函数与序列化相关:

torch.serialization.register_package(priority, tagger, deserializer)[source][source]

注册带有相关优先级的可调用函数,用于对存储对象进行标记和反序列化。 标记在保存时将设备与存储对象关联,而反序列化则在加载时将存储对象移动到适当的设备。taggerdeserializer 按照它们的priority给定的顺序运行,直到某个标记器/反序列化程序返回一个不是None的值。

要覆盖全局注册表中设备的反序列化行为,可以注册一个优先级高于现有标记器的标记器。

此功能还可以用于为新设备注册标记器和反序列化器。

Parameters
Returns

None

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[source][source]

获取是否 torch.save() 为每个记录计算并写入crc32。

默认为 True

Return type

布尔

torch.serialization.set_crc32_options(compute_crc32)[source][source]

设置是否 torch.save() 为每个记录计算并写入crc32。

注意

将此设置为 False 可能会导致解压 torch.save 输出失败或由于CRC32损坏而发出警告。但是 torch.load 将能够加载文件。

Parameters

compute_crc32 (bool) – 设置crc32计算标志

torch.serialization.get_default_load_endianness()[source][source]

获取加载文件的备用字节顺序

如果保存的检查点中不存在字节顺序标记, 则使用此字节顺序作为备用。 默认情况下,它是“本地”字节顺序。

Returns

Optional[LoadEndianness]

Return type

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[source][source]

设置加载文件时的回退字节顺序

如果保存的检查点中不存在字节顺序标记, 则使用此字节顺序作为备用。 默认情况下,它是“本地”字节顺序。

Parameters

字节序 – 新的备用字节顺序

torch.serialization.get_default_mmap_options()[source][source]

获取 torch.load() 的默认内存映射选项,使用 mmap=True

默认为 mmap.MAP_PRIVATE

Returns

整数

Return type

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[source][source]

上下文管理器或函数,用于设置默认的内存映射选项,适用于 torch.load() 并使用 mmap=True 作为标志。

目前,仅支持 mmap.MAP_PRIVATEmmap.MAP_SHARED。 如果您需要添加其他选项,请提交问题。

注意

此功能目前不支持Windows。

Parameters

标志 (整数) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source][source]

将给定的全局变量标记为可以安全地进行weights_only加载。例如,添加到此列表中的函数可以在反序列化期间调用,类可以实例化并设置状态。

列表中的每一项可以是一个函数/类,或者是一个形式为 (函数/类, 字符串) 的元组,其中字符串是函数/类的完整路径。

在序列化格式中,每个函数都以其完整路径标识为{__module__}.{__name__}。调用此API时,您可以提供应与检查点中的路径匹配的完整路径,否则将使用默认的{fn.__module__}.{fn.__name__}

Parameters

safe_globals (列表[联合[可调用对象, 元组[可调用对象, 字符串]]]) – 标记为安全的全局变量列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source][source]

清除对 weights_only 加载安全的全局变量列表。

torch.serialization.get_safe_globals()[source][source]

返回用户添加的安全全局变量列表,适用于weights_only加载。

Return type

列表[联合[可调用对象, 元组[可调用对象, 字符串]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]

返回一个字符串列表,其中包含在 torch.save 对象中不安全用于 weights_only 的函数/类。

对于给定的函数或类 f,相应的字符串将具有以下形式 {f.__module__}.{f.__name__}

此函数将返回检查点中未在标记为安全的集合中的任何全局变量 对于 weights_only(通过 add_safe_globals()safe_globals 上下文或 默认情况下由 torch 允许列表)。

注意

此函数将静态地反汇编检查点中的pickle文件。 这意味着在反序列化过程中动态推送到栈上的任何类都不会包含在输出中。

Parameters

f (Union[str, PathLike, BinaryIO, IO[bytes]]) – 类似文件的对象或字符串,包含通过 torch.save 保存的检查点对象

Returns

检查点中未被列入白名单的pickle GLOBAL字符串列表,不允许用于weights_only

Return type

列表[字符串]

class torch.serialization.safe_globals(safe_globals)[source][source]

上下文管理器,将某些全局变量添加为对 weights_only 加载安全。

Parameters

safe_globals (列表[联合[可调用对象, 元组[可调用对象, 字符串]]]) – 仅用于权重加载的全局变量列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]

上下文管理器,用于跳过为 torch.save 次调用写入存储字节。

存储仍然会被保存,但通常会写入其字节的空间将是空的。然后可以在单独的遍历中填充存储字节。

警告

上下文管理器skip_data是一个早期原型,可能会发生变化。

Parameters

materialize_fake_tensors (bool) – 是否将FakeTensors实例化。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源