目录

模型并行

DistributedModelParallel 是 Pytorch 深度学习框架的主 API,用于分布式训练并优化 TorchRec。

class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)

模型并行化的入口点。

Parameters:
  • 模块 (nn.Module) – 模块用于包装。

  • 环境 (可选[分片环境]) – 有一个进程组的分片环境。

  • 设备 (可选[torch.device]) – 计算设备,默认为 cpu。

  • 计划 (可选[ShardingPlan]) – 分片时使用的计划,默认为 EmbeddingShardingPlanner.collective_plan()

  • 分片器 (可选[列表[ModuleSharder[nn.Module]]]) – ModuleSharders 个可用的用于分片的模块,默认为 EmbeddingBagCollectionSharder()

  • init_data_parallel (bool) – 数据并行模块可以延迟初始化,即它们在第一次前向传播时才延迟参数初始化。传递 True 以延迟数据并行模块的初始化。首先执行一次前向传播,然后调用 DistributedModelParallel.init_data_parallel()。

  • 初始化参数 (bool) – 初始化模块在元设备上的参数。

  • 数据并行封装器 (Optional[DataParallelWrapper]) – 为数据并行模块自定义封装器。

Example:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
copy(device: device) DistributedModelParallel

递归地通过调用每个模块的自定义复制过程将子模块复制到新设备,因为某些模块需要使用原始引用(例如 ShardedModule 用于推理)。

forward(*args, **kwargs) Any

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传播的实现需要在这个函数内定义,但应该在之后调用 Module 实例而不是这个,因为前者负责运行注册的钩子,而后者则无声地忽略它们。

init_data_parallel() None

查看 init_data_parallel 构造函数的参数以了解其使用方法。可以安全地多次调用此方法。

load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys

state_dict 复制参数和缓冲区到此模块及其后代。

如果 strictTrue,那么 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

警告

如果 assignTrue 的话,优化器必须在调用 load_state_dict 之后创建, 除非 get_swap_module_params_on_conversion()True

Parameters:
  • state_dict (dict) – 一个包含参数和持久缓冲区的字典。

  • 严格 (bool, 可选) – 是否严格要求 state_dict 中的键 与这个模块的 state_dict() 函数返回的键匹配。默认值:True

  • assign (bool, optional) – 当 False 时,当前模块中的张量的属性被保留;当 True 时,状态字典中的张量的属性被保留。唯一例外是 requires_grad 字段中的 Default: ``False`

Returns:

  • missing_keys is a list of str containing any keys that are expected

    通过这个模块但缺少提供的state_dict

  • unexpected_keys is a list of str containing the keys that are not

    预期由本模块使用,但提供的state_dict中存在。

Return type:

NamedTuple 个带有 missing_keysunexpected_keys 字段

注意

如果一个参数或缓冲区被注册为None,并且其对应的键 在state_dict中存在, RuntimeError将引发一个错误。

property module: Module

直接访问分片模块的属性,这将不会被DDP、FSDP、DMP或其他任何并行化包装器包裹。

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

返回一个迭代器,遍历模块缓冲区,同时提供缓冲区的名称和缓冲区本身。

Parameters:
  • 前缀 (字符串) – 在所有缓冲区名称前添加的前缀。

  • recurse (bool, optional) – 如果为真,则返回此模块及其所有子模块的缓冲区。否则,仅返回直接属于此模块的缓冲区。默认值为真。

  • 删除重复项 (布尔值, 可选) – 是否在结果中移除重复的缓冲区。默认为True。

Yields:

(字符串, PyTorch张量) – 包含名称和缓冲区的元组

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

返回一个迭代器,遍历模块参数,同时提供参数的名称和参数本身。

Parameters:
  • 前缀 (字符串) – 在所有参数名称前添加的前缀。

  • recurse (bool) – 如果为 True,则返回该模块及其所有子模块的参数。否则,仅返回该模块的直接成员参数。

  • 删除重复项 (布尔值, 可选) – 是否在结果中移除重复的参数。默认为True。

Yields:

(字符串, 参数) – 元组包含名称和参数

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

返回一个包含模块整个状态的字典。

所有参数和持久缓冲区(例如运行平均值)都被包括在内。键是对应的参数和缓冲区名称。 设置为None的参数和缓冲区不被包括。

注意

返回的对象是一个浅拷贝。它包含模块的参数和缓冲区的引用。

警告

目前 state_dict() 也接受位置参数以 destination, prefixkeep_vars 的顺序。然而, 这将被废弃,并在未来的版本中强制使用关键字参数。

警告

请避免使用参数 destination,因为它不是为最终用户设计的。

Parameters:
  • 目的地 (字典, 可选) – 如果提供,模块的状态将更新到字典中,并返回相同的对象。否则,将创建一个 OrderedDict 并返回。默认值:None

  • 前缀 (字符串, 可选) – 在状态字典中组合参数和缓冲区名称的前缀。默认值:''

  • 保持变量 (bool, 可选) – 默认情况下,返回的状态字典中的Tensor个元素不会被自动求导。如果将其设置为True,则不会执行分离操作。默认值:False

Returns:

一个包含模块整个状态的字典

Return type:

字典

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源