torch.utils.checkpoint¶
注意
检查点通过在反向传播期间为每个检查点段重新运行前向传播段来实现。这可能导致诸如随机数生成器(RNG)状态等持久状态比没有检查点时更进一步。默认情况下,检查点包含逻辑来调整 RNG 状态,以确保使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时间带来中等程度的性能损失。如果不需要与非检查点传递相比的确定性输出,请向 preserve_rng_state=False 传入 checkpoint 或 checkpoint_sequential,以跳过每次检查点期间存储和恢复 RNG 状态的操作。
存储逻辑会为CPU和其他设备类型(通过_infer_device_type从张量参数中排除CPU张量推断出设备类型)保存和恢复RNG状态到run_fn。如果有多个设备,设备状态将只为一种设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能导致梯度不正确。(注意,如果检测到CUDA设备,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有CPU张量,则默认设备类型状态(默认值是cuda,也可以通过DefaultDeviceType设置为其他设备)将被保存和恢复。然而,该逻辑无法预知用户是否会在run_fn本身内将张量移动到新设备。因此,如果你在run_fn中将张量移动到一个新设备(“新”是指不属于[当前设备 + 张量参数的设备]集合的设备),则与非检查点传递相比,确定性输出永远无法保证。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]¶
保存模型或模型的一部分
激活检查点(Activation checkpointing)是一种以计算换取内存的技术。 通常,在反向传播过程中,会保留用于梯度计算的张量直到它们被使用。而采用检查点技术的区域在正向传播时会省略保存这些张量,转而在反向传播过程中重新计算它们。这种激活检查点技术可以应用于模型的任何部分。
目前有两种检查点实现方式,由
use_reentrant参数决定。建议您使用use_reentrant=False。有关它们之间差异的讨论,请参阅下面的说明。警告
如果在反向传播过程中调用的
function与正向传播过程不同,例如由于全局变量的原因,检查点保存的版本可能不等价,这可能会引发错误或导致静默的梯度计算错误。警告
如果你正在使用
use_reentrant=True变体(目前这是默认设置),请参阅以下注释以了解重要的注意事项和潜在的限制。注意
可重入的 checkpoint 变体 (
use_reentrant=True) 和 不可重入的 checkpoint 变体 (use_reentrant=False) 在以下方面有所不同:非可重入检查点在所有需要的中间激活值被重新计算后立即停止重新计算。此功能默认启用,但可以通过
set_checkpoint_early_stop()禁用。可重入检查点在反向传递过程中始终完全重新计算function。可重入变体在正向传播过程中不会记录自动求导图,因为它在
torch.no_grad()的上下文中运行。不可重入版本会记录自动求导图,允许在检查点区域内对图执行反向传播。可重入检查点仅支持不带其 inputs 参数的反向传递的
torch.autograd.backward()API,而非可重入版本支持所有执行反向传递的方式。至少有一个输入和输出必须为
requires_grad=True的重入变体。如果未满足此条件,模型中被检查点的部分将不会有梯度。非重入版本没有这个要求。可重入版本不认为嵌套结构(例如,自定义对象、列表、字典等)中的张量参与自动求导,而非可重入版本则认为它们参与。
可重入检查点不支持计算图中从计算图分离的张量的检查点区域,而非可重入版本则支持。对于可重入变体,如果检查点段包含使用
detach()或 通过torch.no_grad()分离的张量,反向传播将引发错误。 这是因为checkpoint会使所有输出需要梯度,并且当模型中定义某个张量不需要梯度时会导致问题。为了避免这种情况,请在checkpoint函数外部分离张量。
- Parameters
function – 描述在模型或模型的部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户传递
(activation, hidden),function应该能正确地将第一个输入视为activation,第二个输入视为hiddenpreserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值:
Trueuse_reentrant (bool, optional) – 使用需要可重入自动求导的检查点实现。如果指定了
use_reentrant=False,checkpoint将使用不需要可重入自动求导的实现。这使得checkpoint可以支持额外的功能,例如与torch.autograd.grad一起正常工作,并支持输入到检查点函数中的关键字参数。请注意,PyTorch 的未来版本将默认使用use_reentrant=False。 默认值:Truecontext_fn (Callable, optional) – 一个可调用对象,返回两个上下文管理器的元组。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。 如果
use_reentrant=False,则支持此参数。determinism_check (str, optional) – 一个字符串,指定要执行的确定性检查。默认情况下,它被设置为
"default", 该设置会将重新计算的张量的形状、数据类型和设备与已保存张量进行比较。若要关闭此检查,请指定"none"。目前仅支持这两个值。如果您希望看到更多的确定性检查功能,请提交问题报告。这个参数仅在use_reentrant=False时受支持,如果use_reentrant=True,则确定性检查始终被禁用。调试 (布尔值, 可选) – 如果为
True,错误消息还将包括 原始前向计算过程中运行的操作符的跟踪信息以及重新计算的信息。此参数仅在use_reentrant=False时支持。args – 包含输入的元组
function
- Returns
在
*args上运行function的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs)[source]¶
用于检查点序列模型的辅助函数。
Sequential 模型按顺序执行一系列模块/函数。因此,我们可以将此类模型划分为多个段,并对每个段进行检查点处理。除了最后一个段之外,其他所有段都不会存储中间激活值。每个被检查点处理的段的输入将会被保存,以便在反向传播过程中重新运行该段。
警告
如果你正在使用
use_reentrant=True` variant (this is the default), please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False。- Parameters
functions – 一个
torch.nn.Sequential或者模块或函数的列表(构成模型),按顺序运行。segments – 在模型中创建的块数
输入 – 一个作为
functions输入的张量preserve_rng_state (bool, optional) – 在每次检查点时省略存储和恢复随机数生成器状态. 默认值:
Trueuse_reentrant (bool, optional) – 使用需要可重入自动求导的检查点实现。如果指定了
use_reentrant=False,checkpoint将使用不需要可重入自动求导的实现。这使得checkpoint可以支持额外的功能,例如与torch.autograd.grad一起按预期工作,并支持输入到检查点函数中的关键字参数。默认值:True
- Returns
依次在
*inputs上运行functions的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)