分布式优化器¶
-
class (params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[来源]
torch.distributed.optim.
ZeroRedundancyOptimizer
¶ 此类包装任意值
并将其状态分片到组中的等级中,如下所示 由 ZeRO 描述。每个 rank 中的本地 optimizer 实例仅为 负责更新大约 parameters 和 因此只需要保持优化器状态。后 参数在本地更新,每个 rank 都会将其参数广播到 所有其他 Peer 节点保持所有模型副本处于相同状态。 可与 结合使用
,以减少每等级峰值 内存消耗。
1 / world_size
1 / world_size
ZeroRedundancyOptimizer
ZeroRedundancyOptimizer
使用排序贪婪算法打包数字 每个等级的参数。每个参数都属于一个等级,并且 不按等级划分。分区是任意的,可能与 参数注册或使用顺序。- 参数
- 关键字参数
optimizer_class () – 本地 优化。
torch.nn.Optimizer
process_group (可选) – (默认值:初始化者
)。
ProcessGroup
torch.distributed
ProcessGroup
dist.group.WORLD
parameters_as_bucket_view (bool, optional) – 如果 , 参数为 打包到 Bucket 中以加快通信速度,字段指向不同偏移量的 Bucket 视图;如果 每个单独的参数都单独通信,并且每个参数都保持不变(默认值:)。
True
param.data
False
params.data
False
overlap_with_ddp (bool, optional) – 如果 ,
为 与 的梯度重叠 同步;这需要 (1) 函数式优化器 对于参数或具有函数式 等效和 (2) 注册 DDP 通信挂钩 由 中的 函数之一构造; 参数被打包到与 中的参数匹配的存储桶中,这意味着该参数将被忽略。 如果 ,
在向后传递后不相交地运行 (按正常值)。 (默认:
True
DistributedDataParallel
optimizer_class
ddp_zero_hook.py
DistributedDataParallel
parameters_as_bucket_view
False
False
)defaults – 任何尾随参数,这些参数将转发到本地的 优化。
例:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,要求所有 传入的参数是相同的 dense 类型。
ZeroRedundancyOptimizer
警告
如果您传递 ,请注意以下情况:给定 目前实现 overlap with
的方式是第一个 2 次或 3 次训练迭代在 优化器步骤,分别取决于 if 或 。这是因为它需要 有关使用的梯度分桶策略的信息,该策略在 第二次向前传球(如果或直到第三次) 如果 .要对此进行调整,请选择一个选项 是预置虚拟输入。
overlap_with_ddp=True
DistributedDataParallel
static_graph=False
static_graph=True
DistributedDataParallel
static_graph=False
static_graph=True
警告
ZeroRedundancyOptimizer 是实验性的,可能会发生变化。
-
add_param_group
(param_group)[来源]¶ 将参数组添加到 的 .
Optimizer
param_groups
这在微调预先训练的网络时非常有用,因为 freeze 层可以设置为可训练并添加到 AS 培训进度。
Optimizer
- 参数
param_group (dict) – 指定要优化的参数,并且 特定于组的优化选项。
警告
此方法处理更新所有分区上的分片 但需要被召集到所有级别。在 军衔会导致训练挂起,因为通信 根据托管参数调用 Primitives,并且 期望所有排名都参与同一组参数。
-
consolidate_state_dict
(to=0)[来源]¶ 合并目标上的 s 列表(每个等级一个) 排。
state_dict
- 参数
to (int) – 接收优化器状态的排名(默认值:0)。
- 提高
RuntimeError – 如果且此方法为 在此
实例之前调用 已完全初始化,这在梯度桶 重建。
overlap_with_ddp=True
DistributedDataParallel
警告
这需要在所有级别上调用。
-
join_hook
(**kwargs)[来源]¶ 返回 ZeRO 连接钩子,该钩子允许对不均匀的输入进行训练 在 Optimizer 步骤中隐藏 Collective 通信。
在调用此钩子之前,必须正确设置渐变。
- 参数
kwargs (dict) – 包含任何关键字参数的 a
在运行时修改 Join 钩子的行为;共享同一联接上下文的所有实例 manager 将转发相同的值。
Joinable
kwargs
这个钩子不支持任何 keyword 参数;即 是 闲置。
kwargs
-
load_state_dict
(state_dict)[来源]¶ 从 input 加载与给定排名相关的状态,根据需要更新本地优化器。
state_dict
- 参数
state_dict (dict) – 优化器状态;应为返回的对象 从对 的调用
。
- 提高
RuntimeError – 如果且此方法为 在此
实例之前调用 已完全初始化,这在梯度桶 重建。
overlap_with_ddp=True
DistributedDataParallel
-
state_dict
()[来源]¶ 返回此排名已知的最后一个全局优化器状态。
- 提高
RuntimeError – 如果且此方法为 在此
实例之前调用 已完全初始化,这在梯度桶 重建;或者,如果调用此方法时没有前面的调用 到
.
overlap_with_ddp=True
DistributedDataParallel