目录

使用 Tensor Parallel (TP) 进行大规模 Transformer 模型训练

创建时间: 2024 年 4 月 19 日 |最后更新时间:2024 年 8 月 19 日 |上次验证: Nov 05, 2024

作者Wanchao LiangTianyu Liu

注意

编辑github 中查看和编辑本教程。

本教程演示了如何使用 Tensor Parallel 和 Fully Sharded Data Parallel 跨数百到数千个 GPU 训练大型 Transformer 类模型。

先决条件:

Tensor Parallel 的工作原理是什么?

张量并行 (TP) 最初是在 Megatron-LM 论文中提出的, 它是一种高效的模型并行技术,用于训练大规模 Transformer 模型。我们在本教程中提到的 Sequence Parallel (SP) 是 Tensor 的变体 Parallel that shards on sequence 维度上用于或进一步节省激活内存 在训练期间。随着模型变大,激活内存成为瓶颈,因此在 Tensor 中 Parallel training 它通常将 Sequence Parallel 应用于 or 层。nn.LayerNormRMSNormLayerNormRMSNorm

威震天-LM TP

图 1.表示 Transformer 模型的 MLP 和 Self-Attention 层上 Tensor Parallel 风格的分片,其中 attention/MLP 中的矩阵乘法都是通过分片计算进行的(图片来源

概括地说,PyTorch Tensor Parallel 的工作原理如下:

分片初始化

  • 确定要应用于每个层的模块,并通过调用 对初始化的模块进行分片。ParallelStyleparallelize_module

  • 并行化模块的模型参数将交换给 DTensor,而 DTensor 将负责使用分片计算运行并行化模块。

运行时前进/后退

  • 根据用户为每个 指定的输入/输出 DTensor 布局,它将运行适当的通信操作来转换输入/输出的 DTensor 布局(例如 、 和 )。ParallelStyleallreduceallgatherreduce_scatter

  • 为并行化层运行分片计算以节省计算/内存(例如 , , )。nn.Linearnn.Embedding

何时以及为何应应用 Tensor Parallel

PyTorch 全分片数据并行 (FSDP) 已经能够将模型训练扩展到特定的 GPU 数量。然而,当涉及到在模型大小和 GPU 数量方面进一步扩展模型训练时, 出现了许多其他挑战,可能需要将 Tensor Parallel 与 FSDP 相结合。

  1. 随着世界大小(GPU 数量)变得过大(超过 128/256 个 GPU),FSDP 集合(例如 )受到环形延迟的主导。 通过在 FSDP 上实施 TP/SP,仅将 FSDP 应用于主机间,可以将 FSDP 世界大小减少 8,从而将延迟成本降低相同数量。allgather

  2. 达到数据并行度限制,由于收敛和 GPU 内存限制,您无法将全局批量大小提高到 GPU 数量以上,张量/序列并行 是 “大致” 全局批量大小并继续使用更多 GPU 进行扩展的唯一已知方法。这意味着模型大小和 GPU 数量都可以继续扩展。

  3. 对于某些类型的模型,当本地批量大小变小时,TP/SP 可以生成更适合浮点运算 (FLOPS) 的矩阵乘法形状。

那么,在准备训练时,达到这些极限有多容易呢?截至目前,即使使用数千个 GPU 也是如此,使用数十亿或数万亿个令牌预训练大型语言模型 (LLM) 可能需要数月时间。

  • 在大规模训练 LLM 时,它总是会达到限制 1。例如,使用 2k GPU 训练 Llama 2 70B 35 天,需要 2k 尺度的多维并行性。

  • 当 Transformer 模型变大时(比如 Llama2 70B),也会很快达到限制 2。由于内存的原因,甚至不能单独使用 FSDP 和 local 和收敛约束。例如,Llama 2 全局批处理大小为 1K,因此不能在 2K GPU 上单独使用数据并行性。batch_size=1

如何应用 Tensor Parallel

PyTorch Tensor Parallel API 提供了一组模块级基元 () 来配置模型的每个单独层的分片,包括:ParallelStyle

  • ColwiseParallel和 :以列或行方式对 和 进行分片。RowwiseParallelnn.Linearnn.Embedding

  • SequenceParallel:对 、 、 等执行分片计算。nn.LayerNormnn.DropoutRMSNormPython

  • PrepareModuleInput和 :使用适当的通信操作配置模块输入/输出分片布局。PrepareModuleOutput

为了演示如何使用 PyTorch 原生 Tensor Parallel API,让我们看一个常见的 Transformer 模型。在本教程中,我们使用最新的 Llama2 模型作为参考 Transformer 模型实现,因为它在社区中也被广泛使用。

由于 Tensor Parallel 在一组设备上对单个张量进行分片,因此我们需要先设置分布式环境(例如 NCCL 通信器)。 张量并行是一种类似于 PyTorch DDP/FSDP 的单程序多数据 (SPMD) 分片算法,它在后台利用了 PyTorch DTensor 执行分片。它还利用 DeviceMesh 抽象(在后台管理 ProcessGroups)进行设备管理和分片。 要了解如何利用 DeviceMesh 设置多维并行度,请参阅本教程。Tensor Parallel 通常在每台主机内工作,因此让我们首先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。

from torch.distributed.device_mesh import init_device_mesh

tp_mesh = init_device_mesh("cuda", (8,))

现在我们已经初始化了 DeviceMesh,让我们详细看看 Llama 2 模型架构,看看我们应该如何进行 Tensor Parallel 分片。 在这里,我们关注 核心 ,其中 Transformer 模型堆叠相同的 s 以扩大模型。TransformerBlockTransformerBlock

核心由一个层和一个层组成。让我们首先看一下更简单的层。 对于 Layer,它由三个 Linear 层组成,在其中执行 SwiGLU 风格的 MLP,查看其前向功能:TransformerBlockAttentionFeedForwardFeedForwardFeedForward

# forward in the FeedForward layer
def forward(self, x):
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

它同时执行 和 matmuls,然后执行 matmuls,然后执行 matmul 以及 w1/w3 组合线性投影结果的结果。这意味着我们可以 使用 Tensor Parallelism 论文中的想法,以 colwise 方式对 w1/w3 Linear 层进行分片,并以 rowwise 方式对 Linear 层进行分片,这样 在所有三个层的末尾只有一个通信。使用 PyTorch 原生 Tensor Parallel,我们可以简单地为该层创建一个,如下所示:w1w3w2w2allreduceparallelize_planFeedForward

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "feed_foward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这就是我们使用 PyTorch Tensor Parallel API 为层配置分片的方式。请注意,用户只需指定如何对各个层进行分片,通信(例如 )将在后台进行。FeedForwardallreduce

继续进行 Layer。它由 、 、 线性层组成,将输入投影到 / / ,然后它与 线性 层一起执行注意力和输出投影。Tensor Parallelism 在这里打算对 q/k/v 投影和线性投影的逐行分片。因此,我们可以将 Attention 计划添加到我们刚刚起草的计划中:Attentionwqwkwvqkvwowotp_plan

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这几乎是我们需要将张量并行度应用于 .但是,我们应该注意的一点是,当按列对线性层进行分片时,线性层的输出将在最后一个张量维度上进行分片,而按行分片线性层直接接受在最后一个维度上分片的输入。 如果在逐列线性和逐行线性之间有更多张量操作(例如视图操作),我们需要将与形状相关的操作调整为分片形状。layer_tp_planTransformerBlock

对于 Llama 模型,在 attention 层中,有几个与形状相关的视图操作。特别是,对于 / / 线性层,激活张量在维度上进行分片,因此我们需要将 调整为局部 。wqwkwvnum_headsnum_headsnum_heads

最后,我们需要调用 API 来使每个计划的 PLAN 有效。在后台,它将模型内部和层的参数分配给 DTensor,并在必要时为模型输入和输出注册通信钩子(分别在每个模块之前和之后):parallelize_moduleTransformerBlockAttentionFeedForward

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated

    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

现在我们已经为每个 制定了分片计划,通常第一层和最终投影层中有一个,用户可以选择按行或按列分片到第一个投影层,按列分片到最后一个投影层,并指定适当的输入和输出布局。 下面是一个示例:TransformerBlocknn.Embeddingnn.Linearnn.Embeddingnn.Linear

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            output_layouts=Replicate(),
        ),
    }
)

注意

如果要分区的模型太大而无法放入 CPU 内存,则可以使用设备初始化(例如,首先在 meta 设备上初始化模型,然后分片层,然后具体化模型),或者在 Transformer 模型初始化期间逐层并行化。metaTransformerBlock

将序列平行应用于图层LayerNorm/RMSNorm

Sequence Parallel 在上图所示的 Tensor Parallel 之上工作。与基本的 Tensor Parallel 相比,Sequence Parallel 仅在模块和模块中对 Tensor 进行分片,并保持其模块输入和输出(即前向传递中的激活和向后传递中的梯度)的复制,而 Sequence Parallel 则将它们保留在序列维度上。AttentionFeedForward

在典型的 中,forward 函数结合了范数层 ( 或 )、注意力层、前馈层和残差连接。例如:TransformerBlockLayerNormRMSNorm

# forward in a TransformerBlock
def forward(self, x):
    h = x + self.attention(self.attention_norm(x))
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

在大多数用例中,激活(和梯度)是 and 模块之外的形状。在 DTensor 的语言中,Sequence Parallel 使用模块前向/后向布局执行激活计算。 按照前面的代码示例,下面的代码演示了我们如何将 Sequence Parallel 应用于 :[batch size, sequence length, hidden dimension]AttentionFeedForwardShard(1)TransformerBlock

首先,我们导入 Sequence Parallel 所需的依赖项:

from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    SequenceParallel,
)

接下来,让我们调整 以在层上启用 sequence parallel:layer_tp_planRMSNorm

layer_tp_plan = {
    # Now the input and output of SequenceParallel has Shard(1) layouts,
    # to represent the input/output tensors sharded on the sequence dimension
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

可以看到,我们现在用于将 Attention 和 FeedForward 层的模块输入布局从 修改为 ,并将其输出布局标记为 。 就像 Tensor Parallelism 发生的情况一样,只需要指定输入和输出的 Tensor 分片布局,层之间的通信就会自动发生。PrepareModuleInputShard(1)Replicate()Shard(1)

请注意,使用 Sequence Parallel 时,我们假设 a 的输入和输出始终在 sequence 维度上分片,以便可以无缝连接多个。 这可以通过将起始层的输出和最终投影层的输入显式指定为:TransformerBlockTransformerBlocksnn.Embeddingnn.LinearShard(1)

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
    }
)

Apply Loss Parallel

Loss Parallel 是一种相关技术,用于在计算损失函数时节省内存和通信,因为模型输出通常非常大。在 Loss Parallel 中,当模型输出在(通常是巨大的)词汇维度上分片时,可以有效地计算交叉熵损失,而无需将所有模型输出收集到每个 GPU 。这不仅显著降低了内存消耗,还通过减少通信开销和并行执行分片计算来提高训练速度。下图简要说明了 Loss Parallel 如何通过执行分片计算来避免将所有模型输出收集到每个 GPU 。

损失平行

图 2.在一个 GPU 上并行损失的交叉熵损失前向计算。蓝色表示分片张量;绿色表示复制的张量;黄色表示具有部分值(要全部缩减)的张量。黑色箭头是本地计算;红色箭头是 GPU 中的函数集合。

在 PyTorch Tensor Parallel API 中,可以通过 Context manager 启用 Loss Parallel ,使用它可以直接使用或不修改代码的其他部分。loss_paralleltorch.nn.functional.cross_entropytorch.nn.CrossEntropyLoss

要应用 Loss Parallel,模型预测(通常为 shape )应在词汇维度上进行分片。这可以通过标记最后一个线性投影图层输出的输出布局来轻松完成:[batch size, sequence length, vocabulary size]

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # use DTensor as the output
            use_local_output=False,
        ),
    },
)

在上面的代码中,我们还在输出之前将 Sequence Parallel 应用于范数层。我们申请让输出保持为 DTensor,以便与上下文管理器一起使用。之后,可以简单地调用 cross_entropy 损失函数,如下所示。请注意,反向计算也需要在上下文中进行。use_local_output=Falseloss_parallel

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

pred = model(input_ids)
with loss_parallel():
    # assuming pred and labels are of the shape [batch, seq, vocab]
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()

将 Tensor Parallel 与 Fully Sharded Data Parallel 组合在一起

现在我们已经展示了如何将 Tensor/Sequence Parallel 应用于模型,那么我们还来看看 Tensor Parallel 和 Fully Sharded Data Parallel 如何协同工作。 由于 Tensor Parallelism 会产生阻止计算的通信,因此我们希望确保它在快速通信通道(例如 NVLink)中运行。 在实践中,我们通常在每个主机内应用 Tensor Parallel,并在主机之间应用 Fully Sharded Data Parallel。

FSDP + TP

图 3.FSDP 和 TP 在不同的设备维度上工作,FSDP 通信发生在主机之间,而 TP 通信发生在主机内部。

这种 2-D 并行模式可以通过 2-D DeviceMesh 轻松表达,我们只需要将每个 “sub” DeviceMesh 传递给每个单独的并行 API:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices

model = Model(...)

tp_plan = {...}

# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)

这将使我们能够轻松地在每个主机(主机内)内应用 Tensor Parallel,并在主机(主机间)之间应用 FSDP,对 Llama 模型进行 0 代码更改。 Tensor(Model) Parallel 和 Data Parallel 技术相结合,能够继续使用大量 GPU 继续增加模型大小和高效训练。

结论

本教程演示了如何将 Tensor Parallel 与 Fully Sharded Data Parallel 结合使用,在数百到数千个 GPU 上训练大型 Transformer 类模型。 它解释了如何将 Tensor Parallel 应用于模型的不同部分,而无需对模型本身进行代码更改。Tensor Parallel 是一种用于大规模训练的高效模型并行技术。

要查看本教程中解释的完整端到端代码示例,请参阅 pytorch/examples 存储库中的 Tensor Parallel 示例

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源