目录

使用 Join Context Manager 进行输入不均匀的分布式训练

创建时间:2021 年 8 月 4 日 |上次更新时间: 2023 年 1 月 9 日 |上次验证: Nov 05, 2024

作者Andrew Gu

注意

编辑github 中查看和编辑本教程。

注意

Join在 PyTorch 1.10 中作为原型功能引入。这 API 可能会发生更改。

在本教程中,您将看到:

  • Join 上下文管理器概述。

  • 如何将上下文管理器与 一起使用的示例。DistributedDataParallel

  • 如何将上下文管理器与 和 一起使用的示例。DistributedDataParallelZeroRedundancyOptimizer

  • 将关键字参数传递给上下文管理器的示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 演示如何使 toy 类与上下文兼容的示例 经理。

什么?Join

Distributed Data Parallel 入门 - 基本用例中,您看到了 使用 DistributedDataParallel 执行数据的通用框架 并行训练。这将隐式地将 all-reduce 安排在每次向后传递中 以同步跨等级的渐变。这种集体交流需要参与 从进程组中的所有等级中,因此如果等级的输入较少,则 其他 rank 将挂起或出错(取决于后端)。更一般地说,这个 对于执行每次迭代同步的任何类,问题仍然存在 集体通信。

Join是一个上下文管理器,用于 per-rank 训练循环 促进输入不均匀的训练。上下文管理器允许 ranks 他们早早地用尽了他们的投入(即早早加入)来跟随集体 由尚未加入的用户执行的通信。其方式 被影子的通信由 hook 指定。

使用JoinDistributedDataParallel

PyTorch 的 DistributedDataParallel 与上下文管理器一起开箱即用。下面是一个示例用法:Join

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将生成以下输出(其中 s 来自 rank 0 和 等级 1 可以任意排序):print()

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

DistributedDataParallel 提供了自己的 join() 上下文管理器 在引入此通用上下文管理器之前。在 上面的示例,using 等同于使用 。现有 API 的一个限制是它不允许多个 参与课程,例如 和 ZeroRedundancyOptimizer 一起。Joinwith Join([model]):with model.join():DistributedDataParallel.join()DistributedDataParallel

使用 with 和JoinDistributedDataParallelZeroRedundancyOptimizer

上下文管理器不仅可以使用单个类,还可以使用 多个类一起。PyTorch 的也是 与 Context Manager 兼容,因此在这里,我们研究如何修改 上一个同时使用 AND 的示例:JoinZeroRedundancyOptimizerDistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与以前相同的输出。显着的变化是 此外,将实例传入 .ZeroRedundancyOptimizerJoin()

传递关键字参数

类可以提供关键字参数来修改它们在上下文中的行为 manager 的 manager 中。例如,提供 参数 ,它确定渐变是否为 除以初始世界大小或有效世界大小(即数字 未加入的行列)。此类关键字参数可以直接传递到 上下文管理器。DistributedDataParalleldivide_by_initial_world_size

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递给上下文管理器的关键字参数在 所有参与的课程。这不应该是一个限制,因为我们确实 不希望出现多个 S 需要不同设置的情况 相同的参数。尽管如此,这是要记住的。Joinable

如何运作?Join

现在我们已经看到了如何使用上下文管理器的一些初步示例,让我们更深入地研究它是如何工作的。这将提供一个 更深入地了解它提供的全部功能,并为您做好准备 兼容您自己的自定义类。在这里,我们将类作为 以及支持类和 .JoinJoinJoinableJoinHook

Joinable

首先,与上下文管理器兼容的类必须继承 从抽象基类 .特别是,必须 实现:JoinJoinableJoinable

  • join_hook(self, **kwargs) -> JoinHook

这将返回 的实例,确定 加入的进程应该影随每次迭代的集体通信 由 执行 。JoinHookJoinableJoinable

  • join_device(self) -> torch.device

这将返回上下文管理器用来执行的设备 集体通信,例如 或。Jointorch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这会将上下文管理器要使用的流程组返回到 执行集体通信。Join

具体而言,和 是必需的 属性来确保上下文管理器可以调度集合 已加入和未加入的流程之间的通信。一种用法是计数 使用 all-reduce 的每次迭代中未加入的进程数。 另一种用途是实现 所需的机制,我们将在后面解释。join_devicejoin_process_groupthrow_on_early_termination=True

DistributedDataParallel并且已经继承了 from 并实现上述方法,这就是为什么我们可以 在前面的示例中直接使用它们。ZeroRedundancyOptimizerJoinable

Joinable类应确保调用构造函数 因为它初始化了一个实例,该实例在 上下文管理器来确保正确性。这将作为 field 保存在每个 中。JoinableJoinConfigJoinable_join_config

JoinHook

接下来,我们来分解该类。A 提供两个 上下文管理器的入口点:JoinHookJoinHook

  • main_hook(self) -> None

当存在一个 rank 时,每个加入的 rank 都会重复调用这个 hook 尚未加入。它旨在掩盖集体通信 在每次训练迭代中执行(例如,在一个 forward 中 pass、backward pass 和 optimizer step)。Joinable

  • post_hook(self, is_last_joiner: bool) -> None

一旦所有 rank 都加入,就会调用这个 hook。它传递一个附加参数 ,该参数指示排名是否为以下 最后一个加入。该参数可能对同步有用。boolis_last_joiner

为了给出这些钩子可能是什么样子的具体示例,提供的 main 钩子按法线执行一个优化器步骤 由于 join 的 rank 仍然负责更新和同步其 参数的分片,以及提供的 post-hook 从最后加入的 Ranks 之一广播最终更新的模型,以确保 在所有级别中都是相同的。ZeroRedundancyOptimizerDistributedDataParallel

Join

最后,让我们看看这些如何适应类本身。Join

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中所看到的,构造函数接受参与训练循环的 s 列表。这些应该是 在每次迭代中执行集体通信的类。Joinable

enable是一个可以设置为 if 你知道有 的输入不会不均匀,在这种情况下,上下文管理器会变得空洞 似。这也可能禁用与联接相关的 在参与的 S 中计算。boolFalsecontextlib.nullcontext()Joinable

throw_on_early_termination是可以设置为 让每个 rank 在检测到不均匀的输入时引发异常。 这对于不符合上下文管理器的 要求,这通常是在进行集体通信时 来自可以任意交错的不同类,例如与具有 Layers 的模型一起使用时。在 在这种情况下,应将此参数设置为 API,以便应用程序 logic 可以捕获异常并确定如何继续。boolTrueDistributedDataParallelSyncBatchNormTrue

  • 核心逻辑发生在方法中,该方法在 存在一个未连接的 rank,调用 each 的 main 钩子,并且 然后,一旦所有 rank 都加入进来,就称他们的 post 为 hooks。两个主钩子 并且 post-hooks 按照 s 的顺序迭代 传入。__exit__()JoinableJoinable

  • 上下文管理器需要来自未加入的进程的检测信号。因此, 每个类都应调用 Before 其每次迭代的 Collective 通信。上下文管理器将 确保只有第一个传入的 API 实际上会发送 心跳。JoinableJoin.notify_join_context()Joinable

警告

如上所述,上下文管理器与 的某些组合不兼容 类。的 s 必须是可序列化的,因为每个 hook 在继续执行下一个之前完全执行。换句话说,两个 钩子不能重叠。此外,目前,主钩和后 钩子以相同的确定性顺序迭代。如果这似乎是 是一个主要限制,我们可能会修改 API 以允许可自定义 订购。throw_on_early_terminationJoinJoinableJoinHook

使 Toy 类使用Join

由于上一节介绍了几个概念,让我们看看它们 用玩具例子练习。在这里,我们将实现一个类,该类对 在其 rank joins 之前在所有 ranks 中看到的输入数。这 应该提供如何使自己的类兼容的基本概念 与 Context Manager 一起使用。Join

具体来说,下面的代码将每个等级打印出 (1) 个 在加入之前看到的所有等级的输入,以及 (2) 总数 所有等级的输入。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于等级 0 看到 5 个输入,等级 1 看到 6 个输入,因此会产生输出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

需要强调的一些关键点:

  • 实例在每次迭代中执行单个 all-reduce,因此 main hook 也执行一个 all-reduce 来隐藏它。Counter

  • 该类在 其方法的开头,因为这是其 per- 迭代集体通信(即它的 all-reduce)。CounterJoin.notify_join_context()__call__()

  • 该参数用于确定 后钩。is_last_joiner

  • 我们将 keyword 参数传递给上下文管理器 然后将其转发到 的 join 钩子。sync_max_countCounter

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源