泛型连接上下文管理器¶
通用连接上下文管理器有助于在不均匀上进行分布式训练
输入。本页概述了相关类的 API:、 和 。有关教程,请参阅使用 Join Context Manager 进行输入不均匀的分布式训练。Join
Joinable
JoinHook
-
类(可加入对象,enable=True,throw_on_early_termination=False,**kwargs)[来源]
torch.distributed.algorithms.
Join
¶ 此类定义了通用连接上下文管理器,它允许自定义 要在进程加入后调用的钩子。这些钩子应该隐藏 未加入的进程的集体通信,以防止挂起和 错误并确保算法的正确性。指
JoinHook
了解有关 Hook 定义的详细信息。警告
上下文管理器要求每个参与
Joinable
自 调用方法notify_join_context()
在它自己的 per- iteration collective 通信来确保正确性。警告
上下文管理器要求将 这
process_group
JoinHook
对象是相同的。如果有多个JoinHook
objects,则使用第一个的 。 进程组和设备信息用于检查非 加入的进程和通知进程在启用时引发异常,这两者都使用全 减少。device
throw_on_early_termination
- 参数
例:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
-
类 [来源]
torch.distributed.algorithms.
Joinable
¶ 这为可连接类定义了一个抽象基类。可加入的类 (继承自
Joinable
) 应实现join_hook()
, ,它返回一个JoinHook
实例,除了join_device()
和join_process_group()
返回设备并 process group 信息。-
abstract 属性
join_device
¶ 返回从中执行集体通信的设备 Join Context Manager 实现本身需要。
-
abstract 属性
join_process_group
¶ 返回 所需的集体通信的进程组 连接上下文管理器本身。
-
abstract 属性
-
类 [来源]
torch.distributed.algorithms.
JoinHook
¶ 这定义了一个 join 钩子,它在 join 中提供了两个入口点 上下文管理器:一个主钩子,当存在时被重复调用 一个未加入的进程,以及一个 post-hook,它被调用一次所有进程 已加入。
要为通用连接上下文管理器实现连接钩子,请定义一个 继承自
JoinHook
并酌情覆盖 和。main_hook()
post_hook()