目录

序列化语义

本说明介绍如何保存和加载 PyTorch 张量和模块状态 以及如何序列化 Python 模块以便可以用 C++ 加载它们。

保存和加载张量

让您轻松保存和加载张量:

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch 文件通常使用“.pt”或“.pth”扩展名编写。

默认使用 Python 的 pickle, 因此,您还可以将多个张量保存为 Python 对象(如元组)的一部分, lists 和 dicts:

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果 数据结构是可 pickle 的。

保存和加载张量会保留视图

保存张量会保留它们的视图关系:

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的 “存储空间”。有关更多信息,请参阅 Tensor 视图 在 views 和 storage 上。

当 PyTorch 保存 tensor 时,它会保存它们的存储对象和张量 元数据。这是一个实现细节,在 future,但它通常会节省空间并允许 PyTorch 轻松 重建加载的 Tensor 之间的视图关系。在上述 snippet 中,只有一个存储被写入 'tensors.pt' 中。

但是,在某些情况下,可能不需要保存当前存储对象 并创建大得令人望而却步的文件。在下面的代码片段中,存储很多 大于保存的张量写入文件:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

而不是仅将张量中的 5 个值保存到 'small.pt' 保存并加载了它与 large 共享的 Storage 中的 999 个值。

当保存元素少于其存储对象的张量时, 可以通过首先克隆张量来减少保存的文件。克隆张量 生成一个新 Tensor,其中包含仅包含值的新 Storage 对象 在 Tensor 中:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

但是,由于克隆的张量彼此独立,因此它们具有 原始张量没有 view relationships 。如果文件大小和 在保存小于其 storage 对象,则必须小心构造新的 Tensor,以最小化 其存储对象的大小,但仍具有所需的视图关系 在保存之前。

保存和加载 torch.nn.Modules

另请参阅:教程:保存和加载模块

在 PyTorch 中,模块的状态经常使用“状态字典”进行序列化。 模块的状态字典包含其所有参数和持久缓冲区:

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出于兼容性原因,建议不要直接保存模块 改为仅保存其 state dict。Python 模块甚至有一个函数 ,可以从 state dict 恢复它们的状态:

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

请注意,状态 dict 首先从其文件中加载,然后使用 .

即使是自定义模块和包含其他模块的模块也有 state dicts 和 可以使用以下模式:

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super(MyModule, self).__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

序列化 torch.nn.Modules 并在 C++ 中加载它们

Смотритетакже: 教程:在 C++ 中加载 TorchScript 模型

ScriptModule 可以序列化为 TorchScript 程序并加载 使用 . 这种序列化对所有模块的方法、子模块、参数、 和属性,它允许用 C++ 加载序列化程序 (即没有 Python)。

may not 之间的区别 立即明确。使用 pickle 保存 Python 对象。 这对于原型设计、研究和培训特别有用。另一方面,将 ScriptModules 序列化为格式 可以用 Python 或 C++ 加载。这在保存和加载 C++ 时非常有用 模块,或者用于运行使用 C++ 在 Python 中训练的模块,这是一种常见的做法 部署 PyTorch 模型时。

要在 Python 中编写脚本、序列化和加载模块:

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟踪的模块也可以用 保存,但需要注意 仅序列化跟踪的代码路径。以下示例演示 这:

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super(ControlFlowModule, self).__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
        return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上面的模块有一个 if 语句,它不是由跟踪的输入触发的, ,因此不是 traced 模块的一部分,也不与它一起序列化。 但是,脚本化模块包含 if 语句并用它进行序列化。 有关脚本编写和跟踪的更多信息,请参阅 TorchScript 文档

最后,要在 C++ 中加载模块:

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档

跨 PyTorch 版本保存和加载 ScriptModule

PyTorch 团队建议保存和加载具有相同版本的 PyTorch 的 Torch 中。旧版本的 PyTorch 可能不支持较新的模块,而较新的 版本可能已删除或修改较旧的行为。这些变化是 在 PyTorch 的发行说明, 依赖于已更改功能的模块可能需要更新 以继续正常工作。在有限的情况下,PyTorch 将 保留序列化 ScriptModule 的历史行为,因此它们不需要 更新。

torch.div 执行整数除法

在 PyTorch 1.5 及更早版本中,当 给定两个整数输入:

# PyTorch 1.5 (and earlier)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

但是,在 PyTorch 1.7 中,将始终执行真正的除法 的输入,就像 Python 3 中的 division 一样:

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

的行为保留在序列化的 ScriptModules 中。 也就是说,使用 1.6 之前的 PyTorch 版本序列化的 ScriptModule 将继续 查看在给定两个整数输入时执行向下取整除法 即使加载了较新版本的 PyTorch 也是如此。在 PyTorch 1.6 及更高版本上使用和序列化的 ScriptModule 无法加载到早期版本的 但是,PyTorch 的 API 版本无法理解新行为。

torch.full 始终推断 float dtype

在 PyTorch 1.5 及更早版本中,始终返回浮点张量, 无论给定的 fill 值如何:

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

但是,在 PyTorch 1.7 中,将推断返回的张量的 dtype 从 fill 值获取:

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

的行为保留在序列化的 ScriptModules 中。那是 使用 1.6 之前的 PyTorch 版本序列化的 ScriptModule 将继续看到 torch.full 默认返回浮点张量,即使给定 bool 或 整数填充值。ScriptModules using and serialized on PyTorch 1.6 及更高版本无法加载到早期版本的 PyTorch 中,因为这些 早期版本不理解新行为。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源