FullyShardedDataParallel¶
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None)[source]¶
A wrapper for sharding module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP。
Example:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
警告
优化器必须在模块被FSDP包装之后进行初始化,因为FSDP将以可能不保留原始参数变量的方式对模块的参数进行分片和转换。因此,之前初始化的优化器可能对参数有陈旧的引用。
警告
如果目标CUDA设备的ID为
dev_id,则应满足以下条件之一:(1)module应该已经放置在该设备上,(2) 应使用torch.cuda.set_device(dev_id)设置设备,或 (3)dev_id应作为device_id构造函数参数传递。 此FSDP实例的计算设备将是该目标设备。对于(1)和(3),FSDP初始化总是在GPU上进行。 对于(2),FSDP初始化发生在module的当前设备上,这可能是CPU。警告
FSDP 目前不支持在使用 CPU 卸载时将梯度累积到
no_sync()之外。尝试这样做会导致 错误的结果,因为 FSDP 将使用新减少的梯度 而不是与任何现有的梯度进行累积。警告
在构造后更改原始参数变量名称将导致未定义行为。
警告
传递
sync_module_states=True标志需要module在GPU上或使用device_id参数来指定FSDP将在FSDP构造函数中移动module的CUDA设备。这是因为sync_module_states=True需要GPU通信。警告
从PyTorch 1.12开始,FSDP仅对共享参数提供有限支持(例如,将一个
Linear层的权重设置为另一个层的)。特别是,共享参数的模块必须作为同一FSDP单元的一部分进行包装。如果您的使用情况需要增强的共享参数支持,请联系https://github.com/pytorch/pytorch/issues/77724警告
FSDP 对冻结参数(即设置
param.requires_grad=False)有一些约束。对于use_orig_params=False,每个 FSDP 实例必须管理全部冻结或全部 非冻结的参数。对于use_orig_params=True,FSDP 支持混合冻结 和非冻结,但我们建议不要这样做,因为这样梯度 内存使用量将高于预期(即相当于没有 冻结这些参数)。这意味着理想情况下,冻结参数 应隔离到自己的nn.Module中并单独用 FSDP 包装。注意
尝试运行包含在FSDP实例中的子模块的前向传递是不支持的,并会导致错误。这是因为子模块的参数会被分片,但它本身并不是一个FSDP实例,因此它的前向传递不会正确地收集所有完整的参数。这可能发生在尝试仅运行编码器-解码器模型的编码器时,而编码器没有被包裹在它自己的FSDP实例中。为了解决这个问题,请将子模块包裹在它自己的FSDP单元中。
注意
FSDP 将输入张量移动到
forward方法到 GPU 计算设备,因此用户无需手动将它们从 CPU 移动。警告
用户不应在前向和后向之间修改参数,除非使用
summon_full_params()上下文,因为这些修改可能不会持久。此外,对于use_orig_params=False,在前向和后向之间访问原始参数可能会引发非法内存访问。警告
对于
use_orig_params=True,ShardingStrategy.SHARD_GRAD_OP在前向传播后暴露的是未分片的参数,而不是分片的参数, 因为它不会像ShardingStrategy.FULL_SHARD那样释放未分片的参数。一个需要注意的地方是,由于梯度 总是被分片或None,ShardingStrategy.SHARD_GRAD_OP将 不会在前向传播后与未分片的参数一起暴露分片的梯度。如果你想检查梯度,尝试使用summon_full_params()并带有with_grads=True。警告
FSDP 在前向和反向计算期间出于自动梯度相关的原因,将托管模块的参数替换为
torch.Tensor个视图。 如果您的模块的前向依赖于保存的参数引用而不是在每次迭代中重新获取引用,则它将看不到 FSDP 新创建的视图,并且自动梯度将无法正常工作。注意
使用
limit_all_gathers=True时,您可能会在FSDP预前向传播中看到一个间隙,此时CPU线程没有发出任何内核。这是有意为之,并显示了速率限制器正在生效。以这种方式同步CPU线程可以防止为后续的全聚操作过度分配内存,并且实际上不应延迟GPU内核执行。注意
当使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD并且分片进程组为节点内,复制进程组为节点间时,设置NCCL_CROSS_NIC=1可以帮助在某些集群配置中提高复制进程组上的全归约时间。- Parameters
模块 (nn.Module) – 这是要用FSDP包装的模块。
进程组 (可选[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型分片的进程组,因此也是用于FSDP的all-gather和reduce-scatter集体通信的进程组。如果为
None,则FSDP使用默认的进程组。对于如ShardingStrategy.HYBRID_SHARD这样的混合分片策略,用户可以传递一个进程组的元组,分别表示要分片和复制的组。如果为None,则FSDP为用户构建进程组,以在节点内分片和跨节点复制。(默认:None)分片策略 (可选[ShardingStrategy]) – 此配置设置分片策略,该策略可能在内存节省和通信开销之间进行权衡。详情请参见
ShardingStrategy。 (默认值:FULL_SHARD)cpu_offload (可选[CPUOffload]) – 此配置用于CPU卸载。如果此值设置为
None,则 不会发生CPU卸载。详情请参见CPUOffload。 (默认值:None)auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy]]) –
这指定了一个策略,用于将FSDP应用于
module的子模块, 这是通信和计算重叠所必需的,因此会影响性能。如果None,那么FSDP仅应用于module,用户应手动将FSDP应用于父模块 本身(自下而上进行)。为了方便起见,这直接接受ModuleWrapPolicy,允许用户指定要包装的 模块类(例如变压器块)。否则, 这应该是一个可调用函数,它接受三个参数module: nn.Module、recurse: bool和nonwrapped_numel: int,并应返回一个bool,指定 传递的module是否应在recurse=False时应用FSDP,或者如果recurse=True则继续遍历模块的子树。用户可以向可调用函数添加其他 参数。size_based_auto_wrap_policy中的torch.distributed.fsdp.wrap.py给出了一个示例可调用函数,该函数在模块子树中的参数超过 100M numel时应用FSDP。我们建议在应用FSDP后打印模型 并根据需要进行调整。Example:
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (可选[BackwardPrefetch]) – 此配置显式预取所有gather的反向传播。如果
None,则FSDP不进行反向预取,并且在反向传播过程中没有通信和计算重叠。详情请参见BackwardPrefetch。 (默认值:BACKWARD_PRE)混合精度 (可选[MixedPrecision]) – 此配置为FSDP设置原生混合精度。如果此值设置为
None, 则不使用混合精度。否则,可以设置参数、缓冲区和梯度减少的数据类型。详情请参见MixedPrecision。 (默认值:None)ignored_modules (可选[Iterable[torch.nn.Module]]) – 本实例将忽略这些模块及其子模块的参数和缓冲区。在
ignored_modules中的任何直接模块都不应是FullyShardedDataParallel实例,并且如果它们嵌套在本实例下,已经构建的FullyShardedDataParallel实例的任何子模块也不会被忽略。此参数可用于避免在使用auto_wrap_policy时按模块粒度对特定参数进行分片,或者如果参数的分片不由FSDP管理。(默认值:None)参数初始化函数 (可选[Callable[[nn.Module], None]]) –
一个
Callable[torch.nn.Module] -> None用于指定当前在元设备上的模块应该如何初始化到实际设备上。从v1.12开始,FSDP通过is_meta检测参数或缓冲区在元设备上的模块,并根据指定应用param_init_fn或者调用nn.Module.reset_parameters()。在这两种情况下,实现应该仅初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重新初始化。此外,FSDP还支持通过torchdistX的(https://github.com/pytorch/torchdistX)deferred_init()API进行延迟初始化,其中延迟模块通过调用指定的param_init_fn或者torchdistX的默认materialize_module()进行初始化。如果指定了param_init_fn,则将其应用于所有元设备模块,这意味着它可能需要根据模块类型进行判断。FSDP在参数扁平化和分片之前调用初始化函数。Example:
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
设备ID (可选[Union[int, torch.device]]) – 一个
int或torch.device,指定进行FSDP初始化的CUDA设备,包括模块初始化(如果需要)和参数分片。如果module在CPU上运行,应指定此参数以提高初始化速度。如果已设置默认CUDA设备(例如通过torch.cuda.set_device设置),则用户可以将torch.cuda.current_device传递给此参数。 (默认:None)sync_module_states (布尔值) – 如果
True,则每个FSDP模块将从rank 0广播模块参数和缓冲区以确保它们在各rank之间复制(为此构造函数增加通信开销)。这可以帮助以一种节省内存的方式加载state_dict检查点通过load_state_dict。参见FullStateDictConfig的示例。(默认:False)forward_prefetch (bool) – 如果
True,则 FSDP 显式地 在当前前向计算之前预取下一个前向传递的全归约。这仅对CPU密集型工作负载有用,在这种情况下,提前发出下一个全归约可能会提高重叠度。这应该只用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认值:False)limit_all_gathers (bool) – 如果
True,则 FSDP 明确同步 CPU 线程以确保仅从两个连续的 FSDP 实例(当前正在运行计算的实例和下一个预取 all-gather 的实例)使用 GPU 内存。如果False,则 FSDP 允许 CPU 线程在没有任何额外同步的情况下发出 all-gathers。(默认值:True)我们通常将此功能称为“速率限制器”。此标志仅应为特定的 CPU 密集型工作负载设置为False,这些工作负载具有较低的内存压力,在这种情况下,CPU 线程可以积极地发出所有内核而不必担心 GPU 内存使用情况。use_orig_params (bool) – 将此设置为
True会使 FSDP 使用module的原始参数。FSDP 通过nn.Module.named_parameters()向用户暴露这些原始参数,而不是 FSDP 内部的FlatParameters。这意味着 优化器步骤在原始参数上运行,从而启用 每个原始参数的超参数。FSDP 保留原始 参数变量并在非分片和分片形式之间操作它们的数据,其中它们始终是底层 非分片或分片FlatParameter的视图,分别。使用当前算法,分片形式始终为 1D,丢失了 原始张量结构。一个原始参数可能在其给定秩中具有全部、部分或没有数据。在没有的情况下, 其数据将类似于大小为 0 的空张量。用户不应编写依赖于给定 原始参数在其分片形式中存在什么数据的程序。需要True来 使用torch.compile()。将其设置为False会通过nn.Module.named_parameters()向用户暴露 FSDP 的内部FlatParameters。(默认值:False)忽略的状态 (可选[可迭代对象[torch.nn.Parameter]], 可选[可迭代对象[torch.nn.Module]]) – 忽略的参数或模块,这些参数或模块将不由该FSDP实例管理,这意味着参数不会被分片且其梯度不会在不同排名间进行归约。此参数与现有的
ignored_modules参数统一,并且我们可能会很快弃用ignored_modules。为了向后兼容性,我们保留了ignored_states和ignored_modules`,但FSDP只允许指定其中一个,而不是None。
- apply(fn)[source]¶
递归地将
fn应用于每个子模块(如由.children()返回的)以及 self。典型用法包括初始化模型的参数 (另见 torch.nn.init)。与
torch.nn.Module.apply相比,此版本在应用fn之前还会收集所有参数。不应在另一个summon_full_params上下文中调用它。- Parameters
fn (
Module-> None) – 应用于每个子模块的函数- Returns
自我
- Return type
- clip_grad_norm_(max_norm, norm_type=2.0)[source]¶
对所有参数的梯度范数进行裁剪。该范数是将所有参数的梯度视为一个单一向量后计算得到的,梯度会以原地修改的方式进行调整。
- Parameters
- Returns
参数的总范数(视为单个向量)。
- Return type
注意
如果每个FSDP实例都使用
NO_SHARD,这意味着梯度不会在各个秩之间进行分片,那么你可以直接使用torch.nn.utils.clip_grad_norm_()。注意
如果至少有一个FSDP实例使用了分片策略(即除了
NO_SHARD之外的其他策略),那么你应该使用这种方法,而不是torch.nn.utils.clip_grad_norm_(),因为这种方法处理了梯度在各个秩之间分片的事实。注意
返回的总范数将具有所有参数/梯度中由PyTorch类型提升语义定义的“最大”dtype。例如,如果所有参数/梯度使用低精度dtype,则返回的范数的dtype将是该低精度dtype,但如果至少存在一个参数/梯度使用FP32,则返回的范数的dtype将是FP32。
警告
这需要在所有级别上调用,因为它使用了集体通信。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]¶
API 类似于
shard_full_optim_state_dict()。唯一的 区别是输入sharded_optim_state_dict应该从sharded_optim_state_dict()返回。因此,每个秩上将 有 all-gather 调用来收集ShardedTensors。- Parameters
sharded_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存分片优化器状态。
模型 (torch.nn.Module) – 参见
shard_full_optim_state_dict()。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。
- Returns
- Return type
- static fsdp_modules(module, root_only=False)[source]¶
返回所有嵌套的FSDP实例,可能包括
module本身 并且仅包括FSDP根模块如果root_only=True。- Parameters
模块 (torch.nn.Module) – 根模块,可能是一个或不是一个
FSDP模块。root_only (布尔值) – 是否仅返回FSDP根模块。 (默认值:
False)
- Returns
嵌套在输入
module中的FSDP模块。- Return type
List[FullyShardedDataParallel]
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]¶
将完整的优化器状态在rank 0上进行整合并返回它 作为一个
dict,遵循torch.optim.Optimizer.state_dict()的惯例,即带有键"state"和"param_groups"。在FSDP模块中包含的model中的扁平化参数被映射回它们未扁平化的参数。警告
这需要在所有秩上调用,因为它使用了 集体通信。但是,如果
rank0_only=True,那么 状态字典仅在秩 0 上填充,而所有其他秩 返回一个空的dict。警告
与
torch.optim.Optimizer.state_dict()不同,此方法 使用完整的参数名称作为键,而不是参数ID。注意
就像在
torch.optim.Optimizer.state_dict()中一样,优化器状态字典中包含的张量不会被克隆,因此可能会有别名意外。为了最佳实践,请考虑立即保存返回的优化器状态字典,例如使用torch.save()。- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_input (可选[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入
optim,表示参数组的列表或参数的可迭代对象; 如果为None,则此方法假设输入为model.parameters()。此参数已弃用,无需再传递它。(默认值:None)rank0_only (布尔值) – 如果
True,仅在 rank 0 上保存填充的dict;如果False,则在所有 ranks 上保存。 (默认:True)组 (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为
None。 (默认:None)
- Returns
一个
dict包含优化器状态,适用于model的原始未展平参数,并包含键 “state” 和 “param_groups”,遵循torch.optim.Optimizer.state_dict()的约定。如果rank0_only=True, 则非零秩返回一个空的dict。- Return type
Dict[字符串, 任意]
- static get_state_dict_type(module)[source]¶
获取 state_dict_type 以及对应配置 用于以
module为根的 FSDP 模块。目标模块 不一定是 FSDP 模块。- Returns
一个
StateDictSettings包含当前设置的状态字典类型和 状态字典 / 优化器状态字典配置。- Raises
如果StateDictSettings不同,则引发`AssertionError` –
FSDP 子模块不同。 –
- Return type
- named_buffers(*args, **kwargs)[source]¶
覆盖
named_buffers()以拦截缓冲区名称并在summon_full_params()上下文管理器内部移除所有 FSDP 特有的展平缓冲区前缀。
- named_parameters(*args, **kwargs)[source]¶
覆盖
named_parameters()以拦截参数名称并 在summon_full_params()上下文管理器内部时 移除所有 FSDP 特有的展平参数前缀。
- no_sync()[source]¶
一个上下文管理器,用于禁用 FSDP 实例之间的梯度同步。在此上下文中,梯度将被累积在模块变量中,之后在退出该上下文后的第一次前向-反向传递过程中进行同步。此功能仅应作用于根 FSDP 实例,并会递归地应用于所有子 FSDP 实例。
注意
这可能会导致更高的内存使用,因为FSDP将累积整个模型的梯度(而不是梯度碎片),直到最终同步。
注意
当与CPU卸载一起使用时,在上下文管理器内部,梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。
- Return type
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]¶
将
optim的state_dict转换为三种类型之一:1) 完整的优化器state_dict,2) 分片的优化器state_dict,3) 本地的优化器state_dict,用于由FSDP分片的model。对于完整的优化器state_dict,所有状态都是未展平且未分片的。 仅限Rank0和CPU,可以通过
state_dict_type()指定以 避免内存不足(OOM)。对于分片优化器的state_dict,所有状态都是未展平但已分片的。 仅限CPU可以通过
state_dict_type()进一步节省 内存。对于本地的 state_dict,不会进行任何转换。但是,状态将从 nn.Tensor 转换为 ShardedTensor 以表示其分片特性(这目前还不支持)。
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> optim_state_dict, model, optim >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果值为 None,则将使用 optim.state_dict()。( 默认值:
None)组 (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为
None。 ( 默认值:None)
- Returns
一个
dict包含优化器状态用于model。优化器状态的分片基于state_dict_type。- Return type
Dict[字符串, 任意]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]¶
给定一个
optim_state_dict,通过optim_state_dict()进行转换,将其转换为可以加载到optim的展平优化器 state_dict,其中optim是model的优化器。model必须通过 FullyShardedDataParallel 进行分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> optim_state_dict, model, optim >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。
is_named_optimizer (布尔值) – 此优化器是否为NamedOptimizer或KeyedOptimizer。仅当
optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时,设置为True。直接加载 (布尔值) – 如果设置为 True,此 API 将在返回结果之前调用 optim.load_state_dict(result)。否则,用户需要自行调用
optim.load_state_dict()(默认值:False)组 (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为
None。 ( 默认值:None)
- Return type
- register_comm_hook(state, hook)[source]¶
注册一个通信钩子,这是一种增强功能,为用户提供了一个灵活的钩子,他们可以指定FSDP如何在多个工作节点之间聚合梯度。 此钩子可用于实现诸如 GossipGrad 和梯度压缩 等几种算法,这些算法涉及不同的通信策略,在使用
FullyShardedDataParallel进行训练时用于参数同步。警告
FSDP通信钩子应在运行初始前向传递之前注册,并且仅注册一次。
- Parameters
状态 (对象) –
传递给钩子以在训练过程中维护任何状态信息。 示例包括梯度压缩中的错误反馈, 以及在 GossipGrad 中与下一个通信的对等方等。 它由每个工作进程本地存储 并由工作进程上的所有梯度张量共享。
钩子 (可调用对象) – 可调用对象,具有以下签名之一: 1)
hook: Callable[torch.Tensor] -> None: 此函数接收一个Python张量,该张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度。 然后执行所有必要的处理并返回None; 2)hook: Callable[torch.Tensor, torch.Tensor] -> None: 此函数接收两个Python张量,第一个张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度 (这些变量未被其他FSDP子单元包装)。第二个张量表示一个预分配大小的张量, 用于存储在归约后分片梯度的一部分。 在这两种情况下,可调用对象执行所有必要的处理并返回None。 具有签名1的可调用对象预计处理NO_SHARD情况下的梯度通信。 具有签名2的可调用对象预计处理分片情况下的梯度通信。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]¶
将优化器状态字典
optim_state_dict的键重新映射为使用键 类型optim_state_key_type。这可用于实现具有 FSDP 实例的模型和不具有 FSDP 实例的模型之间的优化器状态字典兼容性。要重新键入FSDP完整优化器状态字典(即从
full_optim_state_dict())以使用参数ID并可加载到 非包装模型中:>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
要将普通优化器的状态字典从非包装模型重新键入以便可以加载到包装模型中:
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- Returns
优化器状态字典已使用
optim_state_key_type指定的参数键重新键入。- Return type
Dict[字符串, 任意]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]¶
将完整的优化器状态字典从 rank 0 散列到所有其他 rank, 并在每个 rank 上返回分片的优化器状态字典。返回值与
shard_full_optim_state_dict()相同,并且在 rank 0 上,第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (可选[Dict[str, Any]]) – 优化器状态 字典,对应于未展平的参数,并在秩为0时保存完整的非分片优化器状态;在非零秩上忽略该参数。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果
None,则此方法假定输入为model.parameters()。此参数已弃用,无需再传递。 (默认值:None)优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更优选的参数。 (默认:None)组 (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为
None。 (默认值:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]¶
设置
state_dict_type以及目标模块的所有后代 FSDP 模块的相应(可选) 配置。目标模块不一定是 FSDP 模块。如果目标 模块是 FSDP 模块,其state_dict_type也将被更改。注意
此API应仅针对顶级(根)模块调用。
注意
此API使用户能够透明地使用传统的
state_dictAPI,在根FSDP模块被另一个nn.Module包装的情况下进行模型检查点。例如, 以下将确保在所有非FSDP实例上调用state_dict,同时调度到sharded_state_dict实现 用于FSDP:Example:
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 所需的
state_dict_type设置。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type的配置。
- Returns
一个StateDictSettings,其中包含模块的先前state_dict类型和配置。
- Return type
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]¶
将完整的优化器状态字典
full_optim_state_dict按照分片方式进行处理,通过将状态映射到扁平化参数而非非扁平化参数,并仅保留此 rank 的优化器状态部分。第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存完整的非分片优化器状态。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果
None,则此方法假定输入为model.parameters()。此参数已弃用,无需再传递。 (默认值:None)优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更优选的参数。 (默认:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
- static sharded_optim_state_dict(model, optim, group=None)[source]¶
该API与
full_optim_state_dict()类似,但此API将所有非零维度的状态分块为ShardedTensor以节省内存。 仅当模型state_dict是在上下文管理器with state_dict_type(SHARDED_STATE_DICT):中派生时,才应使用此API。有关详细用法,请参阅
full_optim_state_dict()。警告
返回的状态字典包含
ShardedTensor并且 不能直接用于常规的optim.load_state_dict。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]¶
一个上下文管理器,用于设置目标模块所有后代 FSDP 模块的
state_dict_type。此上下文管理器具有与set_state_dict_type()相同的功能。有关详细信息,请阅读set_state_dict_type()的文档。Example:
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 所需的
state_dict_type设置。state_dict_config (可选[StateDictConfig]) – 目标模型的
state_dict配置state_dict_type。optim_state_dict_config (可选[OptimStateDictConfig]) – 优化器
state_dict配置用于目标state_dict_type。
- Return type
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]¶
一个上下文管理器,用于暴露FSDP实例的完整参数。 在模型完成前向/反向传播后,可以用于获取参数以进行额外处理或检查。它可以接受一个非FSDP模块,并根据
recurse参数的值,为所有包含的FSDP模块及其子模块召唤完整参数。注意
这可以用于内部的FSDPs。
注意
这不能在前向或后向传递中使用。也不能从这个上下文中启动前向和后向传递。
注意
参数将在上下文管理器退出后恢复为本地碎片,存储行为与前向传播相同。
注意
完整的参数可以被修改,但只有与本地参数分片对应的部分会在上下文管理器退出后保留(除非
writeback=False,在这种情况下,更改将被丢弃)。在FSDP不切分参数的情况下,目前仅当world_size == 1或NO_SHARD配置时,无论writeback如何,修改都会被保留。注意
此方法适用于本身不是FSDP的模块,但可能包含多个独立的FSDP单元。在这种情况下,给定的参数将应用于所有包含的FSDP单元。
警告
请注意,
rank0_only=True与writeback=True的组合目前不受支持,并将引发错误。这是因为在上下文中,模型参数的形状在不同的秩之间会有所不同,当退出上下文时,写入它们可能会导致秩之间的不一致。警告
请注意,
offload_to_cpu和rank0_only=False将导致完整参数被冗余地复制到CPU内存中,对于位于同一台机器上的GPU,这可能会导致CPU内存溢出的风险。建议使用offload_to_cpu与rank0_only=True。- Parameters
递归 (布尔值, 可选) – 递归地召唤所有嵌套的FSDP实例的所有参数(默认值:True)。
写回 (布尔值, 可选) – 如果
False, 在上下文管理器退出后对参数的修改将被丢弃; 禁用此选项可能会稍微提高效率(默认值:True)rank0_only (布尔值, 可选) – 如果
True,完整的参数仅在全局排名为0的设备上实例化。这意味着在此上下文中,只有排名为0的设备将拥有完整的参数,其他设备将拥有分片的参数。请注意,在此上下文中设置rank0_only=True和writeback=True是不支持的,因为模型参数形状在不同排名之间会有所不同,写入这些参数可能会导致退出上下文时各排名之间的不一致。offload_to_cpu (布尔值, 可选) – 如果
True,全部参数将被卸载到CPU。请注意,这种卸载目前仅在参数被分片时发生(这仅在world_size = 1或NO_SHARD配置时不适用)。建议与offload_to_cpu和rank0_only=True一起使用,以避免将模型参数的冗余副本卸载到相同的CPU内存。with_grads (bool, Optional) – 如果为
True,则梯度也会与参数一起取消分片。目前,这仅在将use_orig_params=True传递给FSDP构造函数并将offload_to_cpu=False传递给此方法时支持。(默认值:False)
- Return type
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]¶
这配置了显式的反向预取,通过在反向传播过程中实现通信和计算的重叠来提高吞吐量,代价是略微增加内存使用。
BACKWARD_PRE: 这种方式可以实现最大程度的重叠,但也会导致内存使用量最大。它会在当前参数集的梯度计算之前预取下一组参数。这使得下一个all-gather和当前梯度计算重叠,在峰值时,它会将当前参数集、下一组参数以及当前梯度同时保留在内存中。BACKWARD_POST: 这减少了重叠,但需要更少的内存使用。这在当前参数集的梯度计算之后预取下一组参数。这将当前reduce-scatter与下一个梯度计算重叠,并在为下一组参数分配内存之前释放当前参数集,仅在峰值时在内存中保留下一组参数和当前梯度集。FSDP 的
backward_prefetch参数接受None,这将完全禁用反向预取。这没有重叠,并且不会增加内存使用量。一般来说,我们不推荐这种设置,因为它可能会显著降低吞吐量。
对于更多的技术背景:对于使用NCCL后端的单个进程组, 任何集合操作,即使是从不同的流中发出的,也会争夺每个设备的NCCL流, 这意味着集合操作发出的相对顺序对重叠很重要。两个反向预取值对应于不同的发出顺序。
- class torch.distributed.fsdp.ShardingStrategy(value)[source]¶
这指定了用于分布式训练的分片策略,由
FullyShardedDataParallel。FULL_SHARD: 参数、梯度和优化器状态被分片。 对于参数,此策略在前向传播之前通过全聚操作取消分片,在前向传播之后重新分片,在反向计算之前取消分片,并在反向计算之后重新分片。对于梯度,它在反向计算之后通过归约分散操作同步并分片它们。分片的优化器状态按秩本地更新。SHARD_GRAD_OP: 在计算过程中,梯度和优化器状态会被分片,并且参数在计算之外也会被分片。对于参数,此策略在前向传播之前进行不分片,在前向传播之后不重新分片,仅在反向传播计算之后重新分片。分片的优化器状态在每个秩上本地更新。在no_sync()中,参数在反向传播计算之后不会被重新分片。NO_SHARD: 参数、梯度和优化器状态不会被分片,而是像PyTorch的DistributedDataParallelAPI一样在各个秩之间复制。对于梯度,此策略 在反向计算后通过全归约进行同步。未分片的优化器状态在每个秩上本地更新。HYBRID_SHARD: 在节点内应用FULL_SHARD,并在节点之间复制参数。这导致通信量减少,因为昂贵的全收集和减少分散操作仅在节点内进行,这对于中等规模的模型可以更高效。_HYBRID_SHARD_ZERO2: 在节点内应用SHARD_GRAD_OP,并在节点之间复制参数。这类似于HYBRID_SHARD,但可能提供更高的吞吐量,因为未分片的参数在前向传递后不会被释放,从而节省了预反向传播中的所有聚集。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source]¶
这配置了FSDP原生混合精度训练。
- Variables
param_dtype (可选[torch.dtype]) – 这指定了模型参数在前向和后向传播期间的数据类型,因此也决定了前向和后向计算的数据类型。在前向和后向传播之外,分片的参数保持全精度(例如,用于优化器步骤),并且在保存模型检查点时,参数始终以全精度保存。(默认值:
None)reduce_dtype (可选[torch.dtype]) – 这指定了梯度归约(即reduce-scatter或all-reduce)的dtype。如果这是
None但param_dtype不是None,那么它将采用param_dtype值,仍然以低精度运行梯度归约。这允许与param_dtype不同,例如 强制梯度归约以全精度运行。(默认:None)buffer_dtype (可选[torch.dtype]) – 这指定了缓冲区的数据类型。FSDP 不会分割缓冲区。相反,FSDP 在第一次前向传递时将它们转换为
buffer_dtype并在此后保持该数据类型。对于模型检查点,缓冲区以全精度保存,除了LOCAL_STATE_DICT。(默认值:None)keep_low_precision_grads (bool) – 如果
False,则 FSDP 在反向传播后将梯度上转换为全精度,以准备优化器步骤。如果True,则 FSDP 保持梯度在用于梯度减少的数据类型中,如果使用支持低精度运行的自定义优化器,则可以节省内存。 (Default:False)cast_forward_inputs (bool) – 如果
True,则此 FSDP 模块会将其前向参数和关键字参数转换为param_dtype。这是为了确保在前向计算中参数和输入的数据类型匹配,这是许多操作所要求的。当仅对某些而不是所有 FSDP 模块应用混合精度时,可能需要将其设置为True,在这种情况下,混合精度的 FSDP 子模块需要重新转换其输入。(默认值:False)cast_root_forward_inputs (bool) – 如果为
True,则根 FSDP 模块 将其前向参数和关键字参数转换为param_dtype,覆盖cast_forward_inputs的值。对于非根 FSDP 模块, 这不会做任何事情。 (默认:True)_module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 这指定了在使用混合精度时要忽略的模块类。
auto_wrap_policy: 这些类别的模块将单独应用FSDP,并禁用混合精度(这意味着最终的FSDP构建将偏离指定的策略)。如果未指定auto_wrap_policy,则此设置无效。此API是实验性的,可能会发生变化。 (Default:(_BatchNorm,))
注意
此API为实验性质,可能会发生变化。
注意
只有浮点张量会被转换为指定的数据类型。
注意
在
summon_full_params中,参数被强制为全精度,但缓冲区不是。注意
Layer norm 和 batch norm 在
float32中累积,即使 它们的输入是低精度的,如float16或bfloat16。 禁用 FSDP 的混合精度仅针对这些规范化模块意味着 仿射参数保持在float32。然而,这会导致 这些规范化模块进行单独的全聚集和缩减分散,这 可能效率低下,因此如果工作负载允许,用户应优先 对这些模块应用混合精度。注意
默认情况下,如果用户传递了一个包含任何
_BatchNorm模块的模型,并指定了一个auto_wrap_policy,那么批归一化模块将单独应用FSDP,并禁用混合精度。请参阅_module_classes_to_ignore参数。注意
MixedPrecision默认具有cast_root_forward_inputs=True和cast_forward_inputs=False。对于根 FSDP 实例, 其cast_root_forward_inputs优先于其cast_forward_inputs。对于非根 FSDP 实例,它们的cast_root_forward_inputs值被忽略。默认设置适用于典型情况,即每个 FSDP 实例具有相同的MixedPrecision配置,并且只需要在模型的前向传递开始时将输入转换为param_dtype。注意
对于具有不同
MixedPrecision配置的嵌套FSDP实例,我们建议设置各个cast_forward_inputs值以配置在每个实例的前向传播之前是否进行输入类型转换。在这种情况下,由于类型转换发生在每个FSDP实例的前向传播之前,父FSDP实例应该在其FSDP子模块之前运行其非FSDP子模块,以避免由于不同的MixedPrecision配置而导致激活数据类型发生变化。Example:
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
上面显示了一个工作示例。另一方面,如果
model[1]被替换为model[0],意味着使用了不同的MixedPrecision的子模块首先运行其前向计算,那么model[1]将错误地看到float16激活而不是bfloat16激活。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]¶
此配置用于CPU卸载。
- Variables
offload_params (bool) – 这指定了是否在不参与计算时将参数卸载到CPU。如果为
True,则还会将梯度卸载到CPU,这意味着优化器步骤将在CPU上运行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False, use_dtensor=False)[source]¶
StateDictConfig是所有state_dict配置类的基础类。用户应实例化一个子类(例如FullStateDictConfig)以配置 FSDP 支持的相应state_dict类型的设置。
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, use_dtensor=False, rank0_only=False)[source]¶
FullStateDictConfig是一个配置类,用于与StateDictType.FULL_STATE_DICT一起使用。我们建议在保存完整状态字典时分别启用offload_to_cpu=True和rank0_only=True,以节省 GPU 内存和 CPU 内存。此配置类应通过以下方式使用state_dict_type()上下文管理器:>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) >>> # After this point, all ranks have FSDP model with loaded checkpoint.
- Variables
rank0_only (bool) – 如果为
True,则只有 rank 0 保存完整的状态字典,而非零 rank 保存空字典。如果为False,则所有 rank 都保存完整的状态字典。(默认值:False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu: bool = False, use_dtensor: bool = False)[source]¶
- class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False, use_dtensor: bool = False)[source]¶
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True, use_dtensor=False)[source]¶
OptimStateDictConfig是所有optim_state_dict配置类的基本类。用户应实例化一个子类(例如FullOptimStateDictConfig)以配置 FSDP 支持的相应optim_state_dict类型的设置。
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, use_dtensor=False, rank0_only=False)[source]¶
- Variables
rank0_only (bool) – 如果为
True,则只有 rank 0 保存完整的状态字典,而非零 rank 保存空字典。如果为False,则所有 rank 都保存完整的状态字典。(默认值:False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu: bool = True, use_dtensor: bool = False)[source]¶
- class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False, use_dtensor: bool = False)[source]¶
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]¶