目录

torch.utils.checkpoint

注意

检查点是通过重新运行 backward 期间的每个 checkpointed segment。这可能会导致持续 像 RNG 州这样的州比没有 检查点。默认情况下,checkpointing 包括 juggle 逻辑 RNG 状态,以便使用 RNG 的检查点传递 (例如,通过 dropout)具有确定性输出,如 与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑 RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时 的检查点操作。如果确定性输出与 非检查点通行证不是必需的,提供或省略储藏和 在每个检查点期间恢复 RNG 状态。preserve_rng_state=Falsecheckpointcheckpoint_sequential

存储逻辑保存和恢复 CPU 和另一个 设备类型(从 Tensor 参数推断设备类型,不包括 CPU 张量 ) 设置为 .如果有多个 device,则仅保存单个设备类型的设备的设备状态, 其余设备将被忽略。因此,如果任何 checkpoint 函数涉及随机性,这可能会导致 gradient 不正确。(注意 如果 CUDA 设备在检测到的设备中,则它将被优先考虑; 否则,将选择遇到的第一个设备。如果没有 CPU-tensors,默认设备类型状态(默认值为 cuda,它 可以设置为 Other Device 的 ) 将被保存和恢复。 但是,该逻辑无法预测用户是否会移动 Tensors 添加到自身内的新设备。因此,如果您移动 Tensors 添加到新设备(“new”表示不属于 [当前设备 + Tensor 参数的设备])within , 确定性 与非 checkpoint 传递相比,永远无法保证输出。_infer_device_typerun_fnDefaultDeviceTyperun_fnrun_fn

torch.utils.checkpoint。checkpointfunction*argsuse_reentrant=Nonecontext_fn=<function noop_context_fn>determinism_check='default'debug=False**kwargs[来源]

对模型或模型的一部分进行 Checkpoint 操作。

激活检查点是一种用计算换取内存的技术。 而不是保持 backward 所需的张量处于活动状态,直到它们在 Checkpointed 中 backward 和 forward 计算期间的梯度计算 regions 省略了保存张量以供向后,并在 向后传递。激活检查点可以应用于 型。

目前有两种可用的 checkpointing 实现,determined by 参数。建议您使用 .请参阅下面的注释,了解 他们的差异。use_reentrantuse_reentrant=False

警告

如果向后传递期间的调用不同 从前向传递中,例如,由于全局变量,checkpointed checkpointed 版本可能不等效,这可能会导致 错误或导致 Gradient 错误。function

警告

如果您使用的是变体(目前为 默认),请参阅下面的注释以获取重要 注意事项和潜在限制。use_reentrant=True

注意

checkpoint () 和 checkpoint () 的不可重入变体 在以下方面有所不同:use_reentrant=Trueuse_reentrant=False

  • 不可重入 checkpoint 在需要时立即停止重新计算 已重新计算中间激活数。此功能已启用 默认情况下,但可以使用 . Reentrant 检查点始终在其 整个过程。set_checkpoint_early_stop()function

  • reentrant 变体在 forward pass,因为它与 下的 forward pass 一起运行 。不可重入版本会记录 autograd graph 中执行,允许在 检查点区域。

  • 可重入检查点仅支持不带 inputs 参数的向后传递 API,而不可重入版本支持所有方式 执行向后传递。

  • 至少一个输入和输出必须具有 reentrant 变体。如果未满足此条件,则 checkpoint 部分 的模型将没有梯度。不可重入版本执行 没有此要求。requires_grad=True

  • 可重入版本不考虑嵌套结构中的张量 (例如,自定义对象、列表、字典等)作为参与 autograd 的 AUTOGRAD 版本,而 non-reentrant 版本则适用。

  • 可重入检查点不支持 从计算图中分离的张量,而 non-reentrant 版本会。对于可重入变体,如果 Checkpointed Segment 包含使用 OR 分离的张量 使用 时,向后传递将引发错误。 这是因为 make all the output require gradients 当张量被定义为没有 gradient 时,这会导致问题 模型。为避免这种情况,请将张量分离到函数外部。detach()checkpointcheckpoint

参数
  • function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS(activation, hidden)functionactivationhidden

  • preserve_rng_statebooloptional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:True

  • use_reentrantbooloptional) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作并支持 keyword 参数输入到 checkpointed 函数中。请注意,future 的 PyTorch 版本将默认为 . 违约:use_reentrant=Falsecheckpointcheckpointtorch.autograd.graduse_reentrant=FalseTrue

  • context_fnCallableoptional) – 返回 2 元组的可调用对象 上下文管理器。该函数及其重新计算将运行 分别在 First 和 Second Context Manager 下。 仅当 .use_reentrant=False

  • determinism_checkstroptional) – 指定确定性的字符串 检查以执行。默认情况下,它被设置为 比较重新计算的张量的形状、数据类型和设备 针对那些保存的张量。要关闭此检查,请指定 。目前,这是仅有的两个受支持的值。 如果您想看到更多的确定性,请打开一个 issue 检查。仅当 , 如果 ,则始终禁用确定性检查。"default""none"use_reentrant=Falseuse_reentrant=True

  • debugbooloptional) – 如果 ,错误消息还将包括 在原始前向计算期间运行的运算符的跟踪 以及重新计算。仅当 .Trueuse_reentrant=False

  • args – 包含function

返回

运行 on 的输出function*args

torch.utils.checkpoint。checkpoint_sequential函数输入use_reentrant=**kwargs[来源]

对 sequential 模型进行 checkpoint 以节省内存。

顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 之外的所有 segment 都不会存储 中间激活。每个 checkpointed segment 的 inputs 将 保存以在 Backward Pass 中重新运行该段落。

警告

如果您使用的是 .use_reentrant=True` variant (this is the default), please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False

参数
  • functions – A 或模块列表或 函数(组成模型)按顺序运行。

  • segments – 要在模型中创建的块数

  • input – 一个 Tensor 输入到functions

  • preserve_rng_statebooloptional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:True

  • use_reentrantbooloptional) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作并支持 keyword 参数输入到 checkpointed 函数中。 违约:use_reentrant=Falsecheckpointcheckpointtorch.autograd.gradTrue

返回

顺序运行的输出functions*inputs

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint。set_checkpoint_debug_enabled启用[来源]

设置 checkpoint 是否应打印其他调试的上下文管理器 信息。有关更多信息,请参阅 flag 。请注意, 设置后,此上下文管理器将覆盖 Passed to 的值 检查站。要遵循本地设置,请传递到此上下文。debugdebugNone

参数

enabledbool) - 检查点是否应打印调试信息。 默认值为 'None'。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源