目录

torch.distributed.fsdp.fully_shard

PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 提供完全分片的数据并行 (FSDP) 实现 以高性能 Eager-Mode 为目标,同时使用每参数分片以改进 可用性。

  • 如果您是 FSDP 的新用户,我们建议您从 FSDP2 开始,因为 FSDP2 得到了改进 可用性。

  • 如果您当前正在使用 FSDP1,请考虑评估以下内容 差异以查看是否应该切换到 FSDP2:

与 PyTorch FSDP1 () 相比:FullyShardedDataParallel

  • FSDP2 使用基于 -dim-0 的每参数分片,以实现更简单的 分片表示与 FSDP1 的 flat-parameter 分片相比,而 保持类似的吞吐量性能。更具体地说,FSDP2 块 数据并行工作程序中 dim-0 上的每个参数(使用 ),而 FSDP1 展平、连接和分块 a 一组张量,对 每个工作线程和重新分片到不同的并行度很复杂。每个参数 分片提供更直观的用户体验,放宽约束 围绕冻结参数,并允许无通信(分片)状态 dicts 中,否则需要在 FSDP1 中使用 all-gathers。DTensortorch.chunk(dim=0)

  • FSDP2 实现了一种不同的内存管理方法来处理 多流使用,以避免 .这确保了 确定性和预期的内存使用情况,并且不需要阻塞 CPU 就像在 FSDP1 中一样。torch.Tensor.record_streamlimit_all_gathers=True

  • FSDP2 公开了用于手动控制预取和集体的 API 计划,允许高级用户进行更多自定义。有关详细信息,请参阅下面的方法。FSDPModule

  • FSDP2 简化了一些 API 表面:例如,FSDP2 不直接 支持全状态 dict。相反,用户可以对分片的状态字典进行重新分片 使用 API 将 s 包含到完整状态字典本身,或者使用更高级别的 API,如 PyTorch Distributed Checkpoint 的 分布式状态 dict API。此外,还删除了一些其他 args;请在此处查看 详。DTensorDTensorDTensor.full_tensor()

如果您是第一次加入 FSDP,或者如果上述任何一项对 您的使用案例,我们建议您考虑使用 FSDP2。

有关详细信息,请参阅此 RFC 关于系统设计和实施。

注意

torch.distributed.fsdp.fully_shard当前处于原型状态,并且 正在开发中。核心 API 可能不会更改,但我们可能会制作一些 如有必要,请更改 API。

前端 API 可以在 :fully_shardmodule

torch.distributed.fsdp 中。fully_shardmodule*mesh=Nonereshard_after_forward=Trueshard_placement_fn=Nonemp_policy=MixedPrecisionPolicy(param_dtype=无, reduce_dtype=无, output_dtype=无, cast_forward_inputs=True)offload_policy=OffloadPolicy()[来源]

将完全分片数据并行 (FSDP) 应用于 ,其中 FSDP 跨数据分片模块参数、梯度和优化器状态 并行 worker,以牺牲通信为代价来节省内存。module

在初始化时,FSDP 会在数据中对模块的参数进行分片 parallel worker 由 给出。在转发之前,FSDP 会全部收集 sharded 参数以获取未分片的 参数进行前向计算。如果为 ,则 FSDP 在 forward 和 re-all-在 Gradient 计算之前将它们向后收集。渐变后 计算时,FSDP 会释放未分片的参数并 reduce-scatter 跨数据并行工作线程的未分片梯度。meshreshard_after_forwardTrue

此实现将分片参数表示为 s sharded 在 dim-0 上分片,而未分片的参数将与原始 参数 on (例如,如果最初是 )。all-gather 上的模块 forward pre-hook 会收集参数,而 module forward hook on 会释放它们(如果需要)。类似的 backward hooks all-gather 参数,以及后来的 free parameters 和 reduce-scatter gradients。DTensormodulemodulemodule

由于将多个张量分组为一个集合对于以下 通信效率,此实现使此分组优先 类。调用 on 会构造一个组,该组 包括 except those already 中的参数 从子模块上的先前调用分配给组。这意味着 这应该在你的模型上称为 bottom-up。每个组的 参数全部聚集在一个集合中,其渐变为 reduce-scattered 在一个集合中。将模型划分为多个 组(“逐层”)允许峰值内存节省和通信/计算 重叠。用户通常不应只在 最顶层的根模块。modulemodule.parameters()

参数
  • moduleUnion[nn.模块List[nn.Module]) – 要 与 FSDP 分片并组合在一起进行通信。

  • meshOptional[DeviceMesh]) – 此数据并行网格定义 sharding 和 device。如果为 1D,则参数完全分片 跨 1D 网格 (FSDP) 进行放置。如果为 2D,则 然后,参数将在第 1 个 Dim 上分片并复制 穿过第 0 个维度 (HSDP) 并放置。网格的设备类型给出了用于 通信;如果是 CUDA 或类似 CUDA 的设备类型,那么我们使用 当前设备。(Shard(0),)(Replicate(), Shard(0))

  • reshard_after_forwardUnion[boolint]) –

    这将控制参数 forward 之后的行为,并且可以在内存和通信之间进行权衡:

    • 如果 ,则 this 会在 forward 和 re-all-gathers in backward.True

    • 如果 ,则这会将未分片的参数保留在内存中 在 forward 之后,避免 backward 中的所有聚集。False

    • 如果为 ,则表示要重新分片到的世界大小 转发后。它应该是分片 dim 大小的非平凡除数(即不包括 1 和 dim 大小本身)。一个 choice 可以是节点内大小(例如 )。 这允许 backward 中的 all-gather 位于较小的 World 上 size 的代价是内存使用量高于设置为 .intmeshtorch.cuda.device_count()True

    • 根 FSDP 状态的值专门设置为 启发式的,因为它的参数通常是立即的 全集为向后。False

    • 转发后,注册到模块的参数取决于 更改为:注册的参数是分片参数 if ;unsharded 参数 if ;和参数 否则重新分片到较小的网格。修改参数 在 forward 和 backward 之间,注册的参数必须为 分片参数。对于 或 an ,这可以是 通过 手动重新分片 完成。TrueFalseFalseintreshard()

  • shard_placement_fn可选[Callable[[nn.Parameter]Optional[Shard]]]) – 此可调用对象可用于覆盖 parameter 在 dim-0 以外的维度上对参数进行分片。如果 此 callable 返回一个 placement (not ), 然后 FSDP 将根据该位置进行分片(例如 )。 如果在非零 dim 上分片,我们目前需要均匀分片, 即该 dim 上的 tensor dim 大小必须能被 FSDP 整除 分片网格大小。ShardNoneShard(1)

  • mp_policyMixedPrecisionPolicy) – 此参数控制混合精度 策略,为此提供 parameter/reduction 混合精度 模块。有关详细信息,请参阅

  • offload_policyOffloadPolicy) – 这将控制卸载策略, 它提供 parameter/gradient/optimizer state offloading。有关详细信息,请参见 及其子类。

调用 动态构造一个新类,该类 子类和 FSDP 类 。例如,如果 我们调用一个模块,然后 FSDP 构造一个新类并将 的 type 更改为 this。 否则,不改变模块结构和参数 完全限定名称。该类允许提供一些 特定于 FSDP 的方法。fully_shard(module)type(module)FSDPModulefully_shard(linear)linear: nn.LinearFSDPLinearlinearfully_shardFSDPModule

torch.distributed.fsdp 中。FSDPModule*args**kwargs)
reshard[来源][来源]

重新分片模块的参数,如果 它们被分配并将分片参数注册到 模块。此方法不是递归的。

set_is_last_backwardis_last_backward[来源][来源]

设置下一个向后是否为最后一个。在最后一个向后, FSDP 等待待定的梯度降低并清除内部数据 用于向后预取的数据结构。这可能对 微批处理。

set_modules_to_backward_prefetch模块[来源][来源]

设置此 FSDP 模块应显式为其执行的 FSDP 模块 在 Backward 中预取 all-gathers。这将覆盖默认的向后 pretching 实现,它根据 反向后正向顺序。

传递包含先前 FSDP 模块的单例列表会得到 与默认重叠行为相同的全聚集重叠行为。 传递长度至少为 2 的列表是更激进的必要条件 overlap 的 intent 和将占用更多的预留内存。

参数

modulesList[FSDPModule]) – 要预取的 FSDP 模块。

set_modules_to_forward_prefetch模块[source][source]

设置此 FSDP 模块应显式为其执行的 FSDP 模块 在 forward 中预取所有集合。预取在此之后运行 module 的 all-gather copy-out。

传递包含下一个 FSDP 模块的单例列表会得到相同的 all-gather overlap 行为作为默认重叠行为,但 预取的 all-gather 较早从 CPU 发出。传递列表 长度至少为 2 是更激进的重叠所必需的,并且 将使用更多保留内存。

参数

modulesList[FSDPModule]) – 要预取的 FSDP 模块。

set_post_optim_event事件[来源][来源]

为根 FSDP 模块设置 post-optimizer-step 事件以等待 all-gather 流打开。

默认情况下,根 FSDP 模块等待 current stream 来确保优化器步骤之前已完成 全能。但是,如果 在 Optimizer 步骤之后有 unrelated computation。此 API 允许用户提供自己的事件来等待。在根之后 等待事件,则事件会被丢弃,所以这个 API 应该是 调用每个迭代都有一个新事件。

参数

事件Torch.Event) – 在 optimizer 步骤之后记录的事件 等待 All-gather Streams 打开。

set_reduce_scatter_divide_factor因子[来源][来源]

设置 reduce-scatter 的自定义分割因子。这将成为一个 使用 NCCL 的 PreMulSum 自定义 reduce 运算,它允许乘以 减少前的因子。

参数

factorfloat) - 自定义除法因子。

set_requires_all_reducerequires_all_reduce*recurse=True[来源][来源]

设置模块是否应全部减少梯度。这可用于 仅使用 reduce-scatter 实现梯度累积,而不使用 reduce-scatter 实现梯度累积 all-reduce 的 HSDP 的 JSON JSON 的

set_requires_gradient_syncrequires_gradient_sync*recurse=True[源][源]

设置模块是否应同步渐变。这可用于实现 梯度累积,无通讯。对于 HSDP,此控件 reduce-scatter 和 all-reduce 一起。

参数
  • requires_gradient_syncbool) – 是否减少 模块的参数。

  • recursebool) - 是为所有 FSDP 子模块设置,还是只为 传入的模块。

set_reshard_after_backwardreshard_after_backward*recurse=True[来源][来源]

设置模块是否应在 backward 后重新分片参数。这可以 在梯度累积期间使用,以牺牲更高的内存 减少了通信,因为未分片的参数不需要 re-all-gather 在下一个前锋之前。

参数
  • reshard_after_backwardbool) – 是否在以下时间后重新分片参数 向后。

  • recursebool) - 是为所有 FSDP 子模块设置,还是只为 传入的模块。

set_unshard_in_backwardunshard_in_backward[来源][来源]

设置 FSDP 模块的参数是否需要在 向后。这可以在专家案例中使用,当用户知道所有 此 FSDP 模块的参数组中的参数不需要 反向计算(例如 embedding)。

unshardasync_op=False[来源][来源]

通过分配内存和全集合来取消模块的参数分片 参数。此方法不是递归的。取消分片遵循 ,因此如果设置,它将全部聚集在后面。param_dtype

参数

async_opbool) – 如果 ,则返回具有等待取消分片操作的方法的 a。如果 ,则返回并等待内部的句柄 这个函数。Truewait()FalseNone

返回类型

可选[UnshardHandle]

注意

如果 ,则 FSDP 将等待挂起的 unshard。仅用户 如果等待应该发生,则需要显式调用 before pre-forward.async_op=Truewait()

torch.distributed.fsdp 中。UnshardHandle

用于等待 op 的句柄。

wait[来源][来源]

等待 unshard 操作。这可确保当前流可以使用 未分片的参数,这些参数现在已注册到模块中。

torch.distributed.fsdp 中。register_fsdp_forward_methodmodulemethod_name[来源]

注册一个方法,以被视为 的正向方法 FSDP.module

FSDP 在转发前全收集参数,并选择性地释放参数 post-forward (取决于 )。FSDP 只知道 默认情况下执行此操作。此函数修补 用户指定的方法在 方法。如果 不是 ,则 这是一个 no-op。reshard_after_forwardnn.Module.forward()module

参数
  • 模块nn.module) - 要在其上注册 forward 方法的模块。

  • method_namestr) – forward 方法的名称。

torch.distributed.fsdp 中。MixedPrecisionPolicyparam_dtype=reduce_dtype=output_dtype=cast_forward_inputs=)

这将配置 FSDP 的混合精度。与自动投射不同,这适用于混合 module 级别的 precision,而不是 op 级别的 precision,这意味着低精度 为向后保存激活,从高到低精度的强制转换为 仅在模块边界处发生。

FSDP 与模块级混合精度配合得很好,因为它保持了 无论如何,内存中的高精度分片参数。换句话说,FSDP 不需要任何额外的内存来保留 optimizer 步骤的参数。

变量
  • param_dtypeOptional[torch.dtype]) – 指定 unsharded 参数,因此 FORWARD/BACKWARD 的 DTYPE computation 和参数 all-gather 的 SET 来访问。如果这是 ,则 unsharded 参数使用原始 dtype。优化器步骤 使用原始 dtype 中的 sharded 参数。(默认:NoneNone)

  • reduce_dtypeOptional[torch.dtype]) – 指定 梯度减少(即 Reduce-Scatter 或 All-Reduce)。如果这是 but is not ,则减少 使用 Compute dtype。这可用于运行梯度缩减 全精度,同时使用低精度进行计算。如果还 梯度减少通过以下方式禁用 , 则 FSDP 将使用 累积梯度。 (默认:Noneparam_dtypeNoneset_requires_gradient_sync()reduce_dtypeNone)

  • output_dtypeOptional[torch.dtype]) – 指定 强制转换浮点前向输出。这可用于 帮助实现不同模块具有不同 mixed 的情况 精确策略。(默认:None)

  • cast_forward_inputsbool) – 这指定 FSDP 是否应将 forward 的浮点输入张量 to or not。param_dtype

torch.distributed.fsdp 中。卸载策略

此基类表示无卸载策略,仅用作 arg 的默认值。offload_policy

torch.distributed.fsdp 中。CPUOffloadPolicypin_memory=)

此卸载策略将参数、梯度和优化器状态卸载到 中央处理器。分片参数在 all-gather 之前从主机复制到设备。这 所有收集的参数都根据 释放。 分片梯度向后复制到 device-host,优化器 step 在具有 CPU 优化器状态的 CPU 上运行。reshard_after_forward

变量

pin_memorybool) – 是否固定分片参数和梯度 记忆。固定内存允许更高效的 H2D/D2H 拷贝 以及 COPIES 与 COMPUTE 重叠。但是,固定的 内存不能被其他进程使用。将此项设置为 if 您的 CPU 内存不足。(默认:FalseTrue)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源