torch.utils.checkpoint¶
注意
检查点是通过在反向传播期间为每个检查点段重新运行前向传递段来实现的。这可能导致诸如随机数生成器(RNG)状态等持久状态比没有检查点时更先进。默认情况下,检查点包含逻辑来处理RNG状态,使得使用RNG(例如通过dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复RNG状态的逻辑可能会根据检查点操作的运行时间带来中等程度的性能损失。如果不需要与非检查点传递相比的确定性输出,请向preserve_rng_state=False提供checkpoint或checkpoint_sequential以省略在每次检查点期间存储和恢复RNG状态。
存储逻辑会为CPU和其他设备类型(通过_infer_device_type从张量参数中排除CPU张量推断出设备类型)保存和恢复RNG状态到run_fn。如果有多个设备,设备状态将只为一种设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能导致梯度不正确。(注意,如果检测到CUDA设备,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有CPU张量,则默认设备类型状态(默认值是cuda,也可以通过DefaultDeviceType设置为其他设备)将被保存和恢复。然而,该逻辑无法预知用户是否会在run_fn本身内将张量移动到新设备。因此,如果你在run_fn中将张量移动到一个新设备(“新”是指不属于[当前设备 + 张量参数的设备]集合的设备),则与非检查点传递相比,确定性输出永远无法保证。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source][source]¶
保存模型或模型的一部分。
激活检查点(Activation checkpointing)是一种以计算换取内存的技术。 通常,在反向传播过程中,会保留用于梯度计算的张量直到它们被使用。而采用检查点技术的区域在正向传播时会省略保存这些张量,转而在反向传播过程中重新计算它们。这种激活检查点技术可以应用于模型的任何部分。
目前有两种检查点实现方式,由
use_reentrant参数决定。建议您使用use_reentrant=False。有关它们之间差异的讨论,请参阅下面的说明。警告
如果在反向传播过程中调用的
function与正向传播过程不同,例如由于全局变量的影响,检查点版本可能不等价,这可能会引发错误或导致梯度静默地不正确。警告
参数
use_reentrant应该显式传递。在版本 2.4 中,如果未传递use_reentrant,我们将引发异常。 如果您使用的是use_reentrant=True变体,请参阅下面的注意事项和潜在限制。注意
可重入的 checkpoint 变体 (
use_reentrant=True) 和 不可重入的 checkpoint 变体 (use_reentrant=False) 在以下方面有所不同:非可重入检查点在所有需要的中间激活值被重新计算后立即停止重新计算。此功能默认启用,但可以通过
set_checkpoint_early_stop()禁用。可重入检查点在反向传递过程中始终完全重新计算function。可重入变体在正向传播过程中不会记录自动求导图,因为它在
torch.no_grad()的上下文中运行。不可重入版本会记录自动求导图,允许在检查点区域内对图执行反向传播。可重入检查点仅支持不带其 inputs 参数的反向传递的
torch.autograd.backward()API,而非可重入版本支持所有执行反向传递的方式。至少有一个输入和输出必须为
requires_grad=True的重入变体。如果未满足此条件,模型中被检查点的部分将不会有梯度。非重入版本没有这个要求。可重入版本不认为嵌套结构(例如,自定义对象、列表、字典等)中的张量参与自动求导,而非可重入版本则认为它们参与。
可重入检查点不支持计算图中从计算图分离的张量的检查点区域,而非可重入版本则支持。对于可重入变体,如果检查点段包含使用
detach()或 通过torch.no_grad()分离的张量,反向传播将引发错误。 这是因为checkpoint会使所有输出需要梯度,并且当模型中定义某个张量不需要梯度时会导致问题。为了避免这种情况,请在checkpoint函数外部分离张量。
- Parameters
function – 描述在模型或模型的部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递
(activation, hidden),function应该能正确地将第一个输入视为activation,第二个输入视为hiddenpreserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态。请注意,在 torch.compile 下,此标志无效,我们始终保留 RNG 状态。 默认值:
Trueuse_reentrant (bool) – 指定是否使用需要可重入自动求导的激活检查点变体。此参数应显式传递。在 2.5 版本中,如果未传递
use_reentrant,我们将引发异常。如果use_reentrant=False,checkpoint将使用不需要可重入自动求导的实现。这使得checkpoint可以支持额外的功能,例如与torch.autograd.grad正常工作以及支持关键字参数输入到检查点函数中。context_fn (Callable, optional) – 一个可调用对象,返回两个上下文管理器的元组。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。 如果
use_reentrant=False,则支持此参数。determinism_check (str, optional) – 一个字符串,指定要执行的确定性检查。默认情况下,它被设置为
"default", 该设置会将重新计算的张量的形状、数据类型和设备与已保存张量进行比较。若要关闭此检查,请指定"none"。目前仅支持这两个值。如果您希望看到更多的确定性检查功能,请提交问题报告。这个参数仅在use_reentrant=False时受支持,如果use_reentrant=True,则确定性检查始终被禁用。调试 (布尔值, 可选) – 如果为
True,错误消息还将包括 原始前向计算过程中运行的操作符的跟踪信息以及重新计算的信息。此参数仅在use_reentrant=False时支持。args – 包含输入的元组
function
- Returns
在
*args上运行function的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source][source]¶
通过检查点保存顺序模型以节省内存。
Sequential 模型按顺序执行一系列模块/函数。因此,我们可以将此类模型划分为多个段,并对每个段进行检查点处理。除了最后一个段之外,其他所有段都不会存储中间激活值。每个被检查点处理的段的输入将会被保存,以便在反向传播过程中重新运行该段。
警告
The
use_reentrantparameter should be passed explicitly. In version 2.4 we will raise an exception ifuse_reentrantis not passed. If you are using theuse_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False.- Parameters
functions – 一个
torch.nn.Sequential或者模块或函数的列表(构成模型),按顺序运行。segments – 在模型中创建的块数
输入 – 一个作为
functions输入的张量preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值:
Trueuse_reentrant (bool) – 指定是否使用需要可重入自动求导的激活检查点变体。此参数应显式传递。在 2.5 版本中,如果未传递
use_reentrant,我们将引发异常。如果use_reentrant=False,checkpoint将使用不需要可重入自动求导的实现。这使得checkpoint可以支持额外的功能,例如与torch.autograd.grad正常工作以及支持关键字参数输入到检查点函数中。
- Returns
依次在
*inputs上运行functions的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source][source]¶
上下文管理器,用于设置在运行时检查点是否应打印额外的调试信息。有关
debug标志的更多信息,请参阅checkpoint()。请注意,当设置此上下文管理器时,它将覆盖传递给检查点的debug的值。若要使用本地设置,请向此上下文传递None。- Parameters
启用 (布尔值) – 是否应打印检查点的调试信息。 默认值为‘无’。
- class torch.utils.checkpoint.CheckpointPolicy(value)[source][source]¶
用于指定反向传播期间检查点策略的枚举。
支持以下策略:
{MUST,PREFER}_SAVE: 操作的输出将在前向传递期间保存,并且在反向传递期间不会重新计算{MUST,PREFER}_RECOMPUTE: 在正向传播过程中,操作的输出将不会被保存,并将在反向传播过程中重新计算
使用
MUST_*而不是PREFER_*来表示该策略不应被其他子系统如 torch.compile 覆盖。注意
一个始终返回
PREFER_RECOMPUTE的策略函数等同于普通检查点。一个策略函数,它在每次操作时都返回
PREFER_SAVE并不等同于不使用检查点。使用这样的策略将保存额外的张量,这些张量不仅限于梯度计算实际所需的张量。
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source][source]¶
在选择性检查点期间传递给策略函数的上下文。
此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括当前调用策略函数是否处于重新计算过程中。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source][source]¶
在激活检查点期间避免重新计算某些操作的辅助工具。
使用此功能与 torch.utils.checkpoint.checkpoint 来控制在反向传播过程中哪些操作会被重新计算。
- Parameters
policy_fn_or_list (Callable 或 List) –
如果提供了策略函数,它应该接受一个
SelectiveCheckpointContext,OpOverload,传递给操作的 args 和 kwargs,并返回一个CheckpointPolicy枚举值, 表示该操作的执行是否需要重新计算。如果提供了一系列表示操作的列表,这相当于一个策略, 对于指定的操作返回 CheckpointPolicy.MUST_SAVE,而对于所有其他操作返回 CheckpointPolicy.PREFER_RECOMPUTE。
allow_cache_entry_mutation (bool, optional) – 默认情况下,如果通过选择性激活检查点缓存的任何张量被修改,则会引发错误,以确保正确性。如果设置为 True,则禁用此检查。
- Returns
一个包含两个上下文管理器的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )