目录

FullyShardedDataParallel

class moduleprocess_group=sharding_strategy=cpu_offload=auto_wrap_policy=无backward_prefetch=mixed_precision=ignored_modules=param_init_fn=device_id=sync_module_states=错误[来源]torch.distributed.fsdp.FullyShardedDataParallel

用于跨数据并行工作程序对 Module 参数进行分片的包装器。这 的灵感来自 Xu 等人以及 DeepSpeed 的 ZeRO Stage 3。 FullyShardedDataParallel 通常简称为 FSDP。

例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

警告

优化器必须在模块包装初始化, 因为 FSDP 将就地分片参数,这将破坏任何 以前初始化的优化器。

警告

如果目标 CUDA 设备具有 ID ,则 (1) 应已放置在该设备上,(2) 设备 应该使用 , 进行设置,或者 (3) 应该传递到构造函数中 论点。此 FSDP 实例的计算设备将是该目标 装置。对于 (1) 和 (3),FSDP 初始化始终在 GPU 上进行。 对于 (2),FSDP 初始化发生在 的当前 device,可能是 CPU。dev_idmoduletorch.cuda.set_device(dev_id)dev_iddevice_idmodule

警告

FSDP 目前不支持在使用 CPU 卸载时在外部进行梯度累积。尝试这样做会产生 结果不正确,因为 FSDP 将使用新降低的梯度 而不是与任何现有梯度累积。no_sync()

警告

构造后更改原始参数变量名称将 导致未定义的行为。

警告

传入 sync_module_states=True 标志需要将 module 在 GPU 上,或使用参数指定 CUDA 设备 FSDP 将 move module to.这是因为需要 GPU 通信。device_idsync_module_states=True

警告

从 PyTorch 1.12 开始,FSDP 仅提供对共享参数的有限支持 (例如,将一个图层的权重设置为另一个图层的权重)。在 特别是,共享参数的模块必须包装为 相同的 FSDP 单元。如果您的 使用案例,请 ping https://github.com/pytorch/pytorch/issues/77724Linear

注意

FSDP 函数的输入将移动到计算设备 (同一设备 FSDP 模块开启)之前,因此用户执行 不必手动从 CPU > GPU 移动输入。forwardforward

参数
  • 模块nn.Module) – 要用 FSDP 包装的模块。

  • process_groupOptional[ProcessGroup]) – 用于分片的进程组

  • sharding_strategyOptional[ShardingStrategy]) – 配置分片算法,不同的分片算法有交易 off 在内存节省和通信开销之间。 如果未指定 sharding_strategy。FULL_SHARD

  • cpu_offloadOptional[CPUOffload]) – CPU 卸载配置。目前只有 parameter 和 gradient CPU 支持卸载。可以通过传入来启用它。请注意,此 当前隐式启用梯度卸载到 CPU,以便 params 和 grads 位于同一设备上才能与 Optimizer 一起使用。这 API 可能会发生更改。默认是在这种情况下有 将不卸载。cpu_offload=CPUOffload(offload_params=True)None

  • auto_wrap_policy可选[可调用]) –

    一个可调用对象,指定一个策略以 FSDP 递归方式包装层。 请注意,此策略目前仅适用于 传入的模块。其余模块始终包装在 返回的 FSDP 根实例。 写入 is Callable 的一个示例,此策略包装层 参数个数大于 100M。 written in 是类 Transformer 模型架构的 Callable 示例。用户可以提供应接受以下参数的自定义可调用对象:、、、 额外的自定义参数也可以添加到自定义的可调用对象中。最好打印出来 分片模型,并检查分片模型是否是什么 应用程序需要,然后进行相应调整。size_based_auto_wrap_policytorch.distributed.fsdp.wrapauto_wrap_policytransformer_auto_wrap_policytorch.distributed.fsdp.wrapauto_wrap_policyauto_wrap_policymodule: nn.Modulerecurse: boolunwrapped_params: intauto_wrap_policy

    例:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     unwrapped_params: int,
    >>>     # These are customizable for this policy function.
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return unwrapped_params >= min_num_params
    

  • backward_prefetchOptional[BackwardPrefetch]) – 这是一项实验性功能,可能会在 不久的将来。它允许用户启用两种不同的backward_prefetch 算法来帮助反向通信和计算重叠。 每种算法的优缺点在 类 中进行了解释。BackwardPrefetch

  • mixed_precisionOptional[MixedPrecision]) – 实例 描述要使用的混合精度训练配置。 支持配置 Parameter、Buffer 和 Gradient 通信 dtype注意 仅将浮点数据强制转换为降低的精度。这允许 用户可能会节省内存并加快训练速度,同时进行权衡 模型训练期间的准确性。如果 ,则不应用混合精度。 请注意,如果为 FSDP 模型启用了 contains with ,FSDP 将采用 注意通过包装来禁用 Units 的混合精度 它们分别在自己的 FSDP 单元中。 这样做是因为一些内核没有实现 目前减少了类型支持。如果单独包装模型,则 用户必须注意设置 for units。 (默认:MixedPrecisionMixedPrecisionNonemixed_precisionBatchNormauto_wrap_policyBatchNormmixed_precision=NoneBatchNormmixed_precision=NoneBatchNormNone)

  • ignored_modulesOptional[Iterable[torch.nn.Module]]) – 其 自己的 parameters 和子模块的 parameters 和 buffer 是 被此实例忽略。直接进入的模块都不应该是实例,并且任何已经构造的子模块都不会被忽略,如果 它们嵌套在此实例下。此参数可用于 避免在使用 an 时对特定参数进行分片,或者如果参数的分片不是由 FSDP.(默认:ignored_modulesauto_wrap_policyNone)

  • param_init_fn可选[Callable[[nn.模块]]]) –

    一个 指定当前位于 meta 设备上的模块应如何初始化 拖动到实际设备上。请注意,从 v1.12 开始,我们在 meta device 并应用默认初始化,该初始化在传入的 if 上调用方法,否则我们运行以初始化传入的 在。具体而言,这意味着如果对于任何 module 参数,则假定你的模块正确实现了 a,否则将引发错误。请注意,我们还提供对模块的支持 使用 torchdistX 的 (https://github.com/pytorch/torchdistX) API 初始化。在这种情况下,将初始化延迟的模块 通过调用 torchdistX 的默认初始化函数,如果不是,则调用传入的 。这同样适用于初始化所有 meta 模块。 请注意,此初始化函数在执行任何 FSDP 分片之前应用 逻辑。Callable[torch.nn.Module] -> Noneis_metareset_parametersnn.Moduleparam_init_fnparam_init_fnnn.Moduleis_meta=Trueparam_init_fnreset_paramters()deferred_initmaterialize_moduleparam_init_fnNoneCallable

    例:

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module):
    >>>     # responsible for initializing a module, such as with reset_parameters
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_idOptional[Union[inttorch.device]]) – 描述 FSDP 模块应移动到的 CUDA 设备的 或,以确定 FSDP 模块的位置 进行分片等初始化。如果未指定此参数 并且使用的是 CPU,我们将更快地迁移到当前的 CUDA 设备 initialization 并在返回之前移回 CPU。 如果指定,则生成的 FSDP 实例将驻留在此设备上。 请注意,如果已指定但已 在不同的 CUDA 设备上,将引发错误。(默认:inttorch.devicemodulemodulemoduledevice_idmoduleNone)

  • sync_module_statesbool) – 如果,每个单独包装的 FSDP 单元将广播 module 参数,以确保它们在 0 之后的所有等级中都相同 初始化。这有助于确保模型参数在不同等级之间相同 ,但至少会给 增加通信开销 每个单独包装的 FSDP 单元触发一次广播。 这也有助于以内存高效的方式加载 Takes Taken 和 To be loading 的 checkpoint。有关此示例,请参阅文档。(默认:True__init__state_dictload_state_dictFullStateDictConfigFalse)

apply(fn[来源]

递归应用于每个子模块(由 ) 以及自我。典型用途包括初始化模型的参数 (另请参见 torch.nn.init)。fn.children()

与 相比,此版本还收集了 应用之前的完整参数。它不应从 在另一个上下文中。torch.nn.Module.applyfnsummon_full_params

参数

fn ( -> None) – 要应用于每个子模块的函数Module

返回

自我

返回类型

模块

clip_grad_norm_(max_normnorm_type=2.0[来源]

在此时间点剪辑所有渐变。范数是计算 梯度组合在一起,就像它们被连接成一个向量一样。 就地修改渐变。

参数
  • max_normfloat or int) - 梯度的最大范数

  • norm_typefloat or int) - 使用的 p-norm 的类型。可以是无穷大范数。'inf'

返回

参数的总范数(视为单个向量)。

注意

这类似于 但是 处理分区和每个 rank 下的多个设备 罩。默认的 torch util 在这里不适用,因为每个 rank 仅具有模型中所有 grads 的部分视图,因此 为 FSDP 模型调用它会导致不同的缩放 按模型参数的子集应用。torch.nn.utils.clip_grad_norm_

警告

这需要在所有等级上调用,因为同步 primitives 的调用。

static moduleroot_only=False[来源]fsdp_modules

返回所有嵌套的 FSDP 实例,可能包括其自身 并且仅在 .moduleroot_only=True

参数
  • moduletorch.nn.Module) – 根模块,可以是模块,也可以是模块。FSDP

  • root_onlybool) – 是否仅返回 FSDP 根模块。 (默认:False)

返回

嵌套在 输入 .module

返回类型

列表[FullyShardedDataParallel]

static modeloptimoptim_input=Nonerank0_only=True[来源]full_optim_state_dict

合并排名 0 上的完整优化器状态并返回它 遵循 的约定,即使用键和 。模块中扁平化的参数 包含在 中,则映射回其未拼合的参数。"state""param_groups"FSDPmodel

警告

自同步以来,需要在所有等级上调用 使用基元。但是,如果 ,则 状态 dict 仅在排名 0 上填充,所有其他排名都返回 空的 .rank0_only=True

警告

与 不同,此方法 使用完整的参数名称作为键,而不是参数 ID。torch.optim.Optimizer.state_dict()

警告

如果您未作为第一个传递 参数传递给优化器,那么您应该将相同的值传递给 此方法为 .model.parameters()optim_input

注意

与 中一样,张量 包含在优化器状态 dict 中,因此可能会有 是别名惊喜。对于最佳实践,请考虑将 立即返回 Optimizer state dict,例如使用 .torch.save()

参数
  • modeltorch.nn.Module) – 根模块 (可以是也可能不是实例),其参数 传递到 Optimizer 中。optim

  • optimtorch.optim.Optimizer) – 的 Optimizer 参数。model

  • optim_inputOptional[Union[List[Dict[strAny]], Iterable[torch.nn.Parameter]]] – 传入优化器的输入,表示of 参数组或参数的可迭代对象; 如果 ,则此方法假定输入为 。(默认:optimNonemodel.parameters()None)

  • rank0_onlybool) – 如果 ,则仅保存排名 0 上填充的内容;if ,则将其保存在所有等级上。(默认:TrueFalseTrue)

返回

A 包含 的原始未拼合参数的优化器状态,并包含键 “state” 和 “param_groups” 遵循 .如果 则非零排名返回空 modelrank0_only=True

返回类型

Dict[str, 任意]

load_state_dict(state_dict*args[来源]

所有三个 FSDP API 的入口点。默认情况下, 调用 FSDP 模块将导致 FSDP 尝试加载一个 “full” state_dict,即state_dict一个由 完整、未分片、未拼合的原始模块参数。这需要 FSDP 在每个秩上加载完整的参数上下文,这可能会导致 在 GPU OOM 中。因此, API 可用于 configure 之间的实现。因此,用户可以使用 context manager 加载一个仅恢复的本地 state dict 检查点 模块的本地分片。目前,唯一支持的 implementations 是 and (default)。有关创建 FSDP 检查点的文档,请参阅load_state_dictload_state_dictload_state_dictwith self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)StateDictType.LOCAL_STATE_DICTStateDictType.FULL_STATE_DICT

例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> checkpoint = torch.load(PATH)
>>> full_state_dict = checkpoint['full_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT):
>>>     sharded_module.load_state_dict(full_state_dict)
>>> full_dict.keys()
>>> odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> local_state_dict = checkpoint['local_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
>>>     sharded_module.load_state_dict(local_state_dict)
>>> local_dict.keys()
>>> odict_keys(['flat_param', 'inner.flat_param'])

警告

这需要在所有等级上调用,因为同步 可以使用 primitives 。

财产 module

使 model.module 可访问,就像 DDP 一样。

named_buffers(*args**kwargs[来源]

用于拦截缓冲区名称的覆盖,以及 删除所有出现的特定于 FSDP 的扁平化缓冲区前缀 当进入 Context Manager 中时。

named_parameters(*args**kwargs[来源]

用于拦截参数名称的 overrides 和 删除所有出现的特定于 FSDP 的扁平化参数前缀 当进入 Context Manager 中时。

no_sync()[来源]

用于禁用跨 FSDP 的梯度同步的上下文管理器 实例。在此上下文中,梯度将在 module 中累积 变量,稍后将在第一个 forward-backward 传递。这应该只是 在根 FSDP 实例上使用,并将递归地应用于所有 子 FSDP 实例。

注意

这可能会导致更高的内存使用率,因为 FSDP 会 累积完整的模型梯度(而不是梯度分片) 直到最终同步。

注意

当与 CPU 卸载一起使用时,梯度不会 在 Context Manager 中卸载到 CPU。相反,他们 只会在最终同步后立即卸载。

财产 params_with_grad

递归返回具有梯度的所有模块参数的列表。

static optim_state_dictoptim_state_key_typemodeloptim_input=None[来源]rekey_optim_state_dict

重新对优化器 state dict 进行 key 操作以使用 key 类型。这可以用来实现 来自 FSDP 的模型的优化器 state dict 之间的兼容性 实例和没有的实例。optim_state_dictoptim_state_key_type

要重新键入 FSDP 完整优化器状态 dict(即 from )以使用参数 ID 并可加载到 非包装模型:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要将普通优化器 state dict 从未包装模型重新生成密钥,请将其设置为 loadable to a wrapped model 的

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
返回

优化器状态 dict 使用 由 指定的参数键。optim_state_key_type

返回类型

Dict[str, 任意]

static full_optim_state_dictmodeloptim_input=Nonegroup=None[来源]scatter_full_optim_state_dict

将完整的优化器状态字典从等级 0 分散到所有其他等级, 返回每个排名的分片优化器状态 dict。回归 value 与 相同,并且 rank 0,则第一个参数应为 的返回值

例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

Both 都可用于获取 分片优化器状态 dict 来加载。假设 full optimizer state dict 驻留在 CPU 内存中,前者需要 每个 rank 在 CPU 内存中拥有完整的 dict,其中每个 rank 单独对 dict 进行分片而不进行任何通信,而 后者只需要 rank 0 即可在 CPU 内存中拥有完整的 dict, 其中,排名 0 将每个分片移动到 GPU 内存(对于 NCCL),并且 适当地将其传达给 Rank。因此,前者具有 更高的总 CPU 内存成本,而后者具有更高的 通信成本。

参数
  • full_optim_state_dictOptional[Dict[strAny]]) – 优化器状态 dict 对应的未扁平化参数,并按住 如果处于 rank 0 上,则为完整的非分片优化器状态;参数 在非零等级上被忽略。

  • modeltorch.nn.Module) – 根模块 (可以是也可能不是实例),其参数 对应于 中的优化器状态。full_optim_state_dict

  • optim_inputOptional[Union[List[Dict[strAny]], Iterable[torch.nn.Parameter]]] – 传入优化器的输入,表示of 参数组或参数的可迭代对象; 如果 ,则此方法假定输入为 ;该参数在非零时被忽略 行列。(默认:Nonemodel.parameters()None)

  • groupOptional[Any]) – 模型的进程组,或者如果使用 默认进程组。(默认:NoneNone)

返回

完整的优化器状态 dict 现在重新映射到 展平参数而不是未展平参数,以及 restricted 以仅包含此 rank 的 Optimizer 状态部分。

返回类型

Dict[str, 任意]

static full_optim_state_dictmodeloptim_input=None[来源]shard_full_optim_state_dict

按以下方式对完整的优化器 state dict 进行分片 将状态重新映射到 flattened 参数,而不是 unflattened 参数 参数,并限制为仅优化器的此 rank 部分 州。第一个参数应为 .full_optim_state_dict

例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

警告

如果您未作为第一个传递 参数传递给优化器,那么您应该将相同的值传递给 此方法为 .model.parameters()optim_input

注意

Both 都可用于获取 分片优化器状态 dict 来加载。假设 full optimizer state dict 驻留在 CPU 内存中,前者需要 每个 rank 在 CPU 内存中拥有完整的 dict,其中每个 rank 单独对 dict 进行分片而不进行任何通信,而 后者只需要 rank 0 即可在 CPU 内存中拥有完整的 dict, 其中,排名 0 将每个分片移动到 GPU 内存(对于 NCCL),并且 适当地将其传达给 Rank。因此,前者具有 更高的总 CPU 内存成本,而后者具有更高的 通信成本。

参数
  • full_optim_state_dictDict[strAny]) – 优化器状态 dict 对应于未展平的参数,并按住 full non-sharded optimizer 状态。

  • modeltorch.nn.Module) – 根模块 (可以是也可能不是实例),其参数 对应于 中的优化器状态。full_optim_state_dict

  • optim_inputOptional[Union[List[Dict[strAny]], Iterable[torch.nn.Parameter]]] – 传入优化器的输入,表示of 参数组或参数的可迭代对象; 如果 ,则此方法假定输入为 。(默认:Nonemodel.parameters()None)

返回

完整的优化器状态 dict 现在重新映射到 展平参数而不是未展平参数,以及 restricted 以仅包含此 rank 的 Optimizer 状态部分。

返回类型

Dict[str, 任意]

state_dict(*args**kwargs[来源]

这是所有三个 FSDP API 的入口点:完整、 local 和 sharded 的 API 中。对于完整的 state dict (),FSDP 尝试取消对模型进行分片 在所有 ranks 上,如果完整模型不能 适合单个 GPU。在这种情况下,用户可以传入 a 以仅保存排名 0 的检查点和/ 或者将其逐层卸载到 CPU 内存,从而实现更大的 检查站。如果 CPU 内存无法容纳完整模型,则用户可以 而是采用本地状态 dict () 这只会保存模型的局部分片。分片状态 dict () 将模型参数保存为 s。可以使用 上下文管理器。state_dictStateDictType.FULL_STATE_DICTFullStateDictConfigStateDictType.LOCAL_STATE_DICTStateDictType.SHARDED_STATE_DICTShardedTensorstate_dict

例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config):
>>>     full_dict = sharded_module.state_dict()
>>> full_dict.keys()
>>> odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
>>>     local_dict = sharded_module.state_dict()
>>> local_dict.keys()
>>> odict_keys(['flat_param', 'inner.flat_param'])

警告

这需要在所有等级上调用,因为同步 可以使用 primitives 。

static modulestate_dict_typestate_dict_config=None[来源]state_dict_type

一个上下文管理器,用于设置所有 descendant 目标模块的 FSDP 模块。目标模块不必 是 FSDP 模块。如果目标模块是 FSDP 模块,它也将被更改。state_dict_typestate_dict_type

注意

此 API 应仅针对顶级 (root) 调用 模块。

注意

此 API 使用户能够透明地使用传统 API 来获取模型检查点,在这种情况下, 根 FSDP 模块由另一个 .例如 以下内容将确保在所有非 FSDP 上调用 实例,同时分派到local_state_dict实现中 对于 FSDP:state_dictnn.Modulestate_dict

例:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
>>>     checkpoint = model.state_dict()
参数
  • moduletorch.nn.Module) – 根模块。

  • state_dict_typeStateDictType) – 要设置的。state_dict_type

static modulerecurse=Truewriteback=Truerank0_only=Falseoffload_to_cpu=False[来源]summon_full_params

一个上下文管理器,用于公开 FSDP 实例的完整参数。 在前进/后退,模型可以得到 用于其他处理或检查的参数。它可以采用非 FSDP 模块,并将为所有包含的 FSDP 模块调用完整的参数作为 以及他们的孩子,这取决于争论。recurse

注意

这可用于内部 FSDP。

注意

不能在向前或向后传递中使用。也不 可以从此上下文中启动 forward 和 backward。

注意

参数将在上下文之后恢复为其本地分片 manager 退出时,存储行为与 forward 相同。

注意

可以修改 full 参数,但只能修改 portion 对应的本地参数分片将在 上下文管理器退出(除非 ,在这种情况下 更改将被丢弃)。在 FSDP 不分片的情况下 参数(当前仅在 、 或 config 时)保留修改,而不管 .writeback=Falseworld_size == 1NO_SHARDwriteback

注意

此方法适用于本身不是 FSDP 但 可能包含多个独立的 FSDP 商品。在这种情况下,给定的 参数将应用于所有包含的 FSDP 单位。

警告

请注意,目前不支持 with 结合使用,并且会引发 错误。这是因为模型参数形状会有所不同 在上下文中跨等级,并写入它们可能会导致 退出上下文时等级之间的不一致。rank0_only=Truewriteback=True

警告

请注意,和 will 导致完整参数被冗余复制到 CPU 内存 GPU 位于同一台计算机上,这可能会产生 CPU OOM 的 OOM 中。建议与 一起使用。offload_to_cpurank0_only=Falseoffload_to_cpurank0_only=True

参数
  • recursebool可选) – 递归调用嵌套的所有参数 FSDP 实例 (默认值:True)。

  • writebackboolOptional) – 如果 ,对参数的修改是 在上下文管理器存在后丢弃; 禁用此选项可能会稍微更有效(默认值:True)False

  • rank0_onlybool可选) – 如果 ,则完整参数为 仅在全局排名 0 上实现。这意味着,在 context,只有排名 0 才会有完整的参数,而其他 ranks 将具有分片参数。请注意,不支持 with 的设置, 因为模型参数形状会因等级而异 在上下文中,写入它们可能会导致 退出上下文时等级之间的不一致。Truerank0_only=Truewriteback=True

  • offload_to_cpuboolOptional) – 如果 ,则完整参数为 卸载到 CPU。请注意,此卸载目前仅 如果参数是分片的(但事实并非如此,则会出现 对于 world_size = 1 或 config)。推荐 与 to use 搭配 to avoid 模型参数的冗余副本被卸载到相同的 CPU 内存。TrueNO_SHARDoffload_to_cpurank0_only=True

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源