序列化语义¶
此说明介绍了如何在Python中保存和加载PyTorch张量和模块状态,以及如何序列化Python模块以便它们可以在C++中加载。
目录
保存和加载张量¶
torch.save() 和 torch.load() 让你可以轻松地保存和加载张量:
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch文件通常使用‘.pt’或‘.pth’扩展名编写。
torch.save() 和 torch.load() 默认使用 Python 的 pickle,
因此你也可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分进行保存:
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
包含PyTorch张量的自定义数据结构也可以保存,只要该数据结构是可pickle的。
保存和加载张量会保留视图¶
保存张量会保留它们的视图关系:
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图。
当PyTorch保存张量时,它会分别保存存储对象和张量元数据。这是一个可能在未来发生变化的实现细节,但它通常可以节省空间,并让PyTorch轻松地重建加载张量之间的视图关系。例如,在上述代码片段中,只有一个存储被写入到‘tensors.pt’。
在某些情况下,保存当前的存储对象可能是不必要的,并且会创建非常大的文件。在下面的代码片段中,写入文件的存储比保存的张量大得多:
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
不是只保存small张量中的五个值到‘small.pt’, 而是保存并加载了它与large共享的999个值。
当保存的张量元素少于其存储对象时,可以通过先克隆张量来减少保存文件的大小。克隆一个张量会产生一个新的张量,该张量具有一个新的存储对象,其中仅包含张量中的值:
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
由于克隆的张量彼此独立,因此它们没有任何原始张量所具有的视图关系。如果在保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么在保存之前必须小心地构建新的张量,以最小化其存储对象的大小,同时仍然具有所需的视图关系。
保存和加载torch.nn模块¶
另请参阅:教程:保存和加载模块
在PyTorch中,模块的状态通常使用“状态字典”进行序列化。 模块的状态字典包含其所有参数和持久缓冲区:
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
为了兼容性原因,建议只保存模块的状态字典而不是直接保存模块。Python模块甚至有一个函数,
load_state_dict(),用于从状态字典恢复它们的状态:
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先从其文件中加载 torch.load()
然后使用 load_state_dict() 恢复状态。
即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用这种模式:
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
序列化torch.nn.Modules并在C++中加载它们¶
ScriptModules 可以被序列化为 TorchScript 程序,并使用 torch.jit.load() 加载。
这种序列化编码了所有模块的方法、子模块、参数和属性,允许序列化的程序在 C++ 中加载(即无需 Python)。
torch.jit.save() 和 torch.save() 之间的区别可能并不立即明显。torch.save() 使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。torch.jit.save() 则将 ScriptModules 序列化为一种可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时非常有用,这是部署 PyTorch 模型时的常见做法。
要编写脚本、序列化和加载Python中的模块:
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
跟踪的模块也可以使用 torch.jit.save() 保存,但需要注意的是,只有跟踪的代码路径会被序列化。以下示例演示了这一点:
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super(ControlFlowModule, self).__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上述模块包含一个if语句,该语句不会被跟踪的输入触发, 因此不包含在跟踪的模块中,也不会与之序列化。 然而,脚本化的模块包含if语句,并且会与之序列化。 有关脚本化和跟踪的更多信息,请参阅TorchScript文档。
最后,要在C++中加载模块:
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有关如何在C++中使用PyTorch模块的详细信息,请参阅PyTorch C++ API文档。
在不同版本的PyTorch之间保存和加载ScriptModules¶
PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。旧版本的 PyTorch 可能不支持新模块,而新版本可能已移除或修改了旧的行为。这些更改在 PyTorch 的 发行说明 中有明确描述, 依赖于已更改功能的模块可能需要更新才能继续正常工作。在以下有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。
torch.div 执行整数除法¶
在PyTorch 1.5及更早版本中,torch.div() 在给定两个整数输入时会执行向下取整除法:
# PyTorch 1.5 (and earlier)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
在PyTorch 1.7中,torch.div() 将始终对其输入执行真正的除法运算,就像Python 3中的除法一样:
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
PyTorch 1.6 之前版本序列化 ScriptModules 中 torch.div() 的行为得以保留。
也就是说,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 在加载时,即使使用更新版本的 PyTorch,当给定两个整数输入时,torch.div() 仍然会执行地板除法。
然而,使用 torch.div() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新的行为。
torch.full 总是推断为浮点数据类型¶
在PyTorch 1.5及更早版本中 torch.full() 总是返回一个浮点数张量,
无论给定的填充值是什么:
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
在PyTorch 1.7中,torch.full() 将根据填充值推断返回张量的
dtype:
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full()的行为在序列化的ScriptModules中得以保留。也就是说,使用1.6版本之前的PyTorch序列化的ScriptModules将继续默认返回浮点张量,即使给定布尔值或整数值填充。然而,在PyTorch 1.6及之后版本上序列化的使用torch.full()的ScriptModules无法在早期版本的PyTorch中加载,因为这些早期版本不了解新的行为。