FSDP 笔记¶
FSDP 预取细微差别¶
对于重叠的 forward 全部收集与 forward 计算,有两种可能的机制:
隐式前向预取(始终启用)
显式前向预取 (
forward_prefetch=True)
隐式 forward 预取是指依赖于从单独的CUDA
流中发出所有收集操作,以允许与之前发出的forward 计算(从CPU的角度来看)重叠。例如,如果我们有层0的所有收集 -> 层0 forward 计算 -> 层1
所有收集 -> …,那么层1的所有收集可以与层0 forward 计算重叠,即使CPU线程在之后发出它。(第1个所有收集将无法与任何内容重叠。)
显式 forward 预取指的是改变CPU线程的发布顺序:例如,第0层
全聚合 -> 第1层全聚合 -> 第0层 forward 计算 -> …. 在急切模式下,通常无法知道在执行
第0层时下一个层是哪一层(例如示例中的第1层)。因此,显式 forward 预取仅应用于每次迭代执行
顺序固定的模型(我们有时称之为“静态图”)。不满足此约束条件的一个模型示例是 FLAVA)。
显式 forward 预取仅节省了发出层的 forward 计算内核所需的时间,代价是必须在当前张量仍在使用时分配下一个全聚操作的输出张量。通过在当前 forward 计算内核之前发出下一个全聚操作,可以在GPU上更早地开始下一个全聚操作。对于大多数LLM工作负载,情况并非如此,因此没有启用 forward_prefetch=True 的动机。
相比之下,对于backward,我们必须使用显式的backward预取,否则通信和计算将没有重叠。原因是我们在all-gather和reduce-scatter中都使用了一个NCCL进程组(部分原因是早期的NCCL版本中,在同一设备上对相同秩同时使用多个进程组是不安全的)。一个NCCL进程组意味着一个内部的NCCL流,reduce-scatters和all-gathers在该流上顺序运行。因此,除非我们显式地重新排序CPU的发出顺序为下一个all-gather -> 当前的reduce-scatter,否则当前的reduce-scatter会阻塞下一个all-gather,从而阻止下一个backward计算,防止当前的reduce-scatter与之重叠。
通信负载大小¶
在FSDP中,通信是:
参数的全聚操作在
forward参数的全聚操作在
backward在
backward中对梯度进行reduce-scatter
如果使用了激活检查点(checkpoint()),则不会产生额外的通信,因为参数在backward期间无论如何都会被预取。
在FSDP设计中,每个秩的通信负载确定如下:每次调用FullyShardedDataParallel都会创建一个包含module.parameters()中的参数(除了任何已分配给嵌套FullyShardedDataParallel实例的参数)的通信组。例如,对于Llama,如果你将FullyShardedDataParallel应用于每个transformer块以及根模块,那么每个transformer块有一个通信组,最后还有一个包含初始嵌入和最终线性层的通信组。每个通信组对应一次all-gather调用和一次reduce-scatter调用。因此,你如何应用FullyShardedDataParallel决定了通信大小。一般来说,对于LLMs,将FSDP应用于每个transformer块是一个很好的启发式方法,在当前设计下很难做得更好。
让我们考虑一个例子,其中我们有一个基于Transformer的模型分布在8个GPU上,分片仅在Transformer块级别进行,每个Transformer块包含1.6B参数,并且这些参数为fp32(每个4字节)。这意味着一旦分片后,每个Transformer块将在每个秩上包含0.2B参数。
第二次传递将以
forward次0.2*4 = 0.8GB块的形式在所有节点间进行通信第
backward轮将通信 2 次0.8GB每次(1次全聚和1次减少分散)
换句话说,将会有3次通信,每次的负载为0.8GB。如果模型由10个Transformer块组成,则总共会有30次通信,总负载为30*0.8=24GB。
为了正式确定每个通信的每个等级的有效载荷大小是
total_transformer_block_params_in_B*dtype_bytes/num_gpus(GB)。
请注意,在这个示例中我们没有包含嵌入所需的额外通信,这也应该被考虑在内。数学计算将取决于输入和输出嵌入是否绑定。如果它们没有绑定,通信量将会增加两倍。
FSDP 缓冲区大小¶
首先,让我们来介绍为通信分配的缓冲区:
forward 当前需要2倍的全聚合缓冲区大小。原因如下:
如在FSDP 预取细节中所解释的,在显式forward预取的情况下
(forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1
all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward,而另一个用于执行预取。
虽然理论上隐式forward预取(forward_prefetch=False,默认值)的相同序列只需要1个缓冲区,但实际上仍然是2倍的全聚大小缓冲区。原因是,在扁平参数FSDP设计中,我们不会从全聚缓冲区复制出来。用于计算的参数直接映射到全聚缓冲区(事实上,“扁平参数”的主要好处正是这个原因)。在这种情况下,当“第1层全聚”与“第0层前向计算”重叠时,“第0层前向计算”正在使用映射到“第0层全聚”缓冲区中的参数。
那么一个自然的问题是,你什么时候会想要 forward_prefetch=False?对于静态图模型(如大多数LLM),有一个主要的技术原因。实际上,我们为了某些CPU受限的内部模型快速添加了这个选项,并且还没有在单元测试中测试所有代码路径,所以我们对它的信心不足。 forward_prefetching=False 可能稍微容易理解一些,因为我们不需要检查记录的前向顺序作为可能的“失败模式”;模块的所有gather操作在其性能分析跟踪中的标签始终为 record_function。
backward 当前至少需要两倍的全聚缓冲区大小,可能还需要更多。原因如下:
当前的FSDP设计使用recordStream来管理在一个流中生成但在另一个流中消耗的分配,这可能导致比预期更多的内存使用。额外的内存使用量可能是“不确定的”,因为它取决于GPU内核定时与CPU的相对关系。 limit_all_gathers=True参数是对此问题的一种缓解措施 - 有关更多详细信息,请参阅此讨论:FSDP & CUDACachingAllocator。
现有的FSDP与自动梯度计算(autograd)的工作方式:
现有的FSDP对
flat_param进行全聚合,这是自动梯度的叶子节点。它调用
torch.split以获取与构成原始参数对应的flat_param的1D视图。它在每个1D分割上调用
torch.view以查看返回到ND。这意味着在
backward中,我们最终得到ViewBackward(ND -> 1D)和SplitWithSizesBackward(这是一个连接操作)。特别是,每个单独的梯度都是作为单独的分配计算的,并且会进行显式的连接以构建reduce-scatter输入缓冲区。这实际上意味着在峰值内存点处reduce-scatter的缓冲区大小为2倍。
总之,对于backward,它大约是reduce-scatter的两倍缓冲区大小加上任何recordStream效果。
其次,让我们讨论一下额外的缓冲区:
在收集来自所有 ranks 的分片参数后,它们需要一个额外的 total_transformer_block_params_in_B*dtype_bytes 缓冲区用于完整参数 - 因此继续前面的例子,如果每个 transformer 块有 1.6B 参数且参数为 fp32,则需要 1.6*4=6.4GB 个缓冲区。
并且需要2个这样的缓冲区,因为当前正在使用一个,而另一个正在预取。
总而言之,我们有:
2 倍通信缓冲区的
total_transformer_block_params_in_B*dtype_bytes/num_gpus2 倍未分片的变压器块参数缓冲区
``total_transformer_block_params_in_B*dtype_bytes
或者如果你一直在跟随这个示例:
2*1.6*4/8=1.6GB2**1.6*4=12.8GB
和总数为 14.4GB。
现在让我们简要讨论一下嵌入层在我们将其从计算中排除后会发生什么变化:
根据我们在笔记中讨论并包含的规则,该规则以“通信缓冲区大小如下确定”开始,我们可以进行如下分析:
假设我们将FSDP应用于根模块(例如,
Transformer类)。假设我们进一步将FSDP应用于每个Transformer块(例如,TransformerBlock类)。通常情况下,嵌入和最终的线性投影是根
Transformer类的直接子元素。根据我们的规则,这意味着嵌入和最终的线性投影被分配给根
Transformer的扁平参数。我们有另一个特殊规则,即根节点在前向传播后不会释放其参数,因为这些参数无论如何都会在反向传播中立即被全部收集。
将这些内容结合在一起,这意味着根节点的扁平参数,包括嵌入和最终投影,在开始前向传播时都会被收集到一起,并且在反向传播结束之前一直保留在GPU内存中。
如果嵌入层和最终的线性层没有权重共享,那么我们可以进一步将FSDP应用于嵌入层和最终的线性层。对于权重共享的参数,我们需要它们成为同一个扁平参数的一部分(否则它会被双重计算)。这将允许在前向传播使用后释放嵌入层,并且只在反向传播结束时进行全局聚合。
希望这能更好地说明——每个FSDP模块都会为其
module.parameters分配参数,除了那些已经分配给另一个嵌套的FSDP模块的参数,而FSDP模块的forward定义了其参数的“活动”区间。因此,嵌套的nn.Module结构可以影响所有gather/free调度,从而影响内存/吞吐量性能。