泛型连接上下文管理器¶
通用连接上下文管理器有助于在不均匀上进行分布式训练
输入。本页概述了相关类的 API:、 和 。有关教程,请参阅使用 Join Context Manager 进行输入不均匀的分布式训练。Join
Joinable
JoinHook
- 类 torch.distributed.algorithms 中。Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[来源]¶
此类定义了通用连接上下文管理器,它允许自定义 要在进程加入后调用的钩子。这些钩子应该隐藏 未加入的进程的集体通信,以防止挂起和 错误并确保算法的正确性。指
JoinHook
了解有关 Hook 定义的详细信息。警告
上下文管理器要求每个参与
Joinable
自 调用方法notify_join_context()
在它自己的 per- iteration collective 通信来确保正确性。警告
上下文管理器要求将 这
process_group
JoinHook
对象是相同的。如果有多个JoinHook
objects,则使用第一个的 。 进程组和设备信息用于检查非 加入的进程和通知进程在启用时引发异常,这两者都使用全 减少。device
throw_on_early_termination
- 参数
例:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
- 类 torch.distributed.algorithms 中。可加入[来源]¶
这为可连接类定义了一个抽象基类。可加入的类 (继承自
Joinable
) 应实现join_hook()
, ,它返回一个JoinHook
实例,除了join_device()
和join_process_group()
返回设备并 process group 信息。
- 类 torch.distributed.algorithms 中。JoinHook[来源]¶
这定义了一个 join 钩子,它在 join 中提供了两个入口点 上下文管理器:一个主钩子,当存在时被重复调用 一个未加入的进程,以及一个 post-hook,它被调用一次所有进程 已加入。
要为通用连接上下文管理器实现连接钩子,请定义一个 继承自
JoinHook
并酌情覆盖 和。main_hook()
post_hook()