FullyShardedDataParallel¶
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]¶
一个用于将模块参数分片到数据并行工作进程中的包装器。
这是受 Xu 等人 以及 DeepSpeed 的 ZeRO 第三阶段启发。 FullyShardedDataParallel 通常缩写为 FSDP。
要了解FSDP内部结构,请参阅 FSDP笔记。
Example:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
使用 FSDP 涉及包装你的模块,然后初始化你的优化器。这是必需的,因为 FSDP 改变了参数变量。
在设置FSDP时,您需要考虑目标CUDA设备。如果设备有ID(
dev_id),您有三个选项:将模块放在该设备上
使用
torch.cuda.set_device(dev_id)设置设备传入
dev_id到device_id构造函数参数中。
这确保了 FSDP 实例的计算设备是目标设备。对于选项 1 和 3,FSDP 初始化总是发生在 GPU 上。对于选项 2,FSDP 初始化发生在模块的当前设备上,该设备可能是 CPU。
如果你使用的是
sync_module_states=True标志,你需要确保模块在GPU上,或者使用device_id参数指定一个CUDA设备,FSDP将在FSDP构造函数中将其移动到该设备。这是必要的,因为sync_module_states=True需要GPU通信。FSDP 还负责将输入张量移动到前向方法所在的 GPU 计算设备,因此您无需手动从 CPU 移动它们。
对于
use_orig_params=True,ShardingStrategy.SHARD_GRAD_OP暴露了未分片的 参数,而不是前向传播后的分片参数,不像ShardingStrategy.FULL_SHARD。如果你想 检查梯度,你可以使用summon_full_params方法与with_grads=True。使用
limit_all_gathers=True时,您可能会在FSDP预前向传播中看到一个间隙,此时CPU线程没有发出任何内核。这是有意为之,并显示了速率限制器正在生效。以这种方式同步CPU线程可以防止为后续的全聚操作过度分配内存,并且实际上不应延迟GPU内核执行。FSDP 在前向和后向计算过程中,为了与 autograd 相关的原因,会用
torch.Tensor替换托管模块的参数视图。如果你的模块前向传播依赖于保存的参数引用,而不是在每次迭代中重新获取引用,那么它将看不到 FSDP 新创建的视图,autograd 将无法正确工作。最后,在使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD时,当分片过程组为节点内且复制过程组为节点间时,设置NCCL_CROSS_NIC=1可以帮助在某些集群配置下改善复制过程组的全减时间。限制条件
在使用 FSDP 时,需要注意以下几个限制:
FSDP 当前不支持在使用 CPU 卸载时在外进行梯度累加。
no_sync()这是因为 FSDP 使用新减少的梯度而不是与任何现有梯度进行累加,这可能导致错误的结果。FSDP 不支持运行包含在 FSDP 实例中的子模块的前向传播。这是因为子模块的参数会被分片,但子模块本身并不是一个 FSDP 实例,因此其前向传播不会适当收集完整的参数。
FSDP 由于其注册反向钩的方式,无法与双重反向一起使用。
在冻结参数时,FSDP 有一些约束。 对于
use_orig_params=False,每个 FSDP 实例必须管理 全部被冻结或全部未被冻结的参数。对于use_orig_params=True,FSDP 支持混合使用冻结和 未被冻结的参数,但为了避免高于预期的梯度内存使用, 建议避免这样做。截至PyTorch 1.12,FSDP 对共享参数提供有限支持。如果你的使用案例需要增强的共享参数支持,请在 此问题 中发帖。
您应该避免在前向传播和反向传播之间不使用
summon_full_params上下文的情况下修改参数,因为这些修改可能不会持久。
- Parameters
模块 (nn.Module) – 这是要用FSDP包装的模块。
进程组 (可选[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型分片的进程组,因此也是用于FSDP的all-gather和reduce-scatter集体通信的进程组。如果为
None,则FSDP使用默认的进程组。对于如ShardingStrategy.HYBRID_SHARD这样的混合分片策略,用户可以传递一个进程组的元组,分别表示要分片和复制的组。如果为None,则FSDP为用户构建进程组,以在节点内分片和跨节点复制。(默认:None)分片策略 (可选[ShardingStrategy]) – 此配置设置分片策略,该策略可能在内存节省和通信开销之间进行权衡。详情请参见
ShardingStrategy。 (默认值:FULL_SHARD)cpu_offload (可选[CPUOffload]) – 此配置用于CPU卸载。如果此值设置为
None,则 不会发生CPU卸载。详情请参见CPUOffload。 (默认值:None)自动包装策略 (可选[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, 自定义策略]]) –
这指定了一个策略,用于将FSDP应用于
module的子模块, 这是通信和计算重叠所必需的,因此会影响性能。如果None,那么FSDP仅应用于module,用户应手动将FSDP应用于父模块 本身(自下而上进行)。为了方便起见,这直接接受ModuleWrapPolicy,允许用户指定要包装的 模块类(例如变压器块)。否则, 这应该是一个可调用函数,它接受三个参数module: nn.Module、recurse: bool和nonwrapped_numel: int,并应返回一个bool,指定 传递的module是否应在recurse=False时应用FSDP,或者如果recurse=True则继续遍历模块的子树。用户可以向可调用函数添加其他 参数。size_based_auto_wrap_policy中的torch.distributed.fsdp.wrap.py给出了一个示例可调用函数,该函数在模块子树中的参数超过 100M numel时应用FSDP。我们建议在应用FSDP后打印模型 并根据需要进行调整。Example:
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (可选[BackwardPrefetch]) – 此配置显式预取所有gather的反向传播。如果
None,则FSDP不进行反向预取,并且在反向传播过程中没有通信和计算重叠。详情请参见BackwardPrefetch。 (默认值:BACKWARD_PRE)混合精度 (可选[MixedPrecision]) – 此配置为FSDP设置原生混合精度。如果此值设置为
None, 则不使用混合精度。否则,可以设置参数、缓冲区和梯度减少的数据类型。详情请参见MixedPrecision。 (默认值:None)ignored_modules (可选[Iterable[torch.nn.Module]]) – 本实例将忽略这些模块及其子模块的参数和缓冲区。在
ignored_modules中的任何直接模块都不应是FullyShardedDataParallel实例,并且如果它们嵌套在本实例下,已经构建的FullyShardedDataParallel实例的任何子模块也不会被忽略。此参数可用于避免在使用auto_wrap_policy时按模块粒度对特定参数进行分片,或者如果参数的分片不由FSDP管理。(默认值:None)参数初始化函数 (可选[Callable[[nn.Module], None]]) –
一个
Callable[torch.nn.Module] -> None用于指定当前在元设备上的模块应该如何初始化到实际设备上。从v1.12开始,FSDP通过is_meta检测参数或缓冲区在元设备上的模块,并根据指定应用param_init_fn或者调用nn.Module.reset_parameters()。在这两种情况下,实现应该仅初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重新初始化。此外,FSDP还支持通过torchdistX的(https://github.com/pytorch/torchdistX)deferred_init()API进行延迟初始化,其中延迟模块通过调用指定的param_init_fn或者torchdistX的默认materialize_module()进行初始化。如果指定了param_init_fn,则将其应用于所有元设备模块,这意味着它可能需要根据模块类型进行判断。FSDP在参数扁平化和分片之前调用初始化函数。Example:
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
设备ID (可选[Union[int, torch.device]]) – 一个
int或torch.device,指定进行FSDP初始化的CUDA设备,包括模块初始化(如果需要)和参数分片。如果module在CPU上运行,应指定此参数以提高初始化速度。如果已设置默认CUDA设备(例如通过torch.cuda.set_device设置),则用户可以将torch.cuda.current_device传递给此参数。 (默认:None)sync_module_states (布尔值) – 如果
True,则每个FSDP模块将从rank 0广播模块参数和缓冲区以确保它们在各rank之间复制(为此构造函数增加通信开销)。这可以帮助以一种节省内存的方式加载state_dict检查点通过load_state_dict。参见FullStateDictConfig的示例。(默认:False)forward_prefetch (bool) – 如果
True,则 FSDP 显式地 在当前前向计算之前预取下一个前向传递的全归约。这仅对CPU密集型工作负载有用,在这种情况下,提前发出下一个全归约可能会提高重叠度。这应该只用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认值:False)limit_all_gathers (bool) – 如果
True,则 FSDP 明确同步 CPU 线程以确保仅从两个连续的 FSDP 实例(当前正在运行计算的实例和下一个预取 all-gather 的实例)使用 GPU 内存。如果False,则 FSDP 允许 CPU 线程在没有任何额外同步的情况下发出 all-gathers。(默认值:True)我们通常将此功能称为“速率限制器”。此标志仅应为特定的 CPU 密集型工作负载设置为False,这些工作负载具有较低的内存压力,在这种情况下,CPU 线程可以积极地发出所有内核而不必担心 GPU 内存使用情况。use_orig_params (bool) – 将此设置为
True会使 FSDP 使用module的原始参数。FSDP 通过nn.Module.named_parameters()向用户暴露这些原始参数,而不是 FSDP 内部的FlatParameters。这意味着 优化器步骤在原始参数上运行,从而启用 每个原始参数的超参数。FSDP 保留原始 参数变量并在非分片和分片形式之间操作它们的数据,其中它们始终是底层 非分片或分片FlatParameter的视图,分别。使用当前算法,分片形式始终为 1D,丢失了 原始张量结构。一个原始参数可能在其给定秩中具有全部、部分或没有数据。在没有的情况下, 其数据将类似于大小为 0 的空张量。用户不应编写依赖于给定 原始参数在其分片形式中存在什么数据的程序。需要True来 使用torch.compile()。将其设置为False会通过nn.Module.named_parameters()向用户暴露 FSDP 的内部FlatParameters。(默认值:False)忽略的状态 (可选[可迭代对象[torch.nn.Parameter]], 可选[可迭代对象[torch.nn.Module]]) – 忽略的参数或模块,这些参数或模块将不由该FSDP实例管理,这意味着参数不会被分片且其梯度不会在不同排名间进行归约。此参数与现有的
ignored_modules参数统一,并且我们可能会很快弃用ignored_modules。为了向后兼容性,我们保留了ignored_states和ignored_modules`,但FSDP只允许指定其中一个,而不是None。设备网格 (可选[DeviceMesh]) – 设备网格可以作为进程组的替代方案。当传递设备网格时,FSDP 将使用底层进程组进行全gather和reduce-scatter通信。因此,这两个参数需要互斥。对于像
ShardingStrategy.HYBRID_SHARD这样的混合分片策略,用户可以传递一个2D设备网格而不是进程组元组。对于2D FSDP + TP,用户必须传递设备网格而不是进程组。有关更多设备网格信息,请访问: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
- apply(fn)[source]¶
递归地将
fn应用于每个子模块(由.children()返回)以及自身。典型用法包括初始化模型的参数(另见 torch.nn.init)。
与
torch.nn.Module.apply相比,此版本在应用fn之前还会收集所有参数。不应在另一个summon_full_params上下文中调用它。- Parameters
fn (
Module-> None) – 应用于每个子模块的函数- Returns
自我
- Return type
- clip_grad_norm_(max_norm, norm_type=2.0)[source]¶
裁剪所有参数的梯度范数。
范数是在将所有参数的梯度视为单个向量的情况下计算的,并且梯度会在原地进行修改。
- Parameters
- Returns
参数的总范数(视为单个向量)。
- Return type
如果每个FSDP实例都使用
NO_SHARD,这意味着梯度不会在各个秩之间进行分片,那么你可以直接使用torch.nn.utils.clip_grad_norm_()。如果至少有一个FSDP实例使用了分片策略(即除了
NO_SHARD之外的其他策略),那么你应该使用这种方法,而不是torch.nn.utils.clip_grad_norm_(),因为这种方法处理了梯度在各个秩之间分片的事实。返回的总范数将具有所有参数/梯度中由PyTorch类型提升语义定义的“最大”dtype。例如,如果所有参数/梯度使用低精度dtype,则返回的范数的dtype将是该低精度dtype,但如果至少存在一个参数/梯度使用FP32,则返回的范数的dtype将是FP32。
警告
这需要在所有级别上调用,因为它使用了集体通信。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]¶
扁平化分片优化器状态字典。
API 类似于
shard_full_optim_state_dict()。唯一的 区别是输入sharded_optim_state_dict应该从sharded_optim_state_dict()返回。因此,每个秩上将 有 all-gather 调用来收集ShardedTensors。- Parameters
sharded_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存分片优化器状态。
模型 (torch.nn.Module) – 参见
shard_full_optim_state_dict()。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。
- Returns
- Return type
- static fsdp_modules(module, root_only=False)[source]¶
返回所有嵌套的FSDP实例。
这可能包括
module本身,并且仅在root_only=True的情况下包含 FSDP 根模块。- Parameters
模块 (torch.nn.Module) – 根模块,可能是一个或不是一个
FSDP模块。root_only (布尔值) – 是否仅返回FSDP根模块。 (默认值:
False)
- Returns
嵌套在输入
module中的FSDP模块。- Return type
List[FullyShardedDataParallel]
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]¶
返回完整的优化器状态字典。
将完整的优化器状态在rank 0上进行整合并返回它 作为一个
dict,遵循torch.optim.Optimizer.state_dict()的惯例,即带有键"state"和"param_groups"。在FSDP模块中包含的model中的扁平化参数被映射回它们未扁平化的参数。这需要在所有秩上调用,因为它使用了 集体通信。但是,如果
rank0_only=True,那么 状态字典仅在秩 0 上填充,而所有其他秩 返回一个空的dict。与
torch.optim.Optimizer.state_dict()不同,此方法 使用完整的参数名称作为键,而不是参数ID。就像在
torch.optim.Optimizer.state_dict()中一样,优化器状态字典中包含的张量不会被克隆,因此可能会有别名意外。为了最佳实践,请考虑立即保存返回的优化器状态字典,例如使用torch.save()。- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_input (可选[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入
optim,表示参数组的列表或参数的可迭代对象; 如果为None,则此方法假设输入为model.parameters()。此参数已弃用,无需再传递它。(默认值:None)rank0_only (布尔值) – 如果
True,仅在 rank 0 上保存填充的dict;如果False,则在所有 ranks 上保存。 (默认:True)组 (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为
None。 (默认:None)
- Returns
一个
dict包含优化器状态,适用于model的原始未展平参数,并包含键 “state” 和 “param_groups”,遵循torch.optim.Optimizer.state_dict()的约定。如果rank0_only=True, 则非零秩返回一个空的dict。- Return type
Dict[字符串, 任意]
- static get_state_dict_type(module)[source]¶
获取根节点为
module的FSDP模块的状态字典类型及其对应的配置。目标模块不一定是FSDP模块。
- Returns
一个
StateDictSettings包含当前设置的状态字典类型和 状态字典 / 优化器状态字典配置。- Raises
如果StateDictSettings不同,则引发`AssertionError` –
FSDP 子模块不同。 –
- Return type
- named_buffers(*args, **kwargs)[source]¶
返回一个迭代器,遍历模块的缓冲区,同时生成缓冲区的名称和缓冲区本身。
拦截缓冲区名称并移除所有在
summon_full_params()上下文管理器内的FSDP特定扁平化缓冲区前缀的所有实例。
- named_parameters(*args, **kwargs)[source]¶
返回一个迭代器,遍历模块参数,同时生成参数的名称和参数本身。
拦截参数名称并移除所有在
summon_full_params()上下文管理器内部出现的FSDP特定扁平化参数前缀的所有实例。
- no_sync()[source]¶
禁用跨FSDP实例的梯度同步。
在此上下文中,梯度将在模块变量中累积,并在退出上下文后的第一次前向-反向传递中进行同步。这仅应在根FSDP实例上使用,并将递归应用于所有子FSDP实例。
注意
这可能会导致更高的内存使用,因为FSDP将累积整个模型的梯度(而不是梯度碎片),直到最终同步。
注意
当与CPU卸载一起使用时,在上下文管理器内部,梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。
- Return type
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]¶
转换与分片模型对应的优化器的状态字典。
给定的状态字典可以转换为以下三种类型之一: 1) 完整优化器状态字典,2) 分片优化器状态字典,3) 本地优化器状态字典。
对于完整的优化器state_dict,所有状态都是未展平且未分片的。 仅限Rank0和CPU,可以通过
state_dict_type()指定以 避免内存不足(OOM)。对于分片优化器的state_dict,所有状态都是未展平但已分片的。 仅限CPU可以通过
state_dict_type()进一步节省 内存。对于本地的 state_dict,不会进行任何转换。但是,状态将从 nn.Tensor 转换为 ShardedTensor 以表示其分片特性(这目前还不支持)。
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果值为 None,则将使用 optim.state_dict()。( 默认值:
None)组 (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为
None。 ( 默认值:None)
- Returns
一个
dict包含优化器状态用于model。优化器状态的分片基于state_dict_type。- Return type
Dict[字符串, 任意]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]¶
将优化器的状态字典转换为可以加载到与FSDP模型关联的优化器中的格式。
给定一个
optim_state_dict,它通过optim_state_dict()进行转换,将被转换为可以加载到optim的扁平化优化器state_dict,其中model是优化器。model必须由FullyShardedDataParallel进行分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。
is_named_optimizer (布尔值) – 此优化器是否为NamedOptimizer或KeyedOptimizer。仅当
optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时,设置为True。直接加载 (布尔值) – 如果设置为 True,此 API 将在返回结果之前调用 optim.load_state_dict(result)。否则,用户需要自行调用
optim.load_state_dict()(默认值:False)组 (dist.ProcessGroup) – 模型的进程组,参数在此组中进行分片或如果使用默认进程组则为
None。 ( 默认值:None)
- Return type
- register_comm_hook(state, hook)[source]¶
注册一个通信钩子。
这是一个增强功能,为用户提供了一个灵活的钩子,他们可以指定FSDP如何在多个工作进程之间聚合梯度。 这个钩子可以用来实现几种算法,如 GossipGrad 和梯度压缩, 这些算法涉及在训练过程中使用不同的通信策略进行参数同步
FullyShardedDataParallel。警告
FSDP通信钩子应在运行初始前向传递之前注册,并且仅注册一次。
- Parameters
状态 (对象) –
传递给钩子以在训练过程中维护任何状态信息。 示例包括梯度压缩中的错误反馈, 以及在 GossipGrad 中与下一个通信的对等方等。 它由每个工作进程本地存储 并由工作进程上的所有梯度张量共享。
钩子 (可调用对象) – 可调用对象,具有以下签名之一: 1)
hook: Callable[torch.Tensor] -> None: 此函数接收一个Python张量,该张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度。 然后执行所有必要的处理并返回None; 2)hook: Callable[torch.Tensor, torch.Tensor] -> None: 此函数接收两个Python张量,第一个张量表示 与该FSDP单元包装的模型相关的所有变量的完整、展平且未分片的梯度 (这些变量未被其他FSDP子单元包装)。第二个张量表示一个预分配大小的张量, 用于存储在归约后分片梯度的一部分。 在这两种情况下,可调用对象执行所有必要的处理并返回None。 具有签名1的可调用对象预计处理NO_SHARD情况下的梯度通信。 具有签名2的可调用对象预计处理分片情况下的梯度通信。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]¶
重新键入优化器状态字典
optim_state_dict以使用键类型optim_state_key_type。这可以用于实现具有FSDP实例的模型和没有FSDP实例的模型之间的优化器状态字典的兼容性。
要重新键入FSDP完整优化器状态字典(即从
full_optim_state_dict())以使用参数ID并可加载到 非包装模型中:>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
要将普通优化器的状态字典从非包装模型重新键入以便可以加载到包装模型中:
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- Returns
优化器状态字典已使用
optim_state_key_type指定的参数键重新键入。- Return type
Dict[字符串, 任意]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]¶
将优化器的完整状态字典从秩0分散到所有其他秩。
返回每个秩上的分片优化器状态字典。 返回值与
shard_full_optim_state_dict()相同,在秩 0 上,第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (可选[Dict[str, Any]]) – 优化器状态 字典,对应于未展平的参数,并在秩为0时保存完整的非分片优化器状态;在非零秩上忽略该参数。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果
None,则此方法假定输入为model.parameters()。此参数已弃用,无需再传递。 (默认值:None)优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更优选的参数。 (默认:None)组 (dist.ProcessGroup) – 模型的进程组,或者如果使用默认进程组则为
None。 (默认值:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]¶
设置目标模块的所有后代FSDP模块的
state_dict_type。还接受(可选)模型和优化器状态字典的配置。 目标模块不一定是FSDP模块。如果目标 模块是FSDP模块,其
state_dict_type也会被更改。注意
此API应仅针对顶级(根)模块调用。
注意
此API使用户能够透明地使用传统的
state_dictAPI,在根FSDP模块被另一个nn.Module包装的情况下进行模型检查点。例如, 以下将确保在所有非FSDP实例上调用state_dict,同时调度到sharded_state_dict实现 用于FSDP:Example:
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 所需的
state_dict_type设置。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type的配置。optim_state_dict_config (可选[OptimStateDictConfig]) – 优化器状态字典的配置。
- Returns
一个StateDictSettings,其中包含模块的先前state_dict类型和配置。
- Return type
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]¶
共享一个完整的优化器状态字典。
将状态从
full_optim_state_dict映射到展平的参数而不是未展平的参数,并仅限于此秩的优化器状态部分。 第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存完整的非分片优化器状态。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[联合[列表[字典[字符串, 任意]], 可迭代对象[torch.nn.Parameter]]]]) – 传递给优化器的输入,表示参数组或参数的可迭代对象; 如果
None,则此方法假定输入为model.parameters()。此参数已弃用,无需再传递。 (默认值:None)优化器 (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更优选的参数。 (默认:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
- static sharded_optim_state_dict(model, optim, group=None)[source]¶
返回优化器状态字典的分片形式。
该API与
full_optim_state_dict()类似,但此API将所有非零维度的状态分块为ShardedTensor以节省内存。 仅当模型state_dict是在上下文管理器with state_dict_type(SHARDED_STATE_DICT):中派生时,才应使用此API。有关详细用法,请参阅
full_optim_state_dict()。警告
返回的状态字典包含
ShardedTensor并且 不能直接用于常规的optim.load_state_dict。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]¶
设置目标模块的所有后代FSDP模块的
state_dict_type。这个上下文管理器具有与
set_state_dict_type()相同的功能。请阅读set_state_dict_type()的文档以了解详细信息。Example:
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 所需的
state_dict_type设置。state_dict_config (可选[StateDictConfig]) – 目标模型的
state_dict配置state_dict_type。optim_state_dict_config (可选[OptimStateDictConfig]) – 优化器
state_dict配置用于目标state_dict_type。
- Return type
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]¶
使用此上下文管理器公开FSDP实例的所有参数。
在模型的前向/反向传播之后,这可能很有用,可以获取参数进行额外处理或检查。它可以接受一个非FSDP模块,并根据
recurse参数调用所有包含的FSDP模块及其子模块的完整参数。注意
这可以用于内部的FSDPs。
注意
这不能在前向或后向传递中使用。也不能从这个上下文中启动前向和后向传递。
注意
参数将在上下文管理器退出后恢复为本地碎片,存储行为与前向传播相同。
注意
完整的参数可以被修改,但只有与本地参数分片对应的部分会在上下文管理器退出后保留(除非
writeback=False,在这种情况下,更改将被丢弃)。在FSDP不切分参数的情况下,目前仅当world_size == 1或NO_SHARD配置时,无论writeback如何,修改都会被保留。注意
此方法适用于本身不是FSDP的模块,但可能包含多个独立的FSDP单元。在这种情况下,给定的参数将应用于所有包含的FSDP单元。
警告
请注意,
rank0_only=True与writeback=True的组合目前不受支持,并将引发错误。这是因为在上下文中,模型参数的形状在不同的秩之间会有所不同,当退出上下文时,写入它们可能会导致秩之间的不一致。警告
请注意,
offload_to_cpu和rank0_only=False将导致完整参数被冗余地复制到CPU内存中,对于位于同一台机器上的GPU,这可能会导致CPU内存溢出的风险。建议使用offload_to_cpu与rank0_only=True。- Parameters
递归 (布尔值, 可选) – 递归地召唤所有嵌套的FSDP实例的所有参数(默认值:True)。
写回 (布尔值, 可选) – 如果
False, 在上下文管理器退出后对参数的修改将被丢弃; 禁用此选项可能会稍微提高效率(默认值:True)rank0_only (布尔值, 可选) – 如果
True,完整的参数仅在全局排名为0的设备上实例化。这意味着在此上下文中,只有排名为0的设备将拥有完整的参数,其他设备将拥有分片的参数。请注意,在此上下文中设置rank0_only=True和writeback=True是不支持的,因为模型参数形状在不同排名之间会有所不同,写入这些参数可能会导致退出上下文时各排名之间的不一致。offload_to_cpu (布尔值, 可选) – 如果
True,全部参数将被卸载到CPU。请注意,这种卸载目前仅在参数被分片时发生(这仅在world_size = 1或NO_SHARD配置时不适用)。建议与offload_to_cpu和rank0_only=True一起使用,以避免将模型参数的冗余副本卸载到相同的CPU内存。with_grads (bool, Optional) – 如果为
True,则梯度也会与参数一起取消分片。目前,这仅在将use_orig_params=True传递给FSDP构造函数并将offload_to_cpu=False传递给此方法时支持。(默认值:False)
- Return type
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]¶
这配置了显式的反向预取,通过在反向传播过程中实现通信和计算的重叠来提高吞吐量,代价是略微增加内存使用。
BACKWARD_PRE: 这种方式可以实现最大程度的重叠,但也会导致内存使用量最大。它会在当前参数集的梯度计算之前预取下一组参数。这使得下一个all-gather和当前梯度计算重叠,在峰值时,它会将当前参数集、下一组参数以及当前梯度同时保留在内存中。BACKWARD_POST: 这减少了重叠,但需要更少的内存使用。这在当前参数集的梯度计算之后预取下一组参数。这将当前reduce-scatter与下一个梯度计算重叠,并在为下一组参数分配内存之前释放当前参数集,仅在峰值时在内存中保留下一组参数和当前梯度集。FSDP 的
backward_prefetch参数接受None,这将完全禁用反向预取。这没有重叠,并且不会增加内存使用量。一般来说,我们不推荐这种设置,因为它可能会显著降低吞吐量。
对于更多的技术背景:对于使用NCCL后端的单个进程组, 任何集合操作,即使是从不同的流中发出的,也会争夺每个设备的NCCL流, 这意味着集合操作发出的相对顺序对重叠很重要。两个反向预取值对应于不同的发出顺序。
- class torch.distributed.fsdp.ShardingStrategy(value)[source]¶
这指定了用于分布式训练的分片策略,由
FullyShardedDataParallel。FULL_SHARD: 参数、梯度和优化器状态被分片。 对于参数,此策略在前向传播之前通过全聚操作取消分片,在前向传播之后重新分片,在反向计算之前取消分片,并在反向计算之后重新分片。对于梯度,它在反向计算之后通过归约分散操作同步并分片它们。分片的优化器状态按秩本地更新。SHARD_GRAD_OP: 在计算过程中,梯度和优化器状态会被分片,并且参数在计算之外也会被分片。对于参数,此策略在前向传播之前进行不分片,在前向传播之后不重新分片,仅在反向传播计算之后重新分片。分片的优化器状态在每个秩上本地更新。在no_sync()中,参数在反向传播计算之后不会被重新分片。NO_SHARD: 参数、梯度和优化器状态不会被分片,而是像PyTorch的DistributedDataParallelAPI一样在各个秩之间复制。对于梯度,此策略 在反向计算后通过全归约进行同步。未分片的优化器状态在每个秩上本地更新。HYBRID_SHARD: 在节点内应用FULL_SHARD,并在节点之间复制参数。这导致通信量减少,因为昂贵的全收集和减少分散操作仅在节点内进行,这对于中等规模的模型可以更高效。_HYBRID_SHARD_ZERO2: 在节点内应用SHARD_GRAD_OP,并在节点之间复制参数。这类似于HYBRID_SHARD,但可能提供更高的吞吐量,因为未分片的参数在前向传递后不会被释放,从而节省了预反向传播中的所有聚集。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source]¶
这配置了FSDP原生混合精度训练。
- Variables
param_dtype (可选[torch.dtype]) – 这指定了模型参数在前向和后向传播期间的数据类型,因此也决定了前向和后向计算的数据类型。在前向和后向传播之外,分片的参数保持全精度(例如,用于优化器步骤),并且在保存模型检查点时,参数始终以全精度保存。(默认值:
None)reduce_dtype (可选[torch.dtype]) – 这指定了梯度归约(即reduce-scatter或all-reduce)的dtype。如果这是
None但param_dtype不是None,那么它将采用param_dtype值,仍然以低精度运行梯度归约。这允许与param_dtype不同,例如 强制梯度归约以全精度运行。(默认:None)buffer_dtype (可选[torch.dtype]) – 这指定了缓冲区的数据类型。FSDP 不会分割缓冲区。相反,FSDP 在第一次前向传递时将它们转换为
buffer_dtype并在此后保持该数据类型。对于模型检查点,缓冲区以全精度保存,除了LOCAL_STATE_DICT。(默认值:None)keep_low_precision_grads (bool) – 如果
False,则 FSDP 在反向传播后将梯度上转换为全精度,以准备优化器步骤。如果True,则 FSDP 保持梯度在用于梯度减少的数据类型中,如果使用支持低精度运行的自定义优化器,则可以节省内存。 (Default:False)cast_forward_inputs (bool) – 如果
True,则此 FSDP 模块会将其前向参数和关键字参数转换为param_dtype。这是为了确保在前向计算中参数和输入的数据类型匹配,这是许多操作所要求的。当仅对某些而不是所有 FSDP 模块应用混合精度时,可能需要将其设置为True,在这种情况下,混合精度的 FSDP 子模块需要重新转换其输入。(默认值:False)cast_root_forward_inputs (bool) – 如果为
True,则根 FSDP 模块 将其前向参数和关键字参数转换为param_dtype,覆盖cast_forward_inputs的值。对于非根 FSDP 模块, 这不会做任何事情。 (默认:True)_module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 这指定了在使用混合精度时要忽略的模块类。
auto_wrap_policy: 这些类别的模块将单独应用FSDP,并禁用混合精度(这意味着最终的FSDP构建将偏离指定的策略)。如果未指定auto_wrap_policy,则此设置无效。此API是实验性的,可能会发生变化。 (Default:(_BatchNorm,))
注意
此API为实验性质,可能会发生变化。
注意
只有浮点张量会被转换为指定的数据类型。
注意
在
summon_full_params中,参数被强制为全精度,但缓冲区不是。注意
Layer norm 和 batch norm 在
float32中累积,即使 它们的输入是低精度的,如float16或bfloat16。 禁用 FSDP 的混合精度仅针对这些规范化模块意味着 仿射参数保持在float32。然而,这会导致 这些规范化模块进行单独的全聚集和缩减分散,这 可能效率低下,因此如果工作负载允许,用户应优先 对这些模块应用混合精度。注意
默认情况下,如果用户传递了一个包含任何
_BatchNorm模块的模型,并指定了一个auto_wrap_policy,那么批归一化模块将单独应用FSDP,并禁用混合精度。请参阅_module_classes_to_ignore参数。注意
MixedPrecision默认具有cast_root_forward_inputs=True和cast_forward_inputs=False。对于根 FSDP 实例, 其cast_root_forward_inputs优先于其cast_forward_inputs。对于非根 FSDP 实例,它们的cast_root_forward_inputs值被忽略。默认设置适用于典型情况,即每个 FSDP 实例具有相同的MixedPrecision配置,并且只需要在模型的前向传递开始时将输入转换为param_dtype。注意
对于具有不同
MixedPrecision配置的嵌套FSDP实例,我们建议设置各个cast_forward_inputs值以配置在每个实例的前向传播之前是否进行输入类型转换。在这种情况下,由于类型转换发生在每个FSDP实例的前向传播之前,父FSDP实例应该在其FSDP子模块之前运行其非FSDP子模块,以避免由于不同的MixedPrecision配置而导致激活数据类型发生变化。Example:
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
上面显示了一个工作示例。另一方面,如果
model[1]被替换为model[0],意味着使用了不同的MixedPrecision的子模块首先运行其前向计算,那么model[1]将错误地看到float16激活而不是bfloat16激活。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]¶
此配置用于CPU卸载。
- Variables
offload_params (bool) – 这指定了是否在不参与计算时将参数卸载到CPU。如果为
True,则还会将梯度卸载到CPU,这意味着优化器步骤将在CPU上运行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]¶
StateDictConfig是所有state_dict配置类的基础类。用户应实例化一个子类(例如FullStateDictConfig)以配置 FSDP 支持的相应state_dict类型的设置。- Variables
offload_to_cpu (bool) – 如果为
True,则 FSDP 将状态字典值卸载到 CPU,如果为False,则 FSDP 保持它们在 GPU 上。 (默认值:False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]¶
FullStateDictConfig是一个配置类,用于与StateDictType.FULL_STATE_DICT一起使用。我们建议在保存完整状态字典时分别启用offload_to_cpu=True和rank0_only=True,以节省 GPU 内存和 CPU 内存。此配置类应通过以下方式使用state_dict_type()上下文管理器:>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) >>> # After this point, all ranks have FSDP model with loaded checkpoint.
- Variables
rank0_only (bool) – 如果为
True,则只有 rank 0 保存完整的状态字典,而非零 rank 保存空字典。如果为False,则所有 rank 都保存完整的状态字典。(默认值:False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source]¶
ShardedStateDictConfig是一个配置类,用于与StateDictType.SHARDED_STATE_DICT一起使用。- Variables
_use_dtensor (bool) – 如果为
True,则 FSDP 会将状态字典值保存为DTensor,如果为False,则 FSDP 会将它们保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedStateDictConfig的一个私有字段 并且它被 FSDP 用来确定状态字典值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source]¶
OptimStateDictConfig是所有optim_state_dict配置类的基本类。用户应实例化一个子类(例如FullOptimStateDictConfig)以配置 FSDP 支持的相应optim_state_dict类型的设置。- Variables
offload_to_cpu (bool) – 如果为
True,则 FSDP 将状态字典的张量值卸载到 CPU,如果为False,则 FSDP 将它们保留在原始设备上(除非启用了参数 CPU 卸载,否则该设备为 GPU)。 (默认值:True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source]¶
- Variables
rank0_only (bool) – 如果为
True,则只有 rank 0 保存完整的状态字典,而非零 rank 保存空字典。如果为False,则所有 rank 都保存完整的状态字典。(默认值:False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source]¶
ShardedOptimStateDictConfig是一个配置类,用于与StateDictType.SHARDED_STATE_DICT一起使用。- Variables
_use_dtensor (bool) – 如果为
True,则 FSDP 会将状态字典值保存为DTensor,如果为False,则 FSDP 会将它们保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedOptimStateDictConfig的一个私有字段 并且它被 FSDP 用来确定状态字典值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]¶