目录

FullyShardedDataParallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=False, use_orig_params=False, ignored_parameters=None)[source]

A wrapper for sharding module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP。

Example:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

警告

优化器必须在模块被包装 之后 进行初始化, 因为 FSDP 会原地对参数进行分片,这将破坏任何 之前初始化的优化器。

警告

如果目标CUDA设备的ID为dev_id,则应满足以下条件之一:(1) module 应该已经放置在该设备上,(2) 应使用 torch.cuda.set_device(dev_id) 设置设备,或 (3) dev_id 应作为 device_id 构造函数参数传递。 此FSDP实例的计算设备将是该目标设备。对于(1)和(3),FSDP初始化总是在GPU上进行。 对于(2),FSDP初始化发生在 module 的当前设备上,这可能是CPU。

警告

FSDP 目前不支持在使用 CPU 卸载时将梯度累积到 no_sync() 之外。尝试这样做会导致 错误的结果,因为 FSDP 将使用新减少的梯度 而不是与任何现有的梯度进行累积。

警告

在构造后更改原始参数变量名称将导致未定义行为。

警告

传递 sync_module_states=True 标志需要将模块放在 GPU 上,或者使用 device_id 参数指定一个 CUDA 设备,FSDP 会将模块移动到该设备上。这是因为 sync_module_states=True 需要 GPU 通信。

警告

从PyTorch 1.12开始,FSDP仅对共享参数提供有限支持(例如,将一个Linear层的权重设置为另一个层的)。特别是,共享参数的模块必须作为同一FSDP单元的一部分进行包装。如果您的使用情况需要增强的共享参数支持,请联系https://github.com/pytorch/pytorch/issues/77724

注意

输入到 FSDP forward 函数的参数将在运行 forward 之前被移动到计算设备 (与 FSDP 模块所在的设备相同),因此用户无需手动将输入从 CPU 移动到 GPU。

Parameters:
  • 模块 (nn.Module) – 这是要用FSDP包装的模块。

  • process_group (可选[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]] 这是用于集体通信的过程组,并且是模型分片所使用的过程组。对于混合分片策略,例如 ShardingStrategy.HYBRID_SHARD 用户可以 传递一个过程组的元组,分别表示要分片和复制的过程组。

  • sharding_strategy (Optional[ShardingStrategy]) – 这配置了FSDP使用的分片策略,这可能会在内存节省和通信开销之间进行权衡。详见 ShardingStrategy。 (默认值: FULL_SHARD)

  • cpu_offload (可选[CPUOffload]) – 此配置用于CPU卸载。如果此值设置为 None,则 不会发生CPU卸载。详情请参见 CPUOffload。 (默认值: None)

  • auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]) –

    这是 None、一个 _FSDPPolicy 或一个具有固定签名的可调用函数。如果是 None,则 module 仅用顶层 FSDP 实例包裹,没有任何嵌套包裹。如果是 _FSDPPolicy,则根据给定的策略进行包裹。ModuleWrapPolicytorch.distributed.fsdp.wrap.py 中是一个示例。如果是一个可调用函数,则应接受三个参数 module: nn.Modulerecurse: boolnonwrapped_numel: int,并返回一个 bool,指定传递的 module 是否应被包裹,或者是否应继续遍历子树。recurse=Falserecurse=True。可以向可调用函数添加额外的自定义参数。size_based_auto_wrap_policytorch.distributed.fsdp.wrap.py 中给出了一个示例可调用函数,当其子树中的参数超过 100M numel 时包裹模块。一个好的做法是在包裹后打印模型,并根据需要进行调整。

    Example:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    

  • backward_prefetch (Optional[BackwardPrefetch]) – 这配置了所有gather的显式反向预取。详见 BackwardPrefetch。 (默认值: BACKWARD_PRE)

  • 混合精度 (可选[MixedPrecision]) – 此配置为FSDP设置原生混合精度。如果此值设置为 None, 则不使用混合精度。否则,可以设置参数、缓冲区和梯度减少的数据类型。详情请参见 MixedPrecision。 (默认值: None)

  • ignored_modules (可选[Iterable[torch.nn.Module]]) – 本实例将忽略这些模块及其子模块的参数和缓冲区。在ignored_modules中的任何直接模块都不应是FullyShardedDataParallel实例,并且如果它们嵌套在本实例下,已经构建的FullyShardedDataParallel实例的任何子模块也不会被忽略。此参数可用于避免在使用auto_wrap_policy时按模块粒度对特定参数进行分片,或者如果参数的分片不由FSDP管理。(默认值: None)

  • 参数初始化函数 (可选[Callable[[nn.Module], None]]) –

    一个 Callable[torch.nn.Module] -> None,用于指定当前位于元设备上的模块应该如何初始化到实际设备上。请注意,从v1.12版本开始,我们通过 is_meta 检查来检测元设备上的模块,并应用默认的初始化方法,该方法会在传入的 nn.Module 上调用 reset_parameters 方法(如果未指定 param_init_fn),否则我们会运行 param_init_fn 来初始化传入的 nn.Module。特别是,这意味着如果任何将被FSDP包装的模块参数的 is_meta=True 未指定,且 param_init_fn 也未指定,我们将假设你的模块正确实现了 reset_parameters(),否则会抛出错误。请注意,此外我们还支持使用torchdistX的 (https://github.com/pytorch/torchdistX) deferred_init API 初始化的模块。在这种情况下,延迟初始化的模块将通过默认的初始化函数进行初始化,该函数会调用 torchdistX 的 materialize_module,或者如果传入的 param_init_fn 不是 None,则使用该传入的 param_init_fn。同样的 Callable 会被用来初始化所有元模块。请注意,这个初始化函数会在执行任何FSDP分片逻辑之前应用。

    Example:

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module):
    >>>     # responsible for initializing a module, such as with reset_parameters
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_id (Optional[Union[int, torch.device]]) – 一个 inttorch.device 描述 FSDP 模块应被移动到的 CUDA 设备,确定初始化(如分片)的位置。如果未指定此参数 且 module 在 CPU 上,我们会发出警告,指出可以指定此参数以加快初始化速度。如果指定了该参数,生成的 FSDP 实例 将位于此设备上,包括在需要时移动被忽略模块的参数。请注意,如果指定了 device_idmodule 已经在不同的 CUDA 设备上,将引发错误。(默认值: None)

  • sync_module_states (bool) – 如果为 True,每个单独封装的 FSDP 单位将从 rank 0 广播模块参数,以确保初始化后所有 rank 上的参数一致。这有助于在训练开始前确保所有 rank 上的模型参数一致,但会增加 __init__ 的通信开销,因为每个单独封装的 FSDP 单位至少会触发一次广播。 这也可以帮助以内存高效的方式加载由 state_dict 保存并由 load_state_dict 加载的检查点。有关此功能的示例,请参阅 FullStateDictConfig 的文档。(默认值: False)

  • forward_prefetch (bool) – 如果 True,则在前向传递执行期间,FSDP 显式地 预取下一个即将到来的 all-gather 操作。 这可能会改善 CPU 密集型工作负载的通信和计算重叠。此功能仅适用于静态图模型,因为前向顺序基于第一次迭代的执行而固定。(默认值: False)

  • limit_all_gathers (bool) – 如果 False,则 FSDP 允许 CPU 线程在没有任何额外同步的情况下调度 all-gathers。 如果 True,则 FSDP 显式同步 CPU 线程以防止过多的正在进行中的 all-gathers。此 bool 仅影响调度 all-gathers 的分片策略。启用此选项可以帮助减少 CUDA malloc 重试次数。

  • ignored_parameters (Optional[Iterable[torch.nn.Parameter]]) – 被忽略的参数不会由该 FSDP 实例进行管理, 这意味着这些参数不会被 FSDP 展平和分片,它们的梯度也不会被同步。通过这个新添加的参数,ignored_modules 可能很快会被弃用。为了向后兼容性, 目前仍然保留 ignored_parametersignored_modules,但 FSDP 只允许其中一个不为 None

apply(fn)[source]

递归地将 fn 应用于每个子模块(如由 .children() 返回的)以及 self。典型用法包括初始化模型的参数 (另见 torch.nn.init)。

torch.nn.Module.apply相比,此版本在应用fn之前还会收集所有参数。不应在另一个summon_full_params上下文中调用它。

Parameters:

fn (Module -> None) – 应用于每个子模块的函数

Returns:

自我

Return type:

模块

clip_grad_norm_(max_norm, norm_type=2.0)[source]

对所有参数的梯度范数进行裁剪。该范数是将所有参数的梯度视为一个单一向量后计算得到的,梯度会以原地修改的方式进行调整。

Parameters:
  • max_norm (floatint) – 梯度的最大范数

  • norm_type (floatint) – 使用的p-范数类型。可以是 'inf' 表示无穷范数。

Returns:

参数的总范数(视为单个向量)。

Return type:

张量

注意

如果每个FSDP实例都使用NO_SHARD,这意味着梯度不会在各个秩之间进行分片,那么你可以直接使用 torch.nn.utils.clip_grad_norm_()

注意

如果至少有一个FSDP实例使用了分片策略(即除了NO_SHARD之外的其他策略),那么你应该使用这种方法,而不是torch.nn.utils.clip_grad_norm_(),因为这种方法处理了梯度在各个秩之间分片的事实。

注意

返回的总范数将具有所有参数/梯度中由PyTorch类型提升语义定义的“最大”dtype。例如,如果所有参数/梯度使用低精度dtype,则返回的范数的dtype将是该低精度dtype,但如果至少存在一个参数/梯度使用FP32,则返回的范数的dtype将是FP32。

警告

这需要在所有级别上调用,因为它使用了集体通信。

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]

API 类似于 shard_full_optim_state_dict()。唯一的 区别是输入 sharded_optim_state_dict 应该从 sharded_optim_state_dict() 返回。因此,每个秩上将 有 all-gather 调用来收集 ShardedTensor s。

Parameters:
  • sharded_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存分片优化器状态。

  • 模型 (torch.nn.Module) – 请参见 :meth:shard_full_optim_state_dict.

  • optim (torch.optim.Optimizer) – 用于 model 的优化器

  • parameters.

Returns:

参考 shard_full_optim_state_dict()

Return type:

字典[字符串, 任意类型]

forward(*args, **kwargs)[source]

对包装模块执行前向传递,在前向传递中插入特定于FSDP的预处理和后处理分片逻辑。

Return type:

任何

static fsdp_modules(module, root_only=False)[source]

返回所有嵌套的FSDP实例,可能包括module本身 并且仅包括FSDP根模块如果root_only=True

Parameters:
  • 模块 (torch.nn.Module) – 根模块,可能是一个或不是一个 FSDP 模块。

  • root_only (布尔值) – 是否仅返回FSDP根模块。 (默认值:False

Returns:

嵌套在输入module中的FSDP模块。

Return type:

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]

将完整的优化器状态在rank 0上进行整合并返回它 作为一个 dict,遵循 torch.optim.Optimizer.state_dict() 的惯例,即带有键 "state""param_groups"。在 FSDP 模块中包含的 model 中的扁平化参数被映射回它们未扁平化的参数。

警告

这需要在所有秩上调用,因为它使用了 集体通信。但是,如果 rank0_only=True,那么 状态字典仅在秩 0 上填充,而所有其他秩 返回一个空的 dict

警告

torch.optim.Optimizer.state_dict() 不同,此方法 使用完整的参数名称作为键,而不是参数ID。

注意

就像在 torch.optim.Optimizer.state_dict() 中一样,优化器状态字典中包含的张量不会被克隆,因此可能会有别名意外。为了最佳实践,请考虑立即保存返回的优化器状态字典,例如使用 torch.save()

Parameters:
  • 模型 (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例)的参数被传递给优化器 optim

  • 优化器 (torch.optim.Optimizer) – 用于 model 的参数的优化器。

  • optim_input (可选[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入 optim,表示参数组的列表或参数的可迭代对象; 如果为 None,则此方法假设输入为 model.parameters()。此参数已弃用,无需再传递它。(默认值: None

  • rank0_only (布尔值) – 如果 True,仅在 rank 0 上保存填充的 dict;如果 False,则在所有 ranks 上保存。 (默认:True

  • (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为 None。 (默认: None)

Returns:

一个 dict 包含优化器状态,适用于 model 的原始未展平参数,并包含键 “state” 和 “param_groups”,遵循 torch.optim.Optimizer.state_dict() 的约定。如果 rank0_only=True, 则非零秩返回一个空的 dict

Return type:

Dict[字符串, 任意]

static load_optim_state_dict_pre_hook(model, optim, optim_state_dict, group=None)[source]

此钩子旨在被 torch.distributed.NamedOptimizer 使用。 该功能与 :meth:optim_state_dict_to_load 相同, 只是参数不同。

Parameters:
  • 模型 (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例)的参数被传递给优化器 optim

  • 优化器 (torch.optim.Optimizer) – 用于 model 的参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。

  • (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为 None。 ( 默认值: None)

Return type:

字典[字符串, 任意类型]

property module: Module

返回被包装的模块(比如 DistributedDataParallel)。

named_buffers(*args, **kwargs)[source]

覆盖 named_buffers() 以拦截缓冲区名称并在 summon_full_params() 上下文管理器内部移除所有 FSDP 特有的展平缓冲区前缀。

Return type:

迭代器[元组[字符串, 张量]]

named_parameters(*args, **kwargs)[source]

覆盖 named_parameters() 以拦截参数名称并 在 summon_full_params() 上下文管理器内部时 移除所有 FSDP 特有的展平参数前缀。

Return type:

迭代器[元组[字符串, 参数]]

no_sync()[source]

一个上下文管理器,用于禁用 FSDP 实例之间的梯度同步。在此上下文中,梯度将被累积在模块变量中,之后在退出该上下文后的第一次前向-反向传递过程中进行同步。此功能仅应作用于根 FSDP 实例,并会递归地应用于所有子 FSDP 实例。

注意

这可能会导致更高的内存使用,因为FSDP将累积整个模型的梯度(而不是梯度碎片),直到最终同步。

注意

当与CPU卸载一起使用时,在上下文管理器内部,梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。

Return type:

生成器

static optim_state_dict(model, optim, group=None)[source]

返回 optim 的状态字典,该状态字典由 FSDP 部分分片。状态可能是分片的、合并的,或者仅在 rank 0 上合并,具体取决于通过 set_state_dict_type()state_dict_type() 设置的 state_dict_type

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkponit()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     optim_state_dict, model, optim
>>> )
>>> optim.load_state_dict(optim_state_dict)
Parameters:
  • 模型 (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例)的参数被传递给优化器 optim

  • 优化器 (torch.optim.Optimizer) – 用于 model 的参数的优化器。

  • (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为 None。 ( 默认值: None)

Returns:

一个 dict 包含优化器状态用于 model。优化器状态的分片基于 state_dict_type

Return type:

Dict[字符串, 任意]

static optim_state_dict_post_hook(model, optim, optim_state_dict, group=None)[source]

此钩子旨在被 torch.distributed.NamedOptimizer 使用。 该功能与 :meth:optim_state_dict 相同, 只是参数不同。

Parameters:
  • 模型 (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例)的参数被传递给优化器 optim

  • 优化器 (torch.optim.Optimizer) – 用于 model 的参数的优化器。

  • (Dict[str (optim) – 要转换的 optim_state_dict。该值通常由 NamedOptimizer.state_dict() 返回。

  • Any] – 要转换的 optim_state_dict。该值通常由 NamedOptimizer.state_dict() 返回。

  • (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为 None。 ( 默认值: None)

Returns:

一个 dict 包含优化器状态用于 model。优化器状态的分片基于 state_dict_type

Return type:

Dict[字符串, 任意]

static optim_state_dict_to_load(optim_state_dict, model, optim, is_named_optimizer=False, group=None)[source]

给定一个已保存的 optim_state_dict,将其转换为可以加载到 optim 的优化器 state_dict,其中 model 的优化器是 optimmodel 通过 FullyShardedDataParallel 进行(部分)分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkponit()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     optim_state_dict, model, optim
>>> )
>>> optim.load_state_dict(optim_state_dict)
Parameters:
  • optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。

  • 模型 (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例)的参数被传递给优化器 optim

  • 优化器 (torch.optim.Optimizer) – 用于 model 的参数的优化器。

  • is_named_optimizer (布尔值) – 此优化器是否为NamedOptimizer或KeyedOptimizer。仅当optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时,设置为True。

  • (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为 None。 ( 默认值: None)

Return type:

字典[字符串, 任意类型]

register_comm_hook(state, hook)[source]

注册一个通信钩子,这是一种增强功能,为用户提供了一个灵活的钩子,他们可以指定FSDP如何在多个工作节点之间聚合梯度。 此钩子可用于实现诸如 GossipGrad 和梯度压缩 等几种算法,这些算法涉及不同的通信策略,在使用 FullyShardedDataParallel 进行训练时用于参数同步。

警告

FSDP通信钩子应在运行初始前向传递之前注册,并且仅注册一次。

Parameters:
  • 状态 (对象) –

    传递给钩子以在训练过程中维护任何状态信息。 示例包括梯度压缩中的错误反馈, 以及在 GossipGrad 中与下一个通信的对等方等。 它由每个工作进程本地存储 并由工作进程上的所有梯度张量共享。

  • 钩子 (可调用对象) – 可调用对象,具有以下签名之一: 1) hook: Callable[torch.Tensor] -> None: 此函数接收一个Python张量,该张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度。 然后执行所有必要的处理并返回None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: 此函数接收两个Python张量,第一个张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度 (这些变量未被其他FSDP子单元包装)。第二个张量表示一个预分配大小的张量, 用于存储在归约后分片梯度的一部分。 在这两种情况下,可调用对象执行所有必要的处理并返回None。 具有签名1的可调用对象预计处理NO_SHARD情况下的梯度通信。 具有签名2的可调用对象预计处理分片情况下的梯度通信。

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]

将优化器状态字典 optim_state_dict 的键重新映射为使用键 类型 optim_state_key_type。这可用于实现具有 FSDP 实例的模型和不具有 FSDP 实例的模型之间的优化器状态字典兼容性。

要重新键入FSDP完整优化器状态字典(即从 full_optim_state_dict())以使用参数ID并可加载到 非包装模型中:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要将普通优化器的状态字典从非包装模型重新键入以便可以加载到包装模型中:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
Returns:

优化器状态字典已使用optim_state_key_type指定的参数键重新键入。

Return type:

Dict[字符串, 任意]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]

将完整的优化器状态字典从 rank 0 散列到所有其他 rank, 并在每个 rank 上返回分片的优化器状态字典。返回值与 shard_full_optim_state_dict() 相同,并且在 rank 0 上,第一个参数应该是 full_optim_state_dict() 的返回值。

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

Both shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。

Parameters:
  • full_optim_state_dict (可选[Dict[str, Any]]) – 优化器状态 字典,对应于未展平的参数,并在秩为0时保存完整的非分片优化器状态;在非零秩上忽略该参数。

  • 模型 (torch.nn.Module) – 根模块(可以是或不是 FullyShardedDataParallel 实例),其参数对应于full_optim_state_dict中的优化器状态。

  • optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果 None,则此方法假定输入为 model.parameters()。此参数已弃用,无需再传递。 (默认值:None

  • 优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比 optim_input 更优选的参数。 (默认: None)

  • (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为 None。 (默认值: None)

Returns:

完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。

Return type:

Dict[字符串, 任意]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]

设置 state_dict_type 以及目标模块的所有后代 FSDP 模块的相应(可选) 配置。目标模块不一定是 FSDP 模块。如果目标 模块是 FSDP 模块,其 state_dict_type 也将被更改。

注意

此API应仅针对顶级(根)模块调用。

注意

此API使用户能够透明地使用传统的 state_dict API,在根FSDP模块被另一个nn.Module包装的情况下进行模型检查点。例如, 以下将确保在所有非FSDP实例上调用state_dict,同时调度到sharded_state_dict实现 用于FSDP:

Example:

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
Parameters:
  • 模块 (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 所需的 state_dict_type 设置。

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的配置。

Returns:

一个StateDictSettings,其中包含模块的先前state_dict类型和配置。

Return type:

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]

将完整的优化器状态字典 full_optim_state_dict 按照分片方式进行处理,通过将状态映射到扁平化参数而非非扁平化参数,并仅保留此 rank 的优化器状态部分。第一个参数应该是 full_optim_state_dict() 的返回值。

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意

Both shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。

Parameters:
  • full_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存完整的非分片优化器状态。

  • 模型 (torch.nn.Module) – 根模块(可以是或不是 FullyShardedDataParallel 实例),其参数对应于full_optim_state_dict中的优化器状态。

  • optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果 None,则此方法假定输入为 model.parameters()。此参数已弃用,无需再传递。 (默认值:None

  • 优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比 optim_input 更优选的参数。 (默认: None)

Returns:

完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。

Return type:

Dict[字符串, 任意]

static sharded_optim_state_dict(model, optim, group=None)[source]

该API与full_optim_state_dict()类似,但此API将所有非零维度的状态分块为ShardedTensor以节省内存。 仅当模型state_dict是在上下文管理器with state_dict_type(SHARDED_STATE_DICT):中派生时,才应使用此API。

有关详细用法,请参阅 full_optim_state_dict()

警告

返回的状态字典包含 ShardedTensor 并且 不能直接用于常规的 optim.load_state_dict

Return type:

字典[字符串, 任意类型]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]

一个上下文管理器,用于设置目标模块所有后代 FSDP 模块的 state_dict_type。此上下文管理器具有与 set_state_dict_type() 相同的功能。有关详细信息,请阅读 set_state_dict_type() 的文档。

Example:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
Parameters:
  • 模块 (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 所需的 state_dict_type 设置。

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的配置。

Return type:

生成器

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]

一个上下文管理器,用于暴露FSDP实例的完整参数。 在模型完成前向/反向传播后,可以用于获取参数以进行额外处理或检查。它可以接受一个非FSDP模块,并根据recurse参数的值,为所有包含的FSDP模块及其子模块召唤完整参数。

注意

这可以用于内部的FSDPs。

注意

这不能在前向或后向传递中使用。也不能从这个上下文中启动前向和后向传递。

注意

参数将在上下文管理器退出后恢复为本地碎片,存储行为与前向传播相同。

注意

完整的参数可以被修改,但只有与本地参数分片对应的部分会在上下文管理器退出后保留(除非 writeback=False,在这种情况下,更改将被丢弃)。在FSDP不切分参数的情况下,目前仅当 world_size == 1NO_SHARD 配置时,无论 writeback 如何,修改都会被保留。

注意

此方法适用于本身不是FSDP的模块,但可能包含多个独立的FSDP单元。在这种情况下,给定的参数将应用于所有包含的FSDP单元。

警告

请注意,rank0_only=Truewriteback=True 的组合目前不受支持,并将引发错误。这是因为在上下文中,模型参数的形状在不同的秩之间会有所不同,当退出上下文时,写入它们可能会导致秩之间的不一致。

警告

请注意,offload_to_cpurank0_only=False 将导致完整参数被冗余地复制到CPU内存中,对于位于同一台机器上的GPU,这可能会导致CPU内存溢出的风险。建议使用 offload_to_cpurank0_only=True

Parameters:
  • 递归 (布尔值, 可选) – 递归地召唤所有嵌套的FSDP实例的所有参数(默认值:True)。

  • 写回 (布尔值, 可选) – 如果 False, 在上下文管理器退出后对参数的修改将被丢弃; 禁用此选项可能会稍微提高效率(默认值:True)

  • rank0_only (布尔值, 可选) – 如果 True,完整的参数仅在全局排名为0的设备上实例化。这意味着在此上下文中,只有排名为0的设备将拥有完整的参数,其他设备将拥有分片的参数。请注意,在此上下文中设置 rank0_only=Truewriteback=True 是不支持的,因为模型参数形状在不同排名之间会有所不同,写入这些参数可能会导致退出上下文时各排名之间的不一致。

  • offload_to_cpu (布尔值, 可选) – 如果 True,全部参数将被卸载到CPU。请注意,这种卸载目前仅在参数被分片时发生(这仅在world_size = 1或NO_SHARD配置时不适用)。建议与offload_to_cpurank0_only=True一起使用,以避免将模型参数的冗余副本卸载到相同的CPU内存。

  • with_grads (bool, Optional) – 如果为True,则梯度也会与参数一起取消分片。目前,这仅在将use_orig_params=True传递给FSDP构造函数并将offload_to_cpu=False传递给此方法时支持。(默认值:False)

Return type:

生成器

class torch.distributed.fsdp.BackwardPrefetch(value)[source]

这配置了显式的反向预取,可以提高吞吐量,但可能会略微增加峰值内存使用。

对于 NCCL 后端,任何集合操作,即使是在不同的流中发出的, 都会竞争同一个设备的 NCCL 流,这就是为什么集合操作发出的相对顺序 会影响重叠效果。不同的反向预取设置对应于不同的顺序。

  • BACKWARD_PRE: 这会在当前参数梯度计算之前预取下一组参数。这通过重叠通信(下一个all-gather)和计算(当前梯度计算)来提高反向传递的吞吐量。

  • BACKWARD_POST: 这会在当前参数梯度计算之后预取下一组参数。这可能会通过重叠通信(当前的reduce-scatter)和计算(下一个梯度计算)来提高反向传播的吞吐量。具体来说,下一次的all-gather操作会被重新排序到当前的reduce-scatter之前。

注意

如果预取导致峰值内存使用量增加是一个问题,您可以考虑将 limit_all_gathers=True 传递给 FSDP 构造函数,这在某些情况下可能有助于减少峰值内存使用量。

class torch.distributed.fsdp.ShardingStrategy(value)[source]

这指定了用于分布式训练的分片策略,由 FullyShardedDataParallel

  • FULL_SHARD: 参数、梯度和优化器状态被分片。 对于参数,此策略在前向传播之前通过全聚操作取消分片,在前向传播之后重新分片,在反向计算之前取消分片,并在反向计算之后重新分片。对于梯度,它在反向计算之后通过归约分散操作同步并分片它们。分片的优化器状态按秩本地更新。

  • SHARD_GRAD_OP: 在计算过程中,梯度和优化器状态会被分片,并且参数在计算之外也会被分片。对于参数,此策略在前向传播之前进行不分片,在前向传播之后不重新分片,仅在反向传播计算之后重新分片。分片的优化器状态在每个秩上本地更新。在 no_sync() 中,参数在反向传播计算之后不会被重新分片。

  • NO_SHARD: 参数、梯度和优化器状态不会被分片,而是像PyTorch的 DistributedDataParallel API一样在各个秩之间复制。对于梯度,此策略 在反向计算后通过全归约进行同步。未分片的优化器状态在每个秩上本地更新。

  • HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across

    节点。这导致通信量减少,因为昂贵的 all-gathers 和 reduce-scatters 操作仅在节点内部执行,这对于中等规模的模型可能性能更优。

  • _HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across

    节点。这类似于 HYBRID_SHARD,但可能提供更高的吞吐量,因为未分片的参数在前向传播后不会被释放,从而节省了预反向阶段的 all-gathers 操作。

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True)[source]

这配置了FSDP原生混合精度训练。

Variables:
  • param_dtype (torch.dtype) – 这指定了模型参数、输入(当为 cast_forward_inputscast_root_forward_inputs``is set to ``True 时)的 dtype,因此也指定了计算的 dtype。 然而,在前向和反向传播之外,参数始终以全精度存储。模型检查点始终以全精度进行。

  • reduce_dtype (torch.dtype) – 这指定了梯度 缩减的 dtype,允许与 param_dtype 不同。

  • buffer_dtype (torch.dtype) – 这指定了缓冲区的数据类型。FSDP 不会对缓冲区进行分片,会在第一次前向传递中将其转换为 buffer_dtype,之后一直保持该数据类型。模型 检查点始终以全精度进行。

  • keep_low_precision_grads (bool) – 这指定是否在反向传播后将梯度上转换回全参数精度。如果使用可以以reduce_dtype执行优化器步骤的自定义优化器,可以将其设置为False以节省内存。 (Default: False)

  • cast_forward_inputs (bool) – 将前向传播参数和关键字参数中的浮点张量转换为 param_dtype。 (默认值: False)

  • cast_root_forward_inputs (bool) – 将前向参数和关键字参数中的浮点张量转换为 param_dtype,用于根 FSDP 实例。 它会覆盖根 FSDP 实例的 cast_forward_inputs 设置。 (默认值: True)

注意

此API为实验性质,可能会发生变化。

注意

只有浮点张量会被转换为指定的数据类型。

注意

summon_full_params 中,参数被强制为全精度,但缓冲区不是。

注意

state_dict checkpoints parameters and buffers in full precision. For buffers, this is only supported for StateDictType.FULL_STATE_DICT.

注意

每个低精度数据类型必须明确指定。例如,MixedPrecision(reduce_dtype=torch.float16) 仅指定缩减数据类型为低精度,FSDP 不会将参数或缓冲区转换为低精度。

注意

如果未指定 reduce_dtype,则梯度缩减 会在 param_dtype 中进行(如果指定了的话),否则使用原始参数的 dtype。

注意

如果用户向FSDP构造函数传递一个包含BatchNorm个模块的模型和 auto_wrap_policy,那么FSDP将为BatchNorm个模块禁用混合精度,方法是将它们分别包装在自己的FSDP实例中,并禁用混合精度。这是由于缺少一些低精度的BatchNorm内核。如果用户不使用auto_wrap_policy,则用户必须注意不要对包含BatchNorm个模块的FSDP实例使用混合精度。

注意

MixedPrecision 默认具有 cast_root_forward_inputs=Truecast_forward_inputs=False。对于根 FSDP 实例, 其 cast_root_forward_inputs 优先于其 cast_forward_inputs。对于非根 FSDP 实例,它们的 cast_root_forward_inputs 值被忽略。默认设置适用于典型情况,即每个 FSDP 实例具有相同的 MixedPrecision 配置,并且只需要在模型的前向传递开始时将输入转换为 param_dtype

注意

对于具有不同MixedPrecision配置的嵌套FSDP实例,我们建议设置各个cast_forward_inputs值以配置在每个实例的前向传播之前是否进行输入类型转换。在这种情况下,由于类型转换发生在每个FSDP实例的前向传播之前,父FSDP实例应该在其FSDP子模块之前运行其非FSDP子模块,以避免由于不同的MixedPrecision配置而导致激活数据类型发生变化。

Example:

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

上面显示了一个工作示例。另一方面,如果 model[1] 被替换为 model[0],意味着使用了不同的 MixedPrecision 的子模块首先运行其前向计算,那么 model[1] 将错误地看到 float16 激活而不是 bfloat16 激活。

class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]

此配置用于CPU卸载。

Variables:

offload_params (bool) – 这指定了在不参与计算时是否将参数卸载到 CPU。如果启用,这也会隐式地将梯度卸载到 CPU。这是为了支持优化器 步骤,该步骤要求参数和梯度位于同一设备上。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源