分布式检查点 - torch.distributed.checkpoint¶
分布式检查点(DCP)支持从多个进程并行加载和保存模型。它处理加载时的重新分片,这使得可以在一种集群拓扑中保存模型,然后在另一种拓扑中加载。
DCP与torch.save和torch.load在几个重要方面有所不同:
它为每个检查点生成多个文件,每个排名至少一个。
它以就地操作方式进行工作,这意味着模型应先分配其数据,然后DCP使用该存储空间。
加载和保存检查点的入口函数如下:
- torch.distributed.checkpoint.load(state_dict, storage_reader, *, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
以SPMD风格加载一个分布式的
state_dict。每个进程将尝试读取最少的数据以满足请求的state_dict。在加载
ShardedTensor或DTensor实例时,每个进程只读取其本地分片的数据。对于每个
Stateful对象(同时具有state_dict和load_state_dict), load 会在尝试反序列化之前首先调用state_dict,然后在反序列化完成后调用load_state_dict。警告
所有在
state_dict中的张量必须在其目标设备上分配 之前 调用此函数。所有非张量数据都使用torch.load()加载,并在state_dict上就地修改。
警告
用户必须在根模块上调用load_state_dict以确保加载后处理和非张量数据正确传播。
- Parameters
state_dict (Dict[str, Any]) – 要加载的 state_dict。请注意,此 state_dict 将在原地更新。
storage_reader (StorageReader) – 用于从此处加载数据的 StorageReader。
process_group (ProcessGroup) – 用于跨等级同步的 ProcessGroup。
coordinator_rank (int) – 用于协调检查点的秩(rank)。 默认使用 rank0。
no_dist (bool) – 如果
True, 分布式检查点将不会以 SPMD 风格加载。 (默认值:False)
- Returns
None.
- Return type
请提供需要翻译的单词列表。
- Examples
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用集体操作在各个排名之间协调读取。 对于基于 NCCL 的进程组,在通信发生之前,对象的内部张量表示必须移动到 GPU 设备。 在这种情况下,使用的设备由
torch.cuda.current_device()指定, 并且用户有责任确保每个排名都有一个单独的 GPU,通过torch.cuda.set_device()来实现。
- torch.distributed.checkpoint.save(state_dict, storage_writer, *, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
以 SPMD 风格保存分布式模型。
此功能不同于
torch.save(),因为它处理ShardedTensor和DTensor,通过让每个排名只保存它们的本地分片。对于每个
Stateful对象(同时具有state_dict和load_state_dict), 保存将在序列化之前调用state_dict。警告
保存的 state_dicts 在 PyTorch 不同版本之间没有向后兼容性的保证。
警告
如果使用process_group参数,请确保只有它的排名调用save_state_dict,并且state_dict中的所有数据都属于它。
注意
在为FSDP的ShardingStrategy.HYBRID_SHARD保存检查点时,shard_group中应该只有一个调用save_state_dict,并且需要传入相应的进程组。
注意
此函数可以在不初始化进程组的情况下,通过传递
no_dist=True来保存 state_dict。- Parameters
state_dict (Dict[str, Any]) – 要保存的state_dict。
storage_writer (StorageWriter) – StorageWrite 的实例,用于执行写入操作。
process_group (ProcessGroup) – 用于跨等级同步的 ProcessGroup。
coordinator_rank (int) – 用于协调检查点的秩(rank)。 默认使用 rank0。
no_dist (bool) – 如果
True, 分布式检查点将不会以SPMD风格保存。 (默认值:False)
- Returns
保存检查点的元数据对象。
- Return type
元数据
示例
>>> my_model = MyModule()
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> torch.distributed.checkpoint.save_state_dict( >>> state_dict=model_state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用集体操作来协调不同排名之间的写入。 对于基于 NCCL 的进程组,在通信发生之前,对象的内部张量表示必须移动到 GPU 设备。 在这种情况下,使用的设备由
torch.cuda.current_device()指定, 并且用户有责任确保每个排名都有一个单独的 GPU,通过torch.cuda.set_device()设置。
- torch.distributed.checkpoint.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
此方法已弃用。请切换到‘load’。
- torch.distributed.checkpoint.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
此方法已弃用。请切换到‘save’。
- Return type
元数据
除了上述入口点外,Stateful个对象,如下面所述,在保存/加载期间提供额外的自定义选项 .. automodule:: torch.distributed.checkpoint.stateful
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source]¶
可检查点和恢复的状态化对象协议。
这个 示例 展示了如何使用 PyTorch 分布式检查点保存 FSDP 模型。
以下类型定义了检查点过程中使用的 IO 接口:
- class torch.distributed.checkpoint.StorageReader[source]¶
由
load_state_dict使用的从存储中读取的接口。一个StorageReader实例在分布式检查点中同时充当协调器和跟随者。在初始化过程中,每个实例都会被告知其角色。
子类应预期以下调用顺序由
load_state_dict:(所有排名)read_metadata()
(所有排名)设置存储阅读器()
(所有排名)prepare_local_plan()
(协调员) prepare_global_plan()
(所有排名)读取数据()
- abstract prepare_global_plan(plans)[source]¶
集中规划存储加载。
此方法仅在协调器实例上被调用。
虽然这种方法可以生成一个完全不同的计划,但推荐的方式是将存储特定的数据存储在 LoadPlan::storage_data 中。
- abstract prepare_local_plan(plan)[source]¶
执行存储特定的本地规划。
虽然这种方法可以生成一个完全不同的计划,但推荐的方式是将存储特定的数据存储在 LoadPlan::storage_data 中。
- abstract read_data(plan, planner)[source]¶
从
中读取所有项目,并使用来解析数据。子类应调用
LoadPlanner::load_bytes将BytesIO对象反序列化到正确的位置。子类应调用
LoadPlanner::resolve_tensor以访问需要加载数据的张量。存储层负责正确安排任何跨设备复制操作。
- Parameters
计划 (LoadPlan) – 本地执行的计划
planner (LoadPlanner) – 用于解析项目的规划器对象。
- Returns
所有读取操作完成后才会完成的未来状态。
- Return type
未来[无]
- class torch.distributed.checkpoint.StorageWriter[source]¶
由
save_state_dict使用的接口,用于写入存储。一个StorageWriter实例在分布式检查点中同时充当协调器和跟随者。在初始化过程中,每个实例都会被告知其角色。
一个子类应期望以下调用顺序。
(所有排名)设置存储写入器 ()
(所有排名)prepare_local_plan()
(协调员) prepare_global_plan()
所有排名 write_data()
(协调员) 结束()
- abstract finish(metadata, results)[source]¶
写入元数据,并将当前检查点标记为成功。
实际用于序列化的metadata格式/模式是一个实现细节。唯一的要求是它可以恢复到相同的对象图。
- abstract prepare_global_plan(plans)[source]¶
集中规划存储。
此方法仅在协调器实例上被调用。
虽然这种方法可以生成完全不同的计划,但推荐的方式是将存储特定的数据存储在 SavePlan::storage_data 中。
- abstract prepare_local_plan(plan)[source]¶
执行存储特定的本地规划。
虽然这种方法可以生成完全不同的计划,但推荐的方式是将存储特定的数据存储在 SavePlan::storage_data 中。
- abstract set_up_storage_writer(is_coordinator)[source]¶
初始化此实例。
- Parameters
is_coordinator (bool) – 是否此实例负责协调检查点。
- abstract write_data(plan, planner)[source]¶
从
plan开始写所有项目,并使用planner来解析数据。子类应在计划中的每个项目上调用
SavePlanner::resolve_data以访问底层对象进行写入。子类应懒惰地调用resolve_data,因为它可以分配内存。 对于张量,做如下假设:
它们可能出现在任何设备上,包括与
WriteItem::tensor_data不匹配的那个它们可能是视图,也可能不连续。只需保存投影。
- Parameters
计划 (SavePlan) – 要执行的保存计划。
planner (SavePlanner) – 用于解析项目到数据的规划器对象。
- Returns
一个将结果完成到 WriteResult 列表的未来
- Return type
以下类型定义了检查点期间使用的计划器接口:
- class torch.distributed.checkpoint.LoadPlanner[source]¶
抽象类,定义了 load_state_dict 使用的协议,以规划加载过程。
LoadPlanner 是有状态的对象,可用于自定义整个加载过程。
LoadPlanner 作为 state_dict 的访问代理,因此对其所做的任何转换都将对整个进程可见。
在调用 load_state_dict 期间,计划器子类可以预期以下调用顺序:
- set_up_planner - called on all ranks.
表示开始加载检查点。
- create_local_plan - called on all ranks.
处理 state_dict 并生成一个LoadPlan,该值将用于全局规划。
- create_global_plan - called on the coordinator rank only.
从所有 ranks 获取 LoadPlan 并做出任何全局决策。
- load_bytes - called multiple times on each rank
这在状态字典中的每个非张量值上调用一次。
- resolve_tensor and commit_tensor - called multiple times on each rank
它们以成对的方式为 state_dict 中的每个张量值调用。
建议用户扩展 DefaultLoadPlanner 而不是直接实现此接口,因为大多数更改可以通过对单个方法的修改来表达。
有两种常见的扩展模式:
重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 的复杂机制。我们需要保留对原始 state_dict 的引用,因为加载是就地进行的,所以我们需要能够就地执行它。
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner(self, state_dict, metadata, is_coordinator): >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
修改 resolve_tensor 和 commit_tensor 以处理加载时的转换。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)[source]¶
调用一次,当StorageReader完成将数据加载到
tensor中时。提供的张量与调用
resolve_tensor返回的张量相同。 此方法仅在该LoadPlanner需要在将其复制回state_dict中的张量之前对tensor进行后处理时才需要。张量的内容将遵循其设备同步模型。
- abstract create_local_plan()[source]¶
基于 set_up_planner 提供的 state_dict 和元数据创建一个 LoadPlan。
注意:这在每个排名上都会被调用。
- Return type
- abstract load_bytes(read_item, value)[source]¶
加载由
read_item``and ``value描述的项。此方法预计将就地修改底层 state_dict。
value的内容由用于生成正在加载的检查点的 SavePlanner 定义。
- class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]¶
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]¶
- class torch.distributed.checkpoint.SavePlanner[source]¶
抽象类,定义了 save_state_dict 使用的协议,以规划保存过程。
SavePlanner 是一种有状态的对象,可用于自定义整个保存过程。
SavePlanner 作为 state_dict 的访问代理,因此对其所做的任何转换都会对整个进程可见。
在调用 save_state_dict 期间,计划子类可以预期以下调用顺序:
- set_up_planner - called on all ranks.
标志着检查点保存的开始。
- create_local_plan - called on all ranks.
处理 state_dict 并生成一个SavePlan,该值将用于全局规划。
- create_global_plan - called on the coordinator rank only.
从所有 ranks 中获取 SavePlan 并做出任何全局决策。
- finish_plan - called on all ranks.
这为每个排名有机会调整全局规划决策。
- resolve_data - called multiple times on each rank
在存储层写入时查找state_dict处的值。
用户建议扩展 DefaultSavePlanner 而不是直接实现此接口,因为大多数更改可以通过对单个方法的修改来表达。
有三种常见的扩展模式:
重写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 的复杂机制:
>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner(self, state_dict, is_coordinator): >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
修改本地计划和查找表的同时进行调整。这在需要精细控制数据如何持久化时很有用。
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
使用全局规划步骤来做出各个层级单独无法做出的关键决策。
>>> from itertools import islice >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> def chunk(it, size): >>> it = iter(it) >>> return list(iter(lambda: tuple(islice(it, size)), ())) >>> all_plans = [ >>> replace(plan, items=items) for plan, items in >>> zip(all_plans, chunk(all_plans[0].items, len(all_plans))) >>> ] >>> return super().create_global_plan(all_plans)
最后,一些规划器需要在检查点中保存额外的元数据,这是通过让每个进程在其本地规划器中贡献其数据项,然后全局规划器聚合这些数据来实现的:
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- abstract create_local_plan()[source]¶
计算当前排名的保存计划。
这将被聚合并传递给 create_global_plan。 规划器特定的数据可以通过 SavePlan::planner_data 传递。
这在所有排名上都被调用。
- Return type
- abstract finish_plan(new_plan)[source]¶
合并由create_local_plan创建的计划和create_global_plan的结果。
这在所有排名上都被调用。
- Return type
- abstract resolve_data(write_item)[source]¶
将
write_item从state_dict转换并准备存储,确保幂等性和线程安全。查找与
write_item关联的对象,并在存储层使用它之前对其进行任何转换(例如序列化)。在最终的保存计划中的每个 WriteItem 至少调用一次,并在每个进程中多次调用。
此方法应具有幂等性且线程安全。StorageWriter 的实现可以自由地根据需要频繁调用它。
任何分配内存的转换都应在调用其方法时延迟执行,以减少检查点所需的峰值内存。
返回张量时,它们可以位于任何设备或格式上,也可以是视图。存储层的责任是确定如何保存它们。
- class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source]¶
- class torch.distributed.checkpoint.WriteItem(index: torch.distributed.checkpoint.metadata.MetadataIndex, type: torch.distributed.checkpoint.planner.WriteItemType, tensor_data: Union[torch.distributed.checkpoint.planner.TensorWriteData, NoneType] = None)[source]¶
我们提供一种基于文件系统的存储层:
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[source]¶
使用文件 IO 实现的 StorageWriter 基本实现。
此实现做出了以下假设和简化:
检查点路径是一个空目录或不存在的目录。
文件创建是原子操作
检查点由每个写入请求对应的一个文件加上一个.metadata文件组成,该文件包含序列化的元数据。
我们提供了LoadPlanner和SavePlanner的默认实现, 可以处理所有torch.distributed构造,例如FSDP、DDP、ShardedTensor和DistributedTensor。
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=True)[source]¶
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[source]¶
在 LoadPlanner 的基础上添加了多个功能的默认加载计划器。
特别是它添加了以下内容:
flatten_state_dict:处理包含嵌套字典的 state_dict flatten_sharded_tensors:在 FSDP 的 2D 并行模式下使用
我们提供了一组 API,帮助用户轻松地获取和设置 state_dict。这是一个实验性功能,可能会发生变化。
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
返回模型的 state_dict 和优化器的 state_dict。
get_state_dict可以处理任何由 PyTorch 并行化的模块,例如 FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module,以及这些并行化方式的任意组合。get_state_dict的主要功能是:1.) 返回一个可以在不同数量的训练器和/或不同的并行化方式下重新划分的模型和优化器状态字典。2.) 隐藏特定于并行化的状态字典 API。用户无需调用这些 API。3.) 对结果状态字典进行合理性检查。结果状态字典的键是规范的FQN(全限定名)。规范的FQN是指基于参数在nn.Module层次结构中的位置的FQN。更具体地说,当模块未通过任何并行方式分布时,参数的规范FQN是由
module.named_parameters()或module.named_buffers()返回的FQN。由于优化器内部使用参数ID来表示一个参数,在调用此API时会将参数ID转换为规范的FQN。get_state_dict也可以处理未并行化的模块。在这种情况下,get_state_dict只执行一个功能——将优化器参数 ID 转换为标准全限定名 (FQN)。示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- Parameters
- Returns
Tuple包含模型 state_dict 和优化器 state_dict。- Return type
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]¶
返回模型的 state_dict 为
model。参见
get_state_dict以获取详细用法。- Parameters
模型 (nn.Module) – 要本地化的 nn.Module。
选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions。
- Returns
状态字典对于
model。- Return type
字典[字符串, 联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, 分片张量, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 数据类型]]]]]
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
返回优化器的组合状态字典。
参见
get_state_dict以获取详细用法。- Parameters
- Returns
状态字典对于
optimizers。- Return type
字典[字符串, 联合[字典[字符串, 联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, ValueType]]]]], 列表[字典[字符串, 联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, ValueType]]]]]]]]
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]¶
加载模型的状态字典和优化器的状态字典。
将
get_state_dict设置为模型和优化器的状态字典的对应项。给定的model_state_dict和optim_state_dict不一定由get_state_dict返回,但必须满足以下要求:1) 所有FQN都是在get_state_dict中定义的标准FQN,2) 如果张量被分片,则必须是ShardedTensor或DTensor,3) 优化器状态字典不能包含参数ID;键应该是标准FQN。- Parameters
模型 (nn.Module) – 要本地化的 nn.Module。
优化器 (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model的优化器。model_state_dict (Union[字典[Module, 字典[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, 列表[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 字典[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, 列表[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 字典[str, ValueType]]]]]]]]]]]]]]) – (Union[字典[nn.Module, 字典[str, ValueType]], 字典[str, ValueType]]): 模型状态字典,用于加载。如果键为
model_state_dict,则表示它是model的子模块,值应为该子模块的状态字典。在加载状态字典时,子模块的前缀将附加到状态字典中。optim_state_dict (字典[字符串, 联合类型[字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 值类型]]]]], 列表[字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 值类型]]]]]]]]]) – 优化器状态字典: 要加载的优化器状态字典。
选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions。
- Returns
missing_keys 是一个包含模型 state_dict 中缺失键的字符串列表。
unexpected_keys 是一个包含模型 state_dict 中意外键的字符串列表。
- Return type
NamedTuple个带有missing_keys和unexpected_keys字段
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]¶
加载模型状态字典。
将
get_model_state_dict设置为状态字典的对应值以应用于模型。有关详细用法,请参见set_state_dict。- Parameters
模型 (nn.Module) – 要本地化的 nn.Module。
model_state_dict (Union[字典[Module, 字典[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, 列表[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 字典[str, Union[DTensor, ShardedTensor, Tensor, int, float, str, 列表[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 元组[Union[DTensor, ShardedTensor, Tensor, int, float, str]], 字典[str, ValueType]]]]]]]]]]]]]]) – (Union[字典[nn.Module, 字典[str, ValueType]], 字典[str, ValueType]]): 模型状态字典,用于加载。如果键为
model_state_dict,则表示它是model的子模块,值应为该子模块的状态字典。在加载状态字典时,子模块的前缀将附加到状态字典中。选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions。
- Returns
missing_keys 是一个包含缺失键的字符串列表
unexpected_keys 是一个包含意外键的字符串列表
- Return type
NamedTuple个带有missing_keys和unexpected_keys字段
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, *, optim_state_dict, options=None)[source]¶
加载优化器的状态字典。
将
get_optimizer_state_dict设置为状态字典以更新优化器。有关详细用法,请参见set_state_dict。- Parameters
模型 (nn.Module) – 要本地化的 nn.Module。
优化器 (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model的优化器。optim_state_dict (字典[字符串, 联合类型[字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 值类型]]]]], 列表[字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串, 列表[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 元组[联合类型[DTensor, ShardedTensor, 张量, 整数, 浮点数, 字符串]], 字典[字符串, 值类型]]]]]]]]]) – 优化器状态字典: 要加载的优化器状态字典。
选项 (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。详情请参见 StateDictOptions。
- Returns
请提供需要翻译的单词列表。
- Return type
请提供需要翻译的单词列表。
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True)[source]¶
这个数据类指定了 get_state_dict/set_state_dict 的工作方式。
full_state_dict: 如果设置为True,则返回的state_dict中的所有张量都将被收集。返回的state_dict中不会包含ShardedTensor和DTensor。cpu_offload: 将所有张量卸载到CPU。为防止CPU内存不足,如果full_state_dict也为真,则只有rank0会获取 state_dict,而其他所有rank将获取空的state_dict。ignore_frozen_params: 如果值为True,返回的state_dict将不包含任何被冻结的参数 –requires_grad是False。 默认值为False。keep_submodule_prefixes: whensubmodulesis not None, this option indicates whether to keep the submodule prefixes from the state_dict keys. or example, if the submodule ismodule.pretrainand the full FQN of the parameter ispretrain.layer1.weightof the param. When this option is True, the parameter’s key in the returned state_dict will bepretrain.layer1.weight. If the options is False, the key will belayer1.weight. Note that ifkeep_submodule_prefixesis False, there may be conflicted FQNs, hence there should be only one submodule insubmodules.strict: thestrictoption whenset_state_dictcalls model.load_state_dict(). 默认值为False。