目录

Checkpointing in torchtune

本深入讲解将引导您了解检查点保存器的设计与行为,以及相关工具。

本深入探讨将涵盖以下内容:
  • torchtune 的检查点设计

  • 检查点格式及其处理方式

  • 检查点场景:中间与最终,以及 LoRA 与全量微调

概览

torchtune 检查点管理器被设计为可组合组件,可插入任何训练、评估或生成流程中。每个检查点管理器支持一组模型和场景,使其易于理解、调试和扩展。

在深入探讨 torchtune 中的检查点保存器之前,让我们先定义一些概念。


检查点格式

在本期深度解析中,我们将探讨不同的检查点格式以及 torchtune 如何处理它们。 让我们仔细看看这些不同的格式。

简单来说,检查点的格式由 state_dict 及其在磁盘文件中的存储方式决定。每个权重都关联一个字符串键,用于在 state_dict 中标识它。如果已保存检查点中键的字符串标识符与模型定义中的不完全匹配,您将遇到显式错误(加载 state_dict 会抛出异常),或者更糟糕的情况——静默错误(加载成功,但训练或推理无法按预期工作)。除了键需要对应之外,权重的形状(state_dict 中的值)也必须与模型定义所期望的形状完全一致。

让我们来看看 Llama 3.2 的两种流行格式。

元格式

这是官方 Llama 3.2 实现支持的格式。当你从 meta-llama 网站 下载 Llama 3.2 3B 模型时,你会获得一个 .pth 检查点文件。你可以轻松地使用 torch.load 查看此检查点的内容。

>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>    print(f'{key}: {value.shape}')

tok_embeddings.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
255

state_dict 包含 255 个键,其中包括一个名为 tok_embeddings 的输入嵌入表。此 state_dict 的模型定义期望一个嵌入层,其中包含 128256 个标记,每个标记的嵌入维度为 3072

Hugging Face 格式

这是Hugging Face Model Hub中最流行的格式,并且是每个torchtune配置中的默认格式。当你从Llama-3.2-3B-Instruct仓库下载llama3.2模型时,也会得到这种格式。

第一个主要区别是 state_dict 被拆分到两个 .safetensors 文件中。要正确加载检查点,你需要将这些文件拼接起来。让我们检查其中一个文件。

>>> from safetensors import safe_open
>>> state_dict = {}
>>> with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
>>>     for k in f.keys():
>>>         state_dict[k] = f.get_tensor(k)

>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>     print(f'{key}: {value.shape}')

model.embed_tokens.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
187

不仅 state_dict 包含的键更少(由于这是两个文件中的一个,因此预期如此),而且嵌入表被命名为 model.embed_tokens 而不是 tok_embeddings。当您尝试加载 state_dict 时,名称不匹配会导致异常。这两个层的大小是相同的,这与预期一致。


如您所见,如果不加小心,您很可能在检查点的加载和保存过程中犯下诸多错误。torchtune 的检查点管理器通过为您管理状态字典,降低了出错风险。torchtune 的设计目标是“与状态字典无关”。

  • 加载时,torchtune 支持从多种来源以多种格式读取检查点。 您无需在每次运行配方时都显式地转换检查点。

  • 在保存时,torchtune 会以与源文件相同的格式生成检查点。这包括将 state_dict 转换回原始形式,并将键和权重拆分到相同数量的文件中。

“状态字典不变性”的一大优势在于,您可以直接使用来自 torchtune 的微调检查点,配合任何支持源格式的后期训练工具(如量化、评估、推理),而无需进行任何代码修改或转换脚本。这是 torchtune 与周边生态系统实现互操作的方式之一。

注意

为了使状态字典(state-dict)保持“不变”,每个检查点保存器(checkpointer)的 load_checkpointsave_checkpoint 方法都使用了权重转换器,这些转换器能够正确地在检查点格式之间映射权重。例如,在加载来自 Hugging Face 的权重时,我们在加载和保存时会对某些权重应用排列操作,以确保检查点的行为完全相同。为了进一步说明这一点,Llama 系列模型使用了一个 通用权重转换函数 ,而像 Phi3 这样的其他模型则有自己的 转换函数 ,这些函数可以在它们的模型文件夹中找到。


处理不同的检查点格式

torchtune 支持三种不同的 检查点程序, 每种都支持不同的检查点格式。

HFCheckpointer

此检查点读取器以与 Hugging Face 的 transformers 框架兼容的格式读写检查点。如上所述,这是 Hugging Face Model Hub 中最流行的格式,也是每个 torchtune 配置中的默认格式。

为了使此检查点程序正常工作,我们假设 checkpoint_dir 包含必要的检查点和 JSON 文件。确保一切正常工作的最简单方法是使用以下流程:

  • 使用 tune download 从 HF 仓库下载模型。这将忽略"pth"文件,因为我们将加载"safetensors"文件。


    tune download meta-llama/Llama-3.2-3B-Instruct \
    --output-dir /tmp/Llama-3.2-3B-Instruct \
    --ignore-patterns "original/consolidated.00.pth"
    
  • 使用 output_dir 作为此处指定的检查点程序的 checkpoint_dir 参数。


以下代码片段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.training.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the folder you used when downloading the model
    checkpoint_dir: /tmp/Llama-3.2-3B-Instruct

    # checkpoint files. For the Llama-3.2-3B-Instruct model we have
    # 2 .safetensor files. The checkpointer takes care of sorting
    # by id and so the order here does not matter
    checkpoint_files: [
        model-00001-of-00002.safetensors,
        model-00002-of-00002.safetensors,
    ]

    # dir for saving the output checkpoints
    output_dir: <output_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA3_2

# set to True if restarting training. More on that later.
resume_from_checkpoint: False

注意

检查点在HF格式与torchtune格式之间转换需要访问模型参数,这些参数直接从config.json文件中读取。这有助于确保我们正确加载权重,或者在HF检查点文件与torchtune的模型实现之间存在差异时出现错误。此JSON文件与模型检查点一起从Hub下载。


MetaCheckpointer

此检查点读取器和写入器使用的格式与原始 meta-llama GitHub 仓库兼容。

为了使此检查点程序正常工作,我们假设 checkpoint_dir 包含必要的检查点和 JSON 文件。确保一切正常工作的最简单方法是使用以下流程:

  • 使用 tune download 从 HF 仓库下载模型。默认情况下,这将忽略"safetensors"文件。


    tune download meta-llama/Llama-3.2-3B-Instruct \
    --output-dir /tmp/Llama-3.2-3B-Instruct \
    --ignore-patterns "*.safetensors"
    
  • 使用 output_dir 作为检查点程序的 checkpoint_dir


以下代码片段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.training.FullModelMetaCheckpointer

    # directory with the checkpoint files
    # this should match the folder you used when downloading the model
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama3.2 3B model we have
    # a single .pth file
    checkpoint_files: [consolidated.00.pth]

    # dir for saving the output checkpoints.
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA3_2

# set to True if restarting training. More on that later.
resume_from_checkpoint: False

TorchTuneCheckpointer

此检查点读取器和写入器以与 torchtune 模型定义兼容的格式读写检查点。它不执行任何 state_dict 转换,目前仅用于测试或加载用于生成的量化模型。


Checkpoint Output

恭喜你走到这一步!假设你已经按照我们的使用 torchtune 的端到端工作流,并使用我们其中一个 LoRA 配方训练了一个 llama 3.2 3B。

现在让我们可视化输出。一种简单的方法是运行 tree -a path/to/outputdir,这应该会显示如下所示的树。 有 3 种类型的文件夹:

  1. recipe_state:保存 recipe_state.pt,其中包含从最后一个中间 epoch 重新启动训练运行所需的信息。稍后详细介绍;

  2. logs: 您的 metric_logger 的输出(如果有);

  3. epoch_{}:包含您训练好的模型权重及模型元数据。如果运行推理或将模型推送到模型库,应直接使用此文件夹;

注意

对于每个 epoch,我们复制原始检查点文件夹的内容,但排除原始检查点和大型文件。 这些文件体积小巧,主要是配置文件,便于用户在下游应用中直接使用 epoch 文件夹。

有关每个文件的更多详细信息,请参阅上述端到端教程。

>>> tree -a /tmp/torchtune/llama3_2_3B/lora_single_device
/tmp/torchtune/llama3_2_3B/lora_single_device
├── epoch_0
│   ├── adapter_config.json
│   ├── adapter_model.pt
│   ├── adapter_model.safetensors
│   ├── config.json
│   ├── ft-model-00001-of-00002.safetensors
│   ├── ft-model-00002-of-00002.safetensors
│   ├── generation_config.json
│   ├── LICENSE.txt
│   ├── model.safetensors.index.json
│   ├── original
│      ├── orig_params.json
│      ├── params.json
│      └── tokenizer.model
│   ├── original_repo_id.json
│   ├── README.md
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.json
│   └── USE_POLICY.md
├── epoch_1
│   ├── adapter_config.json
│   ├── adapter_model.pt
│   ├── adapter_model.safetensors
│   ├── config.json
│   ├── ft-model-00001-of-00002.safetensors
│   ├── ft-model-00002-of-00002.safetensors
│   ├── generation_config.json
│   ├── LICENSE.txt
│   ├── model.safetensors.index.json
│   ├── original
│      ├── orig_params.json
│      ├── params.json
│      └── tokenizer.model
│   ├── original_repo_id.json
│   ├── README.md
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.json
│   └── USE_POLICY.md
├── logs
│   └── log_1734652101.txt
└── recipe_state
    └── recipe_state.pt

Intermediate vs Final Checkpoints

torchtune 检查点器支持两种检查点场景:

训练结束时的检查点保存

训练运行完成后,模型权重会被写入文件。检查点器确保输出检查点文件的键与用于开始训练的输入检查点文件的键相同。检查点器还确保这些键被划分到与原始检查点相同数量的文件中。输出状态字典具有以下标准格式:

{
    "key_1": weight_1,
    "key_2": weight_2,
    ...
}

训练中途检查点

如果在训练过程中进行检查点保存,输出的检查点需要存储额外的信息,以确保后续的训练运行可以正确地重新启动。除了模型检查点文件外,我们还会输出一个recipe_state.pt文件用于中间检查点。这些文件目前会在每个 epoch 结束时输出,并包含优化器状态、已完成的 epoch 数量等信息。

为了防止我们向 output_dir 中填充过多的检查点文件,每个 epoch 结束时都会覆盖配方状态。

输出状态字典具有以下格式:

Model:
    {
        "key_1": weight_1,
        "key_2": weight_2,
        ...
    }

Recipe State:
    {
        "optimizer": ...,
        "epoch": ...,
        ...
    }

从检查点恢复 - 完整微调

有时我们的训练会因某些原因中断。要从之前的检查点文件恢复训练,您需要在配置中更新以下字段:

resume_from_checkpoint: 将其设置为 True;

checkpoint_files: 将路径更改为 epoch_{YOUR_EPOCH}/ft-model={}-of-{}.safetensors;

请注意,我们更改 checkpoint_dir 或 output_dir。由于我们正在从检查点恢复,我们知道在 output_dir 中查找它。

checkpointer:
    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
        epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
        epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
    ]

# set to True if restarting training
resume_from_checkpoint: True

从检查点恢复 - LoRA 微调

与完全微调类似,我们只需要修改两个字段:resume_from_checkpointadapter_checkpoint,它们将从 output_dir 加载。我们不需要修改 checkpoint_files,因为加载的基础模型仍然是相同的。

checkpointer:

    # adapter_checkpoint. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    adapter_checkpoint: epoch_{YOUR_EPOCH}/adapter_model.safetensors

# set to True if restarting training
resume_from_checkpoint: True

# set to True to save only the adapter weights
# it does not influence resuming_from_checkpointing
save_adapter_weights_only: False

注意

在torchtune中,我们输出LoRA的适配器权重和完整模型合并后的权重。合并检查点是一个便利选项,因为它可以在没有特殊工具处理适配器的情况下使用。然而,在恢复训练时,它们**不应**被使用,因为加载合并权重加上适配器会导致错误。因此,当为LoRA恢复训练时,我们将从检查点目录中获取原始未训练权重,并从输出目录中获取训练好的适配器。更多细节,请查看我们的LoRA微调教程

注意

此外,通过设置选项 save_adapter_weights_only,您可以选择仅保存适配器权重。 这可以减少保存检查点所需的存储空间和时间,但对从检查点恢复没有影响。


将这些内容整合在一起

现在让我们将所有这些知识整合起来!我们将加载一些检查点,创建一些模型并运行一个简单的前向传播。

在本节中,我们将使用 HF 格式的 Llama-3.2-3B-Instruct 模型。

import torch
from torchtune.models.llama3_2 import llama3_2_3b
from torchtune.training import FullModelHFCheckpointer

# Set the right directory and files
checkpoint_dir = "/tmp/Llama-3.2-3B-Instruct/"
output_dir = "/tmp/torchtune/llama3_2_3B/full_single_device"

pytorch_files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=output_dir,
    model_type="LLAMA3_2",
)
torchtune_sd = checkpointer.load_checkpoint()

# Setup the model and the input
model = llama3_2_3b()

# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
model.to("cuda")

# We have 128256 vocab tokens; lets generate an input with 24 tokens
x = torch.randint(0, 128256, (1, 24), dtype=torch.long, device="cuda")

tensor([[[ 1.4299,  1.1658,  4.2459,  ..., -2.3259, -2.3262, -2.3259],
        [ 6.5942,  7.2284,  2.4090,  ..., -6.0129, -6.0121, -6.0127],
        [ 5.6462,  4.8787,  4.0950,  ..., -4.6460, -4.6455, -4.6457],
        ...,
        [-0.4156, -0.0626, -0.0362,  ..., -3.6432, -3.6437, -3.6427],
        [-0.5679, -0.6902,  0.5267,  ..., -2.6137, -2.6138, -2.6127],
        [ 0.3688, -0.1350,  1.1764,  ..., -3.4563, -3.4565, -3.4564]]],
    device='cuda:0')

您可以使用torchtune支持的任何模型完成此操作。您可以在 此处 找到完整模型列表和模型构建器。

我们希望此次深入探讨能让您对 torchtune 中的检查点保存器及相关工具获得更深刻的理解。祝您调参顺利!

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源