目录

torch.utils.checkpoint

注意

检查点通过在反向传播期间为每个检查点段重新运行前向传播段来实现。这可能导致诸如随机数生成器(RNG)状态等持久状态比没有检查点时更进一步。默认情况下,检查点包含逻辑来调整 RNG 状态,以确保使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时间带来中等程度的性能损失。如果不需要与非检查点传递相比的确定性输出,请向 preserve_rng_state=False 传入 checkpointcheckpoint_sequential,以跳过每次检查点期间存储和恢复 RNG 状态的操作。

保存和恢复随机数生成器(RNG)状态的机制会为当前设备以及所有cuda张量参数的设备保存并恢复到run_fn。 然而,该机制无法预知用户是否会在run_fn内部将张量移动到新设备上。因此,如果你在run_fn中将张量移动到新设备(“新”指的是不属于[当前设备 + 张量参数的设备]集合的设备),与非检查点传递相比,确定性输出将无法得到保证。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, **kwargs)[source]

保存模型或模型的一部分

检查点机制通过用计算换取内存来实现。与存储整个计算图的所有中间激活值以进行反向传播不同,被检查点标记的部分不会保存中间激活值,而是在反向传播过程中重新计算它们。它可以应用于模型的任何部分。

具体来说,在前向传播过程中,function将以 torch.no_grad()的方式运行,即不存储中间 激活值。相反,前向传播会保存输入元组和 function参数。在反向传播过程中,会检索保存的输入和 function,然后再次在 function上执行前向传播,此时会跟踪中间激活值,之后使用这些激活值计算梯度。

The output of function can contain non-Tensor values and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (ex: custom objects, lists, dicts etc.) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd。

警告

如果在反向传播过程中,function 的调用与正向传播过程中的调用有任何不同,例如由于某些全局变量,检查点版本将不会等效,并且不幸的是无法检测到这一点。

警告

如果指定了use_reentrant=True,那么如果检查点段包含通过detach()torch.no_grad()从计算图中分离的张量,则反向传播将引发错误。这是因为 checkpoint会使所有输出都需要梯度,这会在模型中定义张量不需要梯度时引发问题。 为避免此问题,请在checkpoint函数之外分离张量。请注意,如果指定了 use_reentrant=False,检查点段可以包含从计算图中分离的张量。

警告

如果指定了use_reentrant=True,则在需要模型输入的梯度时,至少有一个输入需要具有requires_grad=True,否则模型的检查点部分将没有梯度。同样,至少有一个输出也需要具有requires_grad=True。请注意,如果指定了use_reentrant=False,则此规则不适用。

警告

如果指定了use_reentrant=True,当前仅支持torch.autograd.backward(),并且只有在其inputs参数未传递时才支持。torch.autograd.grad()不被支持。如果指定了use_reentrant=False,checkpointing将与torch.autograd.grad()一起工作。

Parameters:
  • function – 描述在模型或模型的部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递 (activation, hidden), function 应该能正确地将第一个输入视为 activation,第二个输入视为 hidden

  • preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值: True

  • use_reentrant (bool, optional) – 使用需要可重入自动求导的检查点实现。如果指定了 use_reentrant=Falsecheckpoint 将使用不需要可重入自动求导的实现。这使得 checkpoint 可以支持额外的功能,例如与 torch.autograd.grad 一起正常工作,并支持输入到检查点函数中的关键字参数。请注意,PyTorch 的未来版本将默认使用 use_reentrant=False。 默认值:True

  • args – 包含输入的元组 function

Returns:

*args上运行function的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)[source]

用于检查点序列模型的辅助函数。

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model in various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.

参见 checkpoint() 了解检查点的工作原理。

警告

检查点功能目前仅支持 torch.autograd.backward() 且仅当其 inputs 参数未传递时才支持。 torch.autograd.grad() 不被支持。

Parameters:
  • functions – 一个 torch.nn.Sequential 或者模块或函数的列表(构成模型),按顺序运行。

  • segments – 在模型中创建的块数

  • 输入 – 一个作为 functions 输入的张量

  • preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值: True

Returns:

依次在 *inputs 上运行 functions 的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源