torch.utils.checkpoint¶
注意
检查点是通过重新运行
backward 期间的每个 checkpointed segment。这可能会导致持续
像 RNG 州这样的州比没有
检查点。默认情况下,checkpointing 包括 juggle 逻辑
RNG 状态,以便使用 RNG 的检查点传递
(例如,通过 dropout)具有确定性输出,如
与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑
RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时
的检查点操作。如果确定性输出与
非检查点通行证不是必需的,提供或省略储藏和
在每个检查点期间恢复 RNG 状态。preserve_rng_state=False
checkpoint
checkpoint_sequential
存储逻辑保存和恢复当前设备的 RNG 状态
以及所有 cuda Tensor 参数的 device 到 .
但是,该逻辑无法预测用户是否会移动
Tensors 添加到自身内的新设备。因此,如果您移动
Tensors 添加到新设备(“new”表示不属于
[当前设备 + Tensor 参数的设备])within , 确定性
与非 checkpoint 传递相比,永远无法保证输出。run_fn
run_fn
run_fn
-
torch.utils.checkpoint.
checkpoint
(function, *args, use_reentrant=True, **kwargs)[来源]¶ 对模型或模型的一部分进行 Checkpoint 操作
检查点的工作原理是将计算换取内存。而不是存储所有 用于计算的整个计算图的中间激活 backward,则 checkpointed 部分不保存中间 activation, 而是在 Backward Pass 中重新计算它们。它可以应用于任何部件 的模型。
具体来说,在前向传递中,将以
方式运行,即不存储中间 激活。相反,forward pass 会保存 inputs tuples 和 parameter。在向后传递中,将检索保存的输入 and,并再次计算前向传递,现在跟踪中间激活,然后 使用这些激活值计算梯度。
function
function
function
function
的输出可以包含非 Tensor 值和 gradient 仅对 Tensor 值执行记录。请注意,如果输出 由嵌套结构组成(例如:自定义对象、列表、字典等) 由 Tensor 组成,则这些嵌套在自定义结构中的 Tensor 不会 被视为 Autograd 的一部分。
function
警告
如果 invocation during backward 执行任何不同的操作 而不是 forward 期间的那个,例如,由于某些全局变量, checkpointed 版本不会等效,不幸的是,它不能是 检测。
function
警告
如果指定,则如果 checkpoint 段 包含通过 detach() 或 torch.no_grad() 从计算图中分离的张量,则向后传递将引发错误。这是 因为 checkpoint 使所有输出都需要梯度,而 当将张量定义为在模型中没有梯度时,会导致问题。 要避免这种情况,请将张量从 checkpoint 函数之外分离。请注意,检查点分段可以包含张量 如果为 IS ,则从计算图中分离 指定。
use_reentrant=True
use_reentrant=False
警告
如果指定,则至少需要一个 inputs 如果模型输入需要 grads,则 否则,模型的检查点部分将没有梯度。在 至少有一个输出需要具有 AS 井。请注意,如果是 指定。
use_reentrant=True
requires_grad=True
requires_grad=True
use_reentrant=False
警告
如果指定,则当前仅执行 checkpoint 支持
,并且仅当其 inputs 参数未传递时。
不支持。如果指定,则执行 checkpointing 将与
一起使用。
use_reentrant=True
use_reentrant=False
- 参数
function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS
(activation, hidden)
function
activation
hidden
preserve_rng_state (bool, optional, default=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。
use_reentrant (bool, optional, default=True) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作。请注意,future 的 PyTorch 版本将默认为 .
use_reentrant=False
checkpoint
checkpoint
torch.autograd.grad
use_reentrant=False
args – 包含
function
- 返回
运行 on 的输出
function
*args
-
torch.utils.checkpoint.
checkpoint_sequential
(函数、段、输入、**kwargs)[来源]¶ 用于对 Sequential 模型进行检查点的辅助函数。
顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 外,所有 segment 都将以
manner 运行,即不存储中间 激活。每个检查点 segment 的输入将被保存为 在 Backward Pass 中重新运行该 Segment。
- 参数
segments – 要在模型中创建的块数
input – 一个 Tensor 输入到
functions
preserve_rng_state (bool, optional, default=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。
- 返回
顺序运行的输出
functions
*inputs
例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)