张量并行性 - torch.distributed.tensor.parallel¶
Tensor Parallelism(TP) 构建在 DistributedTensor(DTensor) 和 提供了几种 Parallelism 样式:Rowwise、Colwise 和 Pairwise Parallelism。
警告
Tensor Parallelism API 是实验性的,可能会发生变化。
使用 Tensor Parallelism 并行化的入口点是:nn.Module
- torch.distributed.tensor.parallel 的parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)[来源]¶
在 PyTorch 中应用张量并行 (TP) 的 API。我们并行化模块 或基于parallelize_plan sub_modules。parallelize_plan包含 ,它指示用户对模块或sub_module的需要 进行并行化。
ParallelStyle
用户还可以为每个模块指定不同的并行样式 full qualifed name (FQN)。 API 通过接受 n 维 device_mesh 来原生支持 2D 并行性 用户只需要指定我们执行张量并行的维度。
- 参数
module () – 要并行化的模块。
nn.Module
device_mesh () – 描述网格拓扑的对象 的设备数量。
DeviceMesh
parallelize_plan (Union[, Dict[str, ]]) – 用于并行化模块的计划。它可以是一个对象,其中包含 我们为 Tensor Parallelism 准备 input/output,或者它可以是一个 dict 及其对应的对象。
ParallelStyle
ParallelStyle
ParallelStyle
ParallelStyle
tp_mesh_dim (int) – 我们执行位置的维度 Tensor Parallelism 打开。
device_mesh
- 结果
并行化的对象。
nn.Module
- 返回类型:
- 例::
>>> from torch.distributed._tensor.parallel import parallelize_module, PairwiseParallel >>> >>> # Define the module. >>> m = Model(...) >>> m = parallelize_module(m, PairwiseParallel()) >>>
警告
PairwiseParallel
现在带有约束。如果您需要更精细 granularity 时,您需要传入模块 FQN 和 parallel 样式的 dict。
Tensor Parallelism 支持以下并行样式:
- 类 torch.distributed.tensor.parallel.style 中。RowwiseParallel[来源]¶
对模块的行进行分区。 我们假设 input 是 sharded,output 是 replicated 。
DTensor
DTensor
- 类 torch.distributed.tensor.parallel.style 中。ColwiseParallel[来源]¶
对张量或模块的列进行分区。 我们假设 input 是 replicad,output 是 sharded 。
DTensor
DTensor
- 类 torch.distributed.tensor.parallel.style 中。PairwiseParallel[来源]¶
PairwiseParallel 将 colwise 和 rowwise 样式连接为固定样式 就像 Megatron-LM(https://arxiv.org/abs/1909.08053) 正在做的事情一样。 我们假设输入和输出都需要复制 DTensor。
警告
PairwiseParallel 目前仅支持 或 偶数层 MLP。
nn.Multihead Attention
nn.Transformer
由于 Tensor Parallelism 构建在 DTensor 之上,因此我们需要指定 使用 DTensor 放置模块的输入和输出位置,以便它可以预期地 与 Before, and After 的 Module 交互。以下是函数 用于输入/输出准备:
- torch.distributed.tensor.parallel.style 的make_input_replicate_1d(输入,device_mesh=无)[来源]¶
在 1-D 设备网格上复制输入张量。此函数将在 ParallelStyle 中使用。
- torch.distributed.tensor.parallel.style 的make_input_shard_1d(input, device_mesh=None, dim=0)[来源]¶
一维设备网格上的分片输入张量。此函数将在 ParallelStyle 中使用。
dim
- 参数
device_mesh (可选) – 将分片的一维设备网格。 如果传递 no 并且是 ,则将使用 input.device_mesh。 如果不是 1-D,则将引发异常。 违约:
DeviceMesh
input
DeviceMesh
input
DTensor
DeviceMesh
None
dim (int, optional) – tensor 的分片维度。 默认值:0
input
- 结果
维度上 的分片 。
DTensor
dim
device_mesh
- 返回类型:
DTensor
- torch.distributed.tensor.parallel.style 的make_input_shard_1d_last_dim(输入,device_mesh=无)[来源]¶
包装器 func 与 = -1。
make_input_shard_1d
dim
- torch.distributed.tensor.parallel.style 的make_output_replicate_1d(输出,device_mesh=无)[来源]¶
将 Output DTensor 转换为复制的 DTensor。这将在 ParallelStyle 中使用。
- 参数
output () – 要转换的模块的输出。
DTensor
device_mesh ( 可选 ) – 对象需要复制输出,并且它必须是 1D 的,如果传入非 1D ,我们将引发异常。 如果传入 no,我们将重用 output 中的那个。 违约:
DeviceMesh
device_mesh
device_mesh
device_mesh
None
- 结果
已复制的对象。
DTensor
- 返回类型:
DTensor
- torch.distributed.tensor.parallel.style 的make_output_tensor(输出,device_mesh=无)[来源]¶
首先将 Output DTensor 转换为复制的 DTensor,然后将其转换为 Tensor。
- torch.distributed.tensor.parallel.style 的make_output_shard_1d(输出,device_mesh=无,dim=0)[来源]¶
将 Output DTensor 转换为分片的 DTensor。这将在 ParallelStyle 中使用。
- 参数
output () – 要转换的模块的输出。
DTensor
device_mesh ( 可选 ) – 对象需要对输出进行分片,并且它必须是 1D 的,如果传入非 1D ,我们将引发异常。 如果传入 no,我们将重用 output 中的那个。 违约:
DeviceMesh
device_mesh
device_mesh
device_mesh
None
dim (int) – 为输出分片 dim。默认值:0
- 结果
在给定 dim 上分片的对象。
DTensor
- 返回类型:
DTensor
目前,有一些约束使 nn.MultiheadAttention 模块实现开箱即用的张量并行性,因此我们构建了这个multihead_attention
模块。此外,在 中,我们会自动
在指定 时交换到此自定义模块。parallelize_module
nn.MultiheadAttention
PairwiseParallel
- torch.distributed.tensor.parallel.multihead_attention_tp 类。TensorParallelMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=无, vdim=无, batch_first=False, device=无, DTYPE=无,tp_size=1,self_attention=真)[来源]¶
来自 Transformer 模型的多头 Attention 块。 由于我们需要对注意力层进行一些自定义, 我们正在编写一个自定义但在数学上等效的 attention 模块。
请注意: 我们现在只支持自我注意的情况 limited input args,我们还假设输入张量 的维度为 3。虽然我们确实实现了 logic 对于多头注意力,它没有经过完全测试。
我们还启用了 2D 并行性以与 集成。
用户只需显式调用以下 API:FullyShardedDataParallel