使用 Knowledge Distillation 将 Llama3.1 8B 蒸馏成 Llama3.2 1B¶
本指南将教你知识蒸馏 (KD),并展示如何使用 torchtune 将 Llama3.1 8B 模型提炼成 Llama3.2 1B。 如果你已经知道什么是知识蒸馏,并想直接在 torchtune 中运行你自己的蒸馏, 您可以跳转到 Torchtune 中的 KD 配方教程。
什么是 KD 以及它如何帮助提高模型性能
torchtune 中的 KD 组件概述
如何使用 torchtune 从教师到学生模型进行提炼
如何试验不同的 KD 配置
熟悉 torchtune
确保您已下载 Llama3 模型权重
熟悉 LoRA
什么是知识蒸馏?¶
知识蒸馏是一种广泛使用的压缩技术 将知识从较大的 (教师) 模型转移到较小的 (学生) 模型。较大的模型具有 更多的参数和知识容量,但是,这种更大的容量也更多的是计算 部署成本高昂。知识蒸馏可用于将较大模型的知识压缩为 一个较小的模型。这个想法是,可以通过从较大的模型中学习来提高较小模型的性能 model 的输出。
知识蒸馏如何运作?¶
通过在迁移集上训练知识,将知识从教师转移到学生模型,其中 Student 被训练为模仿教师的标记级概率分布。下图 是 KD 工作原理的简化表示。
可以通过多种方式配置 total loss。torchtune 中的默认 KD 配置将交叉熵 (CE) 损失与 forward Kullback-Leibler (KL) 发散损失, 用于标准 KD 方法。正向 KL 背离旨在通过强制学生的 distribution 以与教师的所有分布保持一致。但是,将学生分布与整体保持一致 教师分布可能效果不佳,并且有多个论文,例如 MiniLLM、DistiLLM 和广义 KD, 引入新的 KD 损失以解决限制。在本教程中,我们来看看 远期 KL 背离损失。
import torch
import torch.nn.functional as F
class ForwardKLLoss(torch.nn.Module):
def __init__(self, ignore_index: int = -100)
super().__init__()
self.ignore_index = ignore_index
def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
# Implementation from https://github.com/jongwooko/distillm
# Computes the softmax of the teacher logits
teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
# Computes the student log softmax probabilities
student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
# Computes the forward KL divergence
prod_probs = teacher_prob * student_logprob
# Compute the sum
x = torch.sum(prod_probs, dim=-1).view(-1)
# We don't want to include the ignore labels in the average
mask = (labels != self.ignore_index).int()
# Loss is averaged over non-ignored targets
return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
为了简化计算,省略了一些细节,但如果您想了解更多信息,
您可以在 中看到实现。
默认情况下,KD 配置用于
减少内存。
当前实现仅支持具有相同输出的 student 和 teacher 模型
logit 形状和相同的 tokenizer 进行匹配。
torchtune 中的 KD 配方¶
借助 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列。 让我们来看看如何使用 torchtune 的 KD 配方来提炼模型。
首先,确保您已下载所有模型权重。在此示例中,我们将使用 Llama3.1-8B 作为 teacher,使用 Llama3.2-1B 作为 student。
tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>
tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>
然后,我们将使用 LoRA 微调教师模型。根据我们的实验和以前的工作, 我们发现,当教师模型已经在 Target 数据集上进行了微调时,KD 的表现会更好。
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
最后,我们可以运行以下命令,将微调后的 8B 模型提炼成单个 GPU 上的 1B 模型。
tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device
消融研究¶
在前面的示例中,我们使用了 LoRA 微调的 8B 教师模型和基线 1B 学生模型
但我们可能想用不同的配置和超参数进行一些实验。
在本教程中,我们将对 truthfulqa_mc2、hellaswag 和 commonsense_qa 任务上的模型进行微调和评估
通过 EleutherAI LM 评估工具。
让我们来看看以下因素的影响:
使用微调的教师模型
使用微调的学生模型
kd_ratio 和学习率的超参数优化
参数数量较近的教师和学生模型
使用微调的教师模型¶
配置中的默认设置使用微调的教师模型。现在,我们来看看
不先微调教师模型的影响。要更改教师模型,您可以在配置中修改 :teacher_checkpointer
teacher_checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files: [
model-00001-of-00004.safetensors,
model-00002-of-00004.safetensors,
model-00003-of-00004.safetensors,
model-00004-of-00004.safetensors
]
在下表中,我们可以看到 1B 模型的标准微调实现了更好的准确性 比基线 1B 模型。通过使用微调的 8B 教师模型,我们看到了类似的结果 对于 TruthfulQA 和改进 Hellaswag 和 Commonsense。当使用基线 8B 作为 老师,我们看到所有指标都有所改善,但低于其他配置。
看一下损失,使用基线 8B 作为教师会导致比 使用微调的教师模式。KD 损失也保持相对稳定,表明 教师模型应具有与转移数据集相同的分布。
使用微调的学生模型¶
对于这些实验,让我们看一下当学生模型已经 微调。在这些实验中,我们研究了基线和微调 8B 的不同组合 和 1B 型号。要更改学生模型,您可以先微调 1B 模型,然后修改 Config 中的 Student Model Checkpointer:
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
checkpoint_files: [
hf_model_0001_0.pt
]
使用微调的学生模型可以进一步提高 truthfulqa 的准确性,但准确性 Hellaswag 和 Commonsense 的掉落。使用微调的教师模型和基线学生 模型在 Hellaswag 和 Common Sense 数据集上取得了最好的结果。基于这些发现, 最佳配置将根据您正在优化的评估数据集和量度而变化。
根据损失图,使用微调的教师模型会导致较低的损失,而不管 学生模型是否经过微调。有趣的是,类的损失 在使用微调的学生模型时开始增加。
超参数优化:学习率¶
默认情况下,配置的学习率为 ,这与 LoRA 配置相同。对于这些实验, 我们将学习率从 high 更改为 Till 。要更改学习率, 您可以简单地使用以下方法覆盖 Learning rate 参数:
tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3
根据结果,最佳学习率会根据要优化的指标而变化。
根据损失图,所有学习率都会导致类似的损失,但 除外,它具有更高的 KD 和类损失。
超参数优化:KD 比率¶
在配置中,我们有 as 0.5,它为类和 KD 损失提供均匀的权重。在这些实验中,
我们研究了不同 KD 比率的影响,其中 0 仅使用类损失,1 仅使用 KD 损失。
与更改学习率类似,可以使用以下方法调整 KD 比率:kd_ratio
tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25
总体而言,KD 比值较高时,评估结果略好。
Qwen2 1.5B 至 0.5B¶
KD 配方也可以应用于不同的模型系列。这里我们看一下 KD 的影响,当 teacher 和 student 模型之间的参数更接近。在本实验中,我们使用了 Qwen2 1.5B 和 Qwen2 0.5B,其配置可以在 qwen2/knowledge_distillation_single_device config 中找到。在这里,我们看到在 alpaca 清理的数据集上进行训练只会提高truthful_qa性能,并会降低其他评估任务的指标。 对于 truthful_qa,KD 使学生模型性能提高了 5.8%,而微调使性能提高了 1.3%。