torch.utils.checkpoint¶
注意
检查点是通过重新运行
向后传播期间的每个 checkpointed 段。这可能会导致持续
像 RNG 州这样的州比没有 RNG 州更先进
检查点。默认情况下,checkpointing 包括 juggle 逻辑
RNG 状态,以便使用 RNG 的检查点传递
(例如,通过 dropout)具有确定性输出,如
与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑
RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时
的检查点操作。如果确定性输出与
非检查点通行证不是必需的,提供或省略储藏和
在每个检查点期间恢复 RNG 状态。preserve_rng_state=False
checkpoint
checkpoint_sequential
存储逻辑保存和恢复 CPU 和另一个
设备类型(从 Tensor 参数推断设备类型,不包括 CPU
张量 ) 设置为 .如果有多个
device,则仅保存单个设备类型的设备的设备状态,
其余设备将被忽略。因此,如果任何 checkpoint
函数涉及随机性,这可能会导致 gradient 不正确。(注意
如果 CUDA 设备在检测到的设备中,则它将被优先考虑;
否则,将选择遇到的第一个设备。如果没有
CPU-tensors,默认设备类型状态(默认值为 cuda,它
可以设置为 Other Device 的 ) 将被保存和恢复。
但是,该逻辑无法预测用户是否会移动
Tensors 添加到自身内的新设备。因此,如果您移动
Tensors 添加到新设备(“new”表示不属于
[当前设备 + Tensor 参数的设备])within , 确定性
与非 checkpoint 传递相比,永远无法保证输出。_infer_device_type
run_fn
DefaultDeviceType
run_fn
run_fn
- torch.utils.checkpoint。checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[来源]¶
对模型或模型的一部分进行 Checkpoint 操作。
激活检查点是一种用计算换取内存的技术。 而不是保持 backward 所需的张量处于活动状态,直到它们在 Checkpointed 中 backward 和 forward 计算期间的梯度计算 regions 省略了保存张量以供向后,并在 向后传递。激活检查点可以应用于 型。
目前有两种可用的 checkpointing 实现,determined by 参数。建议您使用 .请参阅下面的注释,了解 他们的差异。
use_reentrant
use_reentrant=False
警告
如果向后传递期间的调用不同 从前向传递中,例如,由于全局变量,checkpointed version 可能不等效,这可能会导致 错误或导致 Gradient 错误。
function
警告
应显式传递该参数。在版本中 2.4 如果未传递,我们将引发异常。 如果您使用的是变体,请参阅 请注意以下重要注意事项和潜在限制。
use_reentrant
use_reentrant
use_reentrant=True
注意
checkpoint () 和 checkpoint () 的不可重入变体 在以下方面有所不同:
use_reentrant=True
use_reentrant=False
不可重入 checkpoint 在需要时立即停止重新计算 已重新计算中间激活数。此功能已启用 默认情况下,但可以使用 . Reentrant 检查点始终在其 整个过程。
set_checkpoint_early_stop()
function
reentrant 变体在 forward pass,因为它与 下的 forward pass 一起
运行 。不可重入版本会记录 autograd graph 中执行,允许在 检查点区域。
可重入检查点仅支持不带 inputs 参数的向后传递 API,而不可重入版本支持所有方式
执行向后传递。
至少一个输入和输出必须具有 reentrant 变体。如果未满足此条件,则 checkpoint 部分 的模型将没有梯度。不可重入版本执行 没有此要求。
requires_grad=True
可重入版本不考虑嵌套结构中的张量 (例如,自定义对象、列表、字典等)作为参与 autograd 的 AUTOGRAD 版本,而 non-reentrant 版本则适用。
可重入检查点不支持 从计算图中分离的张量,而 non-reentrant 版本会。对于可重入变体,如果 Checkpointed Segment 包含使用 OR 分离的张量 使用
时,向后传递将引发错误。 这是因为 make all the output require gradients 当张量被定义为没有 gradient 时,这会导致问题 模型。为避免这种情况,请将张量分离到函数外部。
detach()
checkpoint
checkpoint
- 参数
function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS
(activation, hidden)
function
activation
hidden
preserve_rng_state (bool, optional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。请注意,在 torch.compile 下, 此标志不会生效,我们始终保留 RNG 状态。 违约:
True
use_reentrant (bool) – 指定是否使用满足 需要可重入 autograd。应传递此参数 明确地。在版本 2.5 中,如果未传递,我们将引发异常。如果 ,将使用不需要 reentrant autograd.这允许支持额外的 功能,例如按预期工作和支持输入到 checkpointed 函数。
use_reentrant
use_reentrant=False
checkpoint
checkpoint
torch.autograd.grad
context_fn (Callable, optional) – 返回 2 元组的可调用对象 上下文管理器。该函数及其重新计算将运行 分别在 First 和 Second Context Manager 下。 仅当 .
use_reentrant=False
determinism_check (str, optional) – 指定确定性的字符串 检查以执行。默认情况下,它被设置为 比较重新计算的张量的形状、数据类型和设备 针对那些保存的张量。要关闭此检查,请指定 。目前,这是仅有的两个受支持的值。 如果您想看到更多的确定性,请打开一个 issue 检查。仅当 , 如果 ,则始终禁用确定性检查。
"default"
"none"
use_reentrant=False
use_reentrant=True
debug (bool, optional) – 如果 ,错误消息还将包括 在原始前向计算期间运行的运算符的跟踪 以及重新计算。仅当 .
True
use_reentrant=False
args – 包含
function
- 返回
运行 on 的输出
function
*args
- torch.utils.checkpoint。checkpoint_sequential(函数、段、输入、use_reentrant=无、**kwargs)[来源]¶
对 sequential 模型进行 checkpoint 以节省内存。
顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 之外的所有 segment 都不会存储 中间激活。每个 checkpointed segment 的 inputs 将 保存以在 Backward Pass 中重新运行该段落。
警告
应显式传递该参数。在版本中 2.4 如果未传递,我们将引发异常。 如果您使用的是 .
use_reentrant
use_reentrant
use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False
- 参数
segments – 要在模型中创建的块数
input – 一个 Tensor 输入到
functions
preserve_rng_state (bool, optional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:
True
use_reentrant (bool) – 指定是否使用满足 需要可重入 autograd。应传递此参数 明确地。在版本 2.5 中,如果未传递,我们将引发异常。如果 ,将使用不需要 reentrant autograd.这允许支持额外的 功能,例如按预期工作和支持输入到 checkpointed 函数。
use_reentrant
use_reentrant=False
checkpoint
checkpoint
torch.autograd.grad
- 返回
顺序运行的输出
functions
*inputs
例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint。set_checkpoint_debug_enabled(启用)[来源]¶
设置 checkpoint 是否应打印其他调试的上下文管理器 信息。有关更多信息,请参阅 flag
。请注意, 设置后,此上下文管理器将覆盖 Passed to 的值 检查站。要遵循本地设置,请传递到此上下文。
debug
debug
None
- 参数
enabled (bool) - 检查点是否应打印调试信息。 默认值为 'None'。
- 类 torch.utils.checkpoint。CheckpointPolicy(value)[来源]¶
用于指定反向传播期间 checkpointing 策略的枚举。
支持以下策略:
{MUST,PREFER}_SAVE
:在转发过程中将保存操作的输出 pass 的 API 中,并且在 backward pass 期间不会重新计算{MUST,PREFER}_RECOMPUTE
:在 forward pass 的 cookie 和 x 的 GO ,并将在 BACKWARD PASS 期间重新计算
使用 over 指示不应覆盖策略 通过其他子系统(如 torch.compile)进行。
MUST_*
PREFER_*
注意
始终返回的策略函数是 相当于原版 Checkpointing。
PREFER_RECOMPUTE
返回每个运算的 policy 函数是 不等同于不使用检查点。使用此类策略将 保存额外的张量,不限于实际需要的张量 梯度计算。
PREFER_SAVE
- 类 torch.utils.checkpoint。SelectiveCheckpointContext(*, is_recompute)[来源]¶
在选择性检查点期间传递给策略函数的上下文。
此类用于在 选择性检查点。元数据包括当前调用 的 policy 函数是否在重新计算期间。
例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint。create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[来源]¶
Helper 避免在激活 checkpointing 期间重新计算某些操作。
将其与 torch.utils.checkpoint.checkpoint 一起使用来控制哪个 操作在 Backward Pass 期间重新计算。
- 参数
policy_fn_or_list (Callable 或 List) –
allow_cache_entry_mutation (bool,可选) – 默认情况下,错误为 如果选择性激活检查点缓存的任何张量为 changed 以确保正确性。如果设置为 True,则此检查 已禁用。
- 返回
两个上下文管理器的元组。
例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )