目录

torch.utils.checkpoint

注意

检查点是通过重新运行 backward 期间的每个 checkpointed segment。这可能会导致持续 像 RNG 州这样的州比没有 检查点。默认情况下,checkpointing 包括 juggle 逻辑 RNG 状态,以便使用 RNG 的检查点传递 (例如,通过 dropout)具有确定性输出,如 与非检查点传递相比。存储 (stash) 和还原 (restore) 的逻辑 RNG 状态可能会对性能造成中等程度的打击,具体取决于运行时 的检查点作。如果确定性输出与 非检查点通行证不是必需的,提供或省略储藏和 在每个检查点期间恢复 RNG 状态。preserve_rng_state=Falsecheckpointcheckpoint_sequential

存储逻辑保存和恢复当前设备的 RNG 状态 以及所有 cuda Tensor 参数的 device 到 . 但是,该逻辑无法预测用户是否会移动 Tensors 添加到自身内的新设备。因此,如果您移动 Tensors 添加到新设备(“new”表示不属于 [当前设备 + Tensor 参数的设备])within , 确定性 与非 checkpoint 传递相比,永远无法保证输出。run_fnrun_fnrun_fn

torch.utils.checkpoint.checkpoint(function*args**kwargs[来源]

对模型或模型的一部分进行 Checkpoint作

检查点的工作原理是将计算换取内存。而不是存储所有 用于计算的整个计算图的中间激活 backward,则 checkpointed 部分保存中间 activation, 而是在 Backward Pass 中重新计算它们。它可以应用于任何部件 的模型。

具体来说,在前向传递中,将运行functiontorch.no_grad()方式,即不存储中间体 激活。相反,forward pass 会保存 inputs tuples 和 parameter。在向后传递中,将检索保存的输入 and,并再次计算前向传递,现在跟踪中间激活,然后 使用这些激活值计算梯度。functionfunctionfunction

的输出可以包含非 Tensor 值和 gradient 仅对 Tensor 值执行记录。请注意,如果输出 由嵌套结构组成(例如:自定义对象、列表、字典等) 由 Tensor 组成,则这些嵌套在自定义结构中的 Tensor 不会 被视为 Autograd 的一部分。function

警告

Checkpointing 目前仅支持torch.autograd.backward()并且仅当其 inputs 参数未传递时。torch.autograd.grad()不支持。

警告

如果 invocation during backward 执行任何不同的作 而不是 forward 期间的那个,例如,由于某些全局变量, checkpointed 版本不会等效,不幸的是,它不能是 检测。function

警告

如果 checkpointed 段包含与计算 graph by detach()torch.no_grad() 时,向后传递将引发一个 错误。这是因为 checkpoint 使所有输出都需要 gradients,当张量被定义为没有 gradient 的 Gradient 值。为了规避这种情况,请将张量从 checkpoint 函数。

警告

至少有一个输入需要具有 if grads 用于模型输入,否则 model 不会有梯度。至少还需要有一个 outputs。requires_grad=Truerequires_grad=True

参数
  • function – 描述在模型的正向传递中运行什么,或者 模型的一部分。它还应该知道如何处理 inputs 作为元组传递。例如,在 LSTM 中,如果用户传递 ,则应正确使用 第一个输入 AS 和第二个输入 AS(activation, hidden)functionactivationhidden

  • preserve_rng_statebooloptionaldefault=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。

  • args – 包含function

返回

运行 on 的输出function*args

torch.utils.checkpoint.checkpoint_sequential(函数输入**kwargs[来源]

用于对 Sequential 模型进行检查点的辅助函数。

顺序模型按顺序执行模块/函数列表 (按顺序)。因此,我们可以将这样的模型分为不同的部分 并为每个段设置 checkpoint。除最后一个 Segment 之外的所有 Segment 都将在torch.no_grad()方式,即不存储中间体 激活。每个检查点 segment 的输入将被保存为 在 Backward Pass 中重新运行该 Segment。

checkpoint()关于检查点的工作原理。

警告

Checkpointing 目前仅支持torch.autograd.backward()并且仅当其 inputs 参数未传递时。torch.autograd.grad()不支持。

参数
  • 函数 – Atorch.nn.Sequential或模块列表或 函数(组成模型)按顺序运行。

  • segments – 要在模型中创建的块数

  • input – 一个 Tensor 输入到functions

  • preserve_rng_statebooloptionaldefault=True) – 省略存储和恢复 每个检查点期间的 RNG 状态。

返回

顺序运行的输出functions*inputs

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源