目录

torch.utils.checkpoint

注意

检查点是通过在反向传播期间为每个检查点段重新运行前向传递段来实现的。这可能导致诸如随机数生成器(RNG)状态等持久状态比没有检查点时更先进。默认情况下,检查点包含逻辑来处理RNG状态,使得使用RNG(例如通过dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复RNG状态的逻辑可能会根据检查点操作的运行时间带来中等程度的性能损失。如果不需要与非检查点传递相比的确定性输出,请向preserve_rng_state=False提供checkpointcheckpoint_sequential以省略在每次检查点期间存储和恢复RNG状态。

存储逻辑会为CPU和其他设备类型(通过_infer_device_type从张量参数中排除CPU张量推断出设备类型)保存和恢复RNG状态到run_fn。如果有多个设备,设备状态将只为一种设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能导致梯度不正确。(注意,如果检测到CUDA设备,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有CPU张量,则默认设备类型状态(默认值是cuda,也可以通过DefaultDeviceType设置为其他设备)将被保存和恢复。然而,该逻辑无法预知用户是否会在run_fn本身内将张量移动到新设备。因此,如果你在run_fn中将张量移动到一个新设备(“新”是指不属于[当前设备 + 张量参数的设备]集合的设备),则与非检查点传递相比,确定性输出永远无法保证。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]

保存模型或模型的一部分。

激活检查点(Activation checkpointing)是一种以计算换取内存的技术。 通常,在反向传播过程中,会保留用于梯度计算的张量直到它们被使用。而采用检查点技术的区域在正向传播时会省略保存这些张量,转而在反向传播过程中重新计算它们。这种激活检查点技术可以应用于模型的任何部分。

目前有两种检查点实现方式,由use_reentrant参数决定。建议您使用use_reentrant=False。有关它们之间差异的讨论,请参阅下面的说明。

警告

如果在反向传播过程中调用的function与正向传播过程不同,例如由于全局变量的影响,检查点版本可能不等价,这可能会引发错误或导致梯度静默地不正确。

警告

参数 use_reentrant 应该显式传递。在版本 2.4 中,如果未传递 use_reentrant,我们将引发异常。 如果您使用的是 use_reentrant=True 变体,请参阅下面的注意事项和潜在限制。

注意

可重入的 checkpoint 变体 (use_reentrant=True) 和 不可重入的 checkpoint 变体 (use_reentrant=False) 在以下方面有所不同:

  • 非可重入检查点在所有需要的中间激活值被重新计算后立即停止重新计算。此功能默认启用,但可以通过 set_checkpoint_early_stop() 禁用。可重入检查点在反向传递过程中始终完全重新计算 function

  • 可重入变体在正向传播过程中不会记录自动求导图,因为它在 torch.no_grad() 的上下文中运行。不可重入版本会记录自动求导图,允许在检查点区域内对图执行反向传播。

  • 可重入检查点仅支持不带其 inputs 参数的反向传递的 torch.autograd.backward() API,而非可重入版本支持所有执行反向传递的方式。

  • 至少有一个输入和输出必须为 requires_grad=True 的重入变体。如果未满足此条件,模型中被检查点的部分将不会有梯度。非重入版本没有这个要求。

  • 可重入版本不认为嵌套结构(例如,自定义对象、列表、字典等)中的张量参与自动求导,而非可重入版本则认为它们参与。

  • 可重入检查点不支持计算图中从计算图分离的张量的检查点区域,而非可重入版本则支持。对于可重入变体,如果检查点段包含使用 detach() 或 通过 torch.no_grad() 分离的张量,反向传播将引发错误。 这是因为 checkpoint 会使所有输出需要梯度,并且当模型中定义某个张量不需要梯度时会导致问题。为了避免这种情况,请在 checkpoint 函数外部分离张量。

Parameters
  • function – 描述在模型或模型的部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递 (activation, hidden), function 应该能正确地将第一个输入视为 activation,第二个输入视为 hidden

  • preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态。请注意,在 torch.compile 下,此标志无效,我们始终保留 RNG 状态。 默认值: True

  • use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。此参数应显式传递。在版本 2.4 中,如果我们未传入 use_reentrant,将引发异常。如果传入 use_reentrant=Falsecheckpoint 将使用不需要可重入自动梯度的实现方式。这使得 checkpoint 可以支持额外的功能,例如与 torch.autograd.grad 正常工作,并支持关键字参数输入到检查点函数中。

  • context_fn (Callable, optional) – 一个可调用对象,返回两个上下文管理器的元组。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。 如果 use_reentrant=False,则支持此参数。

  • determinism_check (str, optional) – 一个字符串,指定要执行的确定性检查。默认情况下,它被设置为 "default", 该设置会将重新计算的张量的形状、数据类型和设备与已保存张量进行比较。若要关闭此检查,请指定 "none"。目前仅支持这两个值。如果您希望看到更多的确定性检查功能,请提交问题报告。这个参数仅在 use_reentrant=False 时受支持,如果 use_reentrant=True,则确定性检查始终被禁用。

  • 调试 (布尔值, 可选) – 如果为 True,错误消息还将包括 原始前向计算过程中运行的操作符的跟踪信息以及重新计算的信息。此参数仅在 use_reentrant=False 时支持。

  • args – 包含输入的元组 function

Returns

*args上运行function的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]

通过检查点保存顺序模型以节省内存。

Sequential 模型按顺序执行一系列模块/函数。因此,我们可以将此类模型划分为多个段,并对每个段进行检查点处理。除了最后一个段之外,其他所有段都不会存储中间激活值。每个被检查点处理的段的输入将会被保存,以便在反向传播过程中重新运行该段。

警告

The use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False.

Parameters
  • functions – 一个 torch.nn.Sequential 或者模块或函数的列表(构成模型),按顺序运行。

  • segments – 在模型中创建的块数

  • 输入 – 一个作为 functions 输入的张量

  • preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值: True

  • use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。此参数应显式传递。在版本 2.4 中,如果我们未传入 use_reentrant,将引发异常。如果传入 use_reentrant=Falsecheckpoint 将使用不需要可重入自动梯度的实现方式。这使得 checkpoint 可以支持额外的功能,例如与 torch.autograd.grad 正常工作,并支持关键字参数输入到检查点函数中。

Returns

依次在 *inputs 上运行 functions 的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]

上下文管理器,用于设置在运行时检查点是否应打印额外的调试信息。有关 debug 标志的更多信息,请参阅 checkpoint()。请注意,当设置此上下文管理器时,它将覆盖传递给检查点的 debug 的值。若要使用本地设置,请向此上下文传递 None

Parameters

启用 (布尔值) – 是否应打印检查点的调试信息。 默认值为‘无’。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源