Knowledge distillation explained

In machine learning, knowledge distillation or model distillation is the process of transferring knowledge from a large model to a smaller one. While large models (such as very deep neural networks or ensembles of many models) have higher knowledge capacity than small models, this capacity might not be fully utilized. A model can be just as computationally expensive to evaluate even if it uses little of its knowledge capacity. Knowledge distillation aims to make this transfer without loss of validity. As smaller models are less expensive to evaluate, they can be deployed on less powerful hardware (such as a mobile device).[1]

Knowledge distillation has been successfully used in several applications of machine learning such as object detection,[2] acoustic models,[3] and natural language processing.[4] Recently, it has also been introduced to graph neural networks applicable to non-grid data.[5]

Concept of distillation

Transferring knowledge from a large model to a small one requires teaching the small model without loss of validity. If both models are trained on the same data, the small model may have insufficient capacity to learn a concise knowledge representation given the same computational resources and the same data as the large model. However, some information about a concise knowledge representation is encoded in the pseudolikelihoods that the large model assigns to its outputs: when a model correctly predicts a class, it assigns a large value to the output variable corresponding to that class, and smaller values to the other output variables. The distribution of values among the outputs for a record provides information on how the large model represents knowledge. Therefore, the goal of economical deployment of a valid model can be achieved by training only the large model on the data, exploiting its better ability to learn concise knowledge representations, and then distilling that knowledge into the smaller model, which would not be able to learn it on its own, by training it to reproduce the soft output of the large model.

A related methodology was model compression or pruning, in which a trained network was reduced in size by methods such as Biased Weight Decay[6] and Optimal Brain Damage.[7]

The idea of using the output of one neural network to train another neural network was studied as the teacher-student network configuration.[8] In 1992, several papers studied the statistical mechanics of the teacher-student configuration in which both networks are committee machines[9][10] or both are parity machines.[11]

Another early example of network distillation was also published in 1992, in the field of recurrent neural networks (RNNs). The problem was sequence prediction, and it was solved by two RNNs. One of them (the "atomizer") predicted the sequence, and the other (the "chunker") predicted the errors of the atomizer. Simultaneously, the atomizer predicted the internal states of the chunker. Once the atomizer managed to predict the chunker's internal states well, it started fixing the errors, and soon the chunker became obsolete, leaving just one RNN in the end.[12]

A related methodology for compressing the knowledge of multiple models into a single neural network was called model compression in 2006. Compression was achieved by training a smaller model on large amounts of pseudo-data labelled by a higher-performing ensemble, optimising to match the logit of the compressed model to the logit of the ensemble.[13] Knowledge distillation is a generalisation of this approach, introduced by Geoffrey Hinton et al. in 2015, in a preprint that formulated the concept and showed some results achieved in the task of image classification.

Knowledge distillation is also related to the concept of behavioral cloning discussed by Faraz Torabi et al.[14]

Formulation

Given a large model as a function of the vector variable $\mathbf{x}$, trained for a specific classification task, typically the final layer of the network is a softmax in the form

$$y_i(\mathbf{x}|t) = \frac{e^{\frac{z_i(\mathbf{x})}{t}}}{\sum_j e^{\frac{z_j(\mathbf{x})}{t}}}$$

where $t$ is a parameter called temperature, that for a standard softmax is normally set to 1. The softmax operator converts the logit values $z_i(\mathbf{x})$ to pseudo-probabilities, and higher values of temperature have the effect of generating a softer distribution of pseudo-probabilities among the output classes. Knowledge distillation consists of training a smaller network, called the distilled model, on a dataset called transfer set (different than the dataset used to train the large model) using the cross entropy as loss function between the output of the distilled model $\mathbf{y}(\mathbf{x}|t)$ and the output $\hat{\mathbf{y}}(\mathbf{x}|t)$ produced by the large model on the same record (or the average of the individual outputs, if the large model is an ensemble), using a high value of softmax temperature $t$ for both models

$$E(\mathbf{x}|t) = -\sum_i \hat{y}_i(\mathbf{x}|t) \log y_i(\mathbf{x}|t).$$
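The temperature softmax and the distillation loss above can be sketched in a few lines of plain Python. This is a minimal illustration rather than a reference implementation: the function names `softmax_t` and `distillation_loss` and the logit values are hypothetical, chosen only for this example.

```python
import math

def softmax_t(logits, t=1.0):
    """Temperature softmax: y_i = exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, t):
    """Cross entropy E = -sum_i y_hat_i(t) * log y_i(t) between soft outputs."""
    y_hat = softmax_t(teacher_logits, t)  # large model
    y = softmax_t(student_logits, t)      # distilled model
    return -sum(p * math.log(q) for p, q in zip(y_hat, y))

# Hypothetical logits for one record with three classes.
teacher_logits = [5.0, 2.0, -1.0]
student_logits = [3.0, 1.0, 0.0]

hard = softmax_t(teacher_logits, t=1.0)  # sharply peaked distribution
soft = softmax_t(teacher_logits, t=4.0)  # softer distribution
loss = distillation_loss(student_logits, teacher_logits, t=4.0)
```

With these values, the output at $t = 4$ spreads noticeably more probability over the two non-maximal classes than the output at $t = 1$, and the loss is smallest when the student logits reproduce the teacher's soft output.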

In this context, a high temperature increases the entropy of the output, and therefore provides more information to learn for the distilled model compared to hard targets, at the same time reducing the variance of the gradient between different records and therefore allowing higher learning rates.
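The effect of temperature on the entropy of the output can be checked numerically. The sketch below assumes a hypothetical set of teacher logits and computes the Shannon entropy of the temperature softmax at several temperatures; the helper names are illustrative only.

```python
import math

def softmax_t(logits, t=1.0):
    """Temperature softmax: y_i = exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats, skipping zero-probability entries."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)

# Hypothetical logits of the large model for one record.
teacher_logits = [5.0, 2.0, -1.0]
temperatures = [1.0, 2.0, 5.0, 20.0]

# Entropy of the soft output at each temperature.
entropies = [entropy(softmax_t(teacher_logits, t)) for t in temperatures]
max_entropy = math.log(len(teacher_logits))  # entropy of the uniform distribution
```

For these logits the entropy grows with each increase in temperature and approaches, without reaching, the entropy of the uniform distribution over the three classes.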

If ground truth is available for the transfer set, the process can be strengthened by adding to the loss the cross-entropy between the output of the distilled model (computed with $t = 1$) and the known label $\bar{y}$

$$E(\mathbf{x}|t) = -t^2 \sum_i \hat{y}_i(\mathbf{x}|t) \log y_i(\mathbf{x}|t) - \sum_i \bar{y}_i \log y_i(\mathbf{x}|1)$$

where the component of the loss with respect to the large model is weighted by a factor of $t^2$ since, as the temperature increases, the gradient of the loss with respect to the model weights scales by a factor of $\frac{1}{t^2}$.
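A minimal sketch of the combined loss follows. It implements the description in the text: a soft term between the two models at temperature $t$, weighted by $t^2$, plus a hard term between the known label and the output of the distilled model at $t = 1$. The function names, the one-hot encoding of the label, and the logit values are assumptions made for this example.

```python
import math

def softmax_t(logits, t=1.0):
    """Temperature softmax: y_i = exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted):
    """-sum_i target_i * log predicted_i, skipping zero-weight terms."""
    return -sum(p * math.log(q) for p, q in zip(target, predicted) if p > 0.0)

def combined_loss(student_logits, teacher_logits, label_index, t):
    """Soft term weighted by t^2 plus hard-label term computed at t = 1."""
    soft_term = cross_entropy(softmax_t(teacher_logits, t),
                              softmax_t(student_logits, t))
    # Assumption: the known label is encoded as a one-hot vector.
    one_hot = [1.0 if i == label_index else 0.0
               for i in range(len(student_logits))]
    hard_term = cross_entropy(one_hot, softmax_t(student_logits, 1.0))
    return t ** 2 * soft_term + hard_term

# Hypothetical logits for one record whose true class is index 0.
teacher_logits = [5.0, 2.0, -1.0]
student_logits = [3.0, 1.0, 0.0]
total = combined_loss(student_logits, teacher_logits, label_index=0, t=4.0)
```

In practice the two terms are often given separate weights; the sketch keeps only the $t^2$ factor that appears in the formula.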

Relationship with model compression

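Under the assumption that the logits have zero mean, it is possible to show that model compression is a special case of knowledge distillation. The gradient of the knowledge distillation loss $E$ with respect to the logit of the distilled model $z_i$ is given by

\begin{align}
\frac{\partial}{\partial z_i} E
&= -\frac{\partial}{\partial z_i} \sum_j \hat{y}_j \log y_j \\
&= -\frac{\partial}{\partial z_i} \hat{y}_i \log y_i + \left( -\frac{\partial}{\partial z_i} \sum_{k \neq i} \hat{y}_k \log y_k \right) \\
&= -\hat{y}_i \frac{1}{y_i} \frac{\partial}{\partial z_i} y_i + \sum_{k \neq i} \left( -\hat{y}_k \cdot \frac{1}{y_k} \cdot e^{\frac{z_k}{t}} \cdot \left( -\frac{1}{\left( \sum_j e^{\frac{z_j}{t}} \right)^2} \right) \cdot e^{\frac{z_i}{t}} \cdot \frac{1}{t} \right) \\
&= -\hat{y}_i \frac{1}{y_i} \frac{\partial}{\partial z_i} \frac{e^{\frac{z_i}{t}}}{\sum_j e^{\frac{z_j}{t}}} + \sum_{k \neq i} \left( \hat{y}_k \cdot \frac{1}{y_k} \cdot y_k \cdot y_i \cdot \frac{1}{t} \right) \\
&= -\hat{y}_i \frac{1}{y_i} \left( \frac{\frac{1}{t} e^{\frac{z_i}{t}} \sum_j e^{\frac{z_j}{t}} - \frac{1}{t} \left( e^{\frac{z_i}{t}} \right)^2}{\left( \sum_j e^{\frac{z_j}{t}} \right)^2} \right) + \frac{y_i \sum_{k \neq i} \hat{y}_k}{t} \\
&= -\hat{y}_i \frac{1}{y_i} \left( \frac{y_i}{t} - \frac{y_i^2}{t} \right) + \frac{y_i (1 - \hat{y}_i)}{t} \\
&= \frac{1}{t} \left( y_i - \hat{y}_i \right) \\
&= \frac{1}{t} \left( \frac{e^{\frac{z_i}{t}}}{\sum_j e^{\frac{z_j}{t}}} - \frac{e^{\frac{\hat{z}_i}{t}}}{\sum_j e^{\frac{\hat{z}_j}{t}}} \right)
\end{align}

where $\hat{z}_i$ are the logits of the large model. For large values of $t$ this can be approximated as

$$\frac{1}{t} \left( \frac{1 + \frac{z_i}{t}}{N + \sum_j \frac{z_j}{t}} - \frac{1 + \frac{\hat{z}_i}{t}}{N + \sum_j \frac{\hat{z}_j}{t}} \right)$$

and under the zero-mean hypothesis $\sum_j z_j = \sum_j \hat{z}_j = 0$ it becomes $\frac{z_i - \hat{z}_i}{N t^2}$, which is the derivative of $\frac{1}{2} \left( z_i - \hat{z}_i \right)^2$, i.e. the loss is equivalent to matching the logits of the two models, as done in model compression.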
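Both steps of this argument can be checked numerically. The sketch below, using hypothetical zero-mean logits, compares a central finite-difference estimate of the gradient with the closed form $(y_i - \hat{y}_i)/t$, and then compares that closed form at a high temperature with the logit-matching expression, taking $N$ to be the number of classes. All names and values are assumptions for illustration.

```python
import math

def softmax_t(logits, t=1.0):
    """Temperature softmax: y_i = exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(z, z_hat, t):
    """E = -sum_j y_hat_j * log y_j, both outputs computed at temperature t."""
    y_hat = softmax_t(z_hat, t)
    y = softmax_t(z, t)
    return -sum(p * math.log(q) for p, q in zip(y_hat, y))

def numerical_gradient(z, z_hat, t, eps=1e-6):
    """Central finite differences of E with respect to each student logit."""
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        diff = distillation_loss(up, z_hat, t) - distillation_loss(down, z_hat, t)
        grad.append(diff / (2.0 * eps))
    return grad

def analytic_gradient(z, z_hat, t):
    """Closed form from the derivation: (y_i - y_hat_i) / t."""
    y, y_hat = softmax_t(z, t), softmax_t(z_hat, t)
    return [(yi - yhi) / t for yi, yhi in zip(y, y_hat)]

# Hypothetical logits; each list sums to zero (zero-mean hypothesis).
z = [1.5, -0.5, -1.0]      # distilled model
z_hat = [2.0, 0.0, -2.0]   # large model
n = len(z)                 # number of classes, N in the text

# Step 1: closed form versus finite differences at a moderate temperature.
numeric = numerical_gradient(z, z_hat, t=2.0)
closed = analytic_gradient(z, z_hat, t=2.0)

# Step 2: closed form versus the logit-matching expression at high temperature.
t_high = 100.0
exact = analytic_gradient(z, z_hat, t_high)
approx = [(zi - zhi) / (n * t_high ** 2) for zi, zhi in zip(z, z_hat)]
```

At $t = 100$ the two expressions in the second step agree to within a few percent for these logits, and the agreement improves as $t$ grows.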

References

  1. Hinton, Geoffrey; Vinyals, Oriol; Dean, Jeff (2015). "Distilling the knowledge in a neural network". arXiv:1503.02531 [stat.ML].
  2. Chen, Guobin; Choi, Wongun; Yu, Xiang; Han, Tony; Chandraker, Manmohan (2017). "Learning efficient object detection models with knowledge distillation". Advances in Neural Information Processing Systems: 742–751.
  3. Asami, Taichi; Masumura, Ryo; Yamaguchi, Yoshikazu; Masataki, Hirokazu; Aono, Yushi (2017). "Domain adaptation of DNN acoustic models using knowledge distillation". IEEE International Conference on Acoustics, Speech and Signal Processing: 5185–5189.
  4. Cui, Jia; Kingsbury, Brian; Ramabhadran, Bhuvana; Saon, George; Sercu, Tom; Audhkhasi, Kartik; Sethy, Abhinav; Nussbaum-Thom, Markus; Rosenberg, Andrew (2017). "Knowledge distillation across ensembles of multilingual models for low-resource languages". IEEE International Conference on Acoustics, Speech and Signal Processing: 4825–4829.
  5. Yang, Yiding; Qiu, Jiayan; Song, Mingli; Tao, Dacheng; Wang, Xinchao (2020). "Distilling Knowledge from Graph Convolutional Networks". Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition: 7072–7081. arXiv:2003.10477. Bibcode:2020arXiv200310477Y.
  6. Hanson, Stephen; Pratt, Lorien (1988). "Comparing Biases for Minimal Network Construction with Back-Propagation". Advances in Neural Information Processing Systems. 1. Morgan-Kaufmann.
  7. LeCun, Yann; Denker, John; Solla, Sara (1989). "Optimal Brain Damage". Advances in Neural Information Processing Systems. 2. Morgan-Kaufmann.
  8. Watkin, Timothy L. H.; Rau, Albrecht; Biehl, Michael (1993-04-01). "The statistical mechanics of learning a rule". Reviews of Modern Physics. 65 (2): 499–556. doi:10.1103/RevModPhys.65.499.
  9. Schwarze, H.; Hertz, J. (1992-10-15). "Generalization in a Large Committee Machine". Europhysics Letters (EPL). 20 (4): 375–380. doi:10.1209/0295-5075/20/4/015. ISSN 0295-5075.
  10. Mato, G.; Parga, N. (1992-10-07). "Generalization properties of multilayered neural networks". Journal of Physics A: Mathematical and General. 25 (19): 5047–5054. doi:10.1088/0305-4470/25/19/017. ISSN 0305-4470.
  11. Hansel, D.; Mato, G.; Meunier, C. (1992-11-01). "Memorization Without Generalization in a Multilayered Neural Network". Europhysics Letters (EPL). 20 (5): 471–476. doi:10.1209/0295-5075/20/5/015. ISSN 0295-5075.
  12. Schmidhuber, Jürgen (1992). "Learning complex, extended sequences using the principle of history compression". Neural Computation. 4 (2): 234–242. doi:10.1162/neco.1992.4.2.234. S2CID 18271205. ftp://ftp.idsia.ch/pub/juergen/chunker.pdf
  13. Buciluǎ, Cristian; Caruana, Rich; Niculescu-Mizil, Alexandru (2006). "Model compression". Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining.
  14. Torabi, Faraz; Warnell, Garrett; Stone, Peter (2018). "Behavioral Cloning from Observation". arXiv:1805.01954 [cs.AI].
