分布式优化器¶
-
class
torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source]¶ 此类封装了一个任意的
optim.Optimizer,并按照 ZeRO 所描述的方式将状态分片到组中的各个 rank 上。每个 rank 中的本地优化器实例仅负责更新大约1 / world_size个参数,因此只需要保留1 / world_size个优化器状态。在本地更新参数后,每个 rank 会将其参数广播给所有其他对等节点,以保持所有模型副本处于相同的状态。ZeroRedundancyOptimizer可与torch.nn.parallel.DistributedDataParallel联合使用,以减少每个 rank 的峰值内存消耗。ZeroRedundancyOptimizer使用排序贪心算法在每个排名中打包一组参数。每个参数属于单一排名,不会在排名之间分配。分区是任意的,可能与参数注册或使用顺序不匹配。- Parameters
参数 (
Iterable) – 一个Iterable的torch.Tensors 提供所有参数,这些参数将在各个 rank 之间进行分片。- Keyword Arguments
优化器类 (
torch.nn.Optimizer) – 本地优化器的类。process_group (
ProcessGroup, 可选) –torch.distributedProcessGroup(默认由torch.distributed.init_process_group()初始化)。参数作为桶视图 (bool, 可选) – 如果为
True,参数会被打包到桶中以加速通信,并且param.data字段指向不同偏移量的桶视图;如果为False, 每个单独的参数会单独进行通信,并且每个params.data保持完整(默认值:False).overlap_with_ddp (bool, 可选) – 如果为
True,则step()与DistributedDataParallel的梯度同步重叠;这需要 (1)optimizer_class参数的功能优化器或具有功能等效的优化器,以及 (2) 注册由ddp_zero_hook.py中的一个函数构造的 DDP 通信钩子;参数被打包到与DistributedDataParallel匹配的桶中,这意味着parameters_as_bucket_view参数被忽略。 如果为False,则step()在反向传播之后独立运行(按正常情况)。 (默认值:False)**默认参数 – 任何尾随参数,这些参数将传递给本地优化器。
Example:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer要求传入的所有参数都是相同的密集类型。警告
如果你传递
overlap_with_ddp=True,请注意以下事项:鉴于当前实现中重叠的DistributedDataParallel与ZeroRedundancyOptimizer的方式,在前两个或三个训练迭代中,优化器步骤不会执行参数更新,具体取决于是否使用static_graph=False或static_graph=True。这是因为它需要了解由DistributedDataParallel使用的梯度分桶策略,而该策略在第二次前向传递时如果使用static_graph=False,或者在第三次前向传递时如果使用static_graph=True才会最终确定。为了对此进行调整,一种选择是添加虚拟输入。警告
ZeroRedundancyOptimizer 是实验性的,可能会发生变化。
-
add_param_group(param_group)[source]¶ 添加一个参数组到编号为
Optimizer的param_groups中。这在微调预训练网络时很有用,因为冻结的层可以在训练过程中被解冻并添加到
Optimizer中进行训练。- Parameters
param_group (dict) – 指定要优化的参数和组特定的优化选项。
警告
此方法处理所有分区上的分片更新, 但需要在所有 ranks 上调用。仅在部分 ranks 上调用会导致训练卡住, 因为通信原语的调用取决于管理的参数,并期望所有的 ranks 参与相同的参数集。
-
consolidate_state_dict(to=0)[source]¶ Consolidate a list of
state_dicts (one per rank) on the target rank。- Parameters
到 (整型) – 接收优化器状态的等级(默认值:0)。
- Raises
运行时错误 – 如果
overlap_with_ddp=True且在该ZeroRedundancyOptimizer实例完全初始化之前调用此方法, 初始化会在DistributedDataParallel个梯度桶被重建后完成。
警告
这需要在所有排名上调用。
-
join_hook(**kwargs)[source]¶ 返回 ZeRO join hook,它通过在优化器步骤中模拟集体通信,从而实现对不均衡输入的训练。
梯度必须在调用此钩子之前正确设置。
- Parameters
kwargs (字典) – 包含用于在运行时修改连接钩子行为的任何关键字参数 的
dict;所有共享相同连接上下文管理器的Joinable实例都会收到相同的kwargs值。
此钩子不支持任何关键字参数;即
kwargs是未使用的。
-
load_state_dict(state_dict)[source]¶ 从输入
state_dict加载与给定等级相关的状态,必要时更新本地优化器。- Parameters
state_dict (字典) – 优化器状态;应是从
state_dict()调用返回的对象。- Raises
运行时错误 – 如果
overlap_with_ddp=True且在该ZeroRedundancyOptimizer实例完全初始化之前调用此方法, 初始化会在DistributedDataParallel个梯度桶被重建后完成。
-
state_dict()[source]¶ 返回此 rank 所知的最后一个全局优化器状态。
- Raises
运行时错误 – 如果
overlap_with_ddp=True并且在该ZeroRedundancyOptimizer实例完全初始化之前调用了此方法,这会在DistributedDataParallel个梯度桶被重建后发生;或者如果没有先调用consolidate_state_dict()就调用了此方法。