模型并行¶
DistributedModelParallel
是使用 TorchRec 优化进行分布式训练的主要 API。
- torchrec.distributed.model_parallel 类。DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: 可选[device] = None,计划:可选[ShardingPlan] = None, 分片器: 可选[List[ModuleSharder[Module]]] = None, init_data_parallel: 布尔 = True, init_parameters: bool = True, data_parallel_wrapper: 可选 [DataParallelWrapper] = 无)¶
模型并行性的入口点。
- 参数
模块 (nn.Module) – 要包装的模块。
env (Optional[ShardingEnv]) – 具有进程组的分片环境。
device (Optional[torch.device]) – 计算设备,默认为 cpu。
plan (Optional[ShardingPlan]) – 计划在分片时使用,默认为 EmbeddingShardingPlanner.collective_plan()。
分片器 (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders 可用 进行分片,默认为 EmbeddingBagCollectionSharder()。
init_data_parallel (bool) – 数据并行模块可以是 lazy,即它们延迟 参数初始化,直到第一次 forward pass。将 True 传递给延迟 数据并行模块的初始化。先前传,然后调用 DistributedModelParallel.init_data_parallel() 中。
init_parameters (bool) – 初始化仍在 Meta Device 上的模块的参数。
data_parallel_wrapper (Optional[DataParallelWrapper]) – 数据的自定义包装器 并行模块。
例:
@torch.no_grad() def init_weights(m): if isinstance(m, nn.Linear): m.weight.fill_(1.0) elif isinstance(m, EmbeddingBagCollection): for param in m.parameters(): init.kaiming_normal_(param) m = MyModel(device='meta') m = DistributedModelParallel(m) m.apply(init_weights)
- copy(device: device) DistributedModelParallel ¶
通过调用每个模块的自定义复制,递归地将子模块复制到新设备 进程,因为某些模块需要使用原始引用(如 ShardedModule 进行推理)。
- forward(*args, **kwargs) Any ¶
定义每次调用时执行的计算。
应被所有子类覆盖。
注意
尽管前向传递的配方需要在 这个函数,之后应该调用 instance 而不是 this,因为前者负责运行 registered hooks,而后者则默默地忽略它们。
Module
- init_data_parallel() 无 ¶
有关用法,请参见 c-tor 参数init_data_parallel。 多次调用此方法是安全的。
- load_state_dict(state_dict: OrderedDict[str, Tensor], 前缀: str = '', 严格: bool = True) _IncompatibleKeys ¶
将参数和缓冲区复制到此模块及其后代中。
如果 是 ,则 ,则 的
键必须与返回的键完全匹配 通过这个模块的功能。
strict
True
state_dict()
- 参数
- 结果
- missing_keys 是一个 str 列表,其中包含预期的任何键
但提供的 .
state_dict
- unexpected_keys 是一个 str 列表,其中包含不是
此模块预期,但存在于提供的 .
state_dict
- 返回类型:
NamedTuple
with 和 字段missing_keys
unexpected_keys
注意
如果参数或缓冲区注册为 及其对应的键 存在于 中
,将引发 。
None
RuntimeError
- property module: 模块¶
属性直接访问分片模块,不会被 DDP 包裹, FSDP、DMP 或任何其他并行包装器。
- named_buffers(prefix: str = '', 递归: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]] ¶
返回模块缓冲区的迭代器,从而产生缓冲区的名称以及缓冲区本身。
- 参数
prefix (str) – 所有缓冲区名称前面的前缀。
recurse (bool, optional) – 如果为 True,则生成此模块的缓冲区 和所有子模块。否则,仅生成 是此模块的直接成员。默认为 True。
remove_duplicate (bool, optional) – 是否删除结果中的重复缓冲区。默认为 True。
- 产量:
(str, Torch。Tensor) – 包含名称和缓冲区的元组
例:
>>> # xdoctest: +SKIP("undefined vars") >>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_parameters(prefix: str = '', 递归: bool = True, remove_duplicate: bool = true) Iterator[Tuple[str, parameter]] ¶
返回模块参数的迭代器,从而产生参数的名称以及参数本身。
- 参数
prefix (str) – 所有参数名称前面的前缀。
recurse (bool) – 如果为 True,则生成此模块的参数 和所有子模块。否则,仅生成 是此模块的直接成员。
remove_duplicate (bool, optional) – 是否删除重复的 参数。默认为 True。
- 产量:
(str, Parameter) – 包含名称和参数的元组
例:
>>> # xdoctest: +SKIP("undefined vars") >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) dict[str, any] ¶
返回一个包含对模块整个 state 的引用的字典。
参数和持久缓冲区(例如运行平均值)都是 包括。键是相应的参数和缓冲区名称。 不包括设置为 (Parameters) 和缓冲区 (buffers)。
None
注意
返回的对象是浅表副本。它包含引用 添加到模块的参数和缓冲区中。
警告
目前还接受 和 order 的位置参数。然而 这已被弃用,关键字参数将在 未来版本。
state_dict()
destination
prefix
keep_vars
警告
请避免使用 argument ,因为它不是 专为最终用户设计。
destination
- 参数
destination (dict, optional) – 如果提供,则 module 的状态将 被更新到 dict 中,并返回相同的对象。 否则,将创建并返回 an。 违约:。
OrderedDict
None
prefix (str, optional) – 添加到参数和缓冲区的前缀 names 来组成 state_dict 中的键。默认值:.
''
keep_vars (bool, optional) – 默认情况下为 s 在 state dict 中返回的 SET 与 autograd 分离。如果是 设置为 ,则不会执行分离。 违约:。
Tensor
True
False
- 结果
一个包含模块整个状态的字典
- 返回类型:
dict (字典)
例:
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']