torch.distributed.fsdp.fully_shard¶
PyTorch FSDP2 (fully_shard)¶
PyTorch FSDP2 提供了全数据并行碎片化(FSDP)实现, 旨在使用按参数分片来提升急切模式的性能,同时改善易用性。
如果您是 FSDP 新手,我们建议您从 FSDP2 开始,因为它在易用性方面有所改进。
如果你目前正在使用 FSDP1,请考虑评估以下差异,看看是否应该切换到 FSDP2:
与 PyTorch FSDP1 (FullyShardedDataParallel)相比:
FSDP2 使用基于
DTensor的 dim-0 按参数分片,与 FSDP1 的扁平参数分片相比,提供了一种更简单的分片表示形式,同时保持了类似的吞吐量性能。具体来说,FSDP2 在 dim-0 上将每个参数在数据并行工作者之间进行分片(使用torch.chunk(dim=0)),而 FSDP1 则会将一组张量展平、连接和分片,这使得理解每个工作者上存在的数据以及重新分片到不同并行度变得复杂。按参数分片提供了更直观的用户体验,放松了对冻结参数的约束,并允许无通信(分片)状态字典,否则在 FSDP1 中需要全收集。FSDP2 实现了一种不同的内存管理方法来处理多流使用,避免了
torch.Tensor.record_stream。这确保了确定性和预期的内存使用,并且不需要像 FSDP1 的limit_all_gathers=True那样阻塞 CPU。FSDP2 提供了手动控制预取和集体调度的API,允许高级用户进行更多自定义。有关详细信息,请参见下面的
FSDPModule方法。FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整的状态字典。相反,用户可以使用
DTensor类似的 API(如DTensor.full_tensor())或通过使用更高级别的 API(如 PyTorch 分布式检查点 的分布式状态字典 API)将包含DTensor的分片状态字典重新分片为完整状态字典。此外,一些其他参数已被移除;详情请参见这里。
如果你是第一次使用 FSDP,或者上述任何一个选项符合你的用例,我们建议你考虑使用 FSDP2。
参见此RFC以获取关于系统设计和实现的详细信息。
注意
torch.distributed.fsdp.fully_shard 当前处于原型状态并在开发中。核心API可能不会发生变化,但如果有必要,我们可能会进行一些API更改。
前端API是fully_shard,可以在module上调用:
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy())[source]¶
将完全分片的数据并行性(FSDP)应用于
module,其中FSDP 在数据并行工作者之间分片模块参数、梯度和优化器状态以节省内存,但会增加通信成本。初始化时,FSDP 将模块的参数跨数据并行工作者进行分片,由
mesh给出。在前向传播之前,FSDP 将分片的参数在数据并行工作者之间全部收集,以获得用于前向计算的未分片参数。如果reshard_after_forward是True,那么在前向传播之后,FSDP 释放未分片的参数,并在反向传播之前重新全部收集它们以进行梯度计算。在梯度计算之后,FSDP 释放未分片的参数,并将未分片的梯度跨数据并行工作者进行减少分片。此实现将分片参数表示为
DTensor,在dim-0维度上进行分片,而未分片的参数将保持原始状态 (例如,如果原来是torch.Tensor,则为torch.Tensor)。一个模块的 前向预钩子 在module处执行所有参数的聚合,并且一个模块的 前向钩子 在module处释放它们(如果需要)。类似的反向钩子会聚合参数并在之后释放参数并减少散列梯度。由于将多个张量分组在一起对于通信效率至关重要,此实现使这种分组成为首要任务。在
module上调用fully_shard()会构建一个组,该组包含module.parameters()中的参数,但不包括那些已经从早期对子模块的调用中分配给组的参数。这意味着fully_shard()应该从下往上对您的模型进行调用。每个组的参数都在一个集体操作中全部收集,并且其梯度在一个集体操作中减少分散。将模型划分为多个组(“一层一层”)可以实现峰值内存节省和通信/计算重叠。用户通常不应仅在最顶层根模块上调用fully_shard()。- Parameters
模块 (Union[nn.Module, List[nn.Module]) – 要使用FSDP进行分片并一起进行通信的模块或模块列表。
mesh (可选[DeviceMesh]) – 此数据并行网格定义了分片和设备。如果为1D,则参数将在1D网格上完全分片(FSDP),位置为
(Shard(0),)。如果为2D,则参数将在第1维上分片并在第0维上复制(HSDP),位置为(Replicate(), Shard(0))。网格的设备类型用于通信;如果是CUDA或类似CUDA的设备类型,则使用当前设备。前向传播后重新划分 (Union[bool, int]) –
这控制了前向传播后的参数行为,并可以在内存和通信之间进行权衡:
如果
True,则在前向传播后重新划分参数并在反向传播时重新聚合。如果
False,则在前向传播后保持未分片的参数在内存中,并避免在反向传播中的全gather操作。如果是一个
int,那么这表示在前向传播后重新划分的世界大小。它应该是mesh划分维度大小的一个非平凡因子(即不包括 1 和维度大小本身)。可以选择节点内部大小(例如torch.cuda.device_count())。这样可以在反向传播中的 all-gather 操作在一个较小的世界大小上进行,代价是比设置为True时占用更高的内存。根 FSDP 状态的值被特别设置为
False,作为一种启发式方法,因为其参数通常会立即进行全收集以供反向传播。前向传播后,注册到模块的参数取决于以下情况:如果为
True,则注册的参数为分片参数;如果为False,则为未分片参数;否则,参数会被重新分片到更小的网格中。要在前向传播和反向传播之间修改参数,注册的参数必须为分片参数。对于False或一个int,可以通过手动重新分片(reshard())来实现。
shard_placement_fn (可选[Callable[[nn.Parameter], 可选[Shard]]]) – 此可调用函数可用于覆盖参数的分片放置,以便在非dim-0维度上对参数进行分片。如果此可调用函数返回
Shard放置(而不是None),那么FSDP将根据该放置进行分片(例如Shard(1))。如果在非零维度上进行分片,我们目前要求均匀分片,即该维度上的张量尺寸必须能被FSDP分片网格大小整除。mp_policy (混合精度策略) – 这控制了混合精度策略,该策略为此模块提供了参数/规约混合精度。详见
MixedPrecisionPolicy。offload_policy (OffloadPolicy) – 这控制了卸载策略, 提供了参数/梯度/优化器状态的卸载。有关详细信息,请参见
OffloadPolicy及其子类。
调用 fully_shard(module) 动态构建了一个新的类,该类继承了 type(module) 并且是一个 FSDP 类 FSDPModule。例如,如果我们对一个模块 linear: nn.Linear 调用 fully_shard(linear),那么 FSDP 将构建一个新的类 FSDPLinear 并将 linear 的类型更改为这个新类。
否则,fully_shard 不会改变模块结构和参数的完全限定名。类 FSDPModule 允许在模块上提供一些特定于 FSDP 的方法。
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)¶
-
- set_is_last_backward(is_last_backward)[source][source]¶
设置下一次反向传播是否为最后一次。在最后一次反向传播时, FSDP 等待挂起的梯度归约并清除用于反向预取的内部数据 结构。这对于微批量处理可能很有用。
- set_modules_to_backward_prefetch(modules)[source][source]¶
设置此 FSDP 模块应在反向传播中显式预取所有集合的 FSDP 模块。这会覆盖默认的反向预取实现,该实现基于反向的前向后顺序来预取下一个 FSDP 模块。
传递一个只包含先前 FSDP 模块的单例列表会产生与默认重叠行为相同的全gather重叠行为。 传递至少长度为二的列表可以实现更激进的重叠,并且会占用更多预留内存。
- Parameters
模块 (列表[FSDPModule]) – 预取的 FSDP 模块。
- set_modules_to_forward_prefetch(modules)[source][source]¶
设置此 FSDP 模块应在前向传播中显式预取全汇聚的 FSDP 模块。预取操作在该模块的全汇聚副本输出之后运行。
传递一个只包含下一个 FSDP 模块的单例列表会产生与默认重叠行为相同的全gather重叠行为,除了预取的全gather会从CPU更早发出。为了实现更激进的重叠,需要传递长度至少为二的列表,并且会占用更多的预留内存。
- Parameters
模块 (列表[FSDPModule]) – 预取的 FSDP 模块。
- set_post_optim_event(event)[source][source]¶
为根 FSDP 模块设置一个优化器步骤后的事件,等待 all-gather 流。
默认情况下,根 FSDP 模块会在当前流上等待 all-gather 流程,以确保在进行 all-gather 之前优化器步骤已完成。然而,如果在优化器步骤之后有不相关的计算,这可能会引入虚假依赖。此 API 允许用户提供自己的事件来等待。根模块在等待事件后会丢弃该事件,因此每次迭代时应调用一个新的事件。
- Parameters
事件 (torch.Event) – 在优化器步骤后记录的事件 以等待全部收集流。
- set_reduce_scatter_divide_factor(factor)[source][source]¶
设置自定义的分割因子用于减少散射。这成为一个自定义的减少操作,使用 NCCL 的 PreMulSum,在减少之前可以乘以该因子。
- Parameters
因子 (浮点数) – 自定义除法因子。
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]¶
设置模块是否进行全归约梯度。这可以用于仅使用归约散射而不使用全归约来实现梯度累加,以支持HSDP。
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]¶
设置模块是否同步梯度。这可以用于实现无需通信的梯度累加。对于HSDP,这同时控制了减少散射和全归约。
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]¶
设置模块是否在反向传播后重新分配参数。这可以在梯度累积期间用于权衡更高的内存以减少通信,因为无需在下一次前向传播之前重新聚集未分片的参数。
- set_unshard_in_backward(unshard_in_backward)[source][source]¶
设置 FSDP 模块的参数在反向传播时是否需要取消分片。这可以在某些专业情况下使用,当用户知道此 FSDP 模块的参数组中的所有参数都不需要用于反向计算(例如嵌入层)时。
- unshard(async_op=False)[source][source]¶
合并模块的参数,通过分配内存和收集所有参数来实现。此方法不会递归调用。合并操作会遵循
MixedPrecisionPolicy,因此如果设置了param_dtype,它将按照该设置进行收集。- Parameters
async_op (bool) – 如果
True,则返回一个UnshardHandle具有wait()方法来等待未分片操作。如果False,则返回None并在此函数内部等待句柄。- Return type
注意
如果
async_op=True,那么FSDP将在模块的pre-forward前等待未分片操作完成。用户只需要在需要提前等待时显式调用wait()。
- class torch.distributed.fsdp.UnshardHandle¶
一个用于等待的句柄
FSDPModule.unshard()操作。
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]¶
在
module上注册一个方法,将其视为FSDP的前向方法。FSDP 在前向传播之前聚合参数,并在后向传播之后根据需要释放参数(取决于
reshard_after_forward)。默认情况下,FSDP 只会对nn.Module.forward()执行此操作。该函数会为用户指定的方法打补丁,在方法执行前后分别运行前向/后向传播钩子。如果module不是一个FSDPModule,那么这将是一个无操作。
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)¶
此配置设置了 FSDP 的混合精度。与自动类型转换不同,这在模块级别应用混合精度,而不是操作级别。这意味着低精度激活值会在反向传播时被保存,并且仅在模块边界处才会发生从高精度到低精度的转换。
FSDP 与模块级别的混合精度配合得很好,因为它无论如何都会在内存中保留高精度的分片参数。换句话说,FSDP 不需要额外的内存来为优化器步骤保留高精度的参数副本。
- Variables
param_dtype (可选[torch.dtype]) – 这指定了未分片参数的数据类型,因此也指定了前向/后向计算和参数全收集的数据类型。如果该值为
None,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型的分片参数。(默认值:None)reduce_dtype (可选[torch.dtype]) – 这指定了梯度归约(即减少散射或全归约)的dtype。如果这是
None但param_dtype不是None,那么归约使用计算dtype。这可以用于以全精度运行梯度归约,同时以低精度进行计算。如果也通过set_requires_gradient_sync()禁用梯度归约, 那么FSDP将使用reduce_dtype积累梯度。(默认值:None)输出数据类型 (可选[torch.dtype]) – 这指定了用于转换浮点前向输出的数据类型。这可以用于帮助实现不同模块具有不同混合精度策略的情况。(默认值:
None)cast_forward_inputs (bool) – 这指定了FSDP是否应将前向传递的浮点输入张量转换为
param_dtype。
- class torch.distributed.fsdp.OffloadPolicy¶
此类表示不卸载的策略,仅用作
offload_policy参数的默认值。
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)¶
此卸载策略将参数、梯度和优化器状态卸载到CPU。分片参数在全gather之前从主机复制到设备。根据
reshard_after_forward释放全gather后的参数。反向传播时,分片梯度从设备复制到主机,并且优化器步骤在具有CPU优化器状态的CPU上运行。- Variables
pin_memory (bool) – 是否固定分片参数和梯度内存。固定内存可以提高H2D/D2H复制效率,并使复制与计算重叠。然而,固定内存不能被其他进程使用。如果你的CPU内存不足,请将其设置为
False。(默认值:True)