目录

torch.utils.checkpoint

注意

检查点是通过重新运行 向后传播期间的每个 checkpointed 段。这可能会导致持续 像 RNG 州这样的州比没有 RNG 州更先进 检查点。默认情况下,checkpointing 包括 juggle 逻辑 RNG 状态,以便使用 RNG 的检查点传递 (例如,通过 dropout)具有确定性输出,如 与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑 RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时 的检查点操作。如果确定性输出与 非检查点通行证不是必需的,提供或省略储藏和 在每个检查点期间恢复 RNG 状态。preserve_rng_state=Falsecheckpointcheckpoint_sequential

存储逻辑保存和恢复 CPU 和另一个 设备类型(从 Tensor 参数推断设备类型,不包括 CPU 张量 ) 设置为 .如果有多个 device,则仅保存单个设备类型的设备的设备状态, 其余设备将被忽略。因此,如果任何 checkpoint 函数涉及随机性,这可能会导致 gradient 不正确。(注意 如果 CUDA 设备在检测到的设备中,则它将被优先考虑; 否则,将选择遇到的第一个设备。如果没有 CPU-tensors,默认设备类型状态(默认值为 cuda,它 可以设置为 Other Device 的 ) 将被保存和恢复。 但是,该逻辑无法预测用户是否会移动 Tensors 添加到自身内的新设备。因此,如果您移动 Tensors 添加到新设备(“new”表示不属于 [当前设备 + Tensor 参数的设备])within , 确定性 与非 checkpoint 传递相比,永远无法保证输出。_infer_device_typerun_fnDefaultDeviceTyperun_fnrun_fn

torch.utils.checkpoint。checkpointfunction*argsuse_reentrant=Nonecontext_fn=<function noop_context_fn>determinism_check='default'debug=False**kwargs[来源]

对模型或模型的一部分进行 Checkpoint 操作。

激活检查点是一种用计算换取内存的技术。 而不是保持 backward 所需的张量处于活动状态,直到它们在 Checkpointed 中 backward 和 forward 计算期间的梯度计算 regions 省略了保存张量以供向后,并在 向后传递。激活检查点可以应用于 型。

目前有两种可用的 checkpointing 实现,determined by 参数。建议您使用 .请参阅下面的注释,了解 他们的差异。use_reentrantuse_reentrant=False

警告

如果向后传递期间的调用不同 从前向传递中,例如,由于全局变量,checkpointed version 可能不等效,这可能会导致 错误或导致 Gradient 错误。function

警告

应显式传递该参数。在版本中 2.4 如果未传递,我们将引发异常。 如果您使用的是变体,请参阅 请注意以下重要注意事项和潜在限制。use_reentrantuse_reentrantuse_reentrant=True

注意

checkpoint () 和 checkpoint () 的不可重入变体 在以下方面有所不同:use_reentrant=Trueuse_reentrant=False

  • 不可重入 checkpoint 在需要时立即停止重新计算 已重新计算中间激活数。此功能已启用 默认情况下,但可以使用 . Reentrant 检查点始终在其 整个过程。set_checkpoint_early_stop()function

  • reentrant 变体在 forward pass,因为它与 下的 forward pass 一起运行 。不可重入版本会记录 autograd graph 中执行,允许在 检查点区域。

  • 可重入检查点仅支持不带 inputs 参数的向后传递 API,而不可重入版本支持所有方式 执行向后传递。

  • 至少一个输入和输出必须具有 reentrant 变体。如果未满足此条件,则 checkpoint 部分 的模型将没有梯度。不可重入版本执行 没有此要求。requires_grad=True

  • 可重入版本不考虑嵌套结构中的张量 (例如,自定义对象、列表、字典等)作为参与 autograd 的 AUTOGRAD 版本,而 non-reentrant 版本则适用。

  • 可重入检查点不支持 从计算图中分离的张量,而 non-reentrant 版本会。对于可重入变体,如果 Checkpointed Segment 包含使用 OR 分离的张量 使用 时,向后传递将引发错误。 这是因为 make all the output require gradients 当张量被定义为没有 gradient 时,这会导致问题 模型。为避免这种情况,请将张量分离到函数外部。detach()checkpointcheckpoint

参数
  • function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS(activation, hidden)functionactivationhidden

  • preserve_rng_statebooloptional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。请注意,在 torch.compile 下, 此标志不会生效,我们始终保留 RNG 状态。 违约:True

  • use_reentrantbool) – 指定是否使用满足 需要可重入 autograd。应传递此参数 明确地。在版本 2.5 中,如果未传递,我们将引发异常。如果 ,将使用不需要 reentrant autograd.这允许支持额外的 功能,例如按预期工作和支持输入到 checkpointed 函数。use_reentrantuse_reentrant=Falsecheckpointcheckpointtorch.autograd.grad

  • context_fnCallableoptional) – 返回 2 元组的可调用对象 上下文管理器。该函数及其重新计算将运行 分别在 First 和 Second Context Manager 下。 仅当 .use_reentrant=False

  • determinism_checkstroptional) – 指定确定性的字符串 检查以执行。默认情况下,它被设置为 比较重新计算的张量的形状、数据类型和设备 针对那些保存的张量。要关闭此检查,请指定 。目前,这是仅有的两个受支持的值。 如果您想看到更多的确定性,请打开一个 issue 检查。仅当 , 如果 ,则始终禁用确定性检查。"default""none"use_reentrant=Falseuse_reentrant=True

  • debugbooloptional) – 如果 ,错误消息还将包括 在原始前向计算期间运行的运算符的跟踪 以及重新计算。仅当 .Trueuse_reentrant=False

  • args – 包含function

返回

运行 on 的输出function*args

torch.utils.checkpoint。checkpoint_sequential函数输入use_reentrant=**kwargs[来源]

对 sequential 模型进行 checkpoint 以节省内存。

顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 之外的所有 segment 都不会存储 中间激活。每个 checkpointed segment 的 inputs 将 保存以在 Backward Pass 中重新运行该段落。

警告

应显式传递该参数。在版本中 2.4 如果未传递,我们将引发异常。 如果您使用的是 .use_reentrantuse_reentrantuse_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False

参数
  • functions – A 或模块列表或 函数(组成模型)按顺序运行。

  • segments – 要在模型中创建的块数

  • input – 一个 Tensor 输入到functions

  • preserve_rng_statebooloptional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:True

  • use_reentrantbool) – 指定是否使用满足 需要可重入 autograd。应传递此参数 明确地。在版本 2.5 中,如果未传递,我们将引发异常。如果 ,将使用不需要 reentrant autograd.这允许支持额外的 功能,例如按预期工作和支持输入到 checkpointed 函数。use_reentrantuse_reentrant=Falsecheckpointcheckpointtorch.autograd.grad

返回

顺序运行的输出functions*inputs

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint。set_checkpoint_debug_enabled启用[来源]

设置 checkpoint 是否应打印其他调试的上下文管理器 信息。有关更多信息,请参阅 flag 。请注意, 设置后,此上下文管理器将覆盖 Passed to 的值 检查站。要遵循本地设置,请传递到此上下文。debugdebugNone

参数

enabledbool) - 检查点是否应打印调试信息。 默认值为 'None'。

torch.utils.checkpoint。CheckpointPolicyvalue[来源]

用于指定反向传播期间 checkpointing 策略的枚举。

支持以下策略:

  • {MUST,PREFER}_SAVE:在转发过程中将保存操作的输出 pass 的 API 中,并且在 backward pass 期间不会重新计算

  • {MUST,PREFER}_RECOMPUTE:在 forward pass 的 cookie 和 x 的 GO ,并将在 BACKWARD PASS 期间重新计算

使用 over 指示不应覆盖策略 通过其他子系统(如 torch.compile)进行。MUST_*PREFER_*

注意

始终返回的策略函数是 相当于原版 Checkpointing。PREFER_RECOMPUTE

返回每个运算的 policy 函数是 不等同于不使用检查点。使用此类策略将 保存额外的张量,不限于实际需要的张量 梯度计算。PREFER_SAVE

torch.utils.checkpoint。SelectiveCheckpointContext*is_recompute[来源]

在选择性检查点期间传递给策略函数的上下文。

此类用于在 选择性检查点。元数据包括当前调用 的 policy 函数是否在重新计算期间。

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint。create_selective_checkpoint_contextspolicy_fn_or_listallow_cache_entry_mutation=False[来源]

Helper 避免在激活 checkpointing 期间重新计算某些操作。

将其与 torch.utils.checkpoint.checkpoint 一起使用来控制哪个 操作在 Backward Pass 期间重新计算。

参数
  • policy_fn_or_listCallableList) –

    • 如果提供了策略函数,它应该接受 a 、 、 、 args 和 kwargs 添加到 op 中,并返回一个枚举值 指示是否应重新计算 op 的执行。OpOverload

    • 如果提供了操作列表,则相当于一个策略 CheckpointPolicy.MUST_SAVE返回指定 操作和所有其他 CheckpointPolicy.PREFER_RECOMPUTE 操作。

  • allow_cache_entry_mutationbool可选) – 默认情况下,错误为 如果选择性激活检查点缓存的任何张量为 changed 以确保正确性。如果设置为 True,则此检查 已禁用。

返回

两个上下文管理器的元组。

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源