分布式检查点 - torch.distributed.checkpoint¶
分布式检查点 (DCP) 支持并行加载和保存来自多个 rank 的模型。 它处理加载时重新分片,从而支持在一个集群拓扑中保存并加载到另一个集群拓扑中。
DCP 与 torch.save 和 torch.load 在几个重要方面有所不同:
它为每个检查点生成多个文件,每个等级至少生成一个文件。
它在原地运行,这意味着模型应该首先分配其数据,然后 DCP 使用该存储。
加载和保存检查点的入口点如下:
- torch.distributed.checkpoint 中。load(state_dict, storage_reader, *, process_group=无, coordinator_rank=0, no_dist=False,planner=None)[来源]¶
以 SPMD 样式加载 distributed。
state_dict
每个等级将尝试读取所需的最少数据量 完成请求的state_dict。在加载或实例时,每个排名仅读取其本地分片的数据。
ShardedTensor
DTensor
对于每个对象(同时具有 a 和 ), load 将在尝试反序列化之前首先调用,然后在反序列化完成后调用。
Stateful
state_dict
load_state_dict
state_dict
load_state_dict
警告
中的所有张量都必须在其 destination device 之前。
state_dict
所有非张量数据都使用 torch.load() 加载并就地修改 在 state_dict。
警告
用户必须在根模块上调用 load_state_dict 以确保负载 pos-processing 和非 Tensor 数据正确传播。
- 参数
state_dict (Dict[str, Any]) – 要加载state_dict。请注意,此 state dict 将就地更新。
storage_reader (StorageReader) – 用于从中加载数据的 StorageReader。
process_group (ProcessGroup) – 用于跨等级同步的 ProcessGroup。
coordinator_rank (int) – 用于协调检查点的排名。 默认情况下使用 rank0。
no_dist (bool) – 如果 ,则不会加载分布式检查点 在 SPMD 样式中。(默认:
True
False
)
- 返回
没有。
- 返回类型
没有
- 例子
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用 collectives 来协调跨等级的读取。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为每个 rank 具有单个 GPU,通过 .
torch.cuda.current_device()
torch.cuda.set_device()
- torch.distributed.checkpoint 中。save(state_dict, storage_writer, *, process_group=无, coordinator_rank=0, no_dist=False,planner=None)[来源]¶
以 SPMD 样式保存分布式模型。
此函数与它处理的不同之处在于,每个排名仅保存其本地分片。
torch.save()
ShardedTensor
DTensor
对于每个对象(同时具有 a 和 ), save 将在序列化之前调用。
Stateful
state_dict
load_state_dict
state_dict
警告
无法保证 PyTorch 版本之间的向后兼容性 用于节省state_dicts。
警告
如果使用 process_group 参数,请确保只有其 ranks 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。
注意
为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存 checkpoint 时,只有一个 shard_group 应调用 save_state_dict 和相应的进程 group 需要传入。
注意
此功能可用于在没有进程组的情况下保存state_dict 通过传递 .
no_dist=True
- 参数
state_dict (Dict[str, Any]) – 要保存的state_dict。
storage_writer (StorageWriter) – StorageWrite 实例用于执行写入。
process_group (ProcessGroup) – 用于跨等级同步的 ProcessGroup。
coordinator_rank (int) – 用于协调检查点的排名。 默认情况下使用 rank0。
no_dist (bool) – 如果 ,则分布式检查点不会保存 在 SPMD 样式中。(默认:
True
False
)
- 返回
Metadata 对象。
- 返回类型
元数据
例
>>> my_model = MyModule()
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> torch.distributed.checkpoint.save_state_dict( >>> state_dict=model_state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用 collectives 来协调跨等级的写入。 对于基于 NCCL 的进程组, 在进行通信之前,必须将对象移动到 GPU 设备。 在这种情况下,使用的设备由 提供,用户有责任确保将其设置为: 每个等级都有一个单独的 GPU,通过 .
torch.cuda.current_device()
torch.cuda.set_device()
- torch.distributed.checkpoint 中。load_state_dict(state_dict, storage_reader, process_group=无, coordinator_rank=0, no_dist=False, 规划者=无)[来源]¶
此方法已弃用。请切换到 'load'。
- torch.distributed.checkpoint 中。save_state_dict(state_dict, storage_writer, process_group=无, coordinator_rank=0, no_dist=False, 规划者=无)[来源]¶
此方法已弃用。请切换到 'save'。
- 返回类型
元数据
除了上述入口点之外,如下所述的 Stateful 对象在保存/加载期间提供了额外的自定义 ..automodule::torch.distributed.checkpoint.stateful
- 类 torch.distributed.checkpoint.stateful 中。Stateful(*args, **kwargs)[来源]¶
用于可进行检查点和还原的对象的有状态协议。
此示例说明如何使用 Pytorch 分布式检查点保存 FSDP 模型。
以下类型定义了 checkpoint 期间使用的 IO 接口:
- 类 torch.distributed.checkpoint 中。StorageReader[来源]¶
用于从存储中读取的接口。
load_state_dict
一个 StorageReader 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。
子类应期望以下调用序列:
load_state_dict
(所有等级) read_metadata()
(所有等级) set_up_storage_reader()
(所有等级) prepare_local_plan()
(协调) prepare_global_plan()
(所有等级) read_data()
- 摘要 prepare_global_plan(计划)[来源]¶
执行存储加载的集中规划。
此方法仅在协调器实例上调用。
虽然这种方法可以产生完全不同的计划,但首选的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。
- 摘要 prepare_local_plan(计划)[来源]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但建议的 方式是将存储特定数据存储在 LoadPlan::storage_data 中。
- 参数
plan (LoadPlan) – 正在使用的本地计划。
LoadPlan
- 返回
A 仓储后本地规划
LoadPlan
- 返回类型
- 摘要 read_data(Plan, Planner)[来源]¶
从 using 中读取所有项目以解析数据。
plan
planner
应调用子类以反序列化 BytesIO object 移动到正确的位置。
LoadPlanner::load_bytes
子类应该调用以获取对 应该将数据加载到的张量。
LoadPlanner::resolve_tensor
StorageLayer 负责正确安排任何跨设备副本 必填。
- 参数
plan (LoadPlan) – 要执行的本地计划
planner (LoadPlanner) – 用于解析项目的 planner 对象。
- 返回
一个 future ,在完成所有读取后完成。
- 返回类型
Future[无]
- 类 torch.distributed.checkpoint 中。StorageWriter[来源]¶
用于写入存储的接口。
save_state_dict
一个 StorageWriter 实例同时充当协调器和追随者 在分布式检查点中。作为初始化的一部分,每个实例 被告知其角色。
子类应期望以下调用序列。
(所有等级) set_up_storage_writer()
(所有等级) prepare_local_plan()
(协调) prepare_global_plan()
(所有等级) write_data()
(协调器) finish()
- 抽象完成(元数据、结果)[来源]¶
写入元数据并将当前检查点标记为成功。
用于序列化元数据的实际格式/架构是 implementation detail 的 implementation detail 中。唯一的要求是它是可恢复的 in 添加到同一对象图中。
- 摘要 prepare_global_plan(计划)[来源]¶
执行集中存储规划。
此方法仅在协调器实例上调用。
虽然这种方法可以产生完全不同的计划,但首选的 方法是将存储特定数据存储在 SavePlan::storage_data 中。
- 摘要 prepare_local_plan(计划)[来源]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但建议的 方法是将存储特定数据存储在 SavePlan::storage_data 中。
- 摘要 write_data(Plan, Planner)[来源]¶
写入 using 中的所有项目以解析数据。
plan
planner
子类应该调用每个项目 从 plan 获取对要写入的基础对象的访问权。
SavePlanner::resolve_data
子类应该延迟调用 resolve_data因为它可以分配内存。 对于 Tensors,请做出以下假设:
它们可能在任何设备上,包括不匹配的
WriteItem::tensor_data
它们可能是视图,也可能不是连续的。只需保存投影。
- 参数
plan (SavePlan) – 要执行的保存计划。
planner (SavePlanner) – 用于将项目解析为数据的 Planner 对象。
- 返回
完成 WriteResult 列表的 future
- 返回类型
以下类型定义了 checkpoint 期间使用的 planner 接口:
- 类 torch.distributed.checkpoint 中。LoadPlanner[来源]¶
定义 load_state_dict 用于规划加载过程的协议的抽象类。
LoadPlanner 是可用于自定义整个加载过程的有状态对象。
LoadPlanner 充当 state_dict 的访问代理,因此对其执行的任何转换 将对整个过程可见。
在 load_state_dict期间,Planner 子类可以预期以下 Sequences 调用:
- set_up_planner - 召集所有级别。
表示开始加载检查点。
- create_local_plan - 召集所有等级。
处理 state_dict 并生成 LoadPlan,该 LoadPlan 将发送以进行全局规划。
- create_global_plan - 仅根据协调者级别调用。
从所有级别获取 LoadPlan 并做出任何全局决策。
- load_bytes - 在每个等级上多次调用
在 state_dict 中,每个非张量值调用一次。
- resolve_tensor 和 commit_tensor - 在每个等级上多次调用
它们对 state_dict 中的每个 Tensor 值成对调用。
建议用户直接扩展 DefaultLoadPlanner 而不是这个接口 大多数更改可以通过单个方法中的更改来表示。
有两种常见的扩展模式:
重写state_dict。这是扩展 load 进程的最简单方法,因为它 并没有完全理解 LoadPlan 工作原理的内在含义。我们需要 在加载发生时保留对原始state_dict的引用,因此 我们需要能够在现场执行
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner(self, state_dict, metadata, is_coordinator): >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
修改 resolve_tensor 和 commit_tensor 以处理加载时间转换。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- 摘要 commit_tensor(read_item, Tensor)[来源]¶
在 StorageReader 完成将数据加载到 后调用 。
tensor
提供的张量与调用 . 仅当此 LoadPlanner 需要在 将其复制回 state_dict 中的那个。
resolve_tensor
tensor
tensor 的内容将遵循其设备同步模型。
- 摘要 load_bytes(read_item, value)[来源]¶
加载 描述的项目。
read_item``and ``value
此方法应就地修改基础state_dict。
的内容由用于生成 正在加载的 checkpoint。
value
- 类 torch.distributed.checkpoint 中。LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data:任何 = 无)[来源]¶
- 类 torch.distributed.checkpoint 中。ReadItem(类型:torch.distributed.checkpoint.planner.LoadItemType,dest_index:torch.distributed.checkpoint.metadata.MetadataIndex,dest_offsets:torch.大小,storage_index:torch.distributed.checkpoint.metadata.MetadataIndex,storage_offsets:torch。尺寸、长度:Torch。尺寸)[来源]¶
- 类 torch.distributed.checkpoint 中。SavePlanner[来源]¶
定义 save_state_dict 用于规划保存过程的协议的抽象类。
SavePlanners 是可用于自定义整个保存过程的有状态对象。
SavePlanner 充当 state_dict 的访问代理,因此对它所做的任何转换 将对整个过程可见。
在 save_state_dict期间,Planner 子类可以预期以下 Sequences 调用:
- set_up_planner - 召集所有级别。
指示检查点保存开始。
- create_local_plan - 召集所有等级。
处理state_dict并生成 SavePlan,该 SavePlan 将发送用于全球规划。
- create_global_plan - 仅根据协调者级别调用。
从所有级别中获取 SavePlan 并做出任何全局决策。
- finish_plan - 召集所有军衔。
这使每个等级都有机会适应全局规划决策。
- resolve_data - 在每个等级上多次调用
在 state_dict 上查找要写入的存储层的值。
建议用户将 DefaultSavePlanner 而不是这个接口直接扩展为 大多数更改可以通过单个方法中的更改来表示。
有 3 种常见的扩展模式:
重写state_dict。这是扩展保存过程的最简单方法,因为它 并不完全理解 SavePlan 工作原理的内在含义:
>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner(self, state_dict, is_coordinator): >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
同时修改本地计划和查找。这在精细控制数据的持久化方式时非常有用
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
使用全局规划步骤做出每个等级无法单独做出的中央决策
>>> from itertools import islice >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> def chunk(it, size): >>> it = iter(it) >>> return list(iter(lambda: tuple(islice(it, size)), ())) >>> all_plans = [ >>> replace(plan, items=items) for plan, items in >>> zip(all_plans, chunk(all_plans[0].items, len(all_plans))) >>> ] >>> return super().create_global_plan(all_plans)
最后,一些 planner 需要在 checkpoint 中保存额外的元数据,这是 通过让每个 rank 在本地计划中贡献其数据项来完成,并且 全局规划器会聚合它们:
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- 摘要 create_local_plan()[来源]¶
计算当前排名的保存计划。
这将被聚合并传递给 create_global_plan。 Planner 特定的数据可以通过 SavePlan::p lanner_data 传递。
这在所有等级上都调用。
- 返回类型
- 摘要 resolve_data(write_item)[来源]¶
转换和准备存储,确保幂等性和线程安全。
write_item
state_dict
查找与 in 关联的对象并应用任何 转换(例如序列化)的 Interface。
write_item
state_dict
对每个排名调用多次,在最终 SavePlan 中对每个 WriteItem 至少调用一次。
此方法应该是幂等的和 thread-save。StorageWriter 实现 可以根据需要随意调用它。
任何分配内存的转换都应该在他的方法 以减少检查点所需的峰值内存。
返回张量时,它们可以位于任何设备或格式上,也可以是视图。 存储层有责任弄清楚如何保存它们。
- 类 torch.distributed.checkpoint 中。SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data:任何 = 无)[来源]¶
- 类 torch.distributed.checkpoint 中。WriteItem(index: torch.distributed.checkpoint.metadata.MetadataIndex, type: torch.distributed.checkpoint.planner.WriteItemType, tensor_data: Union[torch.distributed.checkpoint.planner.TensorWriteData, NoneType] = None)[来源]¶
我们提供基于文件系统的存储层:
- 类 torch.distributed.checkpoint 中。FileSystemWriter(路径, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000 元)[来源]¶
使用文件 IO 的 StorageWriter 的基本实现。
此实现进行了以下假设和简化:
检查点路径是一个空目录或不存在的目录。
文件创建是原子的
检查点由每个写入请求一个文件以及 包含序列化元数据的 .metadata 文件。
我们提供了 LoadPlanner 和 SavePlanner 的默认实现,这些 可以处理所有 torch.distributed 结构,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。
- 类 torch.distributed.checkpoint 中。DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=True)[来源]¶
- 类 torch.distributed.checkpoint 中。DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[来源]¶
DefaultLoadPlanner 在 LoadPlanner 的基础上添加了多个功能。
具体而言,它添加了以下内容:
flatten_state_dict:使用嵌套 dict 处理 state_dict flatten_sharded_tensors:对于 2D 并行模式下的 FSDP
我们提供了一组 API 来帮助用户轻松获取和设置state_dict。这是 这是一项实验性功能,可能会发生更改。
- torch.distributed.checkpoint.state_dict。get_state_dict(model, optimizers, *, submodules=None, options=None)[来源]¶
返回模型 state_dict 和 state_dict 优化器。
get_state_dict
可以处理 PyTorch 并行化的任何模块 FSDP/fully_shard、DDP/复制、tensor_parallel/parallelize_module 和任何 这些并行度的组合。的主要功能是:1.) 返回一个可以重新分片的模型和优化器state_dict 具有不同数量的 trainer 和/或不同的 parallelisms。 2.) 隐藏特定于并行度的 state_dict API。用户不必调用 这些 API 的 API 进行验证。 3.) 对结果进行健全性检查state_dict。get_state_dict
结果状态字典的键是规范的 FQN(完全 限定名称)。规范 FQN 是指基于参数的 position 在 nn.模块层次结构。更具体地说,将规范 FQN 转换为 parameter 是模块未由任何 parallelisms 的 Parallel Actions.由于优化器内部使用参数 ID 来表示 参数,则参数 ID 将转换为 规范 FQN。
module.named_parameters()
module.named_buffers()
get_state_dict
还可以处理未并行化的模块。在 在这种情况下,只执行一个功能——将 optimizer 参数 ID 设置为规范 FQN。get_state_dict
例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- 参数
- 返回
Tuple
,其中包含 Model state_dict 和 Optimizer state_dict。- 返回类型
- torch.distributed.checkpoint.state_dict。get_model_state_dict(model, *, submodules=None, options=None)[来源]¶
返回 的模型state_dict 。
model
有关详细用法,请参见。
get_state_dict
- 参数
模型 (NN.Module) – nn.Module 添加到模型中。
submodules (Optional[Set[Module]]) – 可选[Set[nn.Module]]:仅返回模型参数 属于子模块。
options (StateDictOptions) – 控制方式的选项 应返回 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions。
- 返回
的 state_dict 。
model
- 返回类型
Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]
- torch.distributed.checkpoint.state_dict。get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[来源]¶
返回优化器的组合state_dict。
有关详细用法,请参见。
get_state_dict
- 参数
- 返回
的 state_dict 。
optimizers
- 返回类型
Dict[str, Union[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[联盟[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]], List[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]]]]
- torch.distributed.checkpoint.state_dict。set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[来源]¶
加载模型state_dict state_dict 和优化器。
的对应项将 state_dict 设置为 model,并将 优化器。given 和 do not 必须返回,但必须满足以下条件 要求: 1) 所有 FQN 都是 中定义的规范 FQN, 2) 如果张量被分片,则它必须是 ShardedTensor 或 DTensor, 3) optimizer state_dict 不能包含参数 ID;键应该是 规范的 FQN。
get_state_dict
model_state_dict
optim_state_dict
get_state_dict
get_state_dict
- 参数
模型 (NN.Module) – nn.Module 添加到模型中。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化的优化器。
model
model_state_dict (Union[Dict[Module, Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float、 str、 List[Union[DTensor、 ShardedTensor、 Tensor、 int、 float、 str]]、 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor,int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor,张量、int、浮点、str]]、元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]]) – (联合[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 模型state_dict加载。如果 的 key 是 nn.Module 的 Module,键是 的子模块,值应 是子模块的state_dict。加载 state_dict 时, 子模块的前缀将附加到 state_dict。
model_state_dict
model
optim_state_dict (Dict[str, Union[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]]], Dict[str, ValueType]]]]], List[Dict[str, 联盟[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor,张量、int、浮点、str]]、元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor,int, float, str]], Dict[str, ValueType]]]]]]])) – 优化器StateType: 优化器state_dict加载。
options (StateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions。
- 返回
missing_keys 是一个 str 列表,其中包含模型state_dict的缺失键。
unexpected_keys 是一个 str 列表,其中包含模型state_dict的意外键。
- 返回类型
NamedTuple
with 和 字段missing_keys
unexpected_keys
- torch.distributed.checkpoint.state_dict。set_model_state_dict(model, model_state_dict, *, options=None)[来源]¶
state_dict加载模型。
的对应项将 state_dict 设置为 型。有关详细用法,请参见。
get_model_state_dict
set_state_dict
- 参数
模型 (NN.Module) – nn.Module 添加到模型中。
model_state_dict (Union[Dict[Module, Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float、 str、 List[Union[DTensor、 ShardedTensor、 Tensor、 int、 float、 str]]、 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor,int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor,张量、int、浮点、str]]、元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, ValueType]]]]]]) – (联合[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 模型state_dict加载。如果 的 key 是 nn.Module 的 Module,键是 的子模块,值应 是子模块的state_dict。加载 state_dict 时, 子模块的前缀将附加到 state_dict。
model_state_dict
model
options (StateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions。
- 返回
missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表
- 返回类型
NamedTuple
with 和 字段missing_keys
unexpected_keys
- torch.distributed.checkpoint.state_dict。set_optimizer_state_dict(model, optimizers, *, optim_state_dict, options=None)[来源]¶
state_dict加载优化器。
的对应项将 state_dict 设置为 优化器。有关详细用法,请参见。
get_optimizer_state_dict
set_state_dict
- 参数
模型 (NN.Module) – nn.Module 添加到模型中。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化的优化器。
model
optim_state_dict (Dict[str, Union[Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor, int, float, str]]], Dict[str, ValueType]]]]], List[Dict[str, 联盟[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor,张量、int、浮点、str]]、元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Dict[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, List[Union[DTensor, ShardedTensor, Tensor, int, float, str]], Tuple[Union[DTensor, ShardedTensor, Tensor,int, float, str]], Dict[str, ValueType]]]]]]])) – 优化器StateType: 优化器state_dict加载。
options (StateDictOptions) – 控制方式的选项 应加载 Model state_dict 和 Optimizer state_dict。有关详细信息,请参阅 StateDictOptions。
- 返回
没有
- 返回类型
没有
- torch.distributed.checkpoint.state_dict 类。StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True)[来源]¶
此数据类指定 get_state_dict/set_state_dict 的工作方式。
full_state_dict
:如果设置为 True,则 将收集返回的 state_dict。没有 ShardedTensor 和 DTensor 将位于返回的state_dict中。cpu_offload
:将所有张量卸载到 CPU。为防止 CPU OOM,如果也是 true,则只有 rank0 才会获得 state_dict 和所有其他等级都将为空state_dict。full_state_dict
ignore_frozen_params
:如果值为 True,则返回state_dict 不包含任何冻结的参数 – 为 False。 默认值为 False。requires_grad
keep_submodule_prefixes
:当 is not None 时,此选项 指示是否保留 state_dict 键的 submodule 前缀。 或 example,如果子模块是且 该参数为 param.当此选项 为 True,则返回的 state_dict 中的参数键将为 。如果选项为 False,则键将为 。 请注意,如果为 False,则可能存在冲突 FQN,因此 中应该只有一个子模块 。submodules
module.pretrain
pretrain.layer1.weight
pretrain.layer1.weight
layer1.weight
keep_submodule_prefixes
submodules
strict
:调用 model.load_state_dict() 中。 默认值为 False。strict
set_state_dict