目录

大规模Transformer模型训练与张量并行(TP)

创建日期:2024年4月19日 | 最后更新日期:2024年8月19日 | 最后验证日期:2024年11月5日

作者: 梁万超, 刘天宇

注意

edit 查看和编辑此教程在 github

本教程演示了如何使用张量并行和完全分片数据并行,在数百到数千块 GPU 上训练一个大型类似 Transformer 的模型。

Prerequisites:

张量并行如何工作?

张量并行(TP)最初在 Megatron-LM 论文中提出, 它是一种高效的模型并行技术,用于训练大规模的Transformer模型。 我们在本教程中提到的 序列并行(SP)是张量并行的一种变体, 它在序列维度上进行分片,用于 nn.LayerNormRMSNorm 以进一步节省训练期间的激活内存。 随着模型规模变大,激活内存成为瓶颈,因此在张量并行训练中通常会应用序列并行到 LayerNormRMSNorm 层。

Megatron-LM TP

图1. 表示在Transformer模型的MLP和自注意力层中的张量并行风格的分片,其中注意力/MLP中的矩阵乘法通过分片计算完成(图片来源

从高层次来看,PyTorch 张量并行的工作方式如下:

分片初始化

  • 确定要应用到每一层的 ParallelStyle,并通过调用 parallelize_module 对初始化的模块进行分片。

  • 并行化的模块将会有其模型参数被转换为DTensors,DTensor将负责使用分片计算来运行并行化的模块。

运行时 forward/backward

  • 根据用户为每个 ParallelStyle 指定的输入/输出 DTensor 布局,它将运行适当的通信操作以转换输入/输出的 DTensor 布局(例如 allreduceallgatherreduce_scatter)。

  • 运行分片计算以节省计算/内存(例如,nn.Linearnn.Embedding)。

何时以及为何应使用张量并行

PyTorch 全量分片数据并行(FSDP)已经具备将模型训练扩展到特定数量 GPU 的能力。然而,当涉及到进一步扩展模型规模和 GPU 数量时,会涌现出许多额外的挑战,这些挑战可能需要结合张量并行与 FSDP 来解决。

  1. 随着世界规模(GPU数量)变得异常庞大(超过128/256个GPU),FSDP集体通信(如allgather)正受到环形延迟的主导。 通过在FSDP之上实现TP/SP,将FSDP世界规模减少8倍,通过将FSDP仅应用于主机间通信,从而减少相同数量的延迟成本。

  2. 达到数据并行的限制,由于收敛性和GPU内存的限制,无法将全局批量大小提升到超过GPU数量。此时,张量/序列并行是唯一已知的方法来“大致”调整全局批量大小,并继续使用更多GPU进行扩展。这意味着模型规模和GPU数量都可以继续扩展。

  3. 对于某些类型的模型,当局部批量大小变小时,TP/SP 可以产生更适合浮点运算(FLOPS)的矩阵乘法形状。

所以,在预训练过程中,要达到这些限制有多容易呢?截至目前,使用数十亿或数万亿个标记来预训练一个大型语言模型(LLM)可能需要数月时间,即使使用数千块 GPU 也是如此。

  • 在大规模训练LLM时,总会遇到限制1。例如,使用2000块GPU训练Llama 2 70B模型35天,当规模达到2000时,需要多维并行性。

  • 当Transformer模型变得更大(例如Llama2 70B)时,也会很快遇到限制2。即使使用本地batch_size=1,由于内存和收敛限制,也不能单独使用FSDP。例如,Llama 2的全局批处理大小为1K,因此在2K个GPU上无法单独使用数据并行。

如何应用张量并行

PyTorch 张量并行API提供了一组模块级别的基础操作 (ParallelStyle),用于配置模型中每一层的分片方式,包括:

  • ColwiseParallelRowwiseParallel: 按列或行方式分割 nn.Linearnn.Embedding

  • SequenceParallel: 在 nn.LayerNorm, nn.Dropout, RMSNormPython, 等上执行分片计算。

  • PrepareModuleInput and PrepareModuleOutput: 配置模块输入/输出分片布局,使用适当的通信操作。

要演示如何使用PyTorch原生的张量并行API,让我们来看一个常见的Transformer模型。在本教程中,我们使用最新的Llama2模型作为参考的Transformer模型实现,因为它在社区中也被广泛使用。

由于张量并行在一组设备上对单个张量进行分片,我们首先需要设置分布式环境(如NCCL通信器)。张量并行是一种类似于PyTorch DDP/FSDP的单程序多数据(SPMD)分片算法,它在内部利用PyTorch DTensor来执行分片。它还使用DeviceMesh抽象(在内部管理ProcessGroups)来进行设备管理和分片。要了解如何使用DeviceMesh设置多维并行,请参阅本教程。张量并行通常在每个主机内工作,因此让我们首先初始化一个连接主机内8个GPU的DeviceMesh。

from torch.distributed.device_mesh import init_device_mesh

tp_mesh = init_device_mesh("cuda", (8,))

现在我们已经初始化了DeviceMesh,让我们详细查看Llama 2模型架构,并了解应该如何进行Tensor Parallel分片。 在这里,我们重点关注核心 TransformerBlock,其中Transformer模型堆叠相同的 TransformerBlock 来扩展模型。

核心 TransformerBlock 包含一个 Attention 层和一个 FeedForward 层。让我们先看一下更简单的 FeedForward 层。 对于 FeedForward 层,它包含三个线性层,其中执行一种 SwiGLU 风格的 MLP,查看其前向函数:

# forward in the FeedForward layer
def forward(self, x):
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

它同时执行 w1w3 个矩阵乘法,随后进行一个 w2 矩阵乘法,使用组合后的 w1/w3 线性投影结果。这意味着我们可以借鉴 Tensor Parallelism 论文中的思想,以列并行的方式分割 w1/w3 线性层,并以行并行的方式分割 w2 线性层,这样在三个层处理完后,只会有一个 allreduce 通信发生。使用 PyTorch 原生的 Tensor Parallel,我们可以通过以下方式为 parallelize_planFeedForward 层创建一个:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "feed_foward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这就是我们使用PyTorch张量并行API为FeedForward层配置分片的方式。请注意,用户只需指定如何分片各个层,通信(例如allreduce)将在后台自动处理。

继续进入第Attention层。它包含wqwkwv个线性层,用于将输入投影到q/ k / v,然后通过wo个线性层执行注意力机制和输出投影。此处的张量并行旨在对q/k/v投影进行列式分片,对wo线性投影进行行式分片。因此,我们可以将注意力计划添加到我们刚刚起草的tp_plan中:

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这几乎是我们需要应用于layer_tp_plan的张量并行所需的操作,以实现对TransformerBlock的张量并行。然而,我们需要注意的是,当列式分片线性层时,线性层的输出会在最后一个张量维度上被分片,而行式分片线性层直接接受在最后一个维度上分片的输入。 如果在列式线性和行式线性之间有任何更多的张量操作(例如视图操作),我们需要调整相关的形状相关操作以适应分片后的形状。

对于Llama模型,在注意力层中有一些与形状相关的视图操作。特别是列并行的wq/ wk/ wv线性层,激活张量在num_heads维度上进行分片,因此我们需要将num_heads调整为本地num_heads

最后,我们需要调用 parallelize_module API 使每个 TransformerBlock 的计划生效。在内部,它会将模型参数分布在 AttentionFeedForward 层的 DTensors 中,并在需要时为模型输入和输出注册通信钩子(分别在每个模块之前和之后):

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated

    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

现在我们已经详细说明了每个 TransformerBlock 的分片计划,通常第一层有一个 nn.Embedding,最后有一个 nn.Linear 投影层,用户可以选择按行或按列分片到第一个 nn.Embedding,并按列分片到最后一个 nn.Linear 投影层,通过指定适当的输入和输出布局。 这里有一个示例:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            output_layouts=Replicate(),
        ),
    }
)

注意

如果要分割的模型太大,无法放入CPU内存,可以使用 meta 设备初始化(例如,首先在meta设备上初始化模型,分片层,并将模型具体化),或在Transformer模型初始化期间逐层并行化 TransformerBlock 层。

将序列并行应用到 LayerNorm/RMSNorm

序列并行是在上述张量并行的基础上进行的。与仅在 Attention 模块和 FeedForward 模块内分割张量的基本张量并行不同,序列并行在序列维度上保持它们的分割,同时保留其模块输入和输出(即前向传递中的激活值和反向传递中的梯度)的复制。

在典型的 TransformerBlock 中,前向函数结合了归一化层(LayerNormRMSNorm)、一个注意力层、一个前馈层以及残差连接。例如:

# forward in a TransformerBlock
def forward(self, x):
    h = x + self.attention(self.attention_norm(x))
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

在大多数使用场景中,激活值(和梯度)在 [batch size, sequence length, hidden dimension] 模块外部的形状为 AttentionFeedForward 模块。在DTensor的语言中,Sequence Parallel 使用 Shard(1) 布局对模块的前向/反向进行激活计算。 按照前面的代码示例,下面的代码演示了如何将Sequence Parallel应用于 TransformerBlock 中的归一化层:

首先,让我们导入序列并行所需的依赖项:

from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    SequenceParallel,
)

接下来让我们调整 layer_tp_plan 以在 RMSNorm 层启用序列并行:

layer_tp_plan = {
    # Now the input and output of SequenceParallel has Shard(1) layouts,
    # to represent the input/output tensors sharded on the sequence dimension
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

可以看到,我们现在使用 PrepareModuleInput 来修改 Attention 和 FeedForward 层的模块输入布局,从 Shard(1) 改为 Replicate(),并标记它们的输出布局为 Shard(1)。 就像张量并行化所发生的情况一样,只需指定输入和输出的张量分片布局,层之间的通信将会自动完成。

注意,使用序列并行时,我们假设 TransformerBlock 的输入和输出始终在序列维度上进行分片,以便多个 TransformerBlocks 可以无缝拼接。 这可以通过显式指定开始 nn.Embedding 层的输出和最终 nn.Linear 投影层的输入为 Shard(1) 来实现:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
    }
)

应用损失并行

Loss Parallel 是一种相关技术,用于在计算损失函数时节省内存和通信开销,因为模型输出通常非常大。在 Loss Parallel 中,当模型输出在(通常非常大的)词汇维度上进行分片时,可以高效地计算交叉熵损失,而无需将所有模型输出收集到每个 GPU 上。这不仅显著减少了内存消耗,还通过减少通信开销并并行执行分片计算来提高训练速度。下图简要说明了 Loss Parallel 如何通过分片计算避免将所有模型输出收集到每个 GPU 上。

loss parallel

图2. 在一个GPU上进行交叉熵损失的前向计算,损失并行化。蓝色表示分片张量;绿色表示复制张量;黄色表示具有部分值的张量(需进行全局归约)。黑色箭头表示本地计算;红色箭头表示GPU之间的功能集合通信。

在PyTorch张量并行API中,可以通过上下文管理器 loss_parallel 启用损失并行,这样用户可以直接使用 torch.nn.functional.cross_entropytorch.nn.CrossEntropyLoss 而无需修改代码的其他部分。

应用损失并行时,模型预测结果,通常形状为 [batch size, sequence length, vocabulary size],应在词汇维度上进行分片。这可以通过标记最后一个线性投影层输出的输出布局来轻松实现:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # use DTensor as the output
            use_local_output=False,
        ),
    },
)

在上面的代码中,我们还在输出前的norm层应用了序列并行。我们应用use_local_output=False,以使输出保持为DTensor,以便与loss_parallel上下文管理器配合使用。之后,可以像下面所示那样直接调用交叉熵损失函数。请注意,反向计算也需要在上下文中进行。

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

pred = model(input_ids)
with loss_parallel():
    # assuming pred and labels are of the shape [batch, seq, vocab]
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()

将张量并行与完全分片数据并行结合在一起

既然我们已经展示了如何将张量并行/序列并行应用于模型,那么我们也来看看张量并行和全分片数据并行是如何协同工作的。 由于张量并行会引入阻塞计算的通信开销,我们希望确保它在快速通信通道上运行,例如 NVLink。 实际上,我们通常在每个主机内部应用张量并行,并在多个主机之间应用全分片数据并行。

fsdp + tp

图3. FSDP和TP分别作用于不同的设备维度,FSDP通信发生在主机之间,TP通信发生在同一主机内。

这种二维并行模式可以通过一个二维设备网格(DeviceMesh)轻松表达,我们只需将每个“子”设备网格传递给各个独立的并行化API:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices

model = Model(...)

tp_plan = {...}

# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)

这将使我们能够轻松地在每个主机(intra-host)内应用张量并行,并在主机之间(inter-hosts)应用FSDP,而无需对Llama模型进行任何代码更改。 张量(模型)并行和数据并行技术的结合,使我们能够继续增加模型规模并高效地使用大量GPU进行训练。

结论

本教程演示了如何通过张量并行(Tensor Parallel)结合完全分片数据并行(Fully Sharded Data Parallel),在数百到数千个GPU上训练一个类似Transformer的大模型。 它解释了如何将张量并行应用于模型的不同部分,而无需对模型本身进行任何代码更改。张量并行是一种高效的模型并行技术,适用于大规模训练。

要查看本教程中完整解释的端到端代码示例,请参阅 pytorch/examples 仓库中的 张量并行示例

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源