目录

泛型连接上下文管理器

通用连接上下文管理器有助于在不均匀上进行分布式训练 输入。本页概述了相关类的 API:、 和 。有关教程,请参阅使用 Join Context Manager 进行输入不均匀的分布式训练JoinJoinableJoinHook

可加入对象enable=True,throw_on_early_termination=False**kwargs[来源]torch.distributed.algorithms.Join

此类定义了通用连接上下文管理器,它允许自定义 要在进程加入后调用的钩子。这些钩子应该隐藏 未加入的进程的集体通信,以防止挂起和 错误并确保算法的正确性。有关 hook 定义的详细信息,请参阅 。

警告

上下文管理器要求每个参与 在该方法自己的 per- iteration collective 通信来确保正确性。

警告

上下文管理器要求将 对象是相同的。如果有多个对象,则使用第一个对象。 进程组和设备信息用于检查非 加入的进程和通知进程在启用时引发异常,这两者都使用全 减少。process_groupdevicethrow_on_early_termination

参数
  • joinablesList[Joinable]) —— 参与 s 的列表;它们的钩子在给定的 次序。

  • enablebool) – 启用不均匀输入检测的标志;设置为 将禁用上下文管理器的功能,并且应该 仅当用户知道输入不会不均匀时才设置 (默认值:)。FalseTrue

  • throw_on_early_terminationbool) – 控制是否抛出 检测到不均匀的输入时出现异常(默认值:)。False

例:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static 可加入[来源]notify_join_context

通知连接上下文管理器调用进程尚未 加入;然后,如果 , 检查是否不均匀 检测到输入(即,如果一个进程已加入)和 如果是这样,则引发异常。throw_on_early_termination=True

此方法应从对象调用 它的每次迭代集体通信。例如,这应该 在 中的前向传递开始时调用 。DistributedDataParallel

仅传递到上下文中的第一个对象 manager 在此方法中执行集体通信,并且 对于其他人来说,这种方法是空洞的。

参数

joinableJoinable) – 调用此 方法。

返回

all-reduce 的异步工作句柄,用于通知上下文 进程尚未加入的 manager (如果为 第一个传递到上下文管理器中; 否则。joinableNone

[来源]torch.distributed.algorithms.Joinable

这为可连接类定义了一个抽象基类。可加入的类 (inheriting from ) 应该实现 , 它返回一个实例,此外还返回 Device 和 process group 信息。

abstract 属性join_device

返回从中执行集体通信的设备 Join Context Manager 实现本身需要。

摘要 **kwargs[来源]join_hook

返回给定 实例 .

参数

kwargsdict) – 包含任何关键字参数的 a 在运行时修改 Join 钩子的行为;共享同一联接上下文的所有实例 manager 将转发相同的值。kwargs

abstract 属性join_process_group

返回 所需的集体通信的进程组 连接上下文管理器本身。

[来源]torch.distributed.algorithms.JoinHook

这定义了一个 join 钩子,它在 join 中提供了两个入口点 上下文管理器:一个主钩子,当存在时被重复调用 一个未加入的进程,以及一个 post-hook,它被调用一次所有进程 已加入。

要为通用连接上下文管理器实现连接钩子,请定义一个 继承自 and override 的类。main_hook()post_hook()

main_hook()[来源]

当存在未加入的进程时,将重复调用此 hook 在一次训练迭代中影子集体通信(即在 一个 FORWARD PASS、BACKWARD PASS 和 Optimizer 步骤)。

post_hook(is_last_joiner[来源]

此 hook 在所有进程都加入后调用。它传递了一个 additional 参数,该参数指示 rank 是最后加入的 RANK 之一。boolis_last_joiner

参数

is_last_joinerbool) – 如果排名是最后一个 加入; 否则。TrueFalse

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源