分布式优化器¶
警告
目前在使用 CUDA 张量时还不支持分布式优化器。
torch.distributed.optim 暴露了 DistributedOptimizer,它接受一个远程参数列表
(RRef)并在这些参数所在的工作者上本地运行优化器。分布式
优化器可以使用任何本地优化器 基类 来在每个工作者上应用梯度。
- class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source][source]¶
DistributedOptimizer 获取分布在各个工作节点上的参数的远程引用,并为每个参数本地应用给定的优化器。
此类使用
get_gradients()来检索特定参数的梯度。并发调用
step(), 无论是来自同一个客户端还是不同的客户端,都会在每个工作器上被序列化——因为每个工作器的优化器一次只能处理一组梯度。然而,不能保证一个客户端的完整前向-后向-优化器序列会依次执行。这意味着应用的梯度可能不对应于在给定工作器上执行的最新前向传播。此外,跨工作器之间也没有保证顺序。DistributedOptimizer 默认情况下创建启用了TorchScript的本地优化器,这样在多线程训练(例如分布式模型并行)时,优化器更新不会被Python全局解释器锁(GIL)阻塞。此功能目前对大多数优化器都已启用。您也可以按照PyTorch教程中的方法为自己的自定义优化器启用TorchScript支持。
- Parameters
优化器类 (optim.Optimizer) – 在每个工作进程中实例化的优化器类。
params_rref (列表[RRef]) – 用于优化的本地或远程参数的RRef列表。
args – 要传递给每个工作者的优化器构造函数的参数。
kwargs – 传递给每个工作进程的优化器构造函数的参数。
- Example::
>>> import torch.distributed.autograd as dist_autograd >>> import torch.distributed.rpc as rpc >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> >>> with dist_autograd.context() as context_id: >>> # Forward pass. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) >>> loss = rref1.to_here() + rref2.to_here() >>> >>> # Backward pass. >>> dist_autograd.backward(context_id, [loss.sum()]) >>> >>> # Optimizer. >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> [rref1, rref2], >>> lr=0.05, >>> ) >>> dist_optim.step(context_id)
- step(context_id)[source][source]¶
执行单个优化步骤。
这将在每个工作者上调用
torch.optim.Optimizer.step(),并阻塞直到所有工作者返回。包含需要优化参数的工作进程将被调用。提供的context_id将用于检索包含应应用于参数的梯度的context。- Parameters
context_id – 我们应该为其运行优化器步骤的自动微分上下文 ID。
- class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source][source]¶
封装任意一个
torch.optim.Optimizer并在每一步运行局部SGD后处理, 此优化器在每一步运行局部优化器。 经过预热阶段后,它会在应用局部优化器后定期平均参数。- Parameters
optim (优化器) – 本地优化器。
平均器 (ModelAverager) – 用于运行后局部SGD算法的模型平均器实例。
Example:
>>> import torch >>> import torch.distributed as dist >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( >>> PostLocalSGDState, >>> post_localSGD_hook, >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank >>> ) >>> >>> # Register a post-localSGD communication hook. >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) >>> model.register_comm_hook(state, post_localSGD_hook) >>> >>> # Create a post-localSGD optimizer that wraps a local optimizer. >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) >>> opt = PostLocalSGDOptimizer( >>> optim=local_optim, >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) >>> ) >>> >>> # In the first 100 steps, DDP runs global gradient averaging at every step. >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. >>> for step in range(0, 200): >>> opt.zero_grad() >>> loss = loss_fn(output, labels) >>> loss.backward() >>> opt.step()
- load_state_dict(state_dict)[source][source]¶
这与
torch.optim.Optimizerload_state_dict()相同, 但也会将模型平均器的步数恢复为在提供的state_dict中保存的值。如果没有
"step"条目在state_dict中, 它将触发警告并将模型平均器的步骤初始化为0。
- state_dict()[source][source]¶
这与
torch.optim.Optimizerstate_dict()相同, 但会额外添加一个条目以记录模型平均器的步骤到检查点中, 以确保重新加载不会导致不必要的预热。
- class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source][source]¶
包装一个任意的
optim.Optimizer并在其组中跨等级分片其状态。共享按照ZeRO所述进行。
每个rank中的本地优化器实例只需更新大约
1 / world_size个参数,因此只需保留1 / world_size个优化器状态。在参数被本地更新后,每个rank会将其参数广播给所有其他节点以保持所有模型副本处于相同状态。ZeroRedundancyOptimizer可以与torch.nn.parallel.DistributedDataParallel结合使用,以减少每rank的峰值内存消耗。ZeroRedundancyOptimizer使用排序贪心算法在每个排名中打包一组参数。每个参数属于单一排名,不会在排名之间分配。分区是任意的,可能与参数注册或使用顺序不匹配。- Parameters
参数 (
Iterable) – 一个Iterable的torch.Tensor或dict,给出所有参数,这些参数将在各个排名间进行分片。- Keyword Arguments
优化器类 (
torch.nn.Optimizer) – 本地优化器的类。process_group (
ProcessGroup, 可选) –torch.distributedProcessGroup(默认由torch.distributed.init_process_group()初始化)。参数作为桶视图 (bool, 可选) – 如果为
True,参数会被打包到桶中以加速通信,并且param.data字段指向不同偏移量的桶视图;如果为False, 每个单独的参数会单独进行通信,并且每个params.data保持完整(默认值:False).overlap_with_ddp (bool, 可选) – 如果为
True,则step()与DistributedDataParallel的梯度同步重叠;这需要 (1)optimizer_class参数的功能优化器或具有功能等效的优化器,以及 (2) 注册由ddp_zero_hook.py中的一个函数构造的 DDP 通信钩子;参数被打包到与DistributedDataParallel匹配的桶中,这意味着parameters_as_bucket_view参数被忽略。 如果为False,则step()在反向传播之后独立运行(按正常情况)。 (默认值:False)**默认参数 – 任何尾随参数,这些参数将传递给本地优化器。
Example:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer要求传入的所有参数都是相同的密集类型。警告
如果你传递
overlap_with_ddp=True,请注意以下事项:鉴于当前实现中重叠的DistributedDataParallel与ZeroRedundancyOptimizer的方式,在前两个或三个训练迭代中,优化器步骤不会执行参数更新,具体取决于是否使用static_graph=False或static_graph=True。这是因为它需要了解由DistributedDataParallel使用的梯度分桶策略,而该策略在第二次前向传递时如果使用static_graph=False,或者在第三次前向传递时如果使用static_graph=True才会最终确定。为了对此进行调整,一种选择是添加虚拟输入。警告
ZeroRedundancyOptimizer 是实验性的,可能会发生变化。
- add_param_group(param_group)[source][source]¶
添加一个参数组到编号为
Optimizer的param_groups中。这在微调预训练网络时很有用,因为冻结的层可以在训练过程中被解冻并添加到
Optimizer中进行训练。- Parameters
param_group (dict) – 指定要优化的参数和组特定的优化选项。
警告
此方法处理所有分区上的分片更新, 但需要在所有 ranks 上调用。仅在部分 ranks 上调用会导致训练卡住, 因为通信原语的调用取决于管理的参数,并期望所有的 ranks 参与相同的参数集。
- consolidate_state_dict(to=0)[source][source]¶
在目标排名上合并一个
state_dict的列表(每个排名一个)。- Parameters
到 (整型) – 接收优化器状态的等级(默认值:0)。
- Raises
运行时错误 – 如果
overlap_with_ddp=True且在该ZeroRedundancyOptimizer实例完全初始化之前调用此方法, 初始化会在DistributedDataParallel个梯度桶被重建后完成。
警告
这需要在所有排名上调用。
- join_hook(**kwargs)[source][source]¶
返回 ZeRO join 钩子。
它通过在优化器步骤中复制集体通信,实现了对不规则输入的训练。
梯度必须在调用此钩子之前正确设置。
- Parameters
kwargs (字典) – 包含用于在运行时修改连接钩子行为的任何关键字参数 的
dict;所有共享相同连接上下文管理器的Joinable实例都会收到相同的kwargs值。
此钩子不支持任何关键字参数;即
kwargs是未使用的。
- load_state_dict(state_dict)[source][source]¶
从输入
state_dict加载与给定排名相关的状态,并根据需要更新本地优化器。- Parameters
state_dict (字典) – 优化器状态;应是从
state_dict()调用返回的对象。- Raises
运行时错误 – 如果
overlap_with_ddp=True且在该ZeroRedundancyOptimizer实例完全初始化之前调用此方法, 初始化会在DistributedDataParallel个梯度桶被重建后完成。
- state_dict()[source][source]¶
返回该 ranks 知道的最后一个全局优化器状态。
- Raises
运行时错误 – 如果
overlap_with_ddp=True并且在该ZeroRedundancyOptimizer实例完全初始化之前调用了此方法,这会在DistributedDataParallel个梯度桶被重建后发生;或者如果没有先调用consolidate_state_dict()就调用了此方法。- Return type