
Data-Efficient Knowledge Distillation for Supervised Fine-Tuning with NVIDIA NeMo-Aligner


Knowledge distillation is an approach for transferring the knowledge of a much larger teacher model to a smaller student model, ideally yielding a compact, easily deployable student with comparable accuracy to the teacher. Knowledge distillation has gained popularity in pretraining settings, but there are fewer resources available for performing knowledge distillation during supervised fine-tuning (SFT). 

NVIDIA NeMo-Aligner has open-sourced an implementation for using knowledge distillation during SFT that is more data-efficient and yields higher accuracy than its standard SFT counterpart (Table 1).

| Training Objective | Train Steps | MMLU (5-shot) | MMLU (0-shot) | HumanEval (0-shot) | MBPP (0-shot) | GSM8K (0-shot) | MATH (0-shot) |
|---|---|---|---|---|---|---|---|
| SFT loss | 600,000 | 65.3 | 56.9 | 64.6 | 71.7 | 84.2 | 30.12 |
| KD + SFT loss | 420,000 | 65.3 | 57.3 | 70.1 | 73.3 | 85.2 | 35.84 |
| KD + SFT loss | 600,000 | 65.3 | 57.6 | 72 | 73.8 | 84.8 | 36.6 |
Table 1. Benefits of SFT with knowledge distillation on Nemotron-4 15B

In Table 1, SFT was performed using a math/code dataset. The version of the model fine-tuned using knowledge distillation outperforms the baseline on all math- and code-related benchmarks, even with only 70% of the training steps.

Knowledge distillation in NeMo-Aligner

There are a number of approaches to transfer knowledge from a large model during SFT. The most common approach involves using the teacher model for synthetic data generation, which we refer to as KD-SDG. The synthetically generated data is then used to fine-tune the student model.

There is also a seminal approach in which the student is trained to match the teacher’s output logits. This approach was introduced in Distilling the Knowledge in a Neural Network. We refer to this as KD-logit.

This method provides a more informative gradient signal by exploiting the teacher’s knowledge of the similarities and dissimilarities across classes, termed dark knowledge. For more information, see Dark Knowledge in Neural Networks.

In this post and in NeMo-Aligner, we focus on applying KD-logit during SFT. 

NeMo-Aligner’s offline KD-logit pipeline consists of these key steps:

  1. A preprocessing step in which the teacher model makes predictions on the training data. The logits from the teacher model are added to the training data.
  2. A training step in which the student is trained to match its logits with the teacher’s logits. 
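
The two steps can be sketched roughly as follows. This is a minimal illustration and not the NeMo-Aligner implementation: `teacher`, `student`, the dataset fields, and `kd_loss_fn` are hypothetical placeholders, and Hugging Face-style models whose forward pass returns an object with a `.logits` attribute are assumed. For simplicity, the full logits are cached here; the top-K variant described later reduces the storage cost.

```python
import torch

@torch.no_grad()
def cache_teacher_logits(teacher, dataset, out_path="teacher_logits.pt"):
    """Step 1 (preprocessing): run the teacher once over the training data
    and cache its logits to disk. The teacher is not needed afterward."""
    teacher.eval()
    cached = []
    for example in dataset:
        output = teacher(example["input_ids"])   # forward pass only
        cached.append(output.logits.cpu())       # [seq_len, vocab_size]
    torch.save(cached, out_path)
    return out_path

def train_step(student, example, teacher_logits, kd_loss_fn, optimizer):
    """Step 2 (training): only the student lives on the GPU; it is trained
    to match the cached teacher logits for this example."""
    student_logits = student(example["input_ids"]).logits
    loss = kd_loss_fn(student_logits, teacher_logits)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
```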

Caching the teacher’s logits only has to be performed once. This approach has benefits compared to dynamically computing the teacher’s logits at training time:

  • Memory savings: You don’t have to keep both the teacher and the student model in GPU memory at the same time.
  • Faster training: You don’t have to wait for the teacher to make predictions during training. 

However, saving all of the teacher’s logits to disk requires a large amount of storage. To reduce the storage footprint, we instead save only the teacher’s top-K logits to disk, where K is a hyperparameter chosen by the practitioner.

The larger the value of K, the more fine-grained the information the student can learn from the teacher, but the higher the storage and memory cost. In practice, K is usually chosen to be around 100, which is orders of magnitude smaller than the typical vocabulary size.
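
As a rough illustration of this part of the preprocessing step, only the top-K logit values and their vocabulary indices need to be kept per token position. The function and field names below are hypothetical; only standard PyTorch is assumed.

```python
import torch

def topk_teacher_logits(teacher_logits: torch.Tensor, k: int = 100):
    """Reduce [seq_len, vocab_size] teacher logits to the top-K entries per
    token position, so the cached data scales with K instead of vocab_size."""
    values, indices = torch.topk(teacher_logits, k=k, dim=-1)
    return {"topk_logits": values, "topk_token_ids": indices}

# Example: with a 256K-token vocabulary and K=100, only 100 logit values
# (plus their vocabulary indices) are stored per position instead of 256,000.
seq_len, vocab_size = 512, 256_000
cached = topk_teacher_logits(torch.randn(seq_len, vocab_size), k=100)
print(cached["topk_logits"].shape)  # torch.Size([512, 100])
```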

After the teacher’s logits are added to the dataset, the student is trained to match the teacher’s top-K logits. Concretely, the knowledge distillation loss is the forward KL divergence between the teacher’s and student’s distributions over the top-K tokens:

L^{kd} (p^S, p^T) = \sum_{k=1}^K p_k^T(\log p_k^T - \log p_k^S)

This loss is combined with the vanilla SFT cross-entropy loss to yield the final training objective, where \lambda controls the strength of the SFT term relative to the KD term:

L(p^S, p^T, y) = L^{kd} (p^S, p^T) + \lambda L^{sft}(p^S, y)
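
As a minimal PyTorch sketch of this objective (not NeMo-Aligner’s implementation), the cached top-K teacher logits and their vocabulary indices from the preprocessing step are assumed, and both distributions are renormalized over the K retained tokens:

```python
import torch
import torch.nn.functional as F

def kd_sft_loss(student_logits, topk_teacher_logits, topk_token_ids,
                labels, sft_weight=0.1):
    """Forward KL on the teacher's top-K tokens plus the standard SFT
    cross-entropy, weighted by `sft_weight` (lambda in the formula above).

    student_logits:      [seq_len, vocab_size]
    topk_teacher_logits: [seq_len, K] cached teacher logit values
    topk_token_ids:      [seq_len, K] vocabulary indices of those values
    labels:              [seq_len]    ground-truth token ids
    """
    # Pick out the student's logits at the teacher's top-K vocabulary indices.
    student_topk = torch.gather(student_logits, dim=-1, index=topk_token_ids)

    # Renormalize teacher and student distributions over the K retained tokens.
    log_p_teacher = F.log_softmax(topk_teacher_logits, dim=-1)
    log_p_student = F.log_softmax(student_topk, dim=-1)

    # Forward KL: sum_k p_T * (log p_T - log p_S), averaged over positions.
    kd = (log_p_teacher.exp() * (log_p_teacher - log_p_student)).sum(-1).mean()

    # Vanilla SFT cross-entropy against the ground-truth tokens.
    sft = F.cross_entropy(student_logits, labels)

    return kd + sft_weight * sft
```

In the experiments reported here, K = 100 and \lambda = 0.1 (the `sft_weight` argument in this sketch).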

Results

Table 1 shows that fine-tuning a model using the knowledge distillation objective yields higher accuracy and requires fewer training tokens than vanilla SFT. We conducted experiments using a base Nemotron-4 15B student model and a fine-tuned Nemotron-4 340B teacher model.

The dataset used for SFT is a combination generated using the techniques described in the following papers:

Both the math and code portions of the dataset were generated using synthetic data generation. These experiments set K=100 and \lambda=0.1.

With the same number of training steps, the model fine-tuned using the joint knowledge distillation and SFT objective outperforms the SFT baseline on five of the six evaluation metrics and matches it on the remaining one. In particular, we saw significant improvement on the HumanEval, MBPP, and MATH benchmarks, which measure coding and mathematical reasoning skills. On MMLU, which evaluates a diverse range of language understanding tasks, the KD-finetuned model matches the baseline in the 5-shot setting and outperforms it in the zero-shot setting.

With only 70% of the training steps, the KD-finetuned Nemotron-4 still outperforms the vanilla SFT model on the same five evaluation metrics.

Conclusion

These results have two important implications. First, we’ve shown that knowledge distillation can be used to improve the accuracy of fine-tuned models. This is especially useful in settings where data is scarce, as fewer training tokens are needed to achieve good accuracy. 

Second, we’ve demonstrated that KD-logit can be used in conjunction with synthetically generated (SDG) data to achieve compounding benefits.

For more information about how to add knowledge distillation to your SFT training in NeMo-Aligner, see Supervised Fine-Tuning (SFT) with Knowledge Distillation.
