torch.utils.checkpoint¶
注意
检查点是通过重新运行
backward 期间的每个 checkpointed segment。这可能会导致持续
像 RNG 州这样的州比没有
检查点。默认情况下,checkpointing 包括 juggle 逻辑
RNG 状态,以便使用 RNG 的检查点传递
(例如,通过 dropout)具有确定性输出,如
与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑
RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时
的检查点作。如果确定性输出与
非检查点通行证不是必需的,提供或省略储藏和
在每个检查点期间恢复 RNG 状态。preserve_rng_state=False
checkpoint
checkpoint_sequential
存储逻辑保存和恢复当前设备的 RNG 状态
以及所有 cuda Tensor 参数的 device 到 .
但是,该逻辑无法预测用户是否会移动
Tensors 添加到自身内的新设备。因此,如果您移动
Tensors 添加到新设备(“new”表示不属于
[当前设备 + Tensor 参数的设备])within , 确定性
与非 checkpoint 传递相比,永远无法保证输出。run_fn
run_fn
run_fn
-
torch.utils.checkpoint.
checkpoint
(function, *args, **kwargs)[来源]¶ 对模型或模型的一部分进行 Checkpoint作
检查点的工作原理是将计算换取内存。而不是存储所有 用于计算的整个计算图的中间激活 backward,则 checkpointed 部分不保存中间 activation, 而是在 Backward Pass 中重新计算它们。它可以应用于任何部件 的模型。
具体来说,在前向传递中,将运行
function
torch.no_grad()
方式,即不存储中间体 激活。相反,forward pass 会保存 inputs tuples 和 parameter。在向后传递中,将检索保存的输入 and,并再次计算前向传递,现在跟踪中间激活,然后 使用这些激活值计算梯度。function
function
function
的输出可以包含非 Tensor 值和 gradient 仅对 Tensor 值执行记录。请注意,如果输出 由嵌套结构组成(例如:自定义对象、列表、字典等) 由 Tensor 组成,则这些嵌套在自定义结构中的 Tensor 不会 被视为 Autograd 的一部分。
function
警告
Checkpointing 目前仅支持
torch.autograd.backward()
并且仅当其 inputs 参数未传递时。torch.autograd.grad()
不支持。警告
如果 invocation during backward 执行任何不同的作 而不是 forward 期间的那个,例如,由于某些全局变量, checkpointed 版本不会等效,不幸的是,它不能是 检测。
function
警告
如果 checkpointed 段包含与计算 graph by detach() 或 torch.no_grad() 时,向后传递将引发一个 错误。这是因为 checkpoint 使所有输出都需要 gradients,当张量被定义为没有 gradient 的 Gradient 值。为了规避这种情况,请将张量从 checkpoint 函数。
警告
至少有一个输入需要具有 if grads 用于模型输入,否则 model 不会有梯度。至少还需要有一个 outputs。
requires_grad=True
requires_grad=True
- 参数
function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS
(activation, hidden)
function
activation
hidden
preserve_rng_state (bool, optional, default=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。
args – 包含
function
- 返回
运行 on 的输出
function
*args
-
torch.utils.checkpoint.
checkpoint_sequential
(函数、段、输入、**kwargs)[来源]¶ 用于对 Sequential 模型进行检查点的辅助函数。
顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 Segment 之外的所有 Segment 都将在
torch.no_grad()
方式,即不存储中间体 激活。每个检查点 segment 的输入将被保存为 在 Backward Pass 中重新运行该 Segment。看
checkpoint()
关于检查点的工作原理。警告
Checkpointing 目前仅支持
torch.autograd.backward()
并且仅当其 inputs 参数未传递时。torch.autograd.grad()
不支持。- 参数
函数 – A
torch.nn.Sequential
或模块列表或 函数(组成模型)按顺序运行。segments – 要在模型中创建的块数
input – 一个 Tensor 输入到
functions
preserve_rng_state (bool, optional, default=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。
- 返回
顺序运行的输出
functions
*inputs
例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)