目录

FSDP 注释

FSDP 预取细微差别

对于与 compute 重叠的 all-gathers,有两种可能的机制:forwardforward

  1. 隐式正向预取(始终启用)

  2. 显式正向预取 (forward_prefetch=True)

隐式预取是指依赖于从单独的 CUDA 发出所有收集 stream 以允许将 all-gather 与之前发出的计算重叠(从 CPU 视角)。例如,如果我们有第 0 层 all-gather -> 层 0 compute -> 层 1 all-gather -> ...,则第 1 层 all-gather 可以与第 0 层计算重叠,即使 CPU 线程随后发出了它。(第 1 个 all-gather 将无法与任何内容重叠。forwardforwardforwardforward

显式预取是指更改 CPU 线程的 issue 顺序:例如第 0 层 all-gather -> 第 1 层 all-gather -> 第 0 层计算 -> ....在 Eager 模式下,没有办法 通常知道哪个层是下一层(例如,示例中的第 1 层)仍在执行 图层 0.因此,显式预取只应用于执行 顺序在每次迭代中都是固定的(我们有时称之为 “静态图”)。一个 不满足此约束的模型是 FLAVA)。forwardforwardforward

显式预取仅节省在 当下一个 All-gather 的输出张量仍处于当前 正在使用。通过在当前计算内核之前发出下一个 all-gather,下一个 all-gather 可以在 GPU 上更快地启动。对于大多数 LLM 工作负载,情况并非如此,因此没有 启用 .forwardforwardforwardforward_prefetch=True

相反,对于 ,我们必须使用显式预取,否则将有 0 重叠 通信和计算。原因是我们对两者使用单个 NCCL 流程组 all-gather 和 reduce-scatter (部分原因是在早期的 NCCL 版本中,使用起来不安全 多个并发在相同 rank 的同一设备上)。单个 NCCL 流程组表示 单个内部 NCCL 流,reduce-scatter 和 all-gathers 在其上串行运行。因此,除非 我们显式地将 CPU 的 issue 顺序重新排序为下一个 all-gather ->当前的 reduce-scatter,然后 当前的 reduce-scatter 将阻止下一个 all-gather,从而阻止下一个计算 防止当前的 reduce-scatter 重叠。backwardbackwardbackward

通信负载大小

在 FSDP 中,通信是:

  1. all-gather 对 中的 parametersforward

  2. all-gather 对 中的 parametersbackward

  3. backward

如果使用激活 checkpointing (),则没有 额外的通信,因为在 .backward

在 FSDP 设计中,每个等级的通信有效载荷确定如下:每次调用 to 都会创建一个通信组,该组由 中的参数组成,但已分配给嵌套实例的任何参数除外。例如,对于 Llama,如果您将每个 transformer 块和根模块,则每个 transformer 块,最后是一个具有初始嵌入和最终线性的通信组。 每个通信组对应于一个 all-gather 调用和单个 reduce-scatter 调用。在 这样,您的应用方式将决定通信大小。在 通常,将 FSDP 应用于每个变压器块对于 LLM 来说是一个很好的启发式方法,而且很难做到 考虑到目前的设计,这比这要好。FullyShardedDataParallelmodule.parameters()FullyShardedDataParallelFullyShardedDataParallelFullyShardedDataParallel

让我们考虑一个例子,我们有一个基于 Transformer 的模型,该模型被分片到 8 个 GPU 上,其中 分片只发生在 transformer 块级,每个 transformer 块包含 1.6B parameters 和 arguments 以 fp32 为单位(每个 4 字节)。这意味着,一旦分片,每个 transformer 块的每个等级将包含 0.2B 参数。

  • 该通道将在 all-gather 中以块的形式进行通信forward0.2*4 = 0.8GB

  • 该通道将分别通信 2 次(1 次 all-gather 和 1 次 reduce-scatter)backward0.8GB

换句话说,将有 3 个 communications,每个 communications 的 payload 为 。如果模型是 由 10 个变压器块组成,总共有 30 个通信,总共 。0.8GB30*0.8=24GB

要正式确定每个 rank 每个通信的有效负载大小为 (GB)。total_transformer_block_params_in_B*dtype_bytes/num_gpus

请注意,在此示例中,我们没有包括 嵌入,这也应该被考虑在内。数学运算将取决于 input 和 output embeddings 是否绑定。如果他们没有打成平手,则通信量将增加 2 倍。

FSDP 缓冲区大小

首先,我们介绍分配给通信的缓冲区:

forward当前需要 2 倍的 All-gather 缓冲区大小。原因如下:

FSDP 显式预取情况下的预取细微差别中所述 ( 而另一个用于执行预取。forwardforward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward

虽然理论上相同 sequence 的隐式预取 (, default) 情况应该只需要 1 个 buffer,但实际上它仍然是 2 倍全收集大小的 buffers。原因是在 flat-parameter FSDP design中,我们不会复制出 all-gather buffer。用于计算的参数可以直接查看到 all-gather 缓冲区中(实际上,“flat 参数”的主要好处正是这个原因)。在这种情况下,虽然“第 1 层全聚集”与“第 0 层正向计算”重叠,但“第 0 层正向计算”正在使用在“第 0 层全聚集”缓冲区中查看的参数。forwardforward_prefetch=False

那么一个自然的问题是,你什么时候想要?对于静态图模型(如大多数 LLM),有一个主要的技术原因。更多的是,实际上,我们为一些 CPU 密集型内部模型快速添加了此选项,并且没有在单元测试中使用它测试每个代码路径,因此我们对它不太有信心。 可以稍微容易推断一下,因为我们不必将记录的远期订单作为可能的 “失败模式” 进行检查;模块的 All-Gather 始终可以在其 Profiler 跟踪中其自己的标签下找到。forward_prefetch=Falseforward_prefetching=Falserecord_function

backward当前需要至少 2 倍的 All-gather 缓冲区大小,并且可能更多。原因如下:

当前的 FSDP 设计用于管理在一个流中生成的分配,而另一个流中消耗的分配可能会导致内存使用量超出预期。多少可以是 “非确定性的” ,因为它取决于 GPU 内核相对于 CPU 的时序。这个论点是对此的缓解 - 更多详情请参考这个讨论是FSDP & CUDACachingAllocatorrecordStreamlimit_all_gathers=True

现有 FSDP 与 autograd 的工作方式:

  • 现有 FSDP 全部收集 ,即 autograd 叶。flat_param

  • 它调用以获取与其组成原始参数对应的 1D 视图。torch.splitflat_param

  • 它调用每个 1D 分割以回 view 回到 ND。torch.view

  • 这意味着在 中,我们最终得到 (ND -> 1D) 和 (这是一个 concat)。特别是,每个单独的梯度都是作为单独的分配计算的,并且显式 concat 恰好构造 reduce-scatter input 缓冲区。这意味着在该峰值内存点的 reduce-scatter 实际上是 2 倍的缓冲区大小。backwardViewBackwardSplitWithSizesBackward

总之,对于 ,它大约是 reduce-scatter 加上任何效果的 2 倍缓冲区大小。backwardrecordStream

其次,我们来讨论一下额外的缓冲区:

从所有列收集分片参数后,它们需要 total_transformer_block_params_in_B*dtype_bytes 的额外缓冲区才能获得完整参数 - 因此,继续前面的示例,如果每个 transformer 块是 1.6B 参数并且参数在 fp32 中,那么它将是 1.6*4=6.4GB 缓冲区。

并且需要其中的 2 个缓冲区,因为当前正在使用一个缓冲区,另一个正在预取。

总而言之,我们有:

  1. 2 倍的通信缓冲区total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 2 倍未分片的 transformer 块参数缓冲区``total_transformer_block_params_in_B*dtype_bytes

或者,如果您一直在遵循以下示例:

  1. 2*1.6*4/8=1.6GB

  2. 2**1.6*4=12.8GB

和 的总数。14.4GB

现在让我们简要讨论一下嵌入会发生什么,因为我们在计算中遗漏了这些嵌入:

鉴于我们讨论的规则,您在注释中包含了以“通信缓冲区” 大小确定如下“,我们可以分析如下:

  • 假设我们将 FSDP 应用于根模块(例如类)。假设我们进一步将 FSDP 应用于每个 transformer 块(例如类)。TransformerTransformerBlock

  • 最常见的是,嵌入和最终线性投影是根类的直接子级。Transformer

  • 按照我们的规则,这意味着嵌入和最终线性投影被分配给 root 的 flat 参数。Transformer

  • 我们有 _another_ 特殊规则,即 root 在 forward 之后不会释放其参数,因为它们无论如何都会立即全部聚集在 backward 中。

  • 综上所述,这意味着根的 flat 参数(包括 embedding 和 final projection)全部收集起来,开始向前传输,并保存在 GPU 内存中,直到 backward 结束。

  • 如果嵌入和最终线性没有权重绑定,那么我们_可以_进一步将 FSDP 应用于嵌入和最终线性。对于权重绑定的参数,我们要求它们成为同一 flat 参数的一部分(否则会被重复计算)。这将允许 embedding 在 forward 中使用后被释放,并且仅在 backward 结束时释放 all-gathered。

  • 希望这能提供更好的理解 – 除了已经分配给另一个嵌套 FSDP 模块的参数外,每个 FSDP 模块都会在其中分配参数,并且 FSDP 模块为其参数定义“实时”间隔。因此,嵌套结构会影响 all-gather/free 计划,从而影响内存/吞吐量性能。module.parametersforwardnn.Module

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源