目录

torch.utils.checkpoint

注意

检查点是通过重新运行 backward 期间的每个 checkpointed segment。这可能会导致持续 像 RNG 州这样的州比没有 检查点。默认情况下,checkpointing 包括 juggle 逻辑 RNG 状态,以便使用 RNG 的检查点传递 (例如,通过 dropout)具有确定性输出,如 与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑 RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时 的检查点操作。如果确定性输出与 非检查点通行证不是必需的,提供或省略储藏和 在每个检查点期间恢复 RNG 状态。preserve_rng_state=Falsecheckpointcheckpoint_sequential

存储逻辑保存和恢复当前设备的 RNG 状态 以及所有 cuda Tensor 参数的 device 到 . 但是,该逻辑无法预测用户是否会移动 Tensors 添加到自身内的新设备。因此,如果您移动 Tensors 添加到新设备(“new”表示不属于 [当前设备 + Tensor 参数的设备])within , 确定性 与非 checkpoint 传递相比,永远无法保证输出。run_fnrun_fnrun_fn

torch.utils.checkpoint.checkpoint(function*argsuse_reentrant=True**kwargs[来源]

对模型或模型的一部分进行 Checkpoint 操作

检查点的工作原理是将计算换取内存。而不是存储所有 用于计算的整个计算图的中间激活 backward,则 checkpointed 部分保存中间 activation, 而是在 Backward Pass 中重新计算它们。它可以应用于任何部件 的模型。

具体来说,在前向传递中,将以 方式运行,即不存储中间 激活。相反,forward pass 会保存 inputs tuples 和 parameter。在向后传递中,将检索保存的输入 and,并再次计算前向传递,现在跟踪中间激活,然后 使用这些激活值计算梯度。functionfunctionfunctionfunction

的输出可以包含非 Tensor 值和 gradient 仅对 Tensor 值执行记录。请注意,如果输出 由嵌套结构组成(例如:自定义对象、列表、字典等) 由 Tensor 组成,则这些嵌套在自定义结构中的 Tensor 不会 被视为 Autograd 的一部分。function

警告

如果 invocation during backward 执行任何不同的操作 而不是 forward 期间的那个,例如,由于某些全局变量, checkpointed 版本不会等效,不幸的是,它不能是 检测。function

警告

如果指定,则如果 checkpoint 段 包含通过 detach()torch.no_grad() 从计算图中分离的张量,则向后传递将引发错误。这是 因为 checkpoint 使所有输出都需要梯度,而 当将张量定义为在模型中没有梯度时,会导致问题。 要避免这种情况,请将张量从 checkpoint 函数之外分离。请注意,检查点分段可以包含张量 如果为 IS ,则从计算图中分离 指定。use_reentrant=Trueuse_reentrant=False

警告

如果指定,则至少需要一个 inputs 如果模型输入需要 grads,则 否则,模型的检查点部分将没有梯度。在 至少有一个输出需要具有 AS 井。请注意,如果是 指定。use_reentrant=Truerequires_grad=Truerequires_grad=Trueuse_reentrant=False

警告

如果指定,则当前仅执行 checkpoint 支持,并且仅当其 inputs 参数未传递时。不支持。如果指定,则执行 checkpointing 将与 一起使用。use_reentrant=Trueuse_reentrant=False

参数
  • function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS(activation, hidden)functionactivationhidden

  • preserve_rng_statebooloptionaldefault=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。

  • use_reentrantbooloptionaldefault=True) – 使用检查点 需要重入 autograd 的实现。 如果指定,将使用 不需要重入 autograd 的实现。这 允许支持其他功能,例如 按预期工作。请注意,future 的 PyTorch 版本将默认为 .use_reentrant=Falsecheckpointcheckpointtorch.autograd.graduse_reentrant=False

  • args – 包含function

返回

运行 on 的输出function*args

torch.utils.checkpoint.checkpoint_sequential(函数输入**kwargs[来源]

用于对 Sequential 模型进行检查点的辅助函数。

顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 segment 外,所有 segment 都将以 manner 运行,即不存储中间 激活。每个检查点 segment 的输入将被保存为 在 Backward Pass 中重新运行该 Segment。

请参阅 检查点的工作原理。

警告

Checkpointing 目前仅支持且仅在未传递其 inputs 参数时支持。不支持。

参数
  • functions – A 或模块列表或 函数(组成模型)按顺序运行。

  • segments – 要在模型中创建的块数

  • input – 一个 Tensor 输入到functions

  • preserve_rng_statebooloptionaldefault=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。

返回

顺序运行的输出functions*inputs

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源