目录

分布式检查点 - torch.distributed.checkpoint

分布式检查点 (DCP) 支持并行加载和保存来自多个 rank 的模型。 它处理加载时重新分片,从而支持在一个集群拓扑中保存并加载到另一个集群拓扑中。

DCP 与 torch.savetorch.load 在几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个等级至少生成一个文件。

  • 它在原地运行,这意味着模型应该首先分配其数据,然后 DCP 使用该存储。

加载和保存检查点的入口点如下:

torch.distributed.checkpoint.state_dict_saver。savestate_dict*checkpoint_id=storage_writer=planner=process_group=[来源]

以 SPMD 样式保存分布式模型。

此函数与它处理的不同之处在于,每个排名仅保存其本地分片。torch.save()ShardedTensorDTensor

对于每个对象(同时具有 a 和 ), save 将在序列化之前调用。Statefulstate_dictload_state_dictstate_dict

警告

无法保证 PyTorch 版本之间的向后兼容性 用于节省state_dicts。

警告

如果使用 process_group 参数,请确保只有其 ranks 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。

注意

为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存 checkpoint 时,只有一个 shard_group 应调用 save_state_dict 和相应的进程 group 需要传入。

注意

如果没有可用的进程组,则此函数假定目的是保存

state_dict在本地进程中。

参数
  • state_dictDict[strAny]) – 要保存的state_dict。

  • checkpoint_idUnion[stros.PathLikeNone]) – 此检查点实例的 ID。checkpoint_id的含义 取决于存储。它可以是文件夹或文件的路径。 如果存储是键值存储,它也可以是键。 (默认:None)

  • storage_writerOptional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果这不是 指定,DCP 将根据 checkpoint_id。如果 checkpoint_id 也是 None,则异常将 被提高。(默认:None)

  • plannerOptional[SavePlanner]) – SavePlanner 的实例。如果未指定,则默认的 Planner 的 Planner 中。(默认:None)

  • process_groupOptional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。 (默认:None)

返回

Metadata 对象。

返回类型

元数据

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用 collectives 来协调跨等级的写入。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为: 每个等级都有一个单独的 GPU,通过 .torch.cuda.current_device()torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_saver。async_savestate_dict*checkpoint_id=storage_writer=planner=process_group=[来源]

的异步版本。此代码首先将 state_dict 取消暂存到 暂存存储(默认为 CPU 内存),然后在单独的线程中调用 Savesave

警告

此功能是实验性的,可能会发生更改。

参数
  • state_dictDict[strAny]) – 要保存的state_dict。

  • checkpoint_idUnion[stros.PathLikeNone]) – 此检查点实例的 ID。checkpoint_id的含义 取决于存储。它可以是文件夹或文件的路径。 如果存储是键值存储,它也可以是键。 (默认:None)

  • storage_writerOptional[StorageWriter]) – 用于执行“stage”和“save”的 StorageWriter 实例。如果 未指定,DCP 将根据 checkpoint_id。如果 checkpoint_id 也是 None,则异常将 被提高。(默认:None)

  • plannerOptional[SavePlanner]) – SavePlanner 的实例。如果未指定,则默认的 Planner 的 Planner 中。(默认:None)

  • process_groupOptional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。 (默认:None)

返回

一个 future,用于保存 save 生成的 Metadata 对象。

返回类型

前途

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver。save_state_dictstate_dictstorage_writerprocess_group=coordinator_rank=0no_dist=False规划者=[来源]

此方法已弃用。请切换到 'save'。

返回类型

元数据

torch.distributed.checkpoint.state_dict_loader。loadstate_dict*checkpoint_id=storage_reader=planner=process_group=[来源]

以 SPMD 样式加载 distributed。state_dict

每个等级将尝试读取所需的最少数据量 完成请求的state_dict。在加载或实例时,每个排名仅读取其本地分片的数据。ShardedTensorDTensor

对于每个对象(同时具有 a 和 ), load 将在尝试反序列化之前首先调用,然后在反序列化完成后调用。Statefulstate_dictload_state_dictstate_dictload_state_dict

警告

中的所有张量都必须在其 destination device 之前state_dict

所有非张量数据都使用 torch.load() 加载并就地修改 在 state_dict。

警告

用户必须在根模块上调用 load_state_dict 以确保负载 pos-processing 和非 Tensor 数据正确传播。

参数
  • state_dictDict[strAny]) – 要保存的state_dict。

  • checkpoint_idUnion[stros.PathLikeNone]) – 此检查点实例的 ID。checkpoint_id的含义 取决于存储。它可以是文件夹或文件的路径。 如果存储是键值存储,它也可以是键。 (默认:None)

  • storage_readerOptional[StorageReader]) – 用于执行读取的 StorageWriter 实例。如果这不是 指定,DCP 将根据 checkpoint_id。如果 checkpoint_id 也是 None,则异常将 被提高。(默认:None)

  • plannerOptional[LoadPlanner]) – LoadPlanner 的实例。如果未指定,则默认的 Planner 的 Planner 中。(默认:None)

  • process_groupOptional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。 (默认:None)

返回

没有。

返回类型

没有

例子
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用 collectives 来协调跨等级的读取。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为每个 rank 具有单个 GPU,通过 .torch.cuda.current_device()torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_loader。load_state_dictstate_dictstorage_readerprocess_group=coordinator_rank=0no_dist=False规划者=[来源]

此方法已弃用。请切换到 'load'。

以下模块还可用于对用于异步检查点 (torch.distributed.checkpoint.async_save) 的暂存机制进行其他自定义:

torch.distributed.checkpoint.staging 中。AsyncStager*args**kwargs[来源]

此协议旨在为 dcp.async_save 提供自定义和可扩展性,从而允许用户 自定义在并行执行通常的 DCP.Save 路径之前如何暂存数据。 预期的操作顺序(在 torch.distributed.state_dict_saver.async_save 中具体定义) 如下所示:

  1. AsyncStager.stage_data(state_dict):

    此调用为 AsyncStager 提供了“暂存”的机会 state_dict。在这种情况下,暂存的期望和目的是创建一个“训练安全” 状态 dict 的表示形式,这意味着暂存后对模块数据的任何更新都已完成 不应反映在此方法返回的 state dict 中。例如,在默认的 如果在 CPU RAM 上创建整个 state dict 的副本并在此处返回,则允许用户 继续训练,而不会冒更改正在序列化的数据的风险。

  2. dcp.save 在从 Stage 并行返回的 state_dict 上调用。此调用负责

    用于序列化state_dict并将其写入 storage。

  3. 如果 AsyncStager.should_synchronize_after_execute 为 True,则该方法将在

    序列化线程在从 dcp.async_save 返回之前启动。如果设置为 False,则 假设用户已定义自定义同步点,以便进一步 优化保存训练循环中的延迟(例如,通过将暂存与 forward/backward pass),并且这是用户在适当时间调用 AsyncStager.synchronize_staging 的响应能力。

属性should_synchronize_after_execute: bool

是否在执行 stage 后同步。

stagestate_dict[来源]

返回 state_dict 的 “暂存” 副本。暂存副本的预期是它 从 stage 调用完成后发生的任何更新中接种。

返回类型

dict[strunion[StatefulT任意]]

synchronize_staging[来源]

如果 stage 在某种程度上是异步的,则应调用此方法以确保暂存 已完成,可以安全地开始修改原始state_dict

torch.distributed.checkpoint.staging 中。BlockingAsyncStagercache_staged_state_dict=Falsetype_check=False[来源]

AsyncStager 的一种实现,它将state_dict暂存在 CPU RAM 上并阻止,直到复制完成。 此实现还提供了一个选项,用于使用固定内存优化阶段延迟。

注意 在这种情况下,synchronize_staging 是无操作的。

stagestate_dict[来源]

返回 CPU 上 state_dict 的副本。

返回类型

dict[strunion[StatefulT任意]]

synchronize_staging[来源]

无操作功能,因为 staging 是阻塞的。

除了上述入口点之外,如下所述的 Stateful 对象在保存/加载期间提供了额外的自定义 ..automodule::torch.distributed.checkpoint.stateful

torch.distributed.checkpoint.stateful 中。Stateful*args**kwargs[来源]

用于可进行检查点和还原的对象的有状态协议。

load_state_dictstate_dict[来源]

从提供state_dict恢复对象的状态。

参数

state_dictDict[strAny]) – 要从中恢复的状态 dict

state_dict[来源]

对象应将其state_dict表示形式作为字典返回。 此函数的输出将被检查点,稍后在 load_state_dict() 中恢复。

警告

由于恢复检查点的就地性质,此函数 也在 torch.distributed.checkpoint.load 期间调用。

返回

对象状态 dict

返回类型

字典

此示例说明如何使用 Pytorch 分布式检查点保存 FSDP 模型。

以下类型定义了 checkpoint 期间使用的 IO 接口:

torch.distributed.checkpoint 中。StorageReader[来源]

用于从存储中读取的接口。load_state_dict

一个 StorageReader 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。

子类应期望以下调用序列:load_state_dict

  1. (所有排名)如果用户传递有效的checkpoint_id,则设置为 checkpoint_id。

  2. (所有等级) read_metadata()

  3. (所有等级) set_up_storage_reader()

  4. (所有等级) prepare_local_plan()

  5. (协调) prepare_global_plan()

  6. (所有等级) read_data()

摘要 prepare_global_plan计划[来源]

执行存储加载的集中规划。

此方法仅在协调器实例上调用。

虽然这种方法可以产生完全不同的计划,但首选的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。

参数

plansList[LoadPlan]) – 实例列表,每个等级一个实例。LoadPlan

返回

存储后全局规划的改造清单LoadPlan

返回类型

列表[LoadPlan]]

摘要 prepare_local_plan计划[来源]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。

参数

planLoadPlan) – 正在使用的本地计划。LoadPlan

返回

A 仓储后本地规划LoadPlan

返回类型

LoadPlan (负载计划)

摘要 read_dataPlanPlanner[来源]

从 using 中读取所有项目以解析数据。planplanner

应调用子类以反序列化 BytesIO object 移动到正确的位置。LoadPlanner::load_bytes

子类应该调用以获取对 应该将数据加载到的张量。LoadPlanner::resolve_tensor

StorageLayer 负责正确安排任何跨设备副本 必填。

参数
  • planLoadPlan) – 要执行的本地计划

  • plannerLoadPlanner) – 用于解析项目的 planner 对象。

返回

一个 future ,在完成所有读取后完成。

返回类型

Future[无]

摘要 read_metadata[来源]

读取检查点元数据。

返回

与正在加载的检查点关联的元数据对象。

返回类型

元数据

抽象 resetcheckpoint_id=[来源]

调用 to 表示将要进行全新的检查点读取。 如果用户将 checkpoint_id checkpoint_id 设置为 这个检查点读取。checkpiont_id的含义是 依赖于存储。它可以是文件夹/文件的路径或 键值存储。

参数

checkpoint_idUnion[stros.PathLikeNone]) – 此检查点实例的 ID。checkpoint_id的含义 取决于存储。它可以是文件夹或文件的路径。 如果存储更像键值存储,它也可以是键。 (默认:None)

摘要 set_up_storage_readermetadatais_coordinator[来源]

初始化此实例。

参数
  • metadataMetadata) (元数据) – 要使用的元数据架构。

  • is_coordinatorbool) – 此实例是否负责协调 检查点。

抽象类方法 validate_checkpoint_idcheckpoint_id[来源]

检查存储是否支持给定的 checkpoint_id。这允许 us 启用自动存储选择。

返回类型

布尔

torch.distributed.checkpoint 中。StorageWriter[来源]

用于写入存储的接口。save_state_dict

一个 StorageWriter 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。

子类应期望以下调用序列。

  1. (所有排名)如果用户传递有效的checkpoint_id,则设置为 checkpoint_id。

  2. (所有等级) set_up_storage_writer()

  3. (所有等级) prepare_local_plan()

  4. (协调) prepare_global_plan()

  5. (所有等级) write_data()

  6. (协调器) finish()

抽象完成元数据结果[来源]

写入元数据并将当前检查点标记为成功。

用于序列化元数据的实际格式/架构是 implementation detail 的 implementation detail 中。唯一的要求是它是可恢复的 in 添加到同一对象图中。

参数
  • metadataMetadata) – 新检查点的元数据

  • resultsList[List[WriteResult]]) – 所有排名的 WriteResults 列表。

返回

没有

返回类型

没有

摘要 prepare_global_plan计划[来源]

执行集中存储规划。

此方法仅在协调器实例上调用。

虽然这种方法可以产生完全不同的计划,但首选的 方法是将存储特定数据存储在 SavePlan::storage_data 中。

参数

plansList[SavePlan]) – 实例列表,每个等级一个实例。SavePlan

返回

存储后全局规划的改造清单SavePlan

返回类型

列表[SavePlan]

摘要 prepare_local_plan计划[来源]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的 方法是将存储特定数据存储在 SavePlan::storage_data 中。

参数

planSavePlan) – 正在使用的本地计划。SavePlanner

返回

A 仓储后本地规划SavePlan

返回类型

保存计划

抽象 resetcheckpoint_id=[来源]

调用 to 表示将要进行全新的检查点写入。 如果用户将 checkpoint_id checkpoint_id 设置为 this checkpoint 写入。checkpiont_id的含义是 依赖于存储。它可以是文件夹/文件的路径或 键值存储。

参数

checkpoint_idUnion[stros.PathLikeNone]) – 此检查点实例的 ID。checkpoint_id的含义 取决于存储。它可以是文件夹或文件的路径。 如果存储是键值存储,它也可以是键。 (默认:None)

摘要 set_up_storage_writeris_coordinator[来源]

初始化此实例。

参数

is_coordinatorbool) – 此实例是否负责协调 检查点。

storage_meta)[来源]

返回特定于存储的元数据。这用于存储其他信息 在可用于提供请求级可观察性的检查点中。存储元 在保存调用期间传递给 。默认情况下返回 None。SavePlanner

TODO:提供示例

返回类型

可选[StorageMeta]

抽象类方法 validate_checkpoint_idcheckpoint_id[来源]

检查存储是否支持给定的 checkpoint_id。这允许 us 启用自动存储选择。

返回类型

布尔

摘要 write_dataPlanPlanner[来源]

写入 using 中的所有项目以解析数据。planplanner

子类应该调用每个项目 从 plan 获取对要写入的基础对象的访问权。SavePlanner::resolve_data

子类应该延迟调用 resolve_data因为它可以分配内存。 对于 Tensors,请做出以下假设:

  • 它们可能在任何设备上,包括不匹配的WriteItem::tensor_data

  • 它们可能是视图,也可能不是连续的。只需保存投影。

参数
  • planSavePlan) – 要执行的保存计划。

  • plannerSavePlanner) – 用于将项目解析为数据的 Planner 对象。

返回

完成 WriteResult 列表的 future

返回类型

future[List[WriteResult]]]

以下类型定义了 checkpoint 期间使用的 planner 接口:

torch.distributed.checkpoint 中。LoadPlanner[来源]

定义 load_state_dict 用于规划加载过程的协议的抽象类。

LoadPlanner 是可用于自定义整个加载过程的有状态对象。

LoadPlanner 充当 state_dict 的访问代理,因此对其执行的任何转换 将对整个过程可见。

在 load_state_dict期间,Planner 子类可以预期以下 Sequences 调用:

  1. set_up_planner - 召集所有级别。

    表示开始加载检查点。

  2. create_local_plan - 召集所有等级。

    处理 state_dict 并生成 LoadPlan,该 LoadPlan 将发送以进行全局规划。

  3. create_global_plan - 仅根据协调者级别调用。

    从所有级别获取 LoadPlan 并做出任何全局决策。

  4. load_bytes - 在每个等级上多次调用

    在 state_dict 中,每个非张量值调用一次。

  5. resolve_tensor 和 commit_tensor - 在每个等级上多次调用

    它们对 state_dict 中的每个 Tensor 值成对调用。

建议用户直接扩展 DefaultLoadPlanner 而不是这个接口 大多数更改可以通过单个方法中的更改来表示。

有两种常见的扩展模式:

重写state_dict。这是扩展 load 进程的最简单方法,因为它 并没有完全理解 LoadPlan 工作原理的内在含义。我们需要 在加载发生时保留对原始state_dict的引用,因此 我们需要能够在现场执行

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

修改 resolve_tensor 和 commit_tensor 以处理加载时间转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
摘要 commit_tensorread_itemTensor[来源]

在 StorageReader 完成将数据加载到 后调用 。tensor

提供的张量与调用 . 仅当此 LoadPlanner 需要在 将其复制回 state_dict 中的那个。resolve_tensortensor

tensor 的内容将遵循其设备同步模型。

摘要 create_global_planglobal_plan[来源]

计算每个等级的全局负载计划和返回计划。

.注这仅在协调器等级上调用

返回类型

列表[LoadPlan]]

摘要 create_local_plan[来源]

根据 set_up_planner 提供的 state_dict 和元数据创建 LoadPlan。

.注每个等级都调用了这个版本。

返回类型

LoadPlan (负载计划)

摘要 finish_plancentral_plan[来源]

接受来自协调器的计划并返回最终的 LoadPlan。

返回类型

LoadPlan (负载计划)

摘要 load_bytesread_itemvalue[来源]

加载 描述的项目。read_item``and ``value

此方法应就地修改基础state_dict。

的内容由用于生成 正在加载的 checkpoint。value

resolve_bytesread_item[来源]

返回 StorageReader 用于加载read_item的 BytesIO。

BytesIO 应在底层 state_dict 上别名为 1,因为 StorageReader 将替换其内容。

返回类型

字节 IO

摘要 resolve_tensorread_item[来源]

返回 描述的张量,供 StorageReader 用于加载read_itemread_item

张量应在底层 state_dict 上别名为 1,因为 StorageReader 将替换其内容。 如果出于任何原因无法复制数据,规划者可以使用该方法复制数据 回到 state_dict 年的那个。commit_tensor

返回类型

张肌

摘要 set_up_plannerstate_dictmetadata=Noneis_coordinator=False[来源]

初始化此实例以将数据加载到 中。state_dict

.注每个等级都调用了这个版本。

torch.distributed.checkpoint 中。LoadPlanitems List[torch.distributed.checkpoint.planner.ReadItem]storage_data Any = Noneplanner_data:任何 = [来源]
torch.distributed.checkpoint 中。ReadItem类型torch.distributed.checkpoint.planner.LoadItemType,dest_indextorch.distributed.checkpoint.metadata.MetadataIndex,dest_offsetstorch.大小storage_indextorch.distributed.checkpoint.metadata.MetadataIndex,storage_offsetstorch。尺寸长度Torch。尺寸[来源]
torch.distributed.checkpoint 中。SavePlanner[来源]

定义 save_state_dict 用于规划保存过程的协议的抽象类。

SavePlanners 是可用于自定义整个保存过程的有状态对象。

SavePlanner 充当 state_dict 的访问代理,因此对它所做的任何转换 将对整个过程可见。

在 save_state_dict期间,Planner 子类可以预期以下 Sequences 调用:

  1. set_up_planner - 召集所有级别。

    指示检查点保存开始。

  2. create_local_plan - 召集所有等级。

    处理state_dict并生成 SavePlan,该 SavePlan 将发送用于全球规划。

  3. create_global_plan - 仅根据协调者级别调用。

    从所有级别中获取 SavePlan 并做出任何全局决策。

  4. finish_plan - 召集所有军衔。

    这使每个等级都有机会适应全局规划决策。

  5. resolve_data - 在每个等级上多次调用

    state_dict 上查找要写入的存储层的值。

建议用户将 DefaultSavePlanner 而不是这个接口直接扩展为 大多数更改可以通过单个方法中的更改来表示。

有 3 种常见的扩展模式:

重写state_dict。这是扩展保存过程的最简单方法,因为它 并不完全理解 SavePlan 工作原理的内在含义:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

同时修改本地计划和查找。这在精细控制数据的持久化方式时非常有用

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤做出每个等级无法单独做出的中央决策

>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些 planner 需要在 checkpoint 中保存额外的元数据,这是 通过让每个 rank 在本地计划中贡献其数据项来完成,并且 全局规划器会聚合它们:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
摘要 create_global_planall_plans[来源]

计算全局检查点计划并返回每个 rank 的局部计划。

这仅在协调器等级上调用。

返回类型

元组[List[SavePlan], 元数据]

摘要 create_local_plan[来源]

计算当前排名的保存计划。

这将被聚合并传递给 create_global_plan。 Planner 特定的数据可以通过 SavePlan::p lanner_data 传递。

这在所有等级上都调用。

返回类型

保存计划

摘要 finish_plannew_plan[来源]

合并 create_local_plan 创建的计划和 create_global_plan 的结果。

这在所有等级上都调用。

返回类型

保存计划

摘要 resolve_datawrite_item[来源]

转换和准备存储,确保幂等性和线程安全。write_itemstate_dict

查找与 in 关联的对象并应用任何 转换(例如序列化)的 Interface。write_itemstate_dict

对每个排名调用多次,在最终 SavePlan 中对每个 WriteItem 至少调用一次。

此方法应该是幂等的和 thread-save。StorageWriter 实现 可以根据需要随意调用它。

任何分配内存的转换都应该在他的方法 以减少检查点所需的峰值内存。

返回张量时,它们可以位于任何设备或格式上,也可以是视图。 存储层有责任弄清楚如何保存它们。

返回类型

联合[TensorBytesIO]

摘要 set_up_plannerstate_dictstorage_meta=is_coordinator=False[来源]

初始化此规划器以保存 。state_dict

implementations 应保存这些值,因为在 save 过程中不会提供这些值。

这在所有等级上都调用。

torch.distributed.checkpoint 中。SavePlanitems List[torch.distributed.checkpoint.planner.WriteItem]storage_data Any = Noneplanner_data:任何 = [来源]
torch.distributed.checkpoint.planner 中。WriteItemindextypetensor_data=None[来源]

Data类,该类保存有关需要写入存储的内容的信息。

tensor_storage_size)[来源]

计算底层张量的存储大小,如果这不是张量写入,则为 None。

返回

可选[int] 存储大小,以底层张量的字节数(如果有)。

返回类型

可选[int]

我们提供基于文件系统的存储层:

torch.distributed.checkpoint 中。FileSystemReaderpath[来源]
属性checkpoint_id:Union[str PathLike]

返回将用于加载 checkpoint 的 checkpoint_id。

torch.distributed.checkpoint 中。FileSystemWriter路径single_file_per_rank=Truesync_files=Truethread_count=1per_thread_copy_ahead=10000000cache_staged_state_dict=Falseoverwrite=True[来源]

使用文件 IO 的 StorageWriter 的基本实现。

此实现进行了以下假设和简化:

  • 检查点路径是一个空目录或不存在的目录。

  • 文件创建是原子的

检查点由每个写入请求一个文件以及 包含序列化元数据的 .metadata 文件。

stagestate_dict[来源]

AsyncStager.stage 的重写

返回类型

dict[strunion[StatefulT任意]]

我们提供了 LoadPlannerSavePlanner 的默认实现,这些 可以处理所有 torch.distributed 结构,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。

torch.distributed.checkpoint 中。DefaultSavePlannerflatten_state_dict=Trueflatten_sharded_tensors=Truededup_replicated_tensors=Nonededup_save_to_lowest_rank=[来源]
lookup_objectindex[来源]

extension 来轻松扩展默认 planner。

返回类型

任何

transform_objectwrite_itemobject[来源]

extension 来轻松扩展默认 planner。

torch.distributed.checkpoint 中。DefaultLoadPlannerflatten_state_dict=Trueflatten_sharded_tensors=Trueallow_partial_load=False[来源]

DefaultLoadPlanner 在 LoadPlanner 的基础上添加了多个功能。

具体而言,它添加了以下内容:

flatten_state_dict:使用嵌套 dict 处理 state_dict flatten_sharded_tensors:对于 2D 并行模式下的 FSDP allow_partial_load:如果为 False,则如果键存在于 state_dict 中,但不存在于检查点中,则会引发运行时错误。

lookup_tensorindex[来源]

extension 来轻松扩展默认 planner。

返回类型

张肌

transform_tensorread_itemtensor[来源]

extension 来轻松扩展默认 planner。

由于传统的设计决策,FSDPDDP 的状态字典可能具有不同的键或完全限定名称(例如,layer1.weight),即使原始的未并行化模型相同。此外,FSDP 还提供各种类型的模型状态词典,例如完整和分片状态词典。此外,优化器状态字典使用参数 ID 而不是完全限定名称来标识参数,这可能会在使用并行性(例如,管道并行性)时引起问题。

为了应对这些挑战,我们提供了一系列 API 供用户轻松管理state_dicts。get_model_state_dict 返回一个模型状态字典,其键与未并行化模型状态字典返回的键一致。同样,get_optimizer_state_dict 为 optimizer 状态字典提供在所有应用的 parallelism 中 uniform 的键。为了实现这种一致性,get_optimizer_state_dict 会将参数 ID 转换为与未并行化模型状态字典中查找的名称相同的完全限定名称。

请注意,这些 API 返回的结果可以直接与 torch.distributed.checkpoint.save()torch.distributed.checkpoint.load() 方法一起使用,而无需任何额外的转换。

请注意,此功能是实验性的,API 签名将来可能会更改。

torch.distributed.checkpoint.state_dict。get_state_dictmodeloptimizers*submodules=Noneoptions=None[来源]

返回模型 state_dict 和 state_dict 优化器。

get_state_dict可以处理 PyTorch 并行化的任何模块 FSDP/fully_shard、DDP/复制、tensor_parallel/parallelize_module 和任何 这些并行度的组合。的主要功能是:1.) 返回一个可以重新分片的模型和优化器state_dict 具有不同数量的 trainer 和/或不同的 parallelisms。 2.) 隐藏特定于并行度的 state_dict API。用户不必调用 这些 API 的 API 进行验证。 3.) 对结果进行健全性检查state_dict。get_state_dict

结果状态字典的键是规范的 FQN(完全 限定名称)。规范 FQN 是指基于参数的 position 在 nn.模块层次结构。更具体地说,将规范 FQN 转换为 parameter 是模块未由任何 parallelisms 的 Parallel Actions.由于优化器内部使用参数 ID 来表示 参数,则参数 ID 将转换为 规范 FQN。module.named_parameters()module.named_buffers()

get_state_dict还可以处理未并行化的模块。在 在这种情况下,只执行一个功能——将 optimizer 参数 ID 设置为规范 FQN。get_state_dict

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • optimizersUnion[NoneOptimizerIterable[Optimizer]]) – 用于优化的优化器。model

  • submodulesdeprecated) – 可选[Set[nn.Module]]:仅返回模型参数 属于子模块。

  • optionsStateDictOptions) – 控制方式的选项 应返回 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

Tuple,其中包含 Model state_dict 和 Optimizer state_dict。

返回类型

元组[Dict[str, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict。get_model_state_dictmodel*submodules=Noneoptions=None[来源]

返回 的模型state_dict 。model

有关详细用法,请参见。get_state_dict

参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • submodulesdeprecated) – 可选[Set[nn.Module]]:仅返回模型参数 属于子模块。

  • optionsStateDictOptions) – 控制方式的选项 应返回 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

的 state_dict 。model

返回类型

Dict[str, 值类型]

torch.distributed.checkpoint.state_dict。get_optimizer_state_dictmodeloptimizers*submodules=Noneoptions=None[来源]

返回优化器的组合state_dict。

有关详细用法,请参见。get_state_dict

参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • optimizersUnion[NoneOptimizerIterable[Optimizer]]) – 用于优化的优化器。model

  • submodulesdeprecated) – 可选[Set[nn.Module]]:仅返回模型参数 属于子模块。

  • optionsStateDictOptions) – 控制方式的选项 应返回 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

的 state_dict 。optimizers

返回类型

优化器状态类型

torch.distributed.checkpoint.state_dict。set_state_dictmodeloptimizers*model_state_dictoptim_state_dictoptions=None[来源]

加载模型state_dict state_dict 和优化器。

的对应项将 state_dict 设置为 model,并将 优化器。given 和 do not 必须返回,但必须满足以下条件 要求: 1) 所有 FQN 都是 中定义的规范 FQN, 2) 如果张量被分片,则它必须是 ShardedTensor 或 DTensor, 3) optimizer state_dict 不能包含参数 ID;键应该是 规范的 FQN。get_state_dictmodel_state_dictoptim_state_dictget_state_dictget_state_dict

参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • optimizersUnion[OptimizerIterable[Optimizer]]) – 用于优化的优化器。model

  • model_state_dictDict[strValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 模型state_dict加载。如果 的 key 是 nn.Module 的 Module,键是 的子模块,值应 是子模块的state_dict。加载 state_dict 时, 子模块的前缀将附加到 state_dict。model_state_dictmodel

  • optim_state_dictOptimizerStateType) – OptimizerStateType: 优化器state_dict加载。

  • optionsStateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

  • missing_keys 是一个 str 列表,其中包含模型state_dict的缺失键。

  • unexpected_keys 是一个 str 列表,其中包含模型state_dict的意外键。

返回类型

NamedTuplewith 和 字段missing_keysunexpected_keys

torch.distributed.checkpoint.state_dict。set_model_state_dictmodelmodel_state_dict*options=None[来源]

state_dict加载模型。

的对应项将 state_dict 设置为 型。有关详细用法,请参见。get_model_state_dictset_state_dict

参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • model_state_dictDict[strValueType]) – (Dict[str, ValueType]): 模型state_dict加载。如果 的 key 是 nn.Module 的 Module,键是 的子模块,值应 是子模块的state_dict。加载 state_dict 时, 子模块的前缀将附加到 state_dict。model_state_dictmodel

  • optionsStateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

  • missing_keys 是包含缺失键的 str 列表

  • unexpected_keys 是包含意外键的 str 列表

返回类型

NamedTuplewith 和 字段missing_keysunexpected_keys

torch.distributed.checkpoint.state_dict。set_optimizer_state_dictmodeloptimizersoptim_state_dict*options=None[来源]

state_dict加载优化器。

的对应项将 state_dict 设置为 优化器。有关详细用法,请参见。get_optimizer_state_dictset_state_dict

参数
  • 模型NN.Module) – nn.Module 添加到模型中。

  • optimizersUnion[OptimizerIterable[Optimizer]]) – 用于优化的优化器。model

  • optim_state_dictOptimizerStateType) – OptimizerStateType: 优化器state_dict加载。

  • optionsStateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions

返回

没有

返回类型

没有

torch.distributed.checkpoint.state_dict 类StateDictOptionsfull_state_dict=Falsecpu_offload=Falseignore_frozen_params=Falsekeep_submodule_prefixes=Truestrict=Truebroadcast_from_rank0=Falseflatten_optimizer_state_dict=False[来源]

此数据类指定 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict:如果设置为 True,则 将收集返回的 state_dict。没有 ShardedTensor 和 DTensor 将位于返回的state_dict中。

  • cpu_offload:将所有张量卸载到 CPU。为防止 CPU OOM,如果也是 true,则只有 rank0 才会获得 state_dict 和所有其他等级都将为空state_dict。full_state_dict

  • ignore_frozen_params:如果值为 True,则返回state_dict 不包含任何冻结的参数 – 为 False。 默认值为 False。requires_grad

  • keep_submodule_prefixes(已弃用):当 is not None 时,此选项 指示是否保留 state_dict 键的 submodule 前缀。 或 example,如果子模块是且 该参数为 param.当此选项 为 True,则返回的 state_dict 中的参数键将为 。如果选项为 False,则键将为 。 请注意,如果为 False,则可能存在冲突 FQN,因此 中应该只有一个子模块 。submodulesmodule.pretrainpretrain.layer1.weightpretrain.layer1.weightlayer1.weightkeep_submodule_prefixessubmodules

  • strict:调用 model.load_state_dict() 中。strictset_state_dict

  • broadcast_from_rank0:当选项为 True 时,rank0 应收到一个

    full state_dict,并将广播 state_dict/ optim_state_dict逐个到其他等级。其他等级将获得 根据模型中的局部分片的 Tensors 和 Shard,以及 优化。 使用此选项时,必须设置为 True。 此选项目前仅支持 DTensor,不支持旧版 ShardedTensor。full_state_dict

对于习惯于使用和共享 torch.save 格式模型的用户,提供了以下方法,这些方法提供了用于转换不同格式的离线实用程序。

torch.distributed.checkpoint.format_utils。dcp_to_torch_savedcp_checkpoint_dirtorch_save_path[来源]

给定一个包含 DCP 检查点的目录,此函数会将其转换为 Torch 保存文件。

参数
  • dcp_checkpoint_dirUnion[strPathLike]) – 包含 DCP 检查点的目录。

  • torch_save_pathUnion[strPathLike]) – 用于存储转换后的 Torch 保存文件的文件名。

警告

为避免 OOM,建议仅在单个排名上运行此函数。

torch.distributed.checkpoint.format_utils。torch_save_to_dcptorch_save_pathdcp_checkpoint_dir[来源]

给定 torch 保存文件的位置,将其转换为 DCP 检查点。

参数
  • torch_save_pathUnion[strPathLike]) – Torch 保存文件的文件名。

  • dcp_checkpoint_dirUnion[strPathLike]) – 用于存储 DCP 检查点的目录。

警告

为避免 OOM,建议仅在单个排名上运行此函数。

以下类也可用于从 torch.save 格式在线加载和重新分片模型。

torch.distributed.checkpoint.format_utils 类BroadcastingTorchSaveReadercheckpoint_id=coordinator_rank=0[来源]

StorageReader 用于读取 Torch Save 文件。此读取器将读取整个 checkpoint 在 Coordinator 排名上,然后将每个 Tensor 广播并分片到所有排名。

.注意:旨在与 DynamicMetaLoadPlanner 一起使用

警告

当前实现仅支持加载 Tensor。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
prepare_global_planglobal_plan[来源]

StorageReader 方法的实现

返回类型

列表[LoadPlan]]

prepare_local_plan计划[来源]

StorageReader 方法的实现

返回类型

LoadPlan (负载计划)

read_data计划规划师[来源]

读取协调器等级的Torch保存数据,然后广播 这会产生通信成本,但避免了加载 每个等级上的整个检查点,希望可以防止 OOM 问题

返回类型

Future[无]

read_metadata[来源]

扩展默认 StorageReader 以支持构建元数据文件

返回类型

元数据

resetcheckpoint_id=[来源]

StorageReader 方法的实现

set_up_storage_reader元数据is_coordinator[来源]

StorageReader 方法的实现

类方法 validate_checkpoint_idcheckpoint_id[源代码]

StorageReader 方法的实现

返回类型

布尔

torch.distributed.checkpoint.format_utils 类DynamicMetaLoadPlannerflatten_state_dict=Trueflatten_sharded_tensors=Trueallow_partial_load=False[来源]

DefaultLoadPlanner 的扩展,它根据传入的状态 dict 创建新的 Metadata 对象, 无需从磁盘读取元数据。这在读取没有 元数据文件,例如 Torch Save files。

.注意:旨在与 BroadcastingTorchSaveReader 一起使用

警告

当前实现仅支持加载 Tensor。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_plannerstate_dictmetadata=Noneis_coordinator=False[来源]

Planner 的设置,通过从状态 dict 创建 Metadata 对象来扩展默认行为

提供了以下实验性接口,以提高生产环境中的可观测性:

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源