目录

分布式检查点 - torch.distributed.checkpoint

分布式检查点(DCP)支持从多个排名平行加载和保存模型。 它处理加载时的重新分片,从而可以在一种集群拓扑中保存并在另一种集群拓扑中加载。

DCP与torch.savetorch.load在几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个排名至少一个。

  • 它以就地操作方式进行工作,这意味着模型应先分配其数据,然后DCP使用该存储空间。

加载和保存检查点的入口函数如下:

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]

以 SPMD 风格保存分布式模型。

此功能不同于 torch.save(),因为它处理 ShardedTensorDTensor,通过让每个排名只保存它们的本地分片。

对于每个Stateful对象(同时具有state_dictload_state_dict), 保存将在序列化之前调用state_dict

警告

在不同版本的 PyTorch 中,无法保证保存的 state_dicts 的向后兼容性。

警告

如果使用process_group参数,请确保只有它的排名调用save_state_dict,并且state_dict中的所有数据都属于它。

注意

在为FSDP的ShardingStrategy.HYBRID_SHARD保存检查点时,shard_group中应该只有一个调用save_state_dict,并且需要传入相应的进程组。

注意

If no process group is available, this function assumes the intention is to save the

在本地进程中的 state_dict。

Parameters
  • state_dict (Dict[str, Any]) – 要保存的state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹路径或文件路径。如果存储是键值存储,则也可以是一个键。(默认值:None)

  • storage_writer (可选[StorageWriter]) – StorageWriter 的实例,用于执行写入操作。如果没有指定此参数,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 规划器 (可选[SavePlanner]) – SavePlanner 的实例。如果没有指定,将使用默认的规划器。(默认: None)

  • process_group (可选[ProcessGroup]) – 用于跨等级同步的ProcessGroup。 (默认值:None)

Returns

保存检查点的元数据对象。

Return type

元数据

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用集体操作来协调不同排名之间的写入。 对于基于 NCCL 的进程组,在通信发生之前,对象的内部张量表示必须移动到 GPU 设备。 在这种情况下,使用的设备由 torch.cuda.current_device() 指定, 并且用户有责任确保每个排名都有一个单独的 GPU,通过 torch.cuda.set_device() 设置。

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]

save 的异步版本。此代码首先在 CPU 上取消状态字典的暂存,然后在一个单独的线程中调用 save

警告

此功能处于实验阶段,可能会发生变化。

Parameters
  • state_dict (Dict[str, Any]) – 要保存的state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹路径或文件路径。如果存储是键值存储,则也可以是一个键。(默认值:None)

  • storage_writer (可选[StorageWriter]) – StorageWriter 的实例,用于执行写入操作。如果没有指定此参数,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 规划器 (可选[SavePlanner]) – SavePlanner 的实例。如果没有指定,将使用默认的规划器。(默认: None)

  • process_group (可选[ProcessGroup]) – 用于跨等级同步的ProcessGroup。 (默认值:None)

Returns

save 中将要持有的 Metadata 对象。

Return type

未来

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

此方法已弃用。请切换到‘save’。

Return type

元数据

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[source]

以SPMD风格加载一个分布式的state_dict

每个进程将尝试读取最少的数据以满足请求的state_dict。在加载ShardedTensorDTensor实例时,每个进程只读取其本地分片的数据。

对于每个 Stateful 对象(同时具有 state_dictload_state_dict), load 会在尝试反序列化之前首先调用 state_dict,然后在反序列化完成后调用 load_state_dict

警告

所有在 state_dict 中的张量必须在其目标设备上分配 之前 调用此函数。

所有非张量数据都使用torch.load()加载,并在state_dict上就地修改。

警告

用户必须在根模块上调用load_state_dict以确保加载后处理和非张量数据正确传播。

Parameters
  • state_dict (Dict[str, Any]) – 要保存的state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹路径或文件路径。如果存储是键值存储,则也可以是一个键。(默认值:None)

  • storage_reader (可选[StorageReader]) – StorageWriter 的实例,用于执行读取操作。如果没有指定此参数,DCP 将根据 checkpoint_id 自动推断读取器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 规划器 (可选[LoadPlanner]) – LoadPlanner 的实例。如果没有指定,将使用默认的规划器。(默认: None)

  • process_group (可选[ProcessGroup]) – 用于跨等级同步的ProcessGroup。 (默认值:None)

Returns

None.

Return type

请提供需要翻译的单词列表。

Examples
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用集体操作在各个排名之间协调读取。 对于基于 NCCL 的进程组,在通信发生之前,对象的内部张量表示必须移动到 GPU 设备。 在这种情况下,使用的设备由 torch.cuda.current_device() 指定, 并且用户有责任确保每个排名都有一个单独的 GPU,通过 torch.cuda.set_device() 来实现。

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]

此方法已弃用。请切换到‘load’。

以下模块也有助于异步检查点所使用的阶段机制的进一步自定义 (torch.distributed.checkpoint.async_save):

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source]

该协议旨在为dcp.async_save提供自定义和扩展功能,允许用户在并行执行常规的dcp.save路径之前自定义数据的阶段。 操作的预期顺序(在torch.distributed.state_dict_saver.async_save中具体定义)如下:

  1. AsyncStager.stage_data(state_dict):

    此调用为 AsyncStager 提供了“准备” state_dict 的机会。这里的准备预期和目的是创建一个“训练安全”的状态字典表示,这意味着在准备完成后对模块数据的任何更新都不应反映在此方法返回的状态字典中。例如,在默认情况下,整个状态字典的副本会在 CPU 内存中创建并在这里返回,使用户能够在继续训练的同时避免序列化数据发生变化的风险。

  2. dcp.save is called on the state_dict returned from stage in parallel. This call is respondsible

    用于序列化状态字典并将其写入存储。

  3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after

    序列化线程启动,并在dcp.async_save返回之前结束。如果将其设置为False,假设用户已定义了一个自定义同步点,以进一步优化训练循环中的保存延迟(例如,通过将staging与前向/后向传递重叠),此时由用户负责在适当的时候调用AsyncStager.synchronize_staging

property should_synchronize_after_execute: bool

执行阶段后是否同步。

stage(state_dict)[source]

返回一个“阶段化”的state_dict副本。阶段化副本的预期是,在完成阶段调用后,它不受任何更新的影响。

Return type

字典[字符串, 联合[StatefulT, 任意]]

synchronize_staging()[source]

stage以某种方式为异步的情况下,应调用此方法以确保阶段完成,并且可以安全地开始修改原始state_dict

class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source]

实现了 AsyncStager,它将状态字典 staging 在 CPU 内存中,并阻塞直到复制完成。 此实现还提供了使用固定内存优化 staging 延迟的选项。

注意:在此情况下,synchronize_staging 是一个无操作。

stage(state_dict)[source]

返回在CPU上的state_dict副本。

Return type

字典[字符串, 联合[StatefulT, 任意]]

synchronize_staging()[source]

空操作函数,因为暂存阶段是阻塞的。

除了上述入口点外,Stateful个对象,如下面所述,在保存/加载期间提供额外的自定义选项 .. automodule:: torch.distributed.checkpoint.stateful

class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source]

可检查点和恢复的状态化对象协议。

load_state_dict(state_dict)[source]

从提供的状态字典恢复对象的状态。

Parameters

state_dict (字典[字符串, 任意类型]) – 要从中恢复的状态字典

state_dict()[source]

对象应返回其状态字典表示形式作为字典。 此函数的输出将被检查点保存,并在 load_state_dict()中稍后恢复。

警告

由于恢复检查点的就地特性,此函数也在torch.distributed.checkpoint.load期间被调用。

Returns

对象状态字典

Return type

字典

这个 示例 展示了如何使用 Pytorch 分布式检查点保存 FSDP 模型。

以下类型定义了检查点过程中使用的 IO 接口:

class torch.distributed.checkpoint.StorageReader[source]

load_state_dict使用的从存储中读取的接口。

一个 StorageReader 实例在分布式检查点中同时充当协调者和跟随者。 作为初始化的一部分,每个实例都会被告知其角色。

子类应预期以下调用顺序由load_state_dict

  1. (所有排名)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有排名)read_metadata()

  3. (所有排名)设置存储阅读器()

  4. (所有排名)prepare_local_plan()

  5. (协调员) prepare_global_plan()

  6. (所有排名)读取数据()

abstract prepare_global_plan(plans)[source]

集中规划存储加载。

此方法仅在协调器实例上被调用。

虽然这种方法可以生成完全不同的计划,但更推荐的做法是将特定于存储的数据存储在 LoadPlan::storage_data 中。

Parameters

计划 (列表[加载计划]) – 包含 LoadPlan 个实例的列表,每个实例对应一个等级。

Returns

存储全局规划后的转换列表 LoadPlan

Return type

列表[加载计划]

abstract prepare_local_plan(plan)[source]

执行存储特定的本地规划。

虽然这种方法可以生成完全不同的计划,但推荐的做法是将存储特定数据存储在 LoadPlan::storage_data 中。

Parameters

计划 (加载计划) – 当前使用的 LoadPlan 中的本地计划。

Returns

存储本地规划后转换的 LoadPlan

Return type

LoadPlan

abstract read_data(plan, planner)[source]

使用 planner 读取 plan 中的所有项目以解析数据。

子类应调用LoadPlanner::load_bytes将BytesIO对象反序列化到正确的位置。

子类应调用LoadPlanner::resolve_tensor以访问需要加载数据的张量。

这是 StorageLayer 的职责,负责正确调度任何跨设备复制。

Parameters
Returns

所有读取操作完成后才会完成的未来状态。

Return type

未来[无]

abstract read_metadata()[source]

读取检查点元数据。

Returns

与正在加载的检查点关联的元数据对象。

Return type

元数据

abstract reset(checkpoint_id=None)[source]

调用表示即将进行一次全新的检查点读取。 如果用户为此检查点读取设置了checkpoint_id,则可能会出现checkpoint_id。 checkpoint_id 的含义取决于存储类型。它可以是一个文件夹/文件的路径,也可以是键值存储中的一个键。

Parameters

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹路径或文件路径。如果存储更像键值存储,则也可以是一个键。(默认值:None)

abstract set_up_storage_reader(metadata, is_coordinator)[source]

初始化此实例。

Parameters
  • 元数据 (Metadata) – 要使用的元数据模式。

  • is_coordinator (bool) – 是否此实例负责协调检查点。

abstract classmethod validate_checkpoint_id(checkpoint_id)[source]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

Return type

布尔

class torch.distributed.checkpoint.StorageWriter[source]

save_state_dict使用的接口,用于写入存储。

一个 StorageWriter 实例在一个分布式检查点中同时充当协调者和跟随者。 在初始化过程中,每个实例都会被告知其角色。

一个子类应期望以下调用顺序。

  1. (所有排名)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有排名)设置存储写入器 ()

  3. (所有排名)prepare_local_plan()

  4. (协调员) prepare_global_plan()

  5. 所有排名 write_data()

  6. (协调员) 结束()

abstract finish(metadata, results)[source]

写入元数据,并将当前检查点标记为成功。

实际用于序列化的metadata格式/模式是一个实现细节。唯一的要求是它可以恢复到相同的对象图。

Parameters
  • 元数据 (Metadata) – 新检查点的元数据

  • 结果 (列表[列表[写入结果]]) – 来自所有排名的写入结果列表。

Returns

请提供需要翻译的单词列表。

Return type

请提供需要翻译的单词列表。

abstract prepare_global_plan(plans)[source]

集中规划存储。

此方法仅在协调器实例上被调用。

虽然这种方法可以生成完全不同的计划,但更推荐的方式是将特定于存储的数据存储在 SavePlan::storage_data 中。

Parameters

计划 (列表[保存计划]) – 包含 SavePlan 个实例的列表,每个实例对应一个等级。

Returns

存储全局规划后的转换列表 SavePlan

Return type

列表[保存计划]

abstract prepare_local_plan(plan)[source]

执行存储特定的本地规划。

虽然这种方法可以生成完全不同的计划,但推荐的做法是将存储特定数据保存在 SavePlan::storage_data 中。

Parameters

计划 (保存计划) – 当前使用的 SavePlanner 中的本地计划。

Returns

存储本地规划后转换的 SavePlan

Return type

SavePlan

abstract reset(checkpoint_id=None)[source]

调用表示即将进行一次全新的检查点写入。 如果用户为此检查点写入设置了checkpoint_id,可能会出现checkpoint_id。 checkpoint_id的意义取决于存储类型。它可以是一个文件夹/文件的路径,也可以是键值存储的键。

Parameters

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id 的含义取决于存储方式。它可以是一个文件夹路径或文件路径。如果存储是键值存储,则也可以是一个键。(默认值:None)

abstract set_up_storage_writer(is_coordinator)[source]

初始化此实例。

Parameters

is_coordinator (bool) – 是否此实例负责协调检查点。

storage_meta()[source]

返回特定存储的元数据。这用于在检查点中存储额外信息,这些信息对于提供请求级别的可观测性很有用。StorageMeta 在保存调用期间传递给 SavePlanner。默认情况下返回 None。

请提供待翻译的文本,以便我可以为您进行准确的翻译。

Return type

可选[存储元数据]

abstract classmethod validate_checkpoint_id(checkpoint_id)[source]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

Return type

布尔

abstract write_data(plan, planner)[source]

使用planner写出所有来自plan的项目以解析数据。

子类应在计划中的每个项目上调用SavePlanner::resolve_data以访问底层对象进行写入。

子类应懒惰地调用resolve_data,因为它可以分配内存。 对于张量,做如下假设:

  • 它们可能出现在任何设备上,包括与WriteItem::tensor_data不匹配的那个

  • 它们可能是视图,也可能不连续。只需保存投影。

Parameters
  • 计划 (保存计划) – 要执行的保存计划。

  • 规划器 (保存规划器) – 用于将项目解析为数据的规划器对象。

Returns

一个将结果完成到 WriteResult 列表的未来

Return type

未来[列表[写入结果]]

以下类型定义了检查点期间使用的计划器接口:

class torch.distributed.checkpoint.LoadPlanner[source]

抽象类,定义了 load_state_dict 使用的协议,以规划加载过程。

LoadPlanner 是有状态的对象,可用于自定义整个加载过程。

LoadPlanner 作为状态字典的访问代理,因此对其所做的任何变换都将对整个过程可见。

在调用 load_state_dict 期间,计划器子类可以预期以下调用顺序:

  1. set_up_planner - called on all ranks.

    表示开始加载检查点。

  2. create_local_plan - called on all ranks.

    处理 state_dict 并生成一个LoadPlan,该值将用于全局规划。

  3. create_global_plan - called on the coordinator rank only.

    从所有 ranks 获取 LoadPlan 并做出任何全局决策。

  4. load_bytes - called multiple times on each rank

    这在状态字典中的每个非张量值上调用一次。

  5. resolve_tensor and commit_tensor - called multiple times on each rank

    它们以成对的方式为 state_dict 中的每个张量值调用。

建议用户扩展 DefaultLoadPlanner 而不是直接扩展此接口,因为大多数更改都可以通过单个方法的更改来表达。

有两种常见的扩展模式:

重写 state_dict。这是扩展加载过程的最简单方式,因为它不需要理解 LoadPlan 的工作原理。在加载过程中需要保持对原始 state_dict 的引用,因此我们需要能够在原地进行操作。

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

修改 resolve_tensor 和 commit_tensor 以处理加载时的转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
abstract commit_tensor(read_item, tensor)[source]

调用一次,当StorageReader完成将数据加载到tensor中时。

提供的张量与调用resolve_tensor返回的张量相同。 此方法仅在该LoadPlanner需要在将其复制回state_dict中的张量之前对tensor进行后处理时才需要。

张量的内容将遵循其设备同步模型。

abstract create_global_plan(global_plan)[source]

计算全局负载计划并返回每个排名的计划。

注意:这只会被主节点调用。

Return type

列表[加载计划]

abstract create_local_plan()[source]

基于 set_up_planner 提供的 state_dict 和元数据创建一个 LoadPlan。

注意:这在每个排名上都会被调用。

Return type

LoadPlan

abstract finish_plan(central_plan)[source]

接受协调员的计划并返回最终的负载计划。

Return type

LoadPlan

abstract load_bytes(read_item, value)[source]

加载由 read_item``and ``value 描述的项。

此方法预计将就地修改底层 state_dict。

value 的内容由用于生成正在加载的检查点的 SavePlanner 定义。

resolve_bytes(read_item)[source]

返回由StorageReader加载read_item所使用的BytesIO。

BytesIO 应该与底层状态字典中的一个别名一致,因为 StorageReader 将会替换其内容。

Return type

BytesIO

abstract resolve_tensor(read_item)[source]

返回由read_item描述的张量,供StorageReader加载read_item使用。

张量应与底层state_dict中的一个别名,因为StorageReader将替换其内容。 如果出于任何原因无法做到这一点,计划者可以使用commit_tensor方法将数据复制回state_dict中的一个。

Return type

张量

abstract set_up_planner(state_dict, metadata=None, is_coordinator=False)[source]

初始化此实例以将数据加载到 state_dict

注意:这在每个排名上都会被调用。

class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]
class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]
class torch.distributed.checkpoint.SavePlanner[source]

抽象类,定义了 save_state_dict 使用的协议,以规划保存过程。

SavePlanner 是一种有状态的对象,可用于自定义整个保存过程。

SavePlanner 作为状态字典的访问代理,因此对其所做的任何变换都会在整个过程中可见。

在调用 save_state_dict 期间,计划子类可以预期以下调用顺序:

  1. set_up_planner - called on all ranks.

    标志着检查点保存的开始。

  2. create_local_plan - called on all ranks.

    处理 state_dict 并生成一个SavePlan,该值将用于全局规划。

  3. create_global_plan - called on the coordinator rank only.

    从所有 ranks 中获取 SavePlan 并做出任何全局决策。

  4. finish_plan - called on all ranks.

    这为每个排名有机会调整全局规划决策。

  5. resolve_data - called multiple times on each rank

    在存储层写入时查找state_dict处的值。

建议用户扩展 DefaultSavePlanner 而不是直接扩展此接口,因为大多数更改都可以通过单个方法的修改来表达。

有三种常见的扩展模式:

重写 state_dict。这是扩展保存过程的最简单方式,因为它不需要理解 SavePlan 的工作原理:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

修改本地计划和查找表的同时进行调整。这在需要精细控制数据如何持久化时很有用。

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤来做出各个层级单独无法做出的关键决策。

>>> from itertools import islice
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         def chunk(it, size):
>>>             it = iter(it)
>>>         return list(iter(lambda: tuple(islice(it, size)), ()))
>>>         all_plans = [
>>>             replace(plan, items=items) for plan, items in
>>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些规划者需要在检查点中保存额外的元数据,这通过让每个排名在其本地计划中贡献其数据项,并由全局规划者汇总它们来实现:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)[source]

计算全局检查点计划,并返回每个 ranks 的本地计划。

这仅在协调器排名上调用。

Return type

元组[列表[保存计划], 元数据]

abstract create_local_plan()[source]

计算当前排名的保存计划。

这将被聚合并传递给 create_global_plan。 可以通过 SavePlan::planner_data 传递特定于 Planner 的数据。

这在所有排名上都被调用。

Return type

SavePlan

abstract finish_plan(new_plan)[source]

合并由create_local_plan创建的计划和create_global_plan的结果。

这在所有排名上都被调用。

Return type

SavePlan

abstract resolve_data(write_item)[source]

write_itemstate_dict转换并准备存储,确保幂等性和线程安全。

state_dict 中查找与 write_item 相关联的对象,并在存储层使用它之前应用任何转换(如序列化)。

在最终的保存计划中的每个 WriteItem 至少调用一次,并在每个进程中多次调用。

此方法应该是幂等且线程安全的。StorageWriter 实现可以根据需要频繁调用它。

任何分配内存的转换都应在调用其方法时延迟执行,以减少检查点所需的峰值内存。

当返回张量时,它们可以位于任何设备或格式上,也可以是视图。 这是存储层的责任,确定如何保存它们。

Return type

联合[张量, BytesIO]

abstract set_up_planner(state_dict, storage_meta=None, is_coordinator=False)[source]

初始化此计划以保存 state_dict

实现时应保存这些值,因为在保存过程中不会提供这些值。

这在所有排名上都被调用。

class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source]
class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source]

持有需要写入存储的信息的数据类。

tensor_storage_size()[source]

计算底层张量的存储大小,或者如果这不是张量写入,则为 None。

Returns

可选的整数类型存储大小,以字节为单位,表示底层张量的大小(如果有)。

Return type

可选[整数]

我们提供一种基于文件系统的存储层:

class torch.distributed.checkpoint.FileSystemReader(path)[source]
property checkpoint_id: Union[str, PathLike]

返回将用于保存检查点的 checkpoint_id。

class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True)[source]

使用文件 IO 实现的 StorageWriter 基本实现。

此实现做出了以下假设和简化:

  • 检查点路径是一个空目录或不存在的目录。

  • 文件创建是原子操作

检查点由每个写入请求对应的一个文件加上一个.metadata文件组成,该文件包含序列化的元数据。

stage(state_dict)[source]

异步阶段器的阶段重写

Return type

字典[字符串, 联合[StatefulT, 任意]]

我们提供了LoadPlannerSavePlanner的默认实现, 可以处理所有torch.distributed构造,例如FSDP、DDP、ShardedTensor和DistributedTensor。

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False)[source]
lookup_object(index)[source]

从规划器接口进行扩展,以便轻松扩展默认规划器。

Return type

任何

transform_object(write_item, object)[source]

从规划器接口进行扩展,以便轻松扩展默认规划器。

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]

在 LoadPlanner 的基础上添加了多个功能的默认加载计划器。

特别是它添加了以下内容:

flatten_state_dict:处理嵌套字典的状态字典 flatten_sharded_tensors:用于2D并行模式下的FSDP allow_partial_load:如果为False,在状态字典中存在而在检查点中不存在键时,将引发运行时错误。

lookup_tensor(index)[source]

从规划器接口进行扩展,以便轻松扩展默认规划器。

Return type

张量

transform_tensor(read_item, tensor)[source]

从规划器接口进行扩展,以便轻松扩展默认规划器。

由于历史设计决策,即使原始未并行化的模型相同,FSDPDDP 的状态字典可能具有不同的键或完全限定名称(例如,layer1.weight)。此外,FSDP 提供了多种类型的模型状态字典,例如完整和分片的状态字典。另外,优化器状态字典使用参数 ID 而不是完全限定名称来标识参数,在使用并行化时(例如管道并行化)可能会导致问题。

为了应对这些挑战,我们提供了一组API,使用户可以轻松管理状态字典。get_model_state_dict 返回一个与未并行化模型状态字典中的键一致的模型状态字典。同样地,get_optimizer_state_dict 提供了一个在所有并行化应用中键统一的优化器状态字典。为了实现这种一致性,get_optimizer_state_dict 将参数ID转换为与未并行化模型状态字典中相同的完全限定名称。

请注意,通过这些API返回的结果可以直接与 torch.distributed.checkpoint.save()torch.distributed.checkpoint.load() 方法一起使用,而无需任何额外的转换。

请注意,此功能仍处于试验阶段,未来的 API 签名可能会发生变化。

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]

返回模型的 state_dict 和优化器的 state_dict。

get_state_dict 可以处理任何由 PyTorch 并行化的模块,例如 FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module,以及这些并行化方式的任意组合。get_state_dict 的主要功能是:1.) 返回一个可以在不同数量的训练器和/或不同的并行化方式下重新划分的模型和优化器状态字典。2.) 隐藏特定于并行化的状态字典 API。用户无需调用这些 API。3.) 对结果状态字典进行合理性检查。

结果状态字典的键是规范的FQN(全限定名)。规范的FQN是指基于参数在nn.Module层次结构中的位置的FQN。更具体地说,当模块未通过任何并行方式分布时,参数的规范FQN是由module.named_parameters()module.named_buffers()返回的FQN。由于优化器内部使用参数ID来表示一个参数,在调用此API时会将参数ID转换为规范的FQN。

get_state_dict 也可以处理未并行化的模块。在这种情况下,get_state_dict 只执行一个功能——将优化器参数 ID 转换为标准全限定名 (FQN)。

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • 优化器 (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • 子模块已弃用) – Optional[Set[nn.Module]]: 只返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions

Returns

Tuple 包含模型 state_dict 和优化器 state_dict。

Return type

元组[字典[字符串, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]

返回模型的 state_dict 为 model

参见 get_state_dict 以获取详细用法。

Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • 子模块已弃用) – Optional[Set[nn.Module]]: 只返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions

Returns

状态字典对于model

Return type

字典[字符串, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]

返回优化器的组合状态字典。

参见 get_state_dict 以获取详细用法。

Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • 优化器 (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • 子模块已弃用) – Optional[Set[nn.Module]]: 只返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions

Returns

状态字典对于optimizers

Return type

OptimizerStateType

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]

加载模型的状态字典和优化器的状态字典。

get_state_dict设置为模型和优化器的状态字典的对应项。给定的model_state_dictoptim_state_dict不一定由get_state_dict返回,但必须满足以下要求:1) 所有FQN都是在get_state_dict中定义的标准FQN,2) 如果张量被分片,则必须是ShardedTensor或DTensor,3) 优化器状态字典不能包含参数ID;键应该是标准FQN。

Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • 优化器 (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型状态字典。如果model_state_dict键是nn.Module,则该键是model的子模块,值应该是该子模块的状态字典。在加载状态字典时,子模块的前缀将被添加到状态字典中。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType:要加载的优化器状态字典。

  • 选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions

Returns

  • missing_keys 是一个包含模型 state_dict 中缺失键的字符串列表。

  • unexpected_keys 是一个包含模型 state_dict 中意外键的字符串列表。

Return type

NamedTuple 个带有 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]

加载模型状态字典。

get_model_state_dict设置为状态字典的对应值以应用于模型。有关详细用法,请参见set_state_dict

Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要加载的模型状态字典。如果model_state_dict的键是nn.Module,则该键是model的子模块,值应该是该子模块的状态字典。在加载状态字典时,子模块的前缀将被添加到状态字典中。

  • 选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions

Returns

  • missing_keys 是一个包含缺失键的字符串列表

  • unexpected_keys 是一个包含意外键的字符串列表

Return type

NamedTuple 个带有 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source]

加载优化器的状态字典。

get_optimizer_state_dict设置为状态字典以更新优化器。有关详细用法,请参见set_state_dict

Parameters
  • 模型 (nn.Module) – 要本地化的 nn.Module。

  • 优化器 (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType:要加载的优化器状态字典。

  • 选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions

Returns

请提供需要翻译的单词列表。

Return type

请提供需要翻译的单词列表。

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False)[source]

这个数据类指定了 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict: 如果设置为True,则返回的state_dict中的所有张量都将被收集。返回的state_dict中不会包含ShardedTensor和DTensor。

  • cpu_offload: 将所有张量卸载到CPU。为防止CPU内存不足,如果 full_state_dict 也为真,则只有rank0会获取 state_dict,而其他所有rank将获取空的state_dict。

  • ignore_frozen_params: 如果值为True,返回的state_dict将不包含任何被冻结的参数 – requires_grad 是False。 默认值为False。

  • keep_submodule_prefixes (已弃用):当 submodules 不为 None 时,此选项 表示是否保留状态字典键中的子模块前缀。 例如,如果子模块是 module.pretrain,且参数的完整 FQN 是 pretrain.layer1.weight。当此选项 为 True 时,返回的状态字典中参数的键将是 pretrain.layer1.weight。如果选项为 False,键将是 layer1.weight。 注意,如果 keep_submodule_prefixes 为 False,可能会有冲突的 FQNs,因此 submodules 中应只有一个子模块。

  • strict: the strict选项当set_state_dict调用 model.load_state_dict()时。

  • broadcast_from_rank0: when the option is True, rank0 should receive a

    完整的state_dict并将state_dict中的张量逐个广播到其他排名。其他排名将接收这些张量,并根据模型和优化器中的本地分片进行分片。使用此选项时,full_state_dict 必须设置为True。 此选项目前仅支持DTensor,不支持传统的ShardedTensor。

对于习惯使用和分享torch.save格式模型的用户,提供了以下方法,用于在格式之间进行离线转换。

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source]

给定一个包含 DCP 检查点的目录,此函数将把它转换为一个 Torch 保存文件。

Parameters
  • dcp_checkpoint_dir (Union[str, PathLike]) – 包含DCP检查点的目录。

  • torch_save_path (Union[str, PathLike]) – 用于存储转换后的Torch保存文件的文件名。

警告

为了避免内存溢出,建议仅在单个进程中运行此函数。

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source]

给定 torch 检查点文件的位置,将其转换为 DCP 检查点。

Parameters
  • torch_save_path (Union[str, PathLike]) – Torch 保存文件的文件名。

  • dcp_checkpoint_dir (Union[str, PathLike]) – 用于存储DCP检查点的目录。

警告

为了避免内存溢出,建议仅在单个进程中运行此函数。

以下类也可以用于从 torch.save 格式在线加载和重新划分模型。

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source]

用于读取Torch保存文件的StorageReader。此读取器将在协调员节点读取整个检查点,然后将每个张量广播并分片到所有节点。

注意:此内容旨在与DynamicMetaLoadPlanner一起使用。

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
prepare_global_plan(global_plan)[source]

StorageReader 方法的实现

Return type

列表[加载计划]

prepare_local_plan(plan)[source]

StorageReader 方法的实现

Return type

LoadPlan

read_data(plan, planner)[source]

在协调器 ranks 上读取 torch 保存的数据,然后进行广播 这会带来通信成本,但避免了在每个 ranks 上都加载整个检查点,有望防止 OOM 问题。

Return type

未来[无]

read_metadata()[source]

扩展默认的 StorageReader 以支持构建元数据文件

Return type

元数据

reset(checkpoint_id=None)[source]

StorageReader 方法的实现

set_up_storage_reader(metadata, is_coordinator)[source]

StorageReader 方法的实现

classmethod validate_checkpoint_id(checkpoint_id)[source]

StorageReader 方法的实现

Return type

布尔

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]

扩展默认加载计划器,默认加载计划器的派生类,它根据传入的状态字典创建一个新的元数据对象,避免了从磁盘读取元数据的需求。这对于读取没有元数据文件的格式(例如 Torch Save 文件)非常有用。

注意:意在与BroadcastingTorchSaveReader一起使用。

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_planner(state_dict, metadata=None, is_coordinator=False)[source]

规划器的设置,通过从状态字典创建元数据对象来扩展默认行为。

以下实验接口提供了生产环境中改进的可观测性:

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源