泛型连接上下文管理器¶
通用连接上下文管理器有助于在不均匀上进行分布式训练
输入。本页概述了相关类的 API:、 和 。有关教程,请参阅使用 Join Context Manager 进行输入不均匀的分布式训练。Join
Joinable
JoinHook
-
类(可加入对象,enable=True,throw_on_early_termination=False,**kwargs)[来源]
torch.distributed.algorithms.
Join
¶ 此类定义了通用连接上下文管理器,它允许自定义 要在进程加入后调用的钩子。这些钩子应该隐藏 未加入的进程的集体通信,以防止挂起和 错误并确保算法的正确性。有关 hook 定义的详细信息,请参阅 。
警告
上下文管理器要求将 对象
是相同的。如果有多个
对象,则使用第一个对象。 进程组和设备信息用于检查非 加入的进程和通知进程在启用时引发异常,这两者都使用全 减少。
process_group
device
throw_on_early_termination
- 参数
例:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
-
static (可加入)[来源]
notify_join_context
¶ 通知连接上下文管理器调用进程尚未 加入;然后,如果 , 检查是否不均匀 检测到输入(即,如果一个进程已加入)和 如果是这样,则引发异常。
throw_on_early_termination=True
此方法应从
对象调用 它的每次迭代集体通信。例如,这应该 在 中的前向传递开始时调用 。
DistributedDataParallel
仅传递到上下文中的第一个
对象 manager 在此方法中执行集体通信,并且 对于其他人来说,这种方法是空洞的。
- 参数
joinable (Joinable) –
调用此 方法。
- 返回
all-reduce 的异步工作句柄,用于通知上下文 进程尚未加入的 manager (如果为 第一个传递到上下文管理器中; 否则。
joinable
None
-
类 [来源]
torch.distributed.algorithms.
Joinable
¶ 这为可连接类定义了一个抽象基类。可加入的类 (inheriting from
) 应该实现
, 它返回一个
实例,此外还返回 Device 和
process group 信息。
-
abstract 属性
join_device
¶ 返回从中执行集体通信的设备 Join Context Manager 实现本身需要。
-
摘要 (**kwargs)[来源]
join_hook
¶ 返回给定
的实例
.
- 参数
kwargs (dict) – 包含任何关键字参数的 a
在运行时修改 Join 钩子的行为;共享同一联接上下文的所有
实例 manager 将转发相同的值。
kwargs
-
abstract 属性
join_process_group
¶ 返回 所需的集体通信的进程组 连接上下文管理器本身。
-
abstract 属性
-
类 [来源]
torch.distributed.algorithms.
JoinHook
¶ 这定义了一个 join 钩子,它在 join 中提供了两个入口点 上下文管理器:一个主钩子,当存在时被重复调用 一个未加入的进程,以及一个 post-hook,它被调用一次所有进程 已加入。
要为通用连接上下文管理器实现连接钩子,请定义一个 继承自
and override 的类。
main_hook()
post_hook()