torch.utils.checkpoint¶
注意
检查点通过在反向传播期间为每个检查点段重新运行前向传播段来实现。这可能导致诸如随机数生成器(RNG)状态等持久状态比没有检查点时更进一步。默认情况下,检查点包含逻辑来调整 RNG 状态,以确保使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时间带来中等程度的性能损失。如果不需要与非检查点传递相比的确定性输出,请向 preserve_rng_state=False 传入 checkpoint 或 checkpoint_sequential,以跳过每次检查点期间存储和恢复 RNG 状态的操作。
保存和恢复随机数生成器(RNG)状态的机制会为当前设备以及所有cuda张量参数的设备保存并恢复到run_fn。
然而,该机制无法预知用户是否会在run_fn内部将张量移动到新设备上。因此,如果你在run_fn中将张量移动到新设备(“新”指的是不属于[当前设备 + 张量参数的设备]集合的设备),与非检查点传递相比,确定性输出将无法得到保证。
-
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, **kwargs)[source]¶ 保存模型或模型的一部分
检查点机制通过用计算换取内存来实现。与存储整个计算图的所有中间激活值以进行反向传播不同,被检查点标记的部分不会保存中间激活值,而是在反向传播过程中重新计算它们。它可以应用于模型的任何部分。
具体来说,在前向传播过程中,
function将以torch.no_grad()的方式运行,即不存储中间 激活值。相反,前向传播会保存输入元组和function参数。在反向传播过程中,会检索保存的输入和function,然后再次在function上执行前向传播,此时会跟踪中间激活值,之后使用这些激活值计算梯度。The output of
functioncan contain non-Tensor values and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (ex: custom objects, lists, dicts etc.) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd。警告
如果在反向传播过程中,
function的调用与正向传播过程中的调用有任何不同,例如由于某些全局变量,检查点版本将不会等效,并且不幸的是无法检测到这一点。警告
如果指定了
use_reentrant=True,那么如果检查点段包含通过detach()或 torch.no_grad()从计算图中分离的张量,则反向传播将引发错误。这是因为 checkpoint会使所有输出都需要梯度,这会在模型中定义张量不需要梯度时引发问题。 为避免此问题,请在checkpoint函数之外分离张量。请注意,如果指定了use_reentrant=False,检查点段可以包含从计算图中分离的张量。警告
如果指定了
use_reentrant=True,则在需要模型输入的梯度时,至少有一个输入需要具有requires_grad=True,否则模型的检查点部分将没有梯度。同样,至少有一个输出也需要具有requires_grad=True。请注意,如果指定了use_reentrant=False,则此规则不适用。警告
如果指定了
use_reentrant=True,当前仅支持torch.autograd.backward(),并且只有在其inputs参数未传递时才支持。torch.autograd.grad()不被支持。如果指定了use_reentrant=False,checkpointing将与torch.autograd.grad()一起工作。- Parameters
function – 描述在模型或模型的部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递
(activation, hidden),function应该能正确地将第一个输入视为activation,第二个输入视为hiddenpreserve_rng_state (bool, optional, default=True) – 在每次检查点期间省略保存和恢复 随机数生成器状态。
use_reentrant (bool, optional, default=True) – 使用需要可重入自动梯度的检查点实现。 如果指定了
use_reentrant=False,checkpoint将使用不需要可重入自动梯度的实现。这 允许checkpoint支持更多功能,例如 与torch.autograd.grad正常配合使用。请注意,未来的 PyTorch 版本将默认使用use_reentrant=False。args – 包含输入的元组
function
- Returns
在
*args上运行function的输出
-
torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)[source]¶ 用于检查点序列模型的辅助函数。
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will run in
torch.no_grad()manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.参见
checkpoint()了解检查点的工作原理。警告
检查点功能目前仅支持
torch.autograd.backward()且仅当其 inputs 参数未传递时才支持。torch.autograd.grad()不被支持。- Parameters
functions – 一个
torch.nn.Sequential或者模块或函数的列表(构成模型),按顺序运行。segments – 在模型中创建的块数
输入 – 一个作为
functions输入的张量preserve_rng_state (bool, optional, default=True) – 在每次检查点期间省略保存和恢复 随机数生成器状态。
- Returns
依次在
*inputs上运行functions的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)