泛型连接上下文管理器¶
通用连接上下文管理器有助于在不均匀上进行分布式训练
输入。本页概述了相关类的 API:、 和 。有关教程,请参阅使用 Join Context Manager 进行输入不均匀的分布式训练。Join
Joinable
JoinHook
- 类 torch.distributed.algorithms 中。Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[来源]¶
此类定义了通用连接上下文管理器,它允许在进程连接后调用自定义钩子。
这些钩子应该隐藏 未加入的进程的集体通信,以防止挂起和 错误并确保算法的正确性。有关 hook 定义的详细信息,请参阅 。
警告
上下文管理器要求将 对象
是相同的。如果有多个
对象,则使用第一个对象。 进程组和设备信息用于检查非 加入的进程和通知进程在启用时引发异常,这两者都使用全 减少。
process_group
device
throw_on_early_termination
- 参数
例:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
- 静态notify_join_context(可加入)[来源]¶
通知连接上下文管理器调用进程尚未加入。
然后,如果 , 则检查是否检测到不均匀的输入 (即,如果一个进程已经加入),如果是这样,则抛出异常。
throw_on_early_termination=True
此方法应从
对象调用 它的每次迭代集体通信。例如,这应该 在 中的前向传递开始时调用 。
DistributedDataParallel
仅传递到上下文中的第一个
对象 manager 在此方法中执行集体通信,并且 对于其他人来说,这种方法是空洞的。
- 参数
joinable (Joinable) –
调用此 方法。
- 返回
all-reduce 的异步工作句柄,用于通知上下文 进程尚未加入的 manager (如果为 第一个传递到上下文管理器中; 否则。
joinable
None
- 类 torch.distributed.algorithms 中。可加入[来源]¶
这为可连接类定义了一个抽象基类。
可加入的类 (inheriting from
) 应该实现
, 它返回一个
实例,此外还返回 Device 和
process group 信息。