目录

torchtune 中的检查点

此深入探讨将引导您了解 checkpointer 的设计和行为,以及 关联的实用程序。

本次深入探讨将涵盖什么:
  • torchtune 的 Checkpointer 设计

  • 检查点格式以及我们如何处理它们

  • 检查点场景:Intermediate vs Final 和 LoRA vs Full-finetune

概述

Torchtune 检查点设计为可插入的可组合组件 到任何配方中 - 训练、评估或生成。每个 checkpointer 都支持 一组模型和场景,使这些易于理解、调试和扩展。

在我们深入研究 torchtune 中的 checkpointer 之前,让我们定义一些概念。


检查点格式

在这次深入探讨中,我们将讨论不同的检查点格式以及 torchtune 如何处理它们。 让我们仔细看看这些不同的格式。

简单地说,检查点的格式由state_dict及其存储方式决定 在磁盘上的文件中。每个权重都与一个字符串键相关联,该键在 state dict 中标识它。 如果存储的 checkpoint 中 key 的字符串标识符不匹配 与模型定义中的参数完全相同,您将遇到显式错误(加载 state dict 将引发异常)或更糟 - 静默错误(加载将成功,但训练或 inference 将无法按预期工作)。除了排列的键之外,您还需要形状 的权重(state_dict中的值)与模型预期的值完全匹配 定义。

让我们看看 Llama2 的两种流行格式。

元格式

这是官方 Llama2 实现支持的格式。当您下载 Llama2 7B 模型时 从 Meta-LLAMA 网站,您将可以访问单个检查点文件。您可以使用.pthtorch.load

>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>    print(f'{key}: {value.shape}')

tok_embeddings.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
292

该state_dict包含 292 个键,包括一个名为 .这 state_dict模型定义需要一个嵌入层,其中每个标记都有一个 嵌入 dim 为 。tok_embeddings320004096

HF 格式

这是 Hugging Face Model Hub 中最受欢迎的格式,并且是 每个 torchtune 配置中的默认格式。这也是您下载 llama2 模型。

第一个很大的区别是 state_dict 被拆分为两个文件。要正确 load the checkpoint 中,您需要将这些文件拼凑在一起。让我们检查其中一个文件。.bin

>>> import torch
>>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>     print(f'{key}: {value.shape}')

model.embed_tokens.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
241

state_dict不仅包含较少的键(意料之中,因为这是两个文件之一),而且 嵌入表称为 ,而不是 。这种不匹配 in names 将在您尝试加载 state_dict 时导致异常。此层的大小为 两者之间相同,正如预期的那样。model.embed_tokenstok_embeddings


正如你所看到的,如果你不小心,你很可能会在 checkpoint load 和 save。torchtune 检查点通过管理状态 dict 来降低出错率 给你的。torchtune 被设计为 “state-dict invariant” 。

  • 加载时,torchtune 接受来自多个来源的多种格式的检查点。 您不必担心每次运行配方时都显式转换检查点。

  • 保存时,torchtune 会生成与源相同格式的检查点。这包括 将 state_dict 转换回原始形式并拆分 Key 和 Weights 在相同数量的文件中。

“state-dict invariant”的一大优点是您应该能够使用 使用任何后训练工具(量化、评估、推理)从 Torchtune 微调检查点 它支持源格式,无需任何代码更改或转换脚本。这是 Torchtune 与周围生态系统互操作的方式。

为了实现“state-dict invariant”,and 方法使用了此处提供的权重转换器。load_checkpointsave_checkpoint


处理不同的 Checkpoint 格式

Torchtune 支持三种不同的 Checkpointer, 每个 Checkpoint 都支持不同的 checkpoint 格式。

HFCheckpointer

此 checkpointer 以与 transformer 兼容的格式读取和写入 checkpoint 框架。如上所述,这是 Hugging Face 中最流行的格式 Model Hub 的 ,是每个 torchtune 配置中的默认格式。

为了使此 checkpointer 正常工作,我们假设它包含必要的 checkpoint 和 json 文件。确保一切正常的最简单方法是使用以程:checkpoint_dir

  • 使用 tune download 从 HF 存储库下载模型。默认情况下,这将忽略 “safetensors” 文件。


    tune download meta-llama/Llama-2-7b-hf \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • 使用 specified here 作为 checkpointer 的参数。output_dircheckpoint_dir


以下代码片段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b-hf model we have
    # 2 .bin files. The checkpointer takes care of sorting
    # by id and so the order here does not matter
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

注意

与 HF 格式之间的检查点转换需要访问模型参数,这些参数是 直接从文件中读取。这有助于确保我们加载权重 正确或错误,如果 HF 检查点文件和 torchtune 的 模型实现。此 json 文件与模型检查点一起从应用中心下载。 有关在转换过程中如何使用这些参数的更多详细信息,请参阅此处config.json


MetaCheckpointer 元校验指针

这个 checkpointer 以与原始 meta-lla 兼容的格式读取和写入 checkpoint GitHub 存储库。

为了使此 checkpointer 正常工作,我们假设它包含必要的 checkpoint 和 json 文件。确保一切正常的最简单方法是使用以程:checkpoint_dir

  • 使用 tune download 从 HF 存储库下载模型。默认情况下,这将忽略 “safetensors” 文件。


    tune download meta-llama/Llama-2-7b \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • 将 above 用作 checkpointer 的 。output_dircheckpoint_dir


以下代码片段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelMetaCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b model we have
    # a single .pth file
    checkpoint_files: [consolidated.00.pth]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

TorchTuneCheck指针

此 checkpointer 以与 torchtune 的 模型定义。这不会执行任何state_dict转换,当前正在使用 用于测试或加载量化模型以进行生成。


中间检查点与最终检查点

torchtune 检查点支持两种检查点方案:

训练结束检查点

模型在完成训练结束时进行加权 run 的 RUN 写入 file。检查点程序确保输出检查点 文件与用于开始训练的输入检查点文件具有相同的键。这 Checkpointer 还确保 Key 在相同数量的 files 作为原始检查点。输出状态 dict 具有以下 标准格式:

{
    "key_1": weight_1,
    "key_2": weight_2,
    ...
}

训练中期 Chekpointing

如果在训练过程中执行 checkpointing,则输出 checkpoint 需要存储额外的 信息,以确保后续训练运行可以正确重启。除了 模型检查点文件中,我们输出一个文件用于 intermediate 检查站。这些当前在每个 epoch 结束时输出,并包含信息 例如优化器状态、已完成的 epoch 数等。recipe_state.pt

为了防止我们被 checkpoint 文件淹没,配方状态为 在每个 epoch 结束时覆盖。output_dir

输出状态 dict 具有以下格式:

Model:
    {
        "key_1": weight_1,
        "key_2": weight_2,
        ...
    }

Recipe State:
    {
        "optimizer": ...,
        "epoch": ...,
        ...
    }

要从以前的检查点文件重新启动,您需要进行以下更改 到配置文件

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
        hf_model_0001_0.pt,
        hf_model_0002_0.pt,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True

LoRA 的检查点

在 torchtune 中,我们输出适配器权重和完整的模型 “合并” 权重 对于 LoRA。“merged” 检查点可以像使用 source 一样使用 checkpoint 替换为任何后训练工具。有关更多详细信息,请查看我们的 LoRA 微调教程

这两个用例之间的主要区别在于何时需要恢复训练 从检查点。在这种情况下,checkpointer 需要访问初始 frozen Base Model 权重以及 Learnt Adapter 权重。此方案的配置 看起来像这样:

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. This is the ORIGINAL frozen checkpoint
    # and NOT the merged checkpoint output during training
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # this refers to the adapter weights learnt during training
    adapter_checkpoint: adapter_0.pt

    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True

把这些放在一起

现在让我们把所有这些知识放在一起吧!我们将加载一些 checkpoints, 创建一些模型并运行一个简单的 forward。

在本节中,我们将使用 HF 格式的 Llama2 13B 模型。

import torch
from torchtune.utils import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b

# Set the right directory and files
checkpoint_dir = 'Llama-2-13b-hf/'
pytorch_files = [
    'pytorch_model-00001-of-00003.bin',
    'pytorch_model-00002-of-00003.bin',
    'pytorch_model-00003-of-00003.bin'
]

# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=checkpoint_dir,
    model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()

# Setup the model and the input
model = llama2_13b()

# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
<All keys matched successfully>

# We have 32000 vocab tokens; lets generate an input with 70 tokens
x = torch.randint(0, 32000, (1, 70))

with torch.no_grad():
    model(x)

tensor([[[ -6.3989,  -9.0531,   3.2375,  ...,  -5.2822,  -4.4872,  -5.7469],
    [ -8.6737, -11.0023,   6.8235,  ...,  -2.6819,  -4.2424,  -4.0109],
    [ -4.6915,  -7.3618,   4.1628,  ...,  -2.8594,  -2.5857,  -3.1151],
    ...,
    [ -7.7808,  -8.2322,   2.8850,  ...,  -1.9604,  -4.7624,  -1.6040],
    [ -7.3159,  -8.5849,   1.8039,  ...,  -0.9322,  -5.2010,  -1.6824],
    [ -7.8929,  -8.8465,   3.3794,  ...,  -1.3500,  -4.6145,  -2.5931]]])

您可以使用 torchtune 支持的任何模型执行此操作。您可以找到完整列表 模型和模型生成器

我们希望这次深入探讨能让您更深入地了解 checkpointer 和 Torchtune 中的关联实用程序。调音愉快!

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源