生成式人工智能/大语言模型

利用 NVIDIA NeMo-Aligner 进行监督式微调的数据高效知识蒸馏

知识蒸馏是一种将更大的教师模型的知识转移到更小的学生模型的方法,理想情况下可生成紧凑、易于部署的学生,且准确度与教师相当。知识蒸馏在预训练设置中越来越受欢迎,但用于在监督式微调(Supervised Fine-Tuning,SFT)期间执行知识蒸馏的资源越来越少。

NVIDIA NeMo-Aligner 开源了一个在 SFT 期间使用知识蒸馏的实现,相较于标准 SFT,该实现的数据效率更高,准确性也更高 (Table 1)。

训练目标 训练步骤 MMLU (5 次采样) MMLU (0 次采样) HumanEval (0 分) MBPP (零射) GSM8K (零射) 数学 (0 分)
表面张力损失 600000 65.3% 56.9% 64.6 71.7% 84.2 30.12
KD = SFT 损失 420000 65.3% 57.3% 70.1 73.3% 85.2 35.84
KD = SFT 损失 600000 65.3% 57.6% 72 73.8 84.8 36.6
表 1、在 Nemotron-4 15B 上以知识蒸馏使用表面张力的优势

在表 1 中,SFT 是使用数学/代码数据集执行的。使用知识蒸馏微调的模型版本在所有数学和代码相关基准测试中均优于基准,即使仅执行 70%的训练步骤也是如此。

NeMo-Aligner 中的知识蒸馏 

在 SFT 期间,有许多方法可以从大型模型传输知识。最常见的方法是使用教师模型生成合成数据,我们称之为 KD-SDG。然后,使用合成生成的数据微调学生模型。

还有一种开创性的方法,即训练学生以匹配教师的输出 logits。此方法在 Distilling the Knowledge in a Neural Network 中引入。我们将其称为 KD-logit。

此方法利用跨类(称为 暗知识 )的知识,生成信息更丰富的梯度信号。有关更多信息,请参阅神经网络中的 Dark Knowledge。

在本文和 NeMo-Aligner 中,我们将重点介绍在 SFT 期间应用 KD-logit。

NeMo-Aligner 的离线 KD-logit 工作流包含以下关键步骤:

  1. 教师模型对训练数据进行预测的预处理步骤。教师模型的 logits 添加到训练数据中。
  2. 这是一个训练步骤,其中对学生进行了训练,使其 logits 与教师的 logits 相匹配。

只需缓存一次教师的 logits。与在训练时动态计算教师逻辑相比,此方法具有以下优势:

  • 节省 内存: 您不必同时在 GPU 上加载教师和学生模型。
  • 加快训练速度: 您不必等待老师在训练期间做出预测。

但是,将所有教师的 logits 保存到磁盘可能需要大量内存。为节省内存,我们仅将教师的最高 K logits 保存到磁盘,其中 K 是从业者选择的超参数。

K 的值越大,学生可以从教师那里学习的细粒度信息越多,但内存压力就越大。在实践中,通常选择 K 值在 100 左右,这比典型的词汇量小。

将教师 logits 添加到数据集后,学生被训练以匹配教师的 top- K logits。具体来说,知识蒸馏损失函数等于 K 学生和教师 logits 之间的前向 KL 差异:

L^{kd} (p^S, p^T) = \sum_{k=1}^K p_k^T(\log p_k^T - \log p_k^S)

此损失函数与 Vanilla SFT 交叉熵损失函数结合使用,以生成最终训练目标,其中 \lambda 控制 SFT 损失项相对于 KD 损失项的强度:

L(p^S, p^T, y) = L^{kd} (p^S, p^T) + \lambda L^{sft}(p^S, y)

结果 

表 1 显示,与 Vanilla SFT 相比,使用知识蒸馏目标微调模型可获得更高的准确性和所需的训练令牌。我们使用 基础 Nemotron-4 15B 学生模型 微调的 Nemotron-4 340B 教师模型 进行实验。

用于 SFT 的数据集是使用以下论文中描述的技术生成的组合:

数据集的数学和代码部分均使用合成数据生成。这些实验设置了 K=100\lambda=0.1

在相同数量的训练步骤中,使用联合知识蒸馏和 SFT 目标微调的模型在七个评估指标中的六个方面的表现优于 SFT 基准。特别是,我们看到 HumanEval、MBPP 和 MATH 基准测试有了显著改进,这些基准用于衡量编码和数学推理技能。在评估各种语言理解任务的 MMLU 上,KD 微调模型的表现至少与零样本设置中的基准相当,并且在 5 镜头设置中优于基准。

KD-finetuned Nemotron-4 仅使用 70% 的训练令牌,但在相同的六个评估指标上,其性能仍然优于 Vanilla SFT 模型。

结束语 

这些结果具有两个重要含义。首先,我们已证明知识蒸馏(Knowledge Distillation)可用于提高微调模型的准确性。这在数据稀缺的设置中特别有用,因为需要更少的训练令牌才能实现良好的准确性。

其次,我们已经证明 KD-logit 可以与您的 SDG 数据结合使用,以实现复合优势。

有关如何在 NeMo-Aligner 中将知识蒸馏添加到 SFT 训练的更多信息,请参阅使用知识蒸馏进行监督微调 (SFT)。

标签