目录

分布式检查点 - torch.distributed.checkpoint

分布式检查点 (DCP) 支持并行加载和保存来自多个 rank 的模型。 它处理加载时重新分片,从而支持在一个集群拓扑中保存并加载到另一个集群拓扑中。

DCP 与 torch.savetorch.load 在几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个等级至少生成一个文件。

  • 它在原地运行,这意味着模型应该首先分配其数据,然后 DCP 使用该存储。

加载和保存检查点的入口点如下:

torch.distributed.checkpoint 中。load_state_dictstate_dictstorage_readerprocess_group=coordinator_rank=0no_dist=False规划者=[来源]

以 SPMD 样式加载 distributed。state_dict

每个等级将尝试读取所需的最少数据量 完成请求的state_dict。加载实例时,每个排名仅读取其本地分片的数据。ShardedTensor

警告

中的所有张量都必须在其 destination device 之前state_dict

所有非张量数据都使用 torch.load() 加载并就地修改 在 state_dict。

警告

用户必须在根模块上调用 load_state_dict 以确保负载 pos-processing 和非 Tensor 数据正确传播。

参数
  • state_dictDict[strAny]) – 要加载state_dict。请注意,此 state dict 将就地更新。

  • storage_readerStorageReader) – 用于从中加载数据的 StorageReader。

  • process_groupProcessGroup) – 用于跨等级同步的 ProcessGroup。

  • coordinator_rankint) – 用于协调检查点的排名。 默认情况下使用 rank0。

  • no_distbool) – 如果 ,则分布式检查点不会保存 在 SPMD 样式中。(默认:TrueFalse)

结果

没有。

返回类型

没有

例子
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_loader = torch.distributed.checkpoint.FileSystemLoader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_loader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用 collectives 来协调跨等级的读取。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为每个 rank 具有单个 GPU,通过 .torch.cuda.current_device()torch.cuda.set_device()

torch.distributed.checkpoint 中。save_state_dictstate_dictstorage_writerprocess_group=coordinator_rank=0no_dist=False规划者=[来源]

以 SPMD 样式保存分布式模型。

此函数与它处理的不同之处在于,每个排名仅保存其本地分片。torch.save()ShardedTensor

警告

无法保证 PyTorch 版本之间的向后兼容性 用于节省state_dicts。

警告

如果使用 process_group 参数,请确保只有其 ranks 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。

注意

此函数可用于保存具有初始进程的 state_dict group 通过传递 .这可用于生成 checkpoint 可以被 load_state_dict 消费的是 SPMD 时尚。no_dist=True

参数
  • state_dictDict[strAny]) – 一个state_dict

  • storage_writerStorageWriter) – StorageWrite 实例用于执行写入。

  • process_groupProcessGroup) – 用于跨等级同步的 ProcessGroup。

  • coordinator_rankint) – 用于协调检查点的排名。 默认情况下使用 rank0。

  • no_distbool) – 如果 ,则分布式检查点不会保存 在 SPMD 样式中。(默认:TrueFalse)

结果

Metadata 对象。

返回类型

元数据

>>> my_model = MyModule()
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_writer=fs_stroage_writer,
>>> )

注意

save_state_dict 使用 collectives 来协调跨等级的写入。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为: 每个等级都有一个单独的 GPU,通过 .torch.cuda.current_device()torch.cuda.set_device()

以下类型定义了 checkpoint 期间使用的 IO 接口:

torch.distributed.checkpoint 中。StorageReader[来源]

用于从存储中读取的接口。load_state_dict

一个 StorageReader 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。

子类应期望以下调用序列:load_state_dict

  1. (所有等级) read_metadata()

  2. (所有等级) set_up_storage_reader()

  3. (所有等级) prepare_local_plan()

  4. (协调) prepare_global_plan()

  5. (所有等级) read_data()

摘要 prepare_global_plan计划[来源]

执行存储加载的集中规划。

此方法仅在协调器实例上调用。

虽然这种方法可以产生完全不同的计划,但首选的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。

参数

plansList[LoadPlan]) – 实例列表,每个等级一个实例。LoadPlan

结果

存储后全局规划的改造清单LoadPlan

返回类型

列表[LoadPlan]]

摘要 prepare_local_plan计划[来源]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。

参数

planLoadPlan) – 正在使用的本地计划。LoadPlan

结果

A 仓储后本地规划LoadPlan

返回类型

LoadPlan (负载计划)

摘要 read_dataPlanPlanner[来源]

从 using 中读取所有项以解析数据。planplanner

应调用子类以反序列化 BytesIO object 移动到正确的位置。LoadPlanner::load_bytes

子类应该调用以获取对 应该将数据加载到的张量。LoadPlanner::resolve_tensor

StorageLayer 负责正确安排任何跨设备副本 必填。

参数
  • planLoadPlan) – 要执行的本地计划

  • plannerLoadPlanner) – 用于解析项目的 planner 对象。

结果

一个 future ,在完成所有读取后完成。

返回类型

Future[无]

摘要 read_metadata[来源]

读取检查点元数据。

结果

与正在加载的检查点关联的 metatada 对象。

返回类型

元数据

摘要 set_up_storage_readermetadatais_coordinator[来源]

初始化此实例。

参数
  • metadataMetadata) (元数据) – 要使用的元数据架构。

  • is_coordinatorbool) – 此实例是否可负责协调 检查点。

torch.distributed.checkpoint 中。StorageWriter[来源]

用于写入存储的接口。save_state_dict

一个 StorageWriter 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。

子类应期望以下调用序列。

  1. (所有等级) set_up_storage_writer()

  2. (所有等级) prepare_local_plan()

  3. (协调) prepare_global_plan()

  4. (所有等级) write_data()

  5. (协调器) finish()

抽象完成元数据结果[来源]

写入元数据并将当前检查点标记为成功。

用于序列化元数据的实际格式/架构是 实施细节。唯一的要求是它是可恢复的 in 添加到同一对象图中。

参数
  • metadataMetadata) – 新检查点的元数据

  • resultsList[List[WriteResult]]) – 所有排名的 WriteResults 列表。

结果

没有

返回类型

没有

摘要 prepare_global_plan计划[来源]

执行集中存储规划。

此方法仅在协调器实例上调用。

虽然这种方法可以产生完全不同的计划,但首选的 方法是将存储特定数据存储在 SavePlan::storage_data 中。

参数

plansList[SavePlan]) – 实例列表,每个等级一个实例。SavePlan

结果

存储后全局规划的改造清单SavePlan

返回类型

列表[SavePlan]

摘要 prepare_local_plan计划[来源]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的 方法是将存储特定数据存储在 SavePlan::storage_data 中。

参数

planSavePlan) – 正在使用的本地计划。SavePlanner

结果

A 仓储后本地规划SavePlan

返回类型

保存计划

摘要 set_up_storage_writeris_coordinator[来源]

初始化此实例。

参数

is_coordinatorbool) – 此实例是否可负责协调 检查点。

摘要 write_dataPlanPlanner[来源]

写入 using 中的所有项目以解析数据。planplanner

子类应该调用每个项目 从 plan 获取对要写入的基础对象的访问权。SavePlanner::resolve_data

子类应该延迟调用 resolve_data因为它可以分配内存。 如果是张量,请做出以下假设:

  • 它们可能在任何设备上,包括不匹配的WriteItem::tensor_data

  • 它们可能是视图,也可能不是连续的。只需保存投影。

参数
  • planSavePlan) – 要执行的保存计划。

  • plannerSavePlanner) – 用于将项目解析为数据的 Planner 对象。

结果

完成 WriteResult 列表的 future

返回类型

future[List[WriteResult]]]

以下类型定义了 checkpoint 期间使用的 planner 接口:

torch.distributed.checkpoint 中。LoadPlanner[来源]

定义 load_state_dict 用于规划加载过程的协议的抽象类。

LoadPlanner 是可用于自定义整个加载过程的有状态对象。

LoadPlanner 充当 state_dict 的访问代理,因此对它执行的任何转换 将对整个过程可见。

在 load_state_dict期间,Planner 子类可以预期以下 Sequences 调用:

  1. set_up_planner - 召集所有级别。

    表示开始加载检查点。

  2. create_local_plan - 召集所有等级。

    处理 state_dict 并生成 LoadPlan,该 LoadPlan 将发送以进行全局规划。

  3. create_global_plan - 仅根据协调者级别调用。

    从所有级别获取 LoadPlan 并做出任何全局决策。

  4. load_bytes - 在每个等级上多次调用

    在 state_dict 中,每个非张量值调用一次。

  5. resolve_tensor 和 commit_tensor - 在每个等级上多次调用

    它们对 state_dict 中的每个 Tensor 值成对调用。

建议用户直接扩展 DefaultLoadPlanner 而不是这个接口 大多数更改可以通过单个方法中的更改来表示。

有两种常见的扩展模式:

重写state_dict。这是扩展 load 进程的最简单方法,因为它 并没有完全理解 LoadPlan 工作原理的内在含义。我们需要 在加载发生时保留对原始state_dict的引用,因此 我们需要能够在现场执行

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(self, state_dict, metadata, is_coordinator):
>>>         self.original_state_dict = state_dict
>>>         super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)

修改 resolve_tensor 和 commit_tensor 以处理加载时间转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
摘要 commit_tensorread_itemTensor[来源]

一旦 StorageReader 完成将数据加载到 中,就会调用此方法。tensor

提供的张量与调用 . 仅当此 LoadPlanner 需要在 将其复制回 state_dict 中的那个。resolve_tensortensor

tensor 的内容将遵循其设备同步模型。

摘要 create_global_planglobal_plan[来源]

计算每个等级的全局负载计划和返回计划。

.注这仅在协调器等级上调用

返回类型

列表[LoadPlan]]

摘要 create_local_plan[来源]

根据 set_up_planner 提供的 state_dict 和元数据创建 LoadPlan。

.注每个等级都调用了这个版本。

返回类型

LoadPlan (负载计划)

摘要 finish_plancentral_plan[来源]

接受来自协调器的计划并返回最终的 LoadPlan。

返回类型

LoadPlan (负载计划)

摘要 load_bytesread_itemvalue[来源]

加载 描述的项目。read_item``and ``value

此方法应就地修改基础state_dict。

的内容由用于生成 正在加载的 checkpoint。value

摘要 resolve_tensorread_item[来源]

返回 描述的张量,供 StorageReader 用于加载read_itemread_item

张量应在底层 state_dict 上别名为 1,因为 StorageReader 将替换其内容。 如果出于任何原因无法复制数据,规划者可以使用该方法复制数据 回到 state_dict 年的那个。commit_tensor

返回类型

张肌

摘要 set_up_plannerstate_dict元数据is_coordinator[来源]

初始化此实例以将数据加载到state_dict

.注每个等级都调用了这个版本。

torch.distributed.checkpoint 中。LoadPlanitems List[torch.distributed.checkpoint.planner.ReadItem]storage_data Any = Noneplanner_data:任何 = [来源]
torch.distributed.checkpoint 中。ReadItem类型torch.distributed.checkpoint.planner.LoadItemType,dest_indextorch.distributed.checkpoint.metadata.MetadataIndex,dest_offsetstorch.大小storage_indextorch.distributed.checkpoint.metadata.MetadataIndex,storage_offsetstorch。尺寸长度Torch。尺寸[来源]
torch.distributed.checkpoint 中。SavePlanner[来源]

定义 save_state_dict 用于规划保存过程的协议的抽象类。

SavePlanners 是可用于自定义整个保存过程的有状态对象。

SavePlanner 充当 state_dict 的访问代理,因此对它所做的任何转换 将对整个过程可见。

在 save_state_dict期间,Planner 子类可以预期以下 Sequences 调用:

  1. set_up_planner - 召集所有级别。

    指示检查点保存开始。

  2. create_local_plan - 召集所有等级。

    处理state_dict并生成 SavePlan,该 SavePlan 将发送用于全球规划。

  3. create_global_plan - 仅根据协调者级别调用。

    从所有级别中获取 SavePlan 并做出任何全局决策。

  4. finish_plan - 召集所有军衔。

    这使每个等级都有机会适应全局规划决策。

  5. resolve_data - 在每个等级上多次调用

    state_dict 上查找要写入的存储层的值。

建议用户将 DefaultSavePlanner 而不是这个接口直接扩展为 大多数更改可以通过单个方法中的更改来表示。

有 3 种常见的扩展模式:

重写state_dict。这是扩展保存过程的最简单方法,因为它 并不完全理解 SavePlan 工作原理的内在含义:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(self, state_dict, is_coordinator):
>>>         # prefix all keys with `foo_``
>>>         super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)

同时修改本地计划和查找。这在精细控制数据的持久化方式时非常有用

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤做出每个等级无法单独做出的中央决策

>>> from itertools import islice
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         def chunk(it, size):
>>>             it = iter(it)
>>>         return list(iter(lambda: tuple(islice(it, size)), ()))
>>>         all_plans = [
>>>             replace(plan, items=items) for plan, items in
>>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些 planner 需要在 checkpoint 中保存额外的元数据,这是 通过让每个 rank 在本地计划中贡献其数据项来完成,并且 全局规划器会聚合它们:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
摘要 create_global_planall_plans[来源]

计算全局检查点计划并返回每个 rank 的局部计划。

这仅在协调器等级上调用。

返回类型

元组[List[SavePlan], 元数据]

摘要 create_local_plan[来源]

计算当前排名的保存计划。 这将被聚合并传递给 create_global_plan。 Planner 特定的数据可以通过 SavePlan::p lanner_data 传递。

这在所有等级上都调用。

返回类型

保存计划

摘要 finish_plannew_plan[来源]

合并 create_local_plan 创建的计划和 create_global_plan 的结果。

这在所有等级上都调用。

返回类型

保存计划

摘要 resolve_datawrite_item[来源]

查找与 in 关联的对象并应用任何 转换(例如序列化)的 Interface。write_itemstate_dict

对每个排名调用多次,在最终 SavePlan 中对每个 WriteItem 至少调用一次。

此方法应该是 idepotent 和 thread-save。StorageWriter 实现 可以根据需要随意调用它。

任何分配内存的转换都应该在他的方法 以减少检查点所需的峰值内存。

返回张量时,它们可以位于任何设备或格式上,也可以是视图。 存储层有责任弄清楚如何保存它们。

返回类型

联合[TensorBytesIO]

摘要 set_up_plannerstate_dictis_coordinator[来源]

初始化此规划器以保存 。state_dict

implementations 应保存这些值,因为在 save 过程中不会提供这些值。

这在所有等级上都调用。

torch.distributed.checkpoint 中。SavePlanitems List[torch.distributed.checkpoint.planner.WriteItem]storage_data Any = Noneplanner_data:任何 = [来源]
torch.distributed.checkpoint 中。WriteItemindex torch.distributed.checkpoint.metadata.MetadataIndextype torch.distributed.checkpoint.planner.WriteItemTypetensor_data: Union[torch.distributed.checkpoint.planner.TensorWriteData NoneType] = None[来源]

我们提供基于文件系统的存储层:

torch.distributed.checkpoint 中。FileSystemReaderpath[来源]
torch.distributed.checkpoint 中。FileSystemWriter路径single_file_per_rank=Truesync_files=Truethread_count=1per_thread_copy_ahead=10000000[来源]

使用文件 IO 的 StorageWriter 的基本实现。

此实现进行了以下假设和简化:

  • 检查点路径是一个空目录或不存在的目录。

  • 文件创建是原子的

检查点由每个写入请求一个文件以及 包含序列化元数据的 .metadata 文件。

我们提供了 LoadPlannerSavePlanner 的默认实现,这些 可以处理所有 torch.distributed 结构,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。

torch.distributed.checkpoint 中。DefaultSavePlannerflatten_state_dict=Trueflatten_sharded_tensors=Truededup_replicated_tensors=True[来源]
lookup_objectindex[来源]

这是 planner 界面的扩展,可以轻松扩展默认 planner

返回类型

任何

transform_objectwrite_itemobject[来源]

这是 planner 界面的扩展,可以轻松扩展默认 planner

torch.distributed.checkpoint 中。DefaultLoadPlannerflatten_state_dict=Trueflatten_sharded_tensors=True[来源]

DefaultLoadPlanner 在 LoadPlanner 的基础上添加了多个功能。

具体而言,它添加了以下内容:

flatten_state_dict:使用嵌套 dict 处理 state_dict flatten_sharded_tensors:对于 2D 并行模式下的 FSDP

lookup_tensorindex[来源]

这是 planner 界面的扩展,可以轻松扩展默认 planner

返回类型

张肌

transform_tensorread_itemtensor[来源]

这是 planner 界面的扩展,可以轻松扩展默认 planner

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源