目录

模型并行

DistributedModelParallel是使用 TorchRec 优化进行分布式训练的主要 API。

torchrec.distributed.model_parallel 类DistributedModelParallelmodule Moduleenv Optional[ShardingEnv] = Nonedevice 可选[device] = None计划可选[ShardingPlan] = None分片器: 可选[List[ModuleSharder[Module]]] = Noneinit_data_parallel 布尔 = Trueinit_parameters bool = Truedata_parallel_wrapper 可选 [DataParallelWrapper] = )

模型并行性的入口点。

参数
  • 模块nn.Module) – 要包装的模块。

  • envOptional[ShardingEnv]) – 具有进程组的分片环境。

  • deviceOptional[torch.device]) – 计算设备,默认为 cpu。

  • planOptional[ShardingPlan]) – 计划在分片时使用,默认为 EmbeddingShardingPlanner.collective_plan()。

  • 分片器Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders 可用 进行分片,默认为 EmbeddingBagCollectionSharder()。

  • init_data_parallelbool) – 数据并行模块可以是 lazy,即它们延迟 参数初始化,直到第一次 forward pass。将 True 传递给延迟 数据并行模块的初始化。先前传,然后调用 DistributedModelParallel.init_data_parallel() 中。

  • init_parametersbool) – 初始化仍在 Meta Device 上的模块的参数。

  • data_parallel_wrapperOptional[DataParallelWrapper]) – 数据的自定义包装器 并行模块。

例:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
copydevice device DistributedModelParallel

通过调用每个模块的自定义复制,递归地将子模块复制到新设备 进程,因为某些模块需要使用原始引用(如 ShardedModule 进行推理)。

forward*args**kwargs Any

定义每次调用时执行的计算。

应被所有子类覆盖。

注意

尽管前向传递的配方需要在 这个函数,之后应该调用 instance 而不是 this,因为前者负责运行 registered hooks,而后者则默默地忽略它们。Module

init_data_parallel

有关用法,请参见 c-tor 参数init_data_parallel。 多次调用此方法是安全的。

load_state_dictstate_dict OrderedDict[str Tensor]前缀 str = ''严格 bool = True _IncompatibleKeys

将参数和缓冲区复制到此模块及其后代中。

如果 是 ,则 ,则 的 键必须与返回的键完全匹配 通过这个模块的功能。strictTruestate_dict()

警告

如果是,则必须在 对 unless 的调用是 。assignTrueget_swap_module_params_on_conversion()True

参数
  • state_dictdict) – 包含参数的 dict 和 持久缓冲区。

  • strictbooloptional) – 是否严格执行 key in 匹配此模块的函数返回的键。违约:state_dict()True

  • assignbooloptional) – 当 时,张量的属性 在当前模块中保留,而 when , 状态 dict 中 Tensor 的属性被保留。唯一的 exception 是FalseTruerequires_grad Default: ``False`

结果

  • missing_keys 是一个 str 列表,其中包含预期的任何键

    但提供的 .state_dict

  • unexpected_keys 是一个 str 列表,其中包含不是

    此模块预期,但存在于提供的 .state_dict

返回类型

NamedTuplewith 和 字段missing_keysunexpected_keys

注意

如果参数或缓冲区注册为 及其对应的键 存在于 中将引发 。NoneRuntimeError

property module 模块

属性直接访问分片模块,不会被 DDP 包裹, FSDP、DMP 或任何其他并行包装器。

named_buffersprefix str = ''递归 bool = Trueremove_duplicate: bool = True Iterator[Tuple[str Tensor]]

返回模块缓冲区的迭代器,从而产生缓冲区的名称以及缓冲区本身。

参数
  • prefixstr) – 所有缓冲区名称前面的前缀。

  • recursebooloptional) – 如果为 True,则生成此模块的缓冲区 和所有子模块。否则,仅生成 是此模块的直接成员。默认为 True。

  • remove_duplicatebooloptional) – 是否删除结果中的重复缓冲区。默认为 True。

产量

(str, Torch。Tensor) – 包含名称和缓冲区的元组

例:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parametersprefix str = ''递归 bool = Trueremove_duplicate: bool = true Iterator[Tuple[str parameter]]

返回模块参数的迭代器,从而产生参数的名称以及参数本身。

参数
  • prefixstr) – 所有参数名称前面的前缀。

  • recursebool) – 如果为 True,则生成此模块的参数 和所有子模块。否则,仅生成 是此模块的直接成员。

  • remove_duplicatebooloptional) – 是否删除重复的 参数。默认为 True。

产量

(str, Parameter) – 包含名称和参数的元组

例:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dictdestination Optional[Dict[str Any]] = Noneprefix str = ''keep_vars bool = False dict[str any]

返回一个包含对模块整个 state 的引用的字典。

参数和持久缓冲区(例如运行平均值)都是 包括。键是相应的参数和缓冲区名称。 不包括设置为 (Parameters) 和缓冲区 (buffers)。None

注意

返回的对象是浅表副本。它包含引用 添加到模块的参数和缓冲区中。

警告

目前还接受 和 order 的位置参数。然而 这已被弃用,关键字参数将在 未来版本。state_dict()destinationprefixkeep_vars

警告

请避免使用 argument ,因为它不是 专为最终用户设计。destination

参数
  • destinationdictoptional) – 如果提供,则 module 的状态将 被更新到 dict 中,并返回相同的对象。 否则,将创建并返回 an。 违约:。OrderedDictNone

  • prefixstroptional) – 添加到参数和缓冲区的前缀 names 来组成 state_dict 中的键。默认值:.''

  • keep_varsbooloptional) – 默认情况下为 s 在 state dict 中返回的 SET 与 autograd 分离。如果是 设置为 ,则不会执行分离。 违约:。TensorTrueFalse

结果

一个包含模块整个状态的字典

返回类型

dict (字典)

例:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源