使用全分片数据并行 (FSDP) 的高级模型训练¶
创建于: 2024年10月31日 |上次更新时间:2024 年 10 月 31 日 |上次验证: Nov 05, 2024
作者: Hamid Shojanazeri, Less 赖特, 罗汉·瓦尔玛, 赵艳丽
PyTorch 的全分片数据并行模块:分片模块参数的包装器
数据并行工作程序。
PyTorch 1.12 或更高版本
阅读有关 FSDP API 的信息。
本教程介绍了 Fully Sharded Data Parallel 的更多高级功能 (FSDP) 作为 PyTorch 1.12 版本的一部分。要熟悉 FSDP,请 请参阅 FSDP 入门教程。
在本教程中,我们将使用 FSDP 对文本微调 HuggingFace (HF) T5 模型 总结作为一个工作示例。
该示例使用 Wikihow,为简单起见,我们将在 具有 8 个 A100 GPU 的单节点 P4dn 实例。我们现在有几篇博客文章 ( (link1), (link2)) 以及一篇关于 在多节点集群上进行大规模 FSDP 训练。
FSDP 是一个生产就绪型软件包,专注于易用性、性能和 长期支持。FSDP 的主要好处之一是减少内存 占用空间。这样可以训练具有较低总数的较大模型 memory 与 DDP 的 v,并利用计算和通信的重叠来 高效训练模型。 这种降低的内存压力可用于训练更大的模型或 增加批量大小,可能有助于整体训练吞吐量。您可以 在此处阅读有关 PyTorch FSDP 的更多信息。
本教程中的 FSDP 功能¶
Transformer 自动换行策略
混合精度
在设备上初始化 FSDP 模型
分片策略
向后预取
通过流式传输到 CPU 保存模型检查点
FSDP 工作原理回顾¶
概括地说,FDSP 的工作原理如下:
在构造函数中
分片模型参数和每个排名仅保留自己的分片
在前向传递中
运行 all_gather 以收集所有等级的所有分片以恢复全部 参数并运行前向计算
丢弃刚刚收集的非拥有参数分片以释放内存
在向后传递中
运行 all_gather 以收集所有等级的所有分片以恢复全部 参数并运行反向计算
丢弃非拥有的参数以释放内存。
运行 reduce_scatter 以同步渐变
微调 HF T5¶
HF T5 预训练模型有四种不同的尺寸,包括 具有 6000 万个参数的小型到具有 110 亿个参数的 XXL。在这个 教程中,我们演示了使用 FSDP 对文本的 T5 3B 进行微调 使用 WikiHow 数据集进行总结。本教程的主要重点是 突出显示 FSDP 中有助于培训的不同可用功能 3B 参数以上的大比例模型。此外,我们还介绍了以下特定功能 基于 Transformer 的模型。本教程的代码在 Pytorch 中可用 示例。
设置
1.1 安装最新的 PyTorch
pip3 install torch torchvision torchaudio
1.2 数据集设置
请创建一个数据文件夹,从 wikihowAll.csv 和 wikihowSep.cs 下载 WikiHow 数据集, 并将它们放置在 data 文件夹中。我们将使用来自 summarization_dataset 的 wikihow 数据集。
接下来,我们将以下代码片段添加到 Python 脚本 “T5_training.py” 中。
注意
本教程的完整源代码可在 PyTorch 示例中找到。
1.3 导入必要的包:
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing_wrapper)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
1.4 分布式训练设置。 这里我们使用两个辅助函数来初始化分布式进程 training 进行清理,然后在 Training 完成后进行清理。在本教程中,我们将 将使用 Torch elastic,使用 torchrun ,这将设置 worker RANK 和 WORLD_SIZE 自动。
def setup():
# initialize the process group
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
2.1 设置 HuggingFace T5 模型:
def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
return model, tokenizer
我们还在此处添加了几个辅助函数,用于日期和格式化内存 指标。
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM'
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
def format_metrics_to_gb(item):
"""quick function to format numbers to gigabyte and round to 4 digit precision"""
metric_num = item / g_gigabyte
metric_num = round(metric_num, ndigits=4)
return metric_num
2.2 定义 train 函数:
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank==0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
for batch in train_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
optimizer.zero_grad()
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
loss = output["loss"]
loss.backward()
optimizer.step()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
2.3 定义验证函数:
def validation(model, rank, world_size, val_loader):
model.eval()
correct = 0
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(3).to(local_rank)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(val_loader)), colour="green", desc="Validation Epoch"
)
with torch.no_grad():
for batch in val_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
fsdp_loss[0] += output["loss"].item() # sum up batch loss
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
val_loss = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(f"Validation Loss: {val_loss:.4f}")
return val_loss
2.4 定义一个分布式训练函数,将模型包装在 FSDP 中:
def fsdp_main(args):
model, tokenizer = setup_model("t5-base")
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset.keys())
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
setup()
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
torch.cuda.set_device(local_rank)
#init_start_event = torch.cuda.Event(enable_timing=True)
#init_end_event = torch.cuda.Event(enable_timing=True)
#init_start_event.record()
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
if bf16_ready:
mp_policy = bfSixteen
else:
mp_policy = None # defaults to fp32
# model is on CPU before input to FSDP
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mp_policy,
#sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device())
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
best_val_loss = float("inf")
curr_val_loss = float("inf")
file_save_name = "T5-model-"
if rank == 0:
time_of_run = get_date_of_run()
dur = []
train_acc_tracking = []
val_acc_tracking = []
training_start_time = time.time()
if rank == 0 and args.track_memory:
mem_alloc_tracker = []
mem_reserved_tracker = []
for epoch in range(1, args.epochs + 1):
t0 = time.time()
train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()
if rank == 0:
print(f"--> epoch {epoch} completed...entering save and stats zone")
dur.append(time.time() - t0)
train_acc_tracking.append(train_accuracy.item())
if args.run_validation:
val_acc_tracking.append(curr_val_loss.item())
if args.track_memory:
mem_alloc_tracker.append(
format_metrics_to_gb(torch.cuda.memory_allocated())
)
mem_reserved_tracker.append(
format_metrics_to_gb(torch.cuda.memory_reserved())
)
print(f"completed save and stats zone...")
if args.save_model and curr_val_loss < best_val_loss:
# save
if rank == 0:
print(f"--> entering save model state")
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
#print(f"saving process: rank {rank} done w state_dict")
if rank == 0:
print(f"--> saving model ...")
currEpoch = (
"-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
)
print(f"--> attempting to save model prefix {currEpoch}")
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
print(f"--> saving as model name {save_name}")
torch.save(cpu_state, save_name)
if curr_val_loss < best_val_loss:
best_val_loss = curr_val_loss
if rank==0:
print(f"-->>>> New Val Loss Record: {best_val_loss}")
dist.barrier()
cleanup()
2.5 解析参数并设置 main 函数:
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=.002, metavar='LR',
help='learning rate (default: .002)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--track_memory', action='store_false', default=True,
help='track the gpu memory')
parser.add_argument('--run_validation', action='store_false', default=True,
help='running the validation')
parser.add_argument('--save-model', action='store_false', default=True,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
fsdp_main(args)
要使用 torchrun 运行训练:
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
Transformer 包装策略¶
如上一个教程中所述, auto_wrap_policy 是 FSDP 的其中一项功能,可以轻松实现自动 对给定模型进行分片,并将模型、优化器和梯度分片放入 不同的 FSDP 单位。
对于某些架构,例如 Transformer 编码器-解码器,某些部分的 Embedding Table 等模型正在与 Encoder 和 Decoder 共享。在 在这种情况下,我们需要将嵌入表放在外部 FSDP 单元中,以便 它可以从 Encoder 和 Decoder 访问。此外,通过注册 Transformer 的 Layer 类,分片计划可以做得更多 沟通高效。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们 为 Transförms 制定 wrapping 策略。
可以按如下方式创建,其中 T5Block 表示 T5 变压器 layer 类(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy)
要查看包装后的模型,您可以轻松打印模型并目视检查 分片和 FSDP 单元。
混合精度¶
FSDP 支持灵活的混合精度训练,允许任意缩减 精度类型(如 FP16 或 BFLOAT16)。目前 BFloat16 仅可用 在 Ampere GPU 上,因此您需要在使用之前确认本机支持。上 例如,V100 仍然可以运行,但由于它是非本机运行的,因此 它可能会导致显着的减速。
要检查 BFloat16 是否原生受支持,您可以使用以下内容:
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
FSDP 中混合精度的优势之一是提供精细控制 参数、梯度和缓冲区的不同精度级别为 遵循:
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
# Gradient communication precision.
reduce_dtype=torch.float32,
# Buffer precision.
buffer_dtype=torch.float32,
)
请注意,如果未指定特定类型 (parameter, reduce, buffer),则它们 不会被施放。
这种灵活性允许用户进行精细控制,例如仅设置 梯度通信以降低的精度发生,并且所有参数 / buffer 计算以全精度完成。这在 节点内通信是主要瓶颈和参数的情况 / 缓冲区必须为全精度,以避免出现精度问题。这是可以做到的 使用以下策略:
grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
在 2.4 中,我们只是将相关的混合精度策略添加到 FSDP 包装器中:
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen)
在我们的实验中,我们观察到使用 BFloat16 将 训练和记忆减少约 30% 在一些实验中可以 用于增加批量大小。
在设备上初始化 FSDP 模型¶
在 1.12 中,FSDP 支持用于初始化输入 CPU 的 device_id 参数 模块device_id。当整个模型 不适合单个 GPU,但适合主机的 CPU 内存。指定 device_id 后,FSDP 会将模型移动到每个 FSDP 上的指定设备 单位基础,避免 GPU OOM 问题,同时初始化速度比 基于 CPU 的初始化:
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device())
分片策略¶
FSDP 分片策略默认设置为对模型参数进行全分片, 梯度和优化器状态在所有等级之间进行分片。(也称为 Zero3 分片)。如果您有兴趣使用 Zero2 分片策略,其中 只有优化器状态和梯度被分片,FSDP 通过以下方式支持此功能 使用 “ShardingStrategy.SHARD_GRAD_OP” 传递 Sharding 策略, 而不是 “ShardingStrategy.FULL_SHARD” 添加到 FSDP 初始化,如下所示:
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)
这将减少 FSDP 中的通信开销,在本例中,它保持完整 参数。
这样可以在向后保存all_gather,从而减少 较高内存占用的成本。请注意,完整的模型参数在 END of BACKWARDS 和 all_gather 将发生在下一次 FORWARD 传递中。
向后预取¶
向后预取设置控制下一个 FSDP 单元的 参数。通过将其设置为 BACKWARD_PRE,下一个 FSDP 的 unit params 可以在 当前单位的计算开始。这与 all_gather 通信和梯度计算重叠,可以提高 换取略高的内存消耗。它可以在 FSDP 中使用 wrapper 中,如下所示:
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
backward_prefetch 有两种模式,BACKWARD_PRE 和 BACKWARD_POST。BACKWARD_POST 表示不会请求下一个 FSDP 单元的参数 直到当前 FSDP 单元处理完成,从而最大限度地减少内存 开销。在某些情况下,使用 BACKWARD_PRE 可以提高模型训练速度 高达 2-10%,对于较大的模型,速度甚至更高。
模型检查点保存,通过流式传输到 Rank0 CPU¶
要使用 FULL_STATE_DICT 保存模型检查点,请将模型保存在 与本地模型相同,PyTorch 1.12 提供了一些实用程序来支持 保存更大的模型。
首先,可以指定 FullStateDictConfig,允许state_dict 仅在 rank 0 上填充并卸载到 CPU。
使用此配置时,FSDP 将全部收集模型参数,卸载 它们一个接一个地连接到 CPU,仅在 rank 0 上。当state_dict终于 saved,则它只会在排名 0 上填充并包含 CPU 张量。这样可以避免 对于大于单个 GPU 内存的模型,可能 OOM,并允许 users 添加到检查点模型,其大小大致是 用户的机器。
此功能可以按如下方式运行:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
if rank == 0:
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
torch.save(cpu_state, save_name)
总结¶
在本教程中,我们介绍了 FSDP 的许多新功能,包括 Pytorch 1.12 的 Pytorch 1.12 中,使用 HF T5 作为运行示例。使用适当的包装 策略,以及混合精度和 向后预取应该可以加快您的训练运行速度。此外,还有 在设备上初始化模型,并通过流式传输到 CPU 来保存 checkpoint 应该有助于避免在处理大型模型时出现 OOM 错误。
我们正在积极努力为下一版本的 FSDP 添加新功能。如果 您有反馈、功能请求、问题或遇到问题 使用 FSDP,请随时通过在 PyTorch Github 存储库中打开一个问题与我们联系。