torch.utils.checkpoint¶
注意
检查点是通过重新运行
backward 期间的每个 checkpointed segment。这可能会导致持续
像 RNG 州这样的州比没有
检查点。默认情况下,checkpointing 包括 juggle 逻辑
RNG 状态,以便使用 RNG 的检查点传递
(例如,通过 dropout)具有确定性输出,如
与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑
RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时
的检查点操作。如果确定性输出与
非检查点通行证不是必需的,提供或省略储藏和
在每个检查点期间恢复 RNG 状态。preserve_rng_state=False
checkpoint
checkpoint_sequential
存储逻辑保存和恢复 CPU 和另一个
设备类型(从 Tensor 参数推断设备类型,不包括 CPU
张量 ) 设置为 .如果有多个
device,则仅保存单个设备类型的设备的设备状态,
其余设备将被忽略。因此,如果任何 checkpoint
函数涉及随机性,这可能会导致 gradient 不正确。(注意
如果 CUDA 设备在检测到的设备中,则它将被优先考虑;
否则,将选择遇到的第一个设备。如果没有
CPU-tensors,默认设备类型状态(默认值为 cuda,它
可以设置为 Other Device 的 ) 将被保存和恢复。
但是,该逻辑无法预测用户是否会移动
Tensors 添加到自身内的新设备。因此,如果您移动
Tensors 添加到新设备(“new”表示不属于
[当前设备 + Tensor 参数的设备])within , 确定性
与非 checkpoint 传递相比,永远无法保证输出。_infer_device_type
run_fn
DefaultDeviceType
run_fn
run_fn
- torch.utils.checkpoint。checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[来源]¶
对模型或模型的一部分进行 Checkpoint 操作。
激活检查点是一种用计算换取内存的技术。 而不是保持 backward 所需的张量处于活动状态,直到它们在 Checkpointed 中 backward 和 forward 计算期间的梯度计算 regions 省略了保存张量以供向后,并在 向后传递。激活检查点可以应用于 型。
目前有两种可用的 checkpointing 实现,determined by 参数。建议您使用 .请参阅下面的注释,了解 他们的差异。
use_reentrant
use_reentrant=False
警告
如果向后传递期间的调用不同 从前向传递中,例如,由于全局变量,checkpointed checkpointed 版本可能不等效,这可能会导致 错误或导致 Gradient 错误。
function
警告
如果您使用的是变体(目前为 默认),请参阅下面的注释以获取重要 注意事项和潜在限制。
use_reentrant=True
注意
checkpoint () 和 checkpoint () 的不可重入变体 在以下方面有所不同:
use_reentrant=True
use_reentrant=False
不可重入 checkpoint 在需要时立即停止重新计算 已重新计算中间激活数。此功能已启用 默认情况下,但可以使用 . Reentrant 检查点始终在其 整个过程。
set_checkpoint_early_stop()
function
reentrant 变体在 forward pass,因为它与 下的 forward pass 一起
运行 。不可重入版本会记录 autograd graph 中执行,允许在 检查点区域。
可重入检查点仅支持不带 inputs 参数的向后传递 API,而不可重入版本支持所有方式
执行向后传递。
至少一个输入和输出必须具有 reentrant 变体。如果未满足此条件,则 checkpoint 部分 的模型将没有梯度。不可重入版本执行 没有此要求。
requires_grad=True
可重入版本不考虑嵌套结构中的张量 (例如,自定义对象、列表、字典等)作为参与 autograd 的 AUTOGRAD 版本,而 non-reentrant 版本则适用。
可重入检查点不支持 从计算图中分离的张量,而 non-reentrant 版本会。对于可重入变体,如果 Checkpointed Segment 包含使用 OR 分离的张量 使用
时,向后传递将引发错误。 这是因为 make all the output require gradients 当张量被定义为没有 gradient 时,这会导致问题 模型。为避免这种情况,请将张量分离到函数外部。
detach()
checkpoint
checkpoint
- 参数
function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS
(activation, hidden)
function
activation
hidden
preserve_rng_state (bool, optional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:
True
use_reentrant (bool, optional) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作并支持 keyword 参数输入到 checkpointed 函数中。请注意,future 的 PyTorch 版本将默认为 . 违约:
use_reentrant=False
checkpoint
checkpoint
torch.autograd.grad
use_reentrant=False
True
context_fn (Callable, optional) – 返回 2 元组的可调用对象 上下文管理器。该函数及其重新计算将运行 分别在 First 和 Second Context Manager 下。 仅当 .
use_reentrant=False
determinism_check (str, optional) – 指定确定性的字符串 检查以执行。默认情况下,它被设置为 比较重新计算的张量的形状、数据类型和设备 针对那些保存的张量。要关闭此检查,请指定 。目前,这是仅有的两个受支持的值。 如果您想看到更多的确定性,请打开一个 issue 检查。仅当 , 如果 ,则始终禁用确定性检查。
"default"
"none"
use_reentrant=False
use_reentrant=True
debug (bool, optional) – 如果 ,错误消息还将包括 在原始前向计算期间运行的运算符的跟踪 以及重新计算。仅当 .
True
use_reentrant=False
args – 包含
function
- 返回
运行 on 的输出
function
*args
- torch.utils.checkpoint。checkpoint_sequential(函数、段、输入、use_reentrant=无、**kwargs)[来源]¶
对 sequential 模型进行 checkpoint 以节省内存。
顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 之外的所有 segment 都不会存储 中间激活。每个 checkpointed segment 的 inputs 将 保存以在 Backward Pass 中重新运行该段落。
警告
如果您使用的是 .
use_reentrant=True` variant (this is the default), please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False
- 参数
segments – 要在模型中创建的块数
input – 一个 Tensor 输入到
functions
preserve_rng_state (bool, optional) – 省略 stashing 和 restore 每个检查点期间的 RNG 状态。 违约:
True
use_reentrant (bool, optional) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作并支持 keyword 参数输入到 checkpointed 函数中。 违约:
use_reentrant=False
checkpoint
checkpoint
torch.autograd.grad
True
- 返回
顺序运行的输出
functions
*inputs
例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)