目录

分布式优化器

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source]

此类封装了一个任意的 optim.Optimizer,并按照 ZeRO 所描述的方式将状态分片到组中的各个 rank 上。每个 rank 中的本地优化器实例仅负责更新大约 1 / world_size 个参数,因此只需要保留 1 / world_size 个优化器状态。在本地更新参数后,每个 rank 会将其参数广播给所有其他对等节点,以保持所有模型副本处于相同的状态。 ZeroRedundancyOptimizer 可与 torch.nn.parallel.DistributedDataParallel 联合使用,以减少每个 rank 的峰值内存消耗。

ZeroRedundancyOptimizer 使用排序贪心算法在每个排名中打包一组参数。每个参数属于单一排名,不会在排名之间分配。分区是任意的,可能与参数注册或使用顺序不匹配。

Parameters

参数 (Iterable) – 一个 Iterabletorch.Tensor s 提供所有参数,这些参数将在各个 rank 之间进行分片。

Keyword Arguments
  • 优化器类 (torch.nn.Optimizer) – 本地优化器的类。

  • process_group (ProcessGroup, 可选) – torch.distributed ProcessGroup (默认由 torch.distributed.init_process_group() 初始化)。

  • 参数作为桶视图 (bool, 可选) – 如果为 True,参数会被打包到桶中以加速通信,并且 param.data 字段指向不同偏移量的桶视图;如果为 False, 每个单独的参数会单独进行通信,并且每个 params.data 保持完整(默认值: False).

  • overlap_with_ddp (bool, 可选) – 如果为 True,则step()DistributedDataParallel 的梯度同步重叠;这需要 (1) optimizer_class 参数的功能优化器或具有功能等效的优化器,以及 (2) 注册由 ddp_zero_hook.py 中的一个函数构造的 DDP 通信钩子;参数被打包到与 DistributedDataParallel 匹配的桶中,这意味着 parameters_as_bucket_view 参数被忽略。 如果为 False,则step() 在反向传播之后独立运行(按正常情况)。 (默认值: False)

  • **默认参数 – 任何尾随参数,这些参数将传递给本地优化器。

Example:

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

目前,ZeroRedundancyOptimizer要求传入的所有参数都是相同的密集类型。

警告

如果你传递 overlap_with_ddp=True,请注意以下事项:鉴于当前实现中重叠的 DistributedDataParallelZeroRedundancyOptimizer 的方式,在前两个或三个训练迭代中,优化器步骤不会执行参数更新,具体取决于是否使用 static_graph=Falsestatic_graph=True。这是因为它需要了解由 DistributedDataParallel 使用的梯度分桶策略,而该策略在第二次前向传递时如果使用 static_graph=False,或者在第三次前向传递时如果使用 static_graph=True 才会最终确定。为了对此进行调整,一种选择是添加虚拟输入。

警告

ZeroRedundancyOptimizer 是实验性的,可能会发生变化。

add_param_group(param_group)[source]

添加一个参数组到编号为Optimizerparam_groups中。

这在微调预训练网络时很有用,因为冻结的层可以在训练过程中被解冻并添加到Optimizer中进行训练。

Parameters

param_group (dict) – 指定要优化的参数和组特定的优化选项。

警告

此方法处理所有分区上的分片更新, 但需要在所有 ranks 上调用。仅在部分 ranks 上调用会导致训练卡住, 因为通信原语的调用取决于管理的参数,并期望所有的 ranks 参与相同的参数集。

consolidate_state_dict(to=0)[source]

Consolidate a list of state_dict s (one per rank) on the target rank。

Parameters

(整型) – 接收优化器状态的等级(默认值:0)。

Raises

运行时错误 – 如果 overlap_with_ddp=True 且在该 ZeroRedundancyOptimizer 实例完全初始化之前调用此方法, 初始化会在 DistributedDataParallel 个梯度桶被重建后完成。

警告

这需要在所有排名上调用。

join_hook(**kwargs)[source]

返回 ZeRO join hook,它通过在优化器步骤中模拟集体通信,从而实现对不均衡输入的训练。

梯度必须在调用此钩子之前正确设置。

Parameters

kwargs (字典) – 包含用于在运行时修改连接钩子行为的任何关键字参数 的dict;所有共享相同连接上下文管理器的 Joinable 实例都会收到相同的 kwargs 值。

此钩子不支持任何关键字参数;即 kwargs 是未使用的。

load_state_dict(state_dict)[source]

从输入 state_dict 加载与给定等级相关的状态,必要时更新本地优化器。

Parameters

state_dict (字典) – 优化器状态;应是从 state_dict() 调用返回的对象。

Raises

运行时错误 – 如果 overlap_with_ddp=True 且在该 ZeroRedundancyOptimizer 实例完全初始化之前调用此方法, 初始化会在 DistributedDataParallel 个梯度桶被重建后完成。

state_dict()[source]

返回此 rank 所知的最后一个全局优化器状态。

Raises

运行时错误 – 如果 overlap_with_ddp=True 并且在该 ZeroRedundancyOptimizer 实例完全初始化之前调用了此方法,这会在 DistributedDataParallel 个梯度桶被重建后发生;或者如果没有先调用 consolidate_state_dict() 就调用了此方法。

step(closure=None, **kwargs)[source]

执行一次优化器步骤,并在所有 rank 之间同步参数。

Parameters

闭包 (可调用对象) – 一个重新评估模型并返回损失的闭包;对于大多数优化器来说是可选的。

Returns

根据底层本地优化器选择的可选损失。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源