Checkpointing in torchtune¶
本深入讲解将引导您了解检查点保存器的设计与行为,以及相关工具。
torchtune 的检查点设计
检查点格式及其处理方式
检查点场景:中间与最终,以及 LoRA 与全量微调
概览¶
torchtune 检查点管理器被设计为可组合组件,可插入任何训练、评估或生成流程中。每个检查点管理器支持一组模型和场景,使其易于理解、调试和扩展。
在深入探讨 torchtune 中的检查点保存器之前,让我们先定义一些概念。
检查点格式¶
在本期深度解析中,我们将探讨不同的检查点格式以及 torchtune 如何处理它们。 让我们仔细看看这些不同的格式。
简单来说,检查点的格式由 state_dict 及其在磁盘文件中的存储方式决定。每个权重都关联一个字符串键,用于在 state_dict 中标识它。如果已保存检查点中键的字符串标识符与模型定义中的不完全匹配,您将遇到显式错误(加载 state_dict 会抛出异常),或者更糟糕的情况——静默错误(加载成功,但训练或推理无法按预期工作)。除了键需要对应之外,权重的形状(state_dict 中的值)也必须与模型定义所期望的形状完全一致。
让我们来看看 Llama2 的两种流行格式。
元格式
这是官方 Llama2 实现支持的格式。当你从 meta-llama 网站 下载 Llama2 7B 模型时,你会获得一个
.pth 检查点文件。你可以轻松地使用 torch.load 查看此检查点的内容。
>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
tok_embeddings.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
292
state_dict 包含 292 个键,其中包括一个名为 tok_embeddings 的输入嵌入表。此 state_dict 的模型定义期望一个嵌入层,该层包含 32000 个标记,每个标记的嵌入维度为 4096。
Hugging Face 格式
这是Hugging Face Model Hub中最流行的格式,并且是每个torchtune配置中的默认格式。当你从Llama-2-7b-hf仓库下载llama2模型时,你也会得到这种格式。
第一个主要区别是 state_dict 被拆分到两个 .bin 文件中。要正确加载检查点,你需要将这些文件拼接起来。让我们检查其中一个文件。
>>> import torch
>>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
model.embed_tokens.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
241
不仅 state_dict 包含的键更少(由于这是两个文件中的一个,因此预期如此),而且嵌入表被命名为 model.embed_tokens 而不是 tok_embeddings。当您尝试加载 state_dict 时,名称不匹配会导致异常。这两个层的大小是相同的,这与预期一致。
如您所见,如果不加小心,您很可能在检查点的加载和保存过程中犯下诸多错误。torchtune 的检查点管理器通过为您管理状态字典,降低了出错风险。torchtune 的设计目标是“与状态字典无关”。
加载时,torchtune 支持从多种来源以多种格式读取检查点。 您无需在每次运行配方时都显式地转换检查点。
在保存时,torchtune 会以与源文件相同的格式生成检查点。这包括将 state_dict 转换回原始形式,并将键和权重拆分到相同数量的文件中。
“状态字典不变性”的一大优势在于,您可以直接使用来自 torchtune 的微调检查点,配合任何支持源格式的后期训练工具(如量化、评估、推理),而无需进行任何代码修改或转换脚本。这是 torchtune 与周边生态系统实现互操作的方式之一。
注意
为了使状态字典(state-dict)保持“不变”,每个检查点保存器(checkpointer)的 load_checkpoint 和 save_checkpoint 方法都使用了权重转换器,这些转换器能够正确地在检查点格式之间映射权重。例如,在加载来自 Hugging Face 的权重时,我们在加载和保存时会对某些权重应用排列操作,以确保检查点的行为完全相同。为了进一步说明这一点,Llama 系列模型使用了一个
通用权重转换函数
,而像 Phi3 这样的其他模型则有自己的 转换函数
,这些函数可以在它们的模型文件夹中找到。
处理不同的检查点格式¶
torchtune 支持三种不同的 检查点程序, 每种都支持不同的检查点格式。
HFCheckpointer¶
此检查点读取器以与 Hugging Face 的 transformers 框架兼容的格式读写检查点。如上所述,这是 Hugging Face Model Hub 中最流行的格式,也是每个 torchtune 配置中的默认格式。
为了使此检查点程序正常工作,我们假设 checkpoint_dir 包含必要的检查点和 JSON 文件。确保一切正常工作的最简单方法是使用以下流程:
使用 tune download 从 HF 仓库下载模型。默认情况下,这将忽略"safetensors"文件。
tune download meta-llama/Llama-2-7b-hf \ --output-dir <checkpoint_dir> \ --hf-token <hf-token>
使用
output_dir作为此处指定的检查点程序的checkpoint_dir参数。
以下代码片段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. For the llama2-7b-hf model we have
# 2 .bin files. The checkpointer takes care of sorting
# by id and so the order here does not matter
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin,
]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state. More on this in the
# next section
recipe_checkpoint: null
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: False
注意
检查点在HF格式与torchtune格式之间转换需要访问模型参数,这些参数直接从config.json文件中读取。这有助于确保我们正确加载权重,或者在HF检查点文件与torchtune的模型实现之间存在差异时出现错误。此JSON文件与模型检查点一起从Hub下载。
MetaCheckpointer¶
此检查点读取器和写入器使用的格式与原始 meta-llama GitHub 仓库兼容。
为了使此检查点程序正常工作,我们假设 checkpoint_dir 包含必要的检查点和 JSON 文件。确保一切正常工作的最简单方法是使用以下流程:
使用 tune download 从 HF 仓库下载模型。默认情况下,这将忽略"safetensors"文件。
tune download meta-llama/Llama-2-7b \ --output-dir <checkpoint_dir> \ --hf-token <hf-token>
使用
output_dir作为检查点程序的checkpoint_dir。
以下代码片段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. For the llama2-7b model we have
# a single .pth file
checkpoint_files: [consolidated.00.pth]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state. More on this in the
# next section
recipe_checkpoint: null
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: False
TorchTuneCheckpointer¶
此检查点读取器和写入器以与 torchtune 模型定义兼容的格式读写检查点。它不执行任何 state_dict 转换,目前仅用于测试或加载用于生成的量化模型。
Intermediate vs Final Checkpoints¶
torchtune 检查点器支持两种检查点场景:
训练结束时的检查点保存
训练运行完成后,模型权重会被写入文件。检查点器确保输出检查点文件的键与用于开始训练的输入检查点文件的键相同。检查点器还确保这些键被划分到与原始检查点相同数量的文件中。输出状态字典具有以下标准格式:
{ "key_1": weight_1, "key_2": weight_2, ... }
训练中途检查点。
如果在训练过程中进行检查点保存,输出的检查点需要存储额外的信息,以确保后续的训练运行可以正确地重新启动。除了模型检查点文件外,我们还会输出一个recipe_state.pt文件用于中间检查点。这些文件目前会在每个 epoch 结束时输出,并包含优化器状态、已完成的 epoch 数量等信息。
为了防止我们向 output_dir 中填充过多的检查点文件,每个 epoch 结束时都会覆盖配方状态。
输出状态字典具有以下格式:
Model: { "key_1": weight_1, "key_2": weight_2, ... } Recipe State: { "optimizer": ..., "epoch": ..., ... }
若要从前一个检查点文件重新开始,您需要对配置文件进行以下更改
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: <checkpoint_dir>
# checkpoint files. Note that you will need to update this
# section of the config with the intermediate checkpoint files
checkpoint_files: [
hf_model_0001_0.pt,
hf_model_0002_0.pt,
]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: True
LoRA的检查点¶
在torchtune中,我们输出LoRA的适配器权重和完整的“合并”模型权重。这个“合并”检查点可以像使用源检查点一样用于任何后续训练工具。更多细节,请查看我们的 LoRA微调教程。此外,通过在保存检查点时将选项“save_adapter_weights_only”设置为True,您可以选择仅保存适配器权重。
这两种用例的主要区别在于您何时希望从检查点恢复训练。在这种情况下,检查点保存器需要同时访问初始冻结的基础模型权重以及已学习到的适配器权重。此场景的配置文件大致如下:
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. This is the ORIGINAL frozen checkpoint
# and NOT the merged checkpoint output during training
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin,
]
# this refers to the adapter weights learnt during training
adapter_checkpoint: adapter_0.pt
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: True
# Set to True to save only the adapter weights
save_adapter_weights_only: False
将这些内容整合在一起¶
现在让我们将所有这些知识整合起来!我们将加载一些检查点,创建一些模型并运行一个简单的前向传播。
在本节中,我们将使用 HF 格式的 Llama2 13B 模型。
import torch
from torchtune.training import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b
# Set the right directory and files
checkpoint_dir = 'Llama-2-13b-hf/'
pytorch_files = [
'pytorch_model-00001-of-00003.bin',
'pytorch_model-00002-of-00003.bin',
'pytorch_model-00003-of-00003.bin'
]
# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=checkpoint_dir,
checkpoint_files=pytorch_files,
output_dir=checkpoint_dir,
model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()
# Setup the model and the input
model = llama2_13b()
# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
<All keys matched successfully>
# We have 32000 vocab tokens; lets generate an input with 70 tokens
x = torch.randint(0, 32000, (1, 70))
with torch.no_grad():
model(x)
tensor([[[ -6.3989, -9.0531, 3.2375, ..., -5.2822, -4.4872, -5.7469],
[ -8.6737, -11.0023, 6.8235, ..., -2.6819, -4.2424, -4.0109],
[ -4.6915, -7.3618, 4.1628, ..., -2.8594, -2.5857, -3.1151],
...,
[ -7.7808, -8.2322, 2.8850, ..., -1.9604, -4.7624, -1.6040],
[ -7.3159, -8.5849, 1.8039, ..., -0.9322, -5.2010, -1.6824],
[ -7.8929, -8.8465, 3.3794, ..., -1.3500, -4.6145, -2.5931]]])
您可以使用torchtune支持的任何模型完成此操作。您可以在 此处 找到完整模型列表和模型构建器。
我们希望此次深入探讨能让您对 torchtune 中的检查点保存器及相关工具获得更深刻的理解。祝您调参顺利!