目录

使用 Knowledge Distillation 将 Llama3.1 8B 蒸馏成 Llama3.2 1B

本指南将教你知识蒸馏 (KD),并展示如何使用 torchtune 将 Llama3.1 8B 模型提炼成 Llama3.2 1B。 如果你已经知道什么是知识蒸馏,并想直接在 torchtune 中运行你自己的蒸馏, 您可以跳转到 Torchtune 中的 KD 配方教程。

您将学到什么
  • 什么是 KD 以及它如何帮助提高模型性能

  • torchtune 中的 KD 组件概述

  • 如何使用 torchtune 从教师到学生模型进行提炼

  • 如何试验不同的 KD 配置

先决条件

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术 将知识从较大的 (教师) 模型转移到较小的 (学生) 模型。较大的模型具有 更多的参数和知识容量,但是,这种更大的容量也更多的是计算 部署成本高昂。知识蒸馏可用于将较大模型的知识压缩为 一个较小的模型。这个想法是,可以通过从较大的模型中学习来提高较小模型的性能 model 的输出。

知识蒸馏如何运作?

通过在迁移集上训练知识,将知识从教师转移到学生模型,其中 Student 被训练为模仿教师的标记级概率分布。下图 是 KD 工作原理的简化表示。

../_images/kd-simplified.png

可以通过多种方式配置 total loss。torchtune 中的默认 KD 配置将交叉熵 (CE) 损失与 forward Kullback-Leibler (KL) 发散损失, 用于标准 KD 方法。正向 KL 背离旨在通过强制学生的 distribution 以与教师的所有分布保持一致。但是,将学生分布与整体保持一致 教师分布可能效果不佳,并且有多个论文,例如 MiniLLMDistiLLM广义 KD, 引入新的 KD 损失以解决限制。在本教程中,我们来看看 远期 KL 背离损失。

import torch
import torch.nn.functional as F

class ForwardKLLoss(torch.nn.Module):
  def __init__(self, ignore_index: int = -100)
    super().__init__()
    self.ignore_index = ignore_index

  def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
    # Implementation from https://github.com/jongwooko/distillm
    # Computes the softmax of the teacher logits
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    # Computes the student log softmax probabilities
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Computes the forward KL divergence
    prod_probs = teacher_prob * student_logprob
    # Compute the sum
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # We don't want to include the ignore labels in the average
    mask = (labels != self.ignore_index).int()
    # Loss is averaged over non-ignored targets
    return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

为了简化计算,省略了一些细节,但如果您想了解更多信息, 您可以在 中看到实现。 默认情况下,KD 配置用于减少内存。 当前实现仅支持具有相同输出的 student 和 teacher 模型 logit 形状和相同的 tokenizer 进行匹配。

torchtune 中的 KD 配方

借助 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列。 让我们来看看如何使用 torchtune 的 KD 配方来提炼模型。

首先,确保您已下载所有模型权重。在此示例中,我们将使用 Llama3.1-8B 作为 teacher,使用 Llama3.2-1B 作为 student。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

然后,我们将使用 LoRA 微调教师模型。根据我们的实验和以前的工作, 我们发现,当教师模型已经在 Target 数据集上进行了微调时,KD 的表现会更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,将微调后的 8B 模型提炼成单个 GPU 上的 1B 模型。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在前面的示例中,我们使用了 LoRA 微调的 8B 教师模型和基线 1B 学生模型 但我们可能想用不同的配置和超参数进行一些实验。 在本教程中,我们将对 truthfulqa_mc2hellaswagcommonsense_qa 任务上的模型进行微调和评估 通过 EleutherAI LM 评估工具。 让我们来看看以下因素的影响:

  1. 使用微调的教师模型

  2. 使用微调的学生模型

  3. kd_ratio 和学习率的超参数优化

  4. 参数数量较近的教师和学生模型

使用微调的教师模型

配置中的默认设置使用微调的教师模型。现在,我们来看看 不先微调教师模型的影响。要更改教师模型,您可以在配置中修改 :teacher_checkpointer

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
      model-00001-of-00004.safetensors,
      model-00002-of-00004.safetensors,
      model-00003-of-00004.safetensors,
      model-00004-of-00004.safetensors
  ]

在下表中,我们可以看到 1B 模型的标准微调实现了更好的准确性 比基线 1B 模型。通过使用微调的 8B 教师模型,我们看到了类似的结果 对于 TruthfulQA 和改进 Hellaswag 和 Commonsense。当使用基线 8B 作为 老师,我们看到所有指标都有所改善,但低于其他配置。

../_images/kd-finetune-teacher.png

看一下损失,使用基线 8B 作为教师会导致比 使用微调的教师模式。KD 损失也保持相对稳定,表明 教师模型应具有与转移数据集相同的分布。

使用微调的学生模型

对于这些实验,让我们看一下当学生模型已经 微调。在这些实验中,我们研究了基线和微调 8B 的不同组合 和 1B 型号。要更改学生模型,您可以先微调 1B 模型,然后修改 Config 中的 Student Model Checkpointer:

checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
     hf_model_0001_0.pt
   ]

使用微调的学生模型可以进一步提高 truthfulqa 的准确性,但准确性 Hellaswag 和 Commonsense 的掉落。使用微调的教师模型和基线学生 模型在 Hellaswag 和 Common Sense 数据集上取得了最好的结果。基于这些发现, 最佳配置将根据您正在优化的评估数据集和量度而变化。

../_images/kd-finetune-student.png

根据损失图,使用微调的教师模型会导致较低的损失,而不管 学生模型是否经过微调。有趣的是,类的损失 在使用微调的学生模型时开始增加。

超参数优化:学习率

默认情况下,配置的学习率为 ,这与 LoRA 配置相同。对于这些实验, 我们将学习率从 high 更改为 Till 。要更改学习率, 您可以简单地使用以下方法覆盖 Learning rate 参数:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

根据结果,最佳学习率会根据要优化的指标而变化。

../_images/kd-hyperparam-lr.png

根据损失图,所有学习率都会导致类似的损失,但 除外,它具有更高的 KD 和类损失。

超参数优化:KD 比率

在配置中,我们有 as 0.5,它为类和 KD 损失提供均匀的权重。在这些实验中, 我们研究了不同 KD 比率的影响,其中 0 仅使用类损失,1 仅使用 KD 损失。 与更改学习率类似,可以使用以下方法调整 KD 比率:kd_ratio

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25

总体而言,KD 比值较高时,评估结果略好。

../_images/kd-hyperparam-kd-ratio.png

Qwen2 1.5B 至 0.5B

KD 配方也可以应用于不同的模型系列。这里我们看一下 KD 的影响,当 teacher 和 student 模型之间的参数更接近。在本实验中,我们使用了 Qwen2 1.5B 和 Qwen2 0.5B,其配置可以在 qwen2/knowledge_distillation_single_device config 中找到。在这里,我们看到在 alpaca 清理的数据集上进行训练只会提高truthful_qa性能,并会降低其他评估任务的指标。 对于 truthful_qa,KD 使学生模型性能提高了 5.8%,而微调使性能提高了 1.3%。

../_images/kd-qwen2-res.png

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源