Large margin nearest neighbor (LMNN)[1] classification is a statistical machine learning algorithm for metric learning. It learns a pseudometric designed for k-nearest neighbor classification. The algorithm is based on semidefinite programming, a sub-class of convex optimization.
The goal of supervised learning (more specifically classification) is to learn a decision rule that can categorize data instances into pre-defined classes. The k-nearest neighbor rule assumes a training data set of labeled instances (i.e. the classes are known). It classifies a new data instance with the class obtained from the majority vote of the k closest (labeled) training instances. Closeness is measured with a pre-defined metric. Large margin nearest neighbors is an algorithm that learns this global (pseudo-)metric in a supervised fashion to improve the classification accuracy of the k-nearest neighbor rule.
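For illustration, the k-nearest neighbor rule can be stated in a few lines of NumPy. The sketch below (array names such as X_train and y_train are hypothetical) classifies a query point by a majority vote over its k closest labeled training instances under the ordinary Euclidean metric, which is exactly the metric that LMNN later replaces with a learned one.

import numpy as np
from collections import Counter

def knn_classify(X_train, y_train, x_query, k=3):
    # Distance from the query point to every labeled training instance
    dists = np.linalg.norm(X_train - x_query, axis=1)
    # Indices of the k closest training instances
    nearest = np.argsort(dists)[:k]
    # Majority vote over their class labels
    return Counter(y_train[nearest]).most_common(1)[0][0]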
The main intuition behind LMNN is to learn a pseudometric under which all data instances in the training set are surrounded by at least k instances that share the same class label. If this is achieved, the leave-one-out error (a special case of cross validation) is minimized. Let the training data consist of a data set
D = \{(\vec{x}_1, y_1), \dots, (\vec{x}_n, y_n)\} \subset \mathbb{R}^d \times C,
where the set of possible class labels is C = \{1, \dots, c\}.
The algorithm learns a pseudometric of the type
d(\vec{x}_i, \vec{x}_j) = (\vec{x}_i - \vec{x}_j)^\top M (\vec{x}_i - \vec{x}_j).
For d(\cdot, \cdot) to be a well-defined pseudometric, the matrix M needs to be positive semi-definite. The Euclidean metric is recovered as the special case in which M is the identity matrix; a general M defines a (squared) Mahalanobis distance.
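A minimal sketch of this pseudometric in NumPy (the factorization M = L^\top L used below is one standard way to guarantee positive semi-definiteness; the variable names are illustrative):

import numpy as np

def lmnn_distance(x_i, x_j, M):
    # Squared pseudometric (x_i - x_j)^T M (x_i - x_j); M must be positive semi-definite
    diff = x_i - x_j
    return diff @ M @ diff

# Any matrix of the form M = L^T L is positive semi-definite, so the
# pseudometric is well defined; L equal to the identity recovers the
# (squared) Euclidean distance.
L = np.array([[2.0, 0.0],
              [0.5, 1.0]])
M = L.T @ L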
Figure 1 illustrates the effect of the learned metric on the neighborhood of a point \vec{x}_i for varying M.
The algorithm distinguishes between two types of special data points: target neighbors and impostors.
Target neighbors are selected before learning. Each instance \vec{x}_i has exactly k target neighbors within D, all of which share its class label y_i. The target neighbors are the data points that should become nearest neighbors of \vec{x}_i under the learned metric; this set is denoted N_i.
An impostor of a data point \vec{x}_i is another data point \vec{x}_j with a different class label (i.e. y_i \neq y_j) that is among the nearest neighbors of \vec{x}_i.
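Under a fixed matrix M, both kinds of points can be found by sorting squared distances. The sketch below (NumPy, with hypothetical helper names) returns the k same-class target neighbors of \vec{x}_i and, for a given target neighbor \vec{x}_j, the differently labeled points that intrude on the unit margin around it.

import numpy as np

def target_neighbors(X, y, i, k, M):
    # k nearest same-class instances of x_i under the (squared) pseudometric
    diffs = X - X[i]
    d = np.einsum('nd,de,ne->n', diffs, M, diffs)
    same = np.flatnonzero((y == y[i]) & (np.arange(len(y)) != i))
    return same[np.argsort(d[same])[:k]]

def impostors(X, y, i, j, M):
    # Differently labeled points that are not at least one unit farther
    # from x_i than its target neighbor x_j
    diffs = X - X[i]
    d = np.einsum('nd,de,ne->n', diffs, M, diffs)
    return np.flatnonzero((y != y[i]) & (d < d[j] + 1.0))

In practice the target neighbors are fixed once before learning begins, typically under the Euclidean metric, whereas the impostors depend on the current M and change during optimization.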
Large margin nearest neighbors optimizes the matrix M with the help of semidefinite programming. The objective is twofold: for every data point \vec{x}_i, the target neighbors should be close and the impostors should be far away. Figure 1 shows the effect of such an optimization on an illustrative example: after learning, \vec{x}_i is surrounded by training instances of the same class, so a test point in its place would be classified correctly under the k = 3 nearest neighbor rule.
The first optimization goal is achieved by minimizing the average distance between instances and their target neighbors:
\sum_{i,\, j \in N_i} d(\vec{x}_i, \vec{x}_j).
The second goal is achieved by penalizing distances to impostors
\vec{x}_l that are less than one unit further away than a target neighbor \vec{x}_j (thereby pushing them out of the local neighborhood of \vec{x}_i):
\sum_{i,\, j \in N_i,\, l \colon y_l \neq y_i} \left[ d(\vec{x}_i, \vec{x}_j) + 1 - d(\vec{x}_i, \vec{x}_l) \right]_+
Here [\cdot]_+ = \max(\cdot, 0) denotes the hinge loss, which ensures that an impostor is not penalized when it lies outside the margin. The margin of exactly one unit fixes the scale of the matrix M: any alternative choice c > 0 would merely rescale M by a factor of 1/c.
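Combining both terms gives the unconstrained LMNN loss for a fixed M and a fixed target-neighbor assignment. The following sketch evaluates it with plain NumPy; the naive double loop is for clarity only, and the names neighbors and lam (the weighting constant \lambda) are illustrative assumptions.

import numpy as np

def lmnn_loss(X, y, neighbors, M, lam=1.0):
    # neighbors[i] holds the indices of the k target neighbors of x_i
    pull, push = 0.0, 0.0
    for i in range(len(X)):
        diffs = X - X[i]
        d = np.einsum('nd,de,ne->n', diffs, M, diffs)   # squared distances from x_i
        for j in neighbors[i]:
            pull += d[j]                                # draw target neighbors closer
            hinge = d[j] + 1.0 - d[y != y[i]]           # margin violations by impostors
            push += np.maximum(hinge, 0.0).sum()        # hinge loss [.]_+
    return pull + lam * push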
The final optimization problem becomes:
\min_M \sum_{i,\, j \in N_i} d(\vec{x}_i, \vec{x}_j) + \lambda \sum_{i,j,l} \xi_{ijl}
subject to, for all i,\, j \in N_i,\, l with y_l \neq y_i:
d(\vec{x}_i, \vec{x}_j) + 1 - d(\vec{x}_i, \vec{x}_l) \leq \xi_{ijl}
\xi_{ijl} \geq 0
M \succeq 0
The hyperparameter \lambda > 0 is some positive constant (typically set through cross-validation). Here the slack variables \xi_{ijl}, together with the first two types of constraints, replace the hinge-loss term in the cost function: they absorb the extent to which the impostor constraints are violated. The last constraint M \succeq 0 ensures that M remains positive semi-definite.
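The solver derived by Weinberger and Saul exploits the special structure of this semidefinite program; a much simpler way to convey the flavor of the optimization is projected gradient descent on M, re-projecting onto the positive semi-definite cone after every step. The sketch below is only an illustration under that simplification, not the published algorithm; the step size eta, the weight lam and the neighbors structure are assumptions carried over from the loss sketch above.

import numpy as np

def project_psd(M):
    # Project a symmetric matrix onto the positive semi-definite cone
    # by clipping its negative eigenvalues at zero
    w, V = np.linalg.eigh((M + M.T) / 2.0)
    return (V * np.maximum(w, 0.0)) @ V.T

def lmnn_gradient_step(X, y, neighbors, M, lam=1.0, eta=1e-3):
    grad = np.zeros_like(M)
    for i in range(len(X)):
        diffs = X - X[i]
        d = np.einsum('nd,de,ne->n', diffs, M, diffs)
        for j in neighbors[i]:
            outer_ij = np.outer(diffs[j], diffs[j])
            grad += outer_ij                              # gradient of the pull term
            active = (y != y[i]) & (d < d[j] + 1.0)       # impostors inside the margin
            for l in np.flatnonzero(active):
                grad += lam * (outer_ij - np.outer(diffs[l], diffs[l]))
    return project_psd(M - eta * grad)                    # gradient step, then PSD projection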
LMNN was extended to multiple local metrics in a 2008 paper.[2] This extension significantly improves the classification error, but involves a more expensive optimization problem. In their 2009 publication in the Journal of Machine Learning Research,[3] Weinberger and Saul derive an efficient solver for the semidefinite program. It can learn a metric for the MNIST handwritten digit data set in several hours, involving billions of pairwise constraints. An open source Matlab implementation is freely available on the authors' web page.
Kumal et al.[4] extended the algorithm to incorporate local invariances to multivariate polynomial transformations and improved regularization.