目录

张量并行性 - torch.distributed.tensor.parallel

张量并行(TP)构建在 PyTorch 分布式张量 (DTensor) 之上,并提供了几种并行方式:行-wise 并行、列-wise 并行和配对并行。

警告

张量并行 API 是实验性的,可能会发生变化。

使用张量并行化的并行化您的 nn.Module 的入口是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)[source]

在PyTorch中应用张量并行(TP)的API。我们根据并行计划(parallelize_plan)对模块或子模块进行并行化。并行计划包含 ParallelStyle,这表示用户希望如何对模块或子模块进行并行化。

用户还可以根据模块的完整限定名称(FQN)指定不同的并行方式。 该 API 本机支持二维并行,通过接受一个 n 维 device_mesh,并且用户只需指定执行张量并行的维度即可。

Parameters
  • 模块 (nn.Module) – 需要并行化的模块。

  • device_mesh (DeviceMesh) – 描述用于 DTensor 的设备网格拓扑结构的对象。

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – 用于并行化模块的计划。它可以是 一个ParallelStyle对象,其中包含我们如何为张量并行准备输入/输出,或者它也可以是一个字典,键为模块的完全限定名称(FQN),值为其对应的ParallelStyle对象。

  • tp_mesh_dim (int) – 我们在维度 device_mesh 上执行张量并行的维度。

Returns

一个 nn.Module 对象并行化。

Return type

模块

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> m = parallelize_module(m, PairwiseParallel())
>>>

警告

PairwiseParallel 目前带有约束条件。如果你需要更细粒度的控制,需要传入一个包含模块全限定名和并行风格的字典代替。

张量并行支持以下并行风格:

class torch.distributed.tensor.parallel.style.RowwiseParallel(_prepare_input=<function make_input_shard_1d_last_dim>, _prepare_output=<function make_output_tensor>)[source]

将模块的行进行划分。 我们假设输入为一个分片的 DTensor,输出为一个 torch.Tensor

class torch.distributed.tensor.parallel.style.ColwiseParallel(_prepare_input=<function make_input_replicate_1d>, _prepare_output=<function make_sharded_output_tensor>)[source]

将张量或模块的列进行分区。 我们假设输入是一个复制的 DTensor,输出是一个分片的 torch.Tensor

class torch.distributed.tensor.parallel.style.PairwiseParallel(_prepare_input=None, _prepare_output=None)[source]

PairwiseParallel 将列方向和行方向的样式作为固定对进行拼接,就像 Megatron-LM(https://arxiv.org/abs/1909.08053) 所做的那样。 我们假设输入和输出都需要复制 DTensors。

警告

PairwiseParallel 目前不支持 nn.MultiheadAttention, nn.Transformer。一种解决方法是将 ColwiseParallelRowwiseParallel 应用于 transformer 的组件中。我们目前建议仅将 PairwiseParallel 用于偶数层的 MLP。

警告

序列并行性仍处于实验阶段,尚未进行评估。

class torch.distributed.tensor.parallel.style.SequenceParallel[source]

SequenceParallel 将列方向和行方向的拼接样式作为固定对一起使用,类似于 Megatron-LM 序列并行 ( https://arxiv.org/pdf/2205.05198.pdf) 的做法。 我们假设输入和输出都需要被分割为 DTensors。

警告

SequenceParallel 目前不支持 nn.MultiheadAttention, nn.Transformer。一种解决方法是将 ColwiseParallelRowwiseParallel 应用于 transformer 的组件中。我们目前建议仅将 SequenceParallel 用于偶数层 MLP。

由于张量并行性是建立在 DTensor 之上的,我们需要使用 DTensor 指定模块的输入和输出位置,以便它能够与前后的模块预期地进行交互。以下是一些用于输入/输出准备的函数:

torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)[source]

在1-D设备网格上复制输入张量。此函数将在ParallelStyle中使用。

Parameters
  • 输入 (Union[torch.Tensor, DTensor]) – 此输入张量将在1-D DeviceMesh 上被复制。

  • device_mesh (DeviceMesh, 可选) – 用于放置 input 的 1-D 设备网格。 如果未传入 DeviceMesh 并且 input 是一个 DTensor, 将使用 input.device_mesh。 如果 DeviceMesh 不是 1-D,将抛出异常。 默认值: None

Returns

A DTensor 复制到 device_mesh

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_reshard_replicate(input, device_mesh)[source]

从不同 rank 上的张量构建一个分片的 DTensor,然后将其转换为复制的 DTensor。

Parameters
  • 输入 (torch.Tensor) – 每个rank上的输入张量,该张量由一个全局DTensor组成,该DTensor在维度0上进行分片,并且在1-DDeviceMesh上进行分片,然后将分片的DTensor转换为复制DTensor。

  • device_mesh (DeviceMesh, 可选) – 用于将 input 进行分片的 1-D 设备网格。 如果 DeviceMesh 不是 1-D,将抛出异常。 默认值: None

Returns

A DTensor sharded on dimension 0 over device_mesh

然后将其转换为复制。

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)[source]

在1-D设备网格上将输入张量分片到dim。此函数将在ParallelStyle中使用。

Parameters
  • 输入 (Union[torch.Tensor, DTensor]) – 单个张量将在维度 dim 上沿 1-D DeviceMesh 进行分片。

  • device_mesh (DeviceMesh, 可选) – 用于将 input 进行分片的 1-D 设备网格。 如果未传递 DeviceMesh 并且 input 是一个 DTensor, 将使用 input.device_mesh。 如果 DeviceMesh 不是 1-D,将抛出异常。 默认值: None

  • dim (int, 可选) – input 张量的分片维度。 默认值: 0

Returns

A DTensor 在维度 dim 上分片为 device_mesh

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)[source]

Wrapper func of make_input_shard_1d with dim = -1.

Parameters
  • 输入 (Union[torch.Tensor, DTensor]) – 这个单一张量将在最后一个维度上被分片到1-D DeviceMesh 上。

  • device_mesh (DeviceMesh, 可选) – 用于将 input 进行分片的 1-D 设备网格。 如果未传递 DeviceMesh 并且 input 是一个 DTensor, 将使用 input.device_mesh。 如果 DeviceMesh 不是 1-D,将抛出异常。 默认值: None

Returns

A DTensor 在最后一个维度上分片到 device_mesh

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)[source]

将输出 DTensor 转换为复制的 DTensor。这将在并行风格中使用。

Parameters
  • 输出 (DTensor) – 要转换的模块的输出。

  • device_mesh (DeviceMesh, 可选) – 用于复制输出所需的对象,它必须是一个1D device_mesh 如果传入的不是1D device_mesh,我们将抛出异常。 如果没有传入 device_mesh,我们将复用输出中的那个。 默认值: None

Returns

一个 DTensor 对象被复制。

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_reshard_tensor(output, device_mesh=None)[source]

将输出的 DTensor 转换为分片的 DTensor,并返回本地张量。

Parameters
  • 输出 (DTensor) – 要转换的模块的输出。

  • device_mesh (DeviceMesh, 可选) – 用于分片输出所需的对象,它必须是一个 1D device_mesh, 如果传入了非 1D device_mesh,我们将抛出异常。 如果没有传入 device_mesh,我们将复用输出中的那个。 默认值: None

Returns

一个从输出DTensor转换而来的 torch.Tensor 对象。

Return type

张量

torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)[source]

将输出 DTensor 转换为分片的 DTensor。这将在 ParallelStyle 中使用。

Parameters
  • 输出 (DTensor) – 要转换的模块的输出。

  • device_mesh (DeviceMesh, 可选) – 用于分片输出所需的对象,它必须是一个 1D device_mesh, 如果传入了非 1D device_mesh,我们将抛出异常。 如果没有传入 device_mesh,我们将复用输出中的那个。 默认值: None

  • dim (int) – 输出的分片维度。默认值:0

Returns

A DTensor 对象在给定维度上进行分片。

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)[source]

首先将输出的 DTensor 转换为复制的 DTensor,然后再将其转换为 Tensor。

Parameters
  • 输出 (DTensor) – 要转换的模块的输出。

  • device_mesh (DeviceMesh, 可选) – 用于复制输出所需的对象,它必须是一个一维的 device_mesh,如果传入非一维的 device_mesh,我们将抛出异常。如果没有传入 device_mesh, 我们将复用输出中的那个。默认值: None

Returns

一个从输出DTensor转换而来的 torch.Tensor 对象。

Return type

张量

目前,存在一些限制,使得MultiheadAttention 模块无法直接用于张量并行,因此我们建议用户尝试为每个参数使用ColwiseParallelRowwiseParallel。由于我们现在在MultiheadAttention模块的头部维度上进行并行化,可能需要进行一些代码更改。

我们还支持二维并行,其中我们将张量并行与数据并行相结合。 要与 FullyShardedDataParallel 集成, 用户只需显式调用以下 API:

torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()[source]

该 API 注册了 Tensor Parallelism(TP)与 FullyShardedDataParallel(FSDP)协同工作所需的扩展。我们首先根据 parallelize_plan 在一个模块或子模块内对参数进行并行化,然后让 FSDP 对分布式参数的本地张量进行重新分片,该张量本质上是一个 DTensor。

Returns

一个 bool 表示扩展注册是否成功。

Return type

布尔

要与 DistributedDataParallel 集成, 用户只需显式调用以下 API 即可:

torch.distributed.tensor.parallel.ddp.pre_dp_module_transform(module)[source]

在使用 DDP 时,启用 PyTorch 中 Tensor Parallelism (TP) 与 Data Parallelism (DP) 的组合性。我们需要在将参数包装到数据并行 API 之前,将那些是 DTensor 的参数转换为本地张量。然后我们注册两个钩子,一个用于在 forward 之前将本地张量转换回 DTensor,另一个用于在 forward 之后将 DTensor 转换回张量。通过这种方式集成,我们可以避免 DDP 对 DTensor 参数进行任何特殊处理,并使 DTensor 的梯度传播回 DP,例如 DDP 的梯度桶。

目前,此API仅支持 DistributedDataParallel。之后将支持其他DP方法,如FSDP。

Parameters

模块 (nn.Module) – 已应用TP的模块。

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
>>>
>>> # Define the module.
>>> m = module(...)
>>> parallelize_module(m, PairwiseParallel())
>>> m = pre_dp_module_transform(m)
>>> m = DDP(m)
>>>

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源