torchtune 中的检查点¶
此深入探讨将引导您了解 checkpointer 的设计和行为,以及 关联的实用程序。
torchtune 的 Checkpointer 设计
检查点格式以及我们如何处理它们
检查点场景:Intermediate vs Final 和 LoRA vs Full-finetune
概述¶
Torchtune 检查点设计为可插入的可组合组件 到任何配方中 - 训练、评估或生成。每个 checkpointer 都支持 一组模型和场景,使这些易于理解、调试和扩展。
在我们深入研究 torchtune 中的 checkpointer 之前,让我们定义一些概念。
检查点格式¶
在这次深入探讨中,我们将讨论不同的检查点格式以及 torchtune 如何处理它们。 让我们仔细看看这些不同的格式。
简单地说,检查点的格式由state_dict及其存储方式决定 在磁盘上的文件中。每个权重都与一个字符串键相关联,该键在 state dict 中标识它。 如果存储的 checkpoint 中 key 的字符串标识符不匹配 与模型定义中的参数完全相同,您将遇到显式错误(加载 state dict 将引发异常)或更糟 - 静默错误(加载将成功,但训练或 inference 将无法按预期工作)。除了排列的键之外,您还需要形状 的权重(state_dict中的值)与模型预期的值完全匹配 定义。
让我们看看 Llama 3.2 的两种流行格式。
元格式
这是官方 Llama 3.2 实现支持的格式。当您下载 Llama 3.2 3B 模型时
从 Meta-LLAMA 网站,您将可以访问单个检查点文件。您可以使用.pth
torch.load
>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
tok_embeddings.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
255
该state_dict包含 255 个键,包括一个名为 .这
state_dict模型定义需要一个嵌入层,其中每个标记都有一个
嵌入 dim 为 。tok_embeddings
128256
3072
HF 格式
这是 Hugging Face Model Hub 中最受欢迎的格式,并且是 每个 torchtune 配置中的默认格式。这也是您下载 llama3.2 模型。
第一个很大的区别是 state_dict 被拆分为两个文件。要正确
load the checkpoint 中,您需要将这些文件拼凑在一起。让我们检查其中一个文件。.safetensors
>>> from safetensors import safe_open
>>> state_dict = {}
>>> with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
>>> for k in f.keys():
>>> state_dict[k] = f.get_tensor(k)
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
model.embed_tokens.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
187
state_dict不仅包含较少的键(意料之中,因为这是两个文件之一),而且
嵌入表称为 ,而不是 。这种不匹配
in names 将在您尝试加载 state_dict 时导致异常。此层的大小为
两者之间相同,正如预期的那样。model.embed_tokens
tok_embeddings
正如你所看到的,如果你不小心,你很可能会在 checkpoint load 和 save。torchtune 检查点通过管理状态 dict 来降低出错率 给你的。torchtune 被设计为 “state-dict invariant” 。
加载时,torchtune 接受来自多个来源的多种格式的检查点。 您不必担心每次运行配方时都显式转换检查点。
保存时,torchtune 会生成与源相同格式的检查点。这包括 将 state_dict 转换回原始形式并拆分 Key 和 Weights 在相同数量的文件中。
“state-dict invariant”的一大优点是您应该能够使用 使用任何后训练工具(量化、评估、推理)从 Torchtune 微调检查点 它支持源格式,无需任何代码更改或转换脚本。这是 Torchtune 与周围生态系统互操作的方式。
注意
要以这种方式成为 state-dict “不变”,每个 checkpointer 的 and 方法
使用权重转换器,在 checkpoint 格式之间正确映射权重。例如,当加载权重
在 Hugging Face 中,我们对 load 和 save 上的某些权重应用排列,以确保 checkpoint 的行为完全相同。
为了进一步说明这一点,Llama 系列模型使用通用的权重转换器功能,而 Phi3 等其他一些模型有自己的转换功能,可以在其模型文件夹中找到。load_checkpoint
save_checkpoint
处理不同的 Checkpoint 格式¶
Torchtune 支持三种不同的 Checkpointer, 每个 Checkpoint 都支持不同的 checkpoint 格式。
HFCheckpointer
¶
此 checkpointer 以与 transformer 兼容的格式读取和写入 checkpoint 框架。如上所述,这是 Hugging Face 中最流行的格式 Model Hub 的 ,是每个 torchtune 配置中的默认格式。
为了使此 checkpointer 正常工作,我们假设它包含必要的 checkpoint
和 json 文件。确保一切正常的最简单方法是使用以程:checkpoint_dir
使用 tune download 从 HF 存储库下载模型。这将忽略 “pth” 文件,因为我们将加载 “safeTensors”。
tune download meta-llama/Llama-3.2-3B-Instruct \ --output-dir /tmp/Llama-3.2-3B-Instruct \ --ignore-patterns "original/consolidated.00.pth"
使用 specified here 作为 checkpointer 的参数。
output_dir
checkpoint_dir
以下代码片段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the folder you used when downloading the model
checkpoint_dir: /tmp/Llama-3.2-3B-Instruct
# checkpoint files. For the Llama-3.2-3B-Instruct model we have
# 2 .safetensor files. The checkpointer takes care of sorting
# by id and so the order here does not matter
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors,
]
# dir for saving the output checkpoints
output_dir: <output_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA3_2
# set to True if restarting training. More on that later.
resume_from_checkpoint: False
注意
与 HF 格式之间的检查点转换需要访问模型参数,这些参数是
直接从文件中读取。这有助于确保我们加载权重
正确或错误,如果 HF 检查点文件和 torchtune 的
模型实现。此 json 文件与模型检查点一起从应用中心下载。config.json
MetaCheckpointer 元校验指针
¶
这个 checkpointer 以与原始 meta-lla 兼容的格式读取和写入 checkpoint GitHub 存储库。
为了使此 checkpointer 正常工作,我们假设它包含必要的 checkpoint
和 json 文件。确保一切正常的最简单方法是使用以程:checkpoint_dir
使用 tune download 从 HF 存储库下载模型。默认情况下,这将忽略 “safetensors” 文件。
tune download meta-llama/Llama-3.2-3B-Instruct \ --output-dir /tmp/Llama-3.2-3B-Instruct \ --ignore-patterns "*.safetensors"
将 above 用作 checkpointer 的 。
output_dir
checkpoint_dir
以下代码片段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the folder you used when downloading the model
checkpoint_dir: <checkpoint_dir>
# checkpoint files. For the llama3.2 3B model we have
# a single .pth file
checkpoint_files: [consolidated.00.pth]
# dir for saving the output checkpoints.
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA3_2
# set to True if restarting training. More on that later.
resume_from_checkpoint: False
TorchTuneCheck指针
¶
此 checkpointer 以与 torchtune 的 模型定义。这不会执行任何state_dict转换,当前正在使用 用于测试或加载量化模型以进行生成。
检查点输出¶
恭喜你走到这一步!假设您已经遵循了我们的 torchtune 端到端工作流程,并使用我们的 LoRA 配方之一训练了一只 llama 3.2 3B。
现在让我们可视化输出。执行此操作的一种简单方法是运行 ,它应该显示类似于下面的树的内容。
有 3 种类型的文件夹:tree -a path/to/outputdir
recipe_state:保存 recipe_state.pt 其中包含从最后一个中间 epoch 重新启动训练运行所需的信息。稍后会详细介绍;
logs:metric_logger的输出(如果有);
epoch_{}:包含经过训练的模型权重和模型元数据。如果运行推理或推送到模型中心,则应直接使用此文件夹;
注意
对于每个 epoch,我们复制原始 checkpoint 文件夹的内容,不包括原始 checkpoint 和大文件。 这些文件是轻量级的,主要是配置文件,使用户更容易直接在下游应用程序中使用 epoch 文件夹。
有关每个文件的更多详细信息,请查看上面提到的端到端教程。
>>> tree -a /tmp/torchtune/llama3_2_3B/lora_single_device /tmp/torchtune/llama3_2_3B/lora_single_device ├── epoch_0 │ ├── adapter_config.json │ ├── adapter_model.pt │ ├── adapter_model.safetensors │ ├── config.json │ ├── ft-model-00001-of-00002.safetensors │ ├── ft-model-00002-of-00002.safetensors │ ├── generation_config.json │ ├── LICENSE.txt │ ├── model.safetensors.index.json │ ├── original │ │ ├── orig_params.json │ │ ├── params.json │ │ └── tokenizer.model │ ├── original_repo_id.json │ ├── README.md │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ ├── tokenizer.json │ └── USE_POLICY.md ├── epoch_1 │ ├── adapter_config.json │ ├── adapter_model.pt │ ├── adapter_model.safetensors │ ├── config.json │ ├── ft-model-00001-of-00002.safetensors │ ├── ft-model-00002-of-00002.safetensors │ ├── generation_config.json │ ├── LICENSE.txt │ ├── model.safetensors.index.json │ ├── original │ │ ├── orig_params.json │ │ ├── params.json │ │ └── tokenizer.model │ ├── original_repo_id.json │ ├── README.md │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ ├── tokenizer.json │ └── USE_POLICY.md ├── logs │ └── log_1734652101.txt └── recipe_state └── recipe_state.pt
中间检查点与最终检查点¶
torchtune 检查点支持两种检查点方案:
训练结束检查点
模型在完成训练结束时进行加权 run 的 RUN 写入 file。检查点程序确保输出检查点 文件与用于开始训练的输入检查点文件具有相同的键。这 Checkpointer 还确保 Key 在相同数量的 files 作为原始检查点。输出状态 dict 具有以下 标准格式:
{ "key_1": weight_1, "key_2": weight_2, ... }
训练中期 Chekpointing。
如果在训练过程中执行 checkpointing,则输出 checkpoint 需要存储额外的
信息,以确保后续训练运行可以正确重启。除了
模型检查点文件中,我们输出一个文件用于 intermediate
检查站。这些当前在每个 epoch 结束时输出,并包含信息
例如优化器状态、已完成的 epoch 数等。recipe_state.pt
为了防止我们被 checkpoint 文件淹没,配方状态为
在每个 epoch 结束时覆盖。output_dir
输出状态 dict 具有以下格式:
Model: { "key_1": weight_1, "key_2": weight_2, ... } Recipe State: { "optimizer": ..., "epoch": ..., ... }
从 checkpoint 恢复 - 完全微调¶
有时我们的训练会因为某种原因而中断。要从以前的检查点文件重新开始训练, 您需要更新配置中的以下字段:
resume_from_checkpoint:将其设置为 True;
checkpoint_files:将路径更改为epoch_{YOUR_EPOCH}/ft-model={}-of-{}.safetensors
;
请注意,我们不会更改 checkpoint_dir 或 output_dir。由于我们正在从 checkpoint 恢复,因此我们知道 以在 output_dir 中查找它。
checkpointer:
# checkpoint files. Note that you will need to update this
# section of the config with the intermediate checkpoint files
checkpoint_files: [
epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
]
# set to True if restarting training
resume_from_checkpoint: True
从检查点恢复 - LoRA Finetuning¶
与完全微调类似,我们也只需要修改两个字段: 和 ,它们将从 output_dir 加载。我们不必修改,
因为正在加载的基础模型仍然相同。resume_from_checkpoint
adapter_checkpoint
checkpoint_files
checkpointer:
# adapter_checkpoint. Note that you will need to update this
# section of the config with the intermediate checkpoint files
adapter_checkpoint: epoch_{YOUR_EPOCH}/adapter_model.safetensors
# set to True if restarting training
resume_from_checkpoint: True
# set to True to save only the adapter weights
# it does not influence resuming_from_checkpointing
save_adapter_weights_only: False
注意
在 torchtune 中,我们输出适配器权重和完整的模型合并权重 对于 LoRA。合并的 checkpoint 很方便,因为它可以在没有特殊 工具来处理适配器。但是,在恢复时不应使用它们 training 的 Defined Defined 的 Lut S Package,因为加载合并的权重 + 适配器将是一个错误。因此,在恢复 LoRA 时, 我们将从 checkpoint dir 中获取原始的未训练的 weigths,并将经过训练的 output_dir 的适配器。有关更多详细信息,请查看我们的 LoRA 微调教程。
注意
此外,通过设置选项 ,您可以选择仅保存适配器权重。
这减少了保存 checkpoint 所需的存储量和时间,但对从 checkpoint 恢复没有影响。save_adapter_weights_only
把这些放在一起¶
现在让我们把所有这些知识放在一起吧!我们将加载一些 checkpoints, 创建一些模型并运行一个简单的 forward。
在本节中,我们将使用 HF 格式的 Llama-3.2-3B-Instruct 模型。
import torch
from torchtune.models.llama3_2 import llama3_2_3b
from torchtune.training import FullModelHFCheckpointer
# Set the right directory and files
checkpoint_dir = "/tmp/Llama-3.2-3B-Instruct/"
output_dir = "/tmp/torchtune/llama3_2_3B/full_single_device"
pytorch_files = [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
]
# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=checkpoint_dir,
checkpoint_files=pytorch_files,
output_dir=output_dir,
model_type="LLAMA3_2",
)
torchtune_sd = checkpointer.load_checkpoint()
# Setup the model and the input
model = llama3_2_3b()
# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
model.to("cuda")
# We have 128256 vocab tokens; lets generate an input with 24 tokens
x = torch.randint(0, 128256, (1, 24), dtype=torch.long, device="cuda")
tensor([[[ 1.4299, 1.1658, 4.2459, ..., -2.3259, -2.3262, -2.3259],
[ 6.5942, 7.2284, 2.4090, ..., -6.0129, -6.0121, -6.0127],
[ 5.6462, 4.8787, 4.0950, ..., -4.6460, -4.6455, -4.6457],
...,
[-0.4156, -0.0626, -0.0362, ..., -3.6432, -3.6437, -3.6427],
[-0.5679, -0.6902, 0.5267, ..., -2.6137, -2.6138, -2.6127],
[ 0.3688, -0.1350, 1.1764, ..., -3.4563, -3.4565, -3.4564]]],
device='cuda:0')
您可以使用 torchtune 支持的任何模型执行此操作。您可以找到完整列表 模型和模型生成器。
我们希望这次深入探讨能让您更深入地了解 checkpointer 和 Torchtune 中的关联实用程序。调音愉快!