Knowledge Distillation (Teacher/Student Models)
Knowledge distillation transfers dark knowledge from a large, pre-trained teacher model to a smaller, efficient student model. By training the student to match the softened probability outputs of the teacher, distillation can retain much of the teacher's accuracy in a compressed model.
The Dark Knowledge Paradigm
Knowledge distillation transfers the dark knowledge of class correlations from a teacher model to a student using soft targets.
Soft Targets
In standard classification training, models are optimized using hard targets (one-hot encoded vectors where the correct class is 1 and all other classes are 0). Hard targets identify the correct label, but they discard valuable information about class similarities. For example, a model classifying a dog image might assign probabilities of 0.8 to "dog", 0.19 to "cat", and 0.01 to "car". This distribution contains "dark knowledge": it shows that the model has learned that dogs are semantically closer to cats than to cars.
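A minimal sketch of why the soft distribution carries more signal than the one-hot label. It reuses the 0.8/0.19/0.01 teacher distribution from the example and compares two hypothetical student predictions (invented for illustration) that give "dog" the same probability but rank "cat" and "car" differently. The helper name cross_entropy is also just for this sketch.

```python
import math

# Class order: ["dog", "cat", "car"]
hard_target = [1.0, 0.0, 0.0]      # one-hot label for a dog image
soft_target = [0.8, 0.19, 0.01]    # teacher distribution from the example above

# Two hypothetical student predictions with the same probability on "dog"
student_a = [0.8, 0.15, 0.05]      # ranks "cat" above "car", like the teacher
student_b = [0.8, 0.05, 0.15]      # ranks "car" above "cat"


def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_i target_i * log(predicted_i); zero-weight terms are skipped."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)


for name, pred in [("student_a", student_a), ("student_b", student_b)]:
    print(f"{name}: hard CE = {cross_entropy(hard_target, pred):.4f}, "
          f"soft CE = {cross_entropy(soft_target, pred):.4f}")
```

Both students receive the same hard-label loss (-log 0.8, about 0.223), so the one-hot target cannot distinguish them. Against the soft target, student_a scores roughly 0.569 and student_b roughly 0.767, so only the soft target rewards the student that has learned the cat/car similarity.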
Knowledge distillation (Hinton et al., 2015) transfers this dark knowledge to a smaller student model. By training the student to match the full probability distribution of the teacher (the soft targets), we guide the student toward the teacher's decision boundaries and learned class similarities, which often improves its generalization compared with training on hard labels alone.
The Softmax Temperature
To extract dark knowledge, we must prevent the teacher's probability distribution from being too sharp (where the correct class approaches 1.0 and other classes approach 0.0, hiding the similarities). We achieve this by introducing a Temperature parameter \\( T \\) into the Softmax function:
\\( p_i = \\frac{e^{z_i / T}}{\\sum_j e^{z_j / T}} \\)
Here \\( z_i \\) is the raw output logit for class \\( i \\). When \\( T = 1 \\), we get the standard Softmax distribution. As \\( T \\) increases, the probability distribution becomes softer (more uniform), highlighting the relationships between non-target classes. The student is trained to match these softened probabilities, using the same temperature parameter \\( T \\) during loss calculation.
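The effect of \\( T \\) can be seen numerically. The sketch below implements the temperature Softmax formula directly (subtracting the maximum scaled logit for numerical stability, which does not change the result) and prints the distribution and its entropy at a few temperatures. The teacher logits are made-up values for the classes "dog", "cat", and "car", and the function names are hypothetical.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), computed stably by subtracting the max."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a softer, more uniform distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Hypothetical teacher logits for ["dog", "cat", "car"]
teacher_logits = [6.0, 3.0, -2.0]

for T in (1.0, 3.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>4}: probs={[round(p, 3) for p in probs]}, entropy={entropy(probs):.3f}")
```

With these logits, the probability of "car" rises from about 0.0003 at \\( T = 1 \\) to about 0.2 at \\( T = 10 \\). Because dividing by \\( T \\) is monotone, the ranking dog > cat > car never changes; only the gaps between the probabilities shrink.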
PyTorch Distillation Implementation
A PyTorch distillation loop trains a student model by optimizing a joint objective that balances a soft loss (matching the teacher's softened outputs) against a hard loss (matching the ground-truth labels).
Model Setup
The following PyTorch script defines a joint distillation loss that combines hard targets with soft teacher guidance, then runs a single backward pass on simulated logits. In a full training loop, this loss is computed for every batch:
<pre><code class="language-python">import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss_fn(student_logits, teacher_logits, labels, T, alpha):
    """
    Computes the joint knowledge distillation loss.
    """
    # 1. Soft loss: KL divergence D_KL(teacher || student) between the softened distributions.
    # Multiplying by T^2 keeps its gradient magnitude comparable to the hard loss.
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    kl_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)

    # 2. Hard loss: standard cross-entropy against ground-truth labels (T = 1)
    hard_loss = F.cross_entropy(student_logits, labels)

    # Joint loss
    return alpha * kl_loss + (1.0 - alpha) * hard_loss


# Example execution setup
# Simulate student and teacher logit outputs for 3 classes, batch size 2
student_out = torch.randn(2, 3, requires_grad=True)
teacher_out = torch.randn(2, 3)  # Teacher is frozen, requires no gradients
y_ground_truth = torch.tensor([0, 2])

loss = distillation_loss_fn(student_out, teacher_out, y_ground_truth, T=3.0, alpha=0.7)
loss.backward()
print("Distillation backpropagation step completed. Loss:", loss.item())</code></pre>
Loss Function
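The distillation loss function combines two objectives. The soft loss is the KL divergence from the softened teacher distribution to the softened student distribution: \\( L_{KD} = D_{KL}(\\text{Softmax}(z^T/T) \\parallel \\text{Softmax}(z^S/T)) \\), where \\( z^S \\) and \\( z^T \\) are the student and teacher logits (the superscript \\( T \\) marks the teacher, while the divisor \\( T \\) is the temperature). This is the direction computed by F.kl_div in the script above, which takes the student's log-probabilities as input and the teacher's probabilities as target. The hard loss is the standard cross-entropy between the student's unsoftened predictions (\\( T = 1 \\)) and the ground-truth labels.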
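A quick way to confirm the argument order is to compare the library call against the KL formula written out by hand. This sketch uses random logits, with the batch size, class count, and temperature chosen arbitrarily; it checks that F.kl_div with reduction="batchmean" equals \\( D_{KL}(p^T \\parallel p^S) \\) summed over classes and averaged over the batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 3.0
student_logits = torch.randn(4, 5)   # batch of 4, 5 classes (arbitrary sizes)
teacher_logits = torch.randn(4, 5)

log_p_student = F.log_softmax(student_logits / T, dim=-1)
p_teacher = F.softmax(teacher_logits / T, dim=-1)
log_p_teacher = F.log_softmax(teacher_logits / T, dim=-1)

# Library call, as used in the distillation script
kl_library = F.kl_div(log_p_student, p_teacher, reduction="batchmean")

# Explicit D_KL(teacher || student) = sum_i p_t,i * (log p_t,i - log p_s,i), then batch mean
kl_manual = (p_teacher * (log_p_teacher - log_p_student)).sum(dim=-1).mean()

print(kl_library.item(), kl_manual.item())
assert torch.allclose(kl_library, kl_manual, atol=1e-6)
```

Note that reduction="batchmean" divides the summed KL by the batch size, which matches the mathematical definition; the default reduction="mean" would instead divide by the total number of elements.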
We balance the soft and hard losses using a weighting parameter \\( \\alpha \\) (typically 0.5 to 0.9): \\( L = \\alpha T^2 L_{KD} + (1 - \\alpha) L_{CE} \\). The \\( T^2 \\) factor compensates for the fact that, when the logits are divided by the temperature, the gradients of the soft loss with respect to the student logits scale approximately as \\( 1/T^2 \\). Multiplying by \\( T^2 \\) keeps the relative contributions of the soft and hard losses roughly constant as \\( T \\) changes, so \\( \\alpha \\) and \\( T \\) can be tuned largely independently.
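The \\( 1/T^2 \\) behaviour can be checked numerically. This sketch, using random logits with an arbitrary batch size, class count, and seed, measures the gradient norm of the unscaled soft loss with respect to the student logits at several temperatures. The helper soft_loss_grad_norm is hypothetical. The approximation behind the \\( 1/T^2 \\) argument holds when \\( T \\) is large relative to the logit scale, so the product with \\( T^2 \\) should level off at high temperatures.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logits = torch.randn(8, 10)   # batch of 8, 10 classes (arbitrary)
student_init = torch.randn(8, 10)


def soft_loss_grad_norm(T):
    """Norm of d(unscaled soft loss)/d(student logits) at temperature T."""
    student_logits = student_init.clone().requires_grad_(True)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    kl.backward()
    return student_logits.grad.norm().item()


for T in (1.0, 2.0, 5.0, 10.0, 20.0):
    g = soft_loss_grad_norm(T)
    print(f"T={T:>4}: grad norm = {g:.6f}, grad norm * T^2 = {g * T * T:.4f}")
```

Doubling \\( T \\) from 10 to 20 should shrink the unscaled gradient by a factor close to 4, which is the imbalance the \\( T^2 \\) factor removes.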
Architectural Choices and Variants
Distillation strategies expand beyond final logits to intermediate features and online multi-model training.
Response, Feature, and Relation Distillation
Knowledge distillation can be applied at different stages of the network. Response distillation (the standard method) matches the final output logits of the model. Feature distillation (introduced by FitNets) matches the intermediate hidden layer activations. We project the student's intermediate feature maps to match the shape of the teacher's feature maps and minimize the L2 distance between them. This forces the student to learn similar intermediate representations.
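A minimal sketch of a FitNets-style hint loss, assuming convolutional feature maps that already share the same spatial size. The class name FeatureHintLoss and all tensor shapes are hypothetical, the projection is a learned 1x1 convolution, and the "L2 distance" is implemented as the mean squared error from F.mse_loss. This is a single loss term, not the full FitNets training recipe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureHintLoss(nn.Module):
    """Hypothetical FitNets-style hint loss: a learned 1x1 convolution maps the
    student's feature map to the teacher's channel count, then MSE is applied."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # Learned projection ("regressor"); trained jointly with the student
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
        projected = self.regressor(student_feat)
        # Assumes both feature maps already share the same spatial size (H x W).
        # Teacher features are fixed targets, so they are detached from any graph.
        return F.mse_loss(projected, teacher_feat.detach())


torch.manual_seed(0)
# Example: a 32-channel student layer guided by a 128-channel teacher layer, both 16x16
hint_loss = FeatureHintLoss(student_channels=32, teacher_channels=128)
student_feat = torch.randn(4, 32, 16, 16, requires_grad=True)
teacher_feat = torch.randn(4, 128, 16, 16)

loss = hint_loss(student_feat, teacher_feat)
loss.backward()
print("hint loss:", loss.item())
```

Because the regressor has its own parameters, they need to be passed to the student's optimizer along with the student's weights.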
Relation distillation matches the relationships between different data samples. Instead of evaluating samples individually, we compute a similarity matrix between samples in a batch using the teacher's features, and train the student to match this similarity structure. Combining these strategies can further improve distillation performance on complex tasks.
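One simple instance of this idea, sketched under stated assumptions: the similarity matrix is the pairwise cosine similarity between the samples in a batch, and the student is trained with the mean squared error between its matrix and the teacher's. The function names and feature sizes are hypothetical, and published relation-based methods use other relations (for example, distances and angles between samples).

```python
import torch
import torch.nn.functional as F


def batch_similarity(features: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine-similarity matrix (B x B) between the samples in a batch."""
    flat = F.normalize(features.flatten(start_dim=1), dim=1)  # unit-length row per sample
    return flat @ flat.t()


def relation_loss(student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
    """MSE between the student's and the teacher's B x B similarity matrices."""
    return F.mse_loss(batch_similarity(student_feat), batch_similarity(teacher_feat).detach())


torch.manual_seed(0)
student_feat = torch.randn(8, 64, requires_grad=True)  # batch of 8, 64-dim student features
teacher_feat = torch.randn(8, 256)                      # same batch, 256-dim teacher features

loss = relation_loss(student_feat, teacher_feat)
loss.backward()
print("relation loss:", loss.item())
```

Since only the B x B matrices are compared, the student and teacher feature dimensions can differ (64 and 256 here) without a projection layer, unlike feature distillation.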
Offline vs. Online Distillation
In offline distillation, the teacher model is pre-trained and its weights are frozen. We pass the training data through the frozen teacher to generate soft targets, and use them to train the student. This setup is simple, but precomputing the teacher's predictions requires storing them for every training example, while computing them on the fly adds a teacher forward pass to every training step.
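A minimal sketch of one offline distillation step with on-the-fly soft targets, using the same joint loss as the earlier script. The two small MLPs are hypothetical stand-ins for a pre-trained teacher and a compact student, the data is random, and the optimizer and hyperparameters are arbitrary. The teacher is put in eval mode (so layers such as dropout or batch norm behave deterministically), its parameters are frozen, and its forward pass runs under torch.no_grad() so no graph is built for it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for a large pre-trained teacher and a small student (10 classes)
teacher = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
student = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 10))

# Offline setup: the teacher is frozen and kept in eval mode
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
T, alpha = 3.0, 0.7


def distillation_step(x, labels):
    with torch.no_grad():  # soft targets computed on the fly, without a graph
        teacher_logits = teacher(x)
    student_logits = student(x)
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits / T, dim=-1),
                    reduction="batchmean") * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    loss = alpha * soft + (1.0 - alpha) * hard
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # only the student's parameters are updated
    return loss.item()


# Random data as a placeholder for a real training set
x = torch.randn(64, 20)
labels = torch.randint(0, 10, (64,))
for step in range(5):
    print(f"step {step}: loss = {distillation_step(x, labels):.4f}")
```

To use precomputed targets instead, teacher_logits would be loaded from storage for each batch rather than recomputed, trading the extra forward pass for disk or memory.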
In online distillation, the teacher and student are trained simultaneously. A common variant is deep mutual learning, in which several student models are trained together and each acts as a teacher for the others. This co-training can lead every model to a better solution than training it alone, without requiring a large pre-trained teacher.
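A simplified sketch of deep mutual learning with two hypothetical peer networks trained from scratch on random data. It assumes no temperature (\\( T = 1 \\)) and equal weighting of the cross-entropy and KL terms. The peers are updated in turn, and each update uses the other peer's current predictions, computed without gradients, as its soft target.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Two hypothetical peer students trained together from scratch (10 classes)
peers = [nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 10)) for _ in range(2)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in peers]


def mutual_learning_step(x, labels):
    """Update each peer in turn: cross-entropy on labels plus D_KL(other peer || this peer)."""
    losses = []
    for i in range(2):
        own_logits = peers[i](x)
        with torch.no_grad():  # the other peer acts as a fixed teacher for this update
            peer_probs = F.softmax(peers[1 - i](x), dim=-1)
        kl = F.kl_div(F.log_softmax(own_logits, dim=-1), peer_probs, reduction="batchmean")
        loss = F.cross_entropy(own_logits, labels) + kl
        optimizers[i].zero_grad()
        loss.backward()
        optimizers[i].step()
        losses.append(loss.item())
    return losses


# Random data as a placeholder for a real training set
x = torch.randn(64, 20)
labels = torch.randint(0, 10, (64,))
for step in range(5):
    print(f"step {step}: peer losses = {mutual_learning_step(x, labels)}")
```

Unlike the offline case, both networks receive gradient updates, and neither needs to be larger or pre-trained.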