FullyShardedDataParallel¶
-
class
torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False)[source]¶ A wrapper for sharding module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP。
Example:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
警告
优化器必须在模块被包装 之后 进行初始化, 因为 FSDP 会原地对参数进行分片,这将破坏任何 之前初始化的优化器。
警告
如果目标CUDA设备的ID为
dev_id,则应满足以下条件之一:(1)module应该已经放置在该设备上,(2) 应使用torch.cuda.set_device(dev_id)设置设备,或 (3)dev_id应作为device_id构造函数参数传递。 此FSDP实例的计算设备将是该目标设备。对于(1)和(3),FSDP初始化总是在GPU上进行。 对于(2),FSDP初始化发生在module的当前设备上,这可能是CPU。警告
FSDP 目前不支持在使用 CPU 卸载时将梯度累积到
no_sync()之外。尝试这样做会导致 错误的结果,因为 FSDP 将使用新减少的梯度 而不是与任何现有的梯度进行累积。警告
在构造后更改原始参数变量名称将导致未定义行为。
警告
传递 sync_module_states=True 标志需要将模块放在 GPU 上,或者使用
device_id参数指定一个 CUDA 设备,FSDP 会将模块移动到该设备上。这是因为sync_module_states=True需要 GPU 通信。警告
从PyTorch 1.12开始,FSDP仅对共享参数提供有限支持(例如,将一个
Linear层的权重设置为另一个层的)。特别是,共享参数的模块必须作为同一FSDP单元的一部分进行包装。如果您的使用情况需要增强的共享参数支持,请联系https://github.com/pytorch/pytorch/issues/77724注意
输入到 FSDP
forward函数的参数将在运行forward之前被移动到计算设备 (与 FSDP 模块所在的设备相同),因此用户无需手动将输入从 CPU 移动到 GPU。- Parameters
模块 (nn.Module) – 要用FSDP包装的模块。
process_group (Optional[ProcessGroup]) – 分片用的进程组
sharding_strategy (Optional[ShardingStrategy]) – 配置分片算法,不同的分片算法在内存节省和通信开销之间存在权衡。如果未指定 sharding_strategy,则会选择
FULL_SHARD。cpu_offload (Optional[CPUOffload]) – CPU卸载配置。目前仅支持参数和梯度的CPU卸载。可以通过传入
cpu_offload=CPUOffload(offload_params=True)来启用。请注意,这 目前会隐式地启用梯度卸载到CPU,以便 参数和梯度位于同一设备上以与优化器配合使用。这个 API可能会发生变化。默认值是None,在这种情况下 不会进行任何卸载。auto_wrap_policy (Optional[Callable]) –
一个可调用对象,用于指定递归地将层用FSDP包装的策略。 请注意,此策略目前仅适用于传入模块的子模块。 其余模块始终会被返回的FSDP根实例所包装。
size_based_auto_wrap_policy用torch.distributed.fsdp.wrap编写的是 一个auto_wrap_policy可调用对象的例子,该策略会包装参数数量大于100M的层。transformer_auto_wrap_policy用torch.distributed.fsdp.wrap编写的是auto_wrap_policy可调用对象的一个例子,适用于类似Transformer的模型架构。用户可以提供自定义的auto_wrap_policy可调用对象,它应接受以下参数:module: nn.Module,recurse: bool,unwrapped_params: int, 还可以向自定义的auto_wrap_policy可调用对象添加额外的自定义参数。打印出分片后的模型并检查分片后的模型是否符合应用需求,然后进行相应调整是一个良好的实践。Example:
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> unwrapped_params: int, >>> # These are customizable for this policy function. >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return unwrapped_params >= min_num_params
backward_prefetch (Optional[BackwardPrefetch]) – 这是一个实验性功能,未来可能会发生变化。它允许用户启用两种不同的 backward_prefetch 算法,以帮助反向通信和计算重叠。每种算法的优缺点在类
BackwardPrefetch中进行了说明。混合精度 (可选[MixedPrecision]) – 一个
MixedPrecision实例 描述要使用的混合精度训练配置。MixedPrecision支持配置参数、缓冲区和梯度通信的数据类型。请注意 只有浮点数据会被转换为较低的精度。这允许 用户在模型训练过程中通过牺牲精度来节省潜在的内存和加快训练速度。如果None,则不应用混合精度。 请注意,如果启用了mixed_precision的 FSDP 模型 包含带有BatchNorm的auto_wrap_policy,FSDP 将 确保通过将它们单独包装在自己的 FSDP 单元中并使用mixed_precision=None来禁用BatchNorm单元的混合精度。 这是因为在目前有一些BatchNorm内核尚未实现 降低类型的支持。如果单独包装模型, 用户必须注意为BatchNorm单元设置mixed_precision=None。 (默认:None)ignored_modules (Optional[Iterable[torch.nn.Module]]) – 被此实例忽略的模块,这些模块自身的参数以及子模块的参数和缓冲区将被忽略。直接位于
ignored_modules中的模块不应是FullyShardedDataParallel实例,任何已经构建的FullyShardedDataParallel实例如果嵌套在此实例下,也不会被忽略。此参数可用于在使用auto_wrap_policy时避免对特定参数进行分片,或者当参数的分片不由 FSDP 管理时。 (默认值:None)param_init_fn (Optional[Callable[[nn.Module], None]]) –
一个
Callable[torch.nn.Module] -> None,用于指定当前位于元设备上的模块应该如何初始化到实际设备上。请注意,从v1.12版本开始,我们通过is_meta检查来检测元设备上的模块,并应用默认的初始化方法,该方法会在传入的nn.Module上调用reset_parameters方法(如果未指定param_init_fn),否则我们会运行param_init_fn来初始化传入的nn.Module。特别是,这意味着如果任何将被FSDP包装的模块参数的is_meta=True未指定,且param_init_fn也未指定,我们将假设你的模块正确实现了reset_paramters(),否则会抛出错误。请注意,此外我们还支持使用torchdistX的 (https://github.com/pytorch/torchdistX)deferred_initAPI 初始化的模块。在这种情况下,延迟初始化的模块将通过默认的初始化函数进行初始化,该函数会调用 torchdistX 的materialize_module,或者如果传入的param_init_fn不是None,则使用该传入的param_init_fn。同样的Callable会被用来初始化所有元模块。请注意,这个初始化函数会在执行任何FSDP分片逻辑之前应用。Example:
>>> module = MyModule(device="meta") >>> def my_init_fn(module): >>> # responsible for initializing a module, such as with reset_parameters >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – 一个
int或torch.device描述 FSDP 模块应被移动到的 CUDA 设备,确定分片等初始化操作的位置。如果未指定此参数 且module在 CPU 上,我们将把module移动到当前 CUDA 设备以加快 初始化,并在返回前将module移回 CPU。 如果指定了该参数,生成的 FSDP 实例将驻留在该设备上。 请注意,如果指定了device_id但module已经 位于不同的 CUDA 设备上,将引发错误。(默认值:None)sync_module_states (bool) – 如果为
True,每个单独封装的 FSDP 单位将从 rank 0 广播模块参数,以确保初始化后所有 rank 上的参数一致。这有助于在训练开始前确保所有 rank 上的模型参数一致,但会增加__init__的通信开销,因为每个单独封装的 FSDP 单位至少会触发一次广播。 这也可以帮助以内存高效的方式加载由state_dict保存并由load_state_dict加载的检查点。有关此功能的示例,请参阅FullStateDictConfig的文档。(默认值:False)
-
apply(fn)[source]¶ 递归地将
fn应用于每个子模块(如由.children()返回的)以及 self。典型用法包括初始化模型的参数 (另见 torch.nn.init)。与
torch.nn.Module.apply相比,此版本在应用fn之前还会收集所有参数。不应在另一个summon_full_params上下文中调用它。- Parameters
fn (
Module-> None) – 应用于每个子模块的函数- Returns
自我
- Return type
-
clip_grad_norm_(max_norm, norm_type=2.0)[source]¶ 在此时剪切所有梯度。规范是将所有梯度视为连接成一个单一向量后计算的。梯度会原地修改。
- Parameters
- Returns
参数的总范数(视为单个向量)。
注意
这类似于
torch.nn.utils.clip_grad_norm_,但 在内部处理了分区和每个rank的多个设备。默认的torch工具在这里不适用,因为每个rank仅拥有模型所有梯度的部分视图,因此 对FSDP模型调用它会导致不同子集的模型参数应用不同的缩放比例。警告
这需要在所有 rank 上调用,因为会使用同步原语。
-
static
fsdp_modules(module, root_only=False)[source]¶ 返回所有嵌套的FSDP实例,可能包括
module本身 并且仅包括FSDP根模块如果root_only=True。- Parameters
模块 (torch.nn.Module) – 根模块,可能是一个或不是一个
FSDP模块。root_only (布尔值) – 是否仅返回FSDP根模块。 (默认值:
False)
- Returns
嵌套在输入
module中的FSDP模块。- Return type
List[FullyShardedDataParallel]
-
static
full_optim_state_dict(model, optim, optim_input=None, rank0_only=True)[source]¶ 将完整的优化器状态在rank 0上进行整合并返回它 作为一个
dict,遵循torch.optim.Optimizer.state_dict()的惯例,即带有键"state"和"param_groups"。在FSDP模块中包含的model中的扁平化参数被映射回它们未扁平化的参数。警告
这需要在所有等级上调用,因为使用了同步原语。然而,如果
rank0_only=True,则状态字典仅在 rank 0 上填充,其他所有 rank 都会返回一个空的dict。警告
与
torch.optim.Optimizer.state_dict()不同,此方法 使用完整的参数名称作为键,而不是参数ID。警告
如果你没有将
model.parameters()作为优化器的第一个参数传递,那么你应该将相同的值以optim_input的形式传递给此方法。注意
就像在
torch.optim.Optimizer.state_dict()中一样,优化器状态字典中包含的张量不会被克隆,因此可能会有别名意外。为了最佳实践,请考虑立即保存返回的优化器状态字典,例如使用torch.save()。- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例)的参数被传递给优化器optim。优化器 (torch.optim.Optimizer) – 用于
model的参数的优化器。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入
optim,表示参数组的list或参数的可迭代对象; 如果为None,则此方法假定输入为model.parameters()。(默认值:None)rank0_only (布尔值) – 如果
True,仅在 rank 0 上保存填充的dict;如果False,则在所有 ranks 上保存。 (默认:True)
- Returns
一个
dict包含优化器状态,适用于model的原始未展平参数,并包含键 “state” 和 “param_groups”,遵循torch.optim.Optimizer.state_dict()的约定。如果rank0_only=True, 则非零秩返回一个空的dict。- Return type
Dict[字符串, 任意]
-
load_state_dict(state_dict, *args)[source]¶ 所有三个FSDP
load_state_dictAPI的入口点。默认情况下, 调用load_state_dict一个FSDP模块将导致FSDP 尝试加载一个“完整”的state_dict,即由 完整的、未拆分的、未展平的原始模块参数组成的state_dict。这需要 FSDP在每个rank上加载完整的参数上下文,这可能导致 GPU内存溢出。因此,state_dict_type()API可用于 在load_state_dict实现之间进行配置。用户可以使用with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)上下文 管理器来加载一个本地state_dict检查点,该检查点仅恢复 模块的本地分片。目前,唯一支持的 实现是StateDictType.LOCAL_STATE_DICT和StateDictType.FULL_STATE_DICT(默认)。有关创建FSDP检查点的文档,请参见state_dict()。Example:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> torch.cuda.set_device(device_id) >>> my_module = nn.Linear(...) >>> sharded_module = FSDP(my_module) >>> checkpoint = torch.load(PATH) >>> full_state_dict = checkpoint['full_state_dict'] >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT): >>> sharded_module.load_state_dict(full_state_dict) >>> full_dict.keys() >>> odict_keys(['weight', 'bias']) >>> # using local state dict >>> local_state_dict = checkpoint['local_state_dict'] >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): >>> sharded_module.load_state_dict(local_state_dict) >>> local_dict.keys() >>> odict_keys(['flat_param', 'inner.flat_param'])
警告
这需要在所有等级上调用,因为可能会使用同步原语。
-
property
module¶ 使 model.module 可访问,就像 DDP 一样。
-
named_buffers(*args, **kwargs)[source]¶ 覆盖
named_buffers()以拦截缓冲区名称并在summon_full_params()上下文管理器内部移除所有 FSDP 特有的展平缓冲区前缀。
-
named_parameters(*args, **kwargs)[source]¶ 覆盖
named_parameters()以拦截参数名称并 在summon_full_params()上下文管理器内部时 移除所有 FSDP 特有的展平参数前缀。
-
no_sync()[source]¶ 一个上下文管理器,用于禁用 FSDP 实例之间的梯度同步。在此上下文中,梯度将被累积在模块变量中,之后在退出该上下文后的第一次前向-反向传递过程中进行同步。此功能仅应作用于根 FSDP 实例,并会递归地应用于所有子 FSDP 实例。
注意
这可能会导致更高的内存使用,因为FSDP将累积整个模型的梯度(而不是梯度碎片),直到最终同步。
注意
当与CPU卸载一起使用时,在上下文管理器内部,梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。
-
property
params_with_grad¶ 递归地返回所有具有梯度的模块参数列表。
-
static
rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None)[source]¶ 将优化器状态字典
optim_state_dict的键重新映射为使用键 类型optim_state_key_type。这可用于实现具有 FSDP 实例的模型和不具有 FSDP 实例的模型之间的优化器状态字典兼容性。要重新键入FSDP完整优化器状态字典(即从
full_optim_state_dict())以使用参数ID并可加载到 非包装模型中:>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
要将普通优化器的状态字典从非包装模型重新键入以便可以加载到包装模型中:
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- Returns
优化器状态字典已使用
optim_state_key_type指定的参数键重新键入。- Return type
Dict[字符串, 任意]
-
static
scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, group=None)[source]¶ 将完整的优化器状态字典从 rank 0 散列到所有其他 rank, 并在每个 rank 上返回分片的优化器状态字典。返回值与
shard_full_optim_state_dict()相同,并且在 rank 0 上,第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (可选[Dict[str, Any]]) – 优化器状态 字典,对应于未展平的参数,并在秩为0时保存完整的非分片优化器状态;在非零秩上忽略该参数。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示一个
list参数组或参数的可迭代对象; 如果为None,则此方法假定输入为model.parameters();在非零 rank 上该参数会被忽略。(默认值:None)group (Optional[Any]) – 模型的进程组或
None如果使用默认进程组。 (默认值:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
-
static
shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None)[source]¶ 将完整的优化器状态字典
full_optim_state_dict按照分片方式进行处理,通过将状态映射到扁平化参数而非非扁平化参数,并仅保留此 rank 的优化器状态部分。第一个参数应该是full_optim_state_dict()的返回值。Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
警告
如果你没有将
model.parameters()作为优化器的第一个参数传递,那么你应该将相同的值以optim_input的形式传递给此方法。注意
Both
shard_full_optim_state_dict()和scatter_full_optim_state_dict()都可以用来获取分片的优化器状态字典以加载。假设完整的优化器状态字典存储在CPU内存中,前者要求每个秩在CPU内存中都有完整的字典,每个秩独立地对字典进行分片而无需任何通信;而后者只需要秩0在CPU内存中有完整的字典,秩0将每个分片移动到GPU内存(对于NCCL)并适当地与各秩通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (Dict[str, Any]) – 优化器状态字典 对应于未展平的参数,并保存完整的非分片优化器状态。
模型 (torch.nn.Module) – 根模块(可以是或不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示一个
list参数组或参数的可迭代对象; 如果为None,则此方法假定输入为model.parameters()。(默认值:None)
- Returns
完整的优化器状态字典现在映射到 扁平化参数而不是非扁平化参数,并且 仅限于包含此秩的优化器状态部分。
- Return type
Dict[字符串, 任意]
-
state_dict(*args, **kwargs)[source]¶ 这是所有三个FSDP
state_dictAPI的入口点:完整、本地和分片。对于完整的状态字典 (StateDictType.FULL_STATE_DICT),FSDP会尝试在所有rank上进行反分片,如果完整模型无法 适合单个GPU,可能会导致OOM错误。在这种情况下,用户可以传入一个FullStateDictConfig仅在rank 0上保存检查点和/ 或逐层将其卸载到CPU内存中,从而支持更大的检查点。如果完整模型无法放入CPU内存,则用户可以 改用本地状态字典 (StateDictType.LOCAL_STATE_DICT) ,它只保存模型的本地分片。分片状态字典 (StateDictType.SHARDED_STATE_DICT) 将模型参数保存为ShardedTensors。state_dict类型可以通过state_dict_type()上下文管理器进行配置。Example:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> torch.cuda.set_device(device_id) >>> my_module = nn.Linear(...) >>> sharded_module = FSDP(my_module) >>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config): >>> full_dict = sharded_module.state_dict() >>> full_dict.keys() >>> odict_keys(['weight', 'bias']) >>> # using local state dict >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): >>> local_dict = sharded_module.state_dict() >>> local_dict.keys() >>> odict_keys(['flat_param', 'inner.flat_param'])
警告
这需要在所有等级上调用,因为可能会使用同步原语。
-
static
state_dict_type(module, state_dict_type, state_dict_config=None)[source]¶ 一个上下文管理器,用于设置目标模块所有后代 FSDP 模块的
state_dict_type。目标模块不一定是 一个 FSDP 模块。如果目标模块是一个 FSDP 模块,其state_dict_type也将被更改。注意
此API应仅针对顶级(根)模块调用。
注意
此API使用户能够透明地使用传统的
state_dictAPI,在根FSDP模块被另一个nn.Module包装的情况下进行模型检查点。例如, 以下将确保在所有非FSDP实例上调用state_dict,同时调度到local_state_dict实现 用于FSDP:Example:
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): >>> checkpoint = model.state_dict()
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 所需的
state_dict_type设置。
-
static
summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False)[source]¶ 一个上下文管理器,用于暴露FSDP实例的完整参数。 在模型完成前向/反向传播后,可以用于获取参数以进行额外处理或检查。它可以接受一个非FSDP模块,并根据
recurse参数的值,为所有包含的FSDP模块及其子模块召唤完整参数。注意
这可以用于内部的FSDPs。
注意
这不能在前向或后向传递中使用。也不能从这个上下文中启动前向和后向传递。
注意
参数将在上下文管理器退出后恢复为本地碎片,存储行为与前向传播相同。
注意
完整的参数可以被修改,但只有与本地参数分片对应的部分会在上下文管理器退出后保留(除非
writeback=False,在这种情况下,更改将被丢弃)。在FSDP不切分参数的情况下,目前仅当world_size == 1或NO_SHARD配置时,无论writeback如何,修改都会被保留。注意
此方法适用于本身不是FSDP的模块,但可能包含多个独立的FSDP单元。在这种情况下,给定的参数将应用于所有包含的FSDP单元。
警告
请注意,
rank0_only=True与writeback=True的组合目前不受支持,并将引发错误。这是因为在上下文中,模型参数的形状在不同的秩之间会有所不同,当退出上下文时,写入它们可能会导致秩之间的不一致。警告
请注意,
offload_to_cpu和rank0_only=False将导致完整参数被冗余地复制到CPU内存中,对于位于同一台机器上的GPU,这可能会导致CPU内存溢出的风险。建议使用offload_to_cpu与rank0_only=True。- Parameters
递归 (布尔值, 可选) – 递归地召唤所有嵌套的FSDP实例的所有参数(默认值:True)。
writeback (bool, Optional) – 如果
False,在上下文管理器退出后,对 params 的修改将被丢弃; 禁用此功能可能会稍微提高效率(默认:True)rank0_only (布尔值, 可选) – 如果
True,完整的参数仅在全局排名为0的设备上实例化。这意味着在此上下文中,只有排名为0的设备将拥有完整的参数,其他设备将拥有分片的参数。请注意,在此上下文中设置rank0_only=True和writeback=True是不支持的,因为模型参数形状在不同排名之间会有所不同,写入这些参数可能会导致退出上下文时各排名之间的不一致。offload_to_cpu (布尔值, 可选) – 如果
True,全部参数将被卸载到CPU。请注意,这种卸载目前仅在参数被分片时发生(这仅在world_size = 1或NO_SHARD配置时不适用)。建议与offload_to_cpu和rank0_only=True一起使用,以避免将模型参数的冗余副本卸载到相同的CPU内存。