This is a course project on selecting a subset of the training data to build an efficient nearest neighbor classifier. The paper can be found here: Latent KMeans Prototype Selection
Choosing a representative subset of "prototypes" from the training set is crucial for accelerating nearest neighbor classifiers. This project proposes projecting the data into a latent space using a pretrained embedder, applying in-class KMeans clustering to the latent representations, and taking the cluster centers as the selected prototypes. Experiments demonstrate that the nearest neighbor classifier built on latent KMeans prototypes significantly outperforms other baselines, even with a training set containing as few as 10 points. An ablation study on the embedder's pretraining domain reveals that training the embedder on in-domain data greatly affects the quality of the representations.
The performance of a nearest neighbor classifier depends heavily on the distribution of the training data. Raw data, such as images, often contains a large amount of high-frequency noise, which can impair the classifier's effectiveness. Perceptual distance in a latent space has been shown to offer advantages over per-pixel distance in the original data space (Johnson et al., 2016), so it is beneficial to project the data into a meaningful latent space using a pretrained embedder. However, regardless of the space in which the data resides, each class can contain outliers that directly hurt the nearest neighbor classifier's accuracy. Therefore, this project proposes using the centers derived from latent KMeans clustering as the training data for the nearest neighbor classifier. First, the entire dataset is projected into the latent space using an embedder built from a pretrained classifier. Then, for each class, KMeans clustering is applied to represent the data points with their cluster centers, which serve as the selected prototypes.
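To make the outlier argument concrete, here is a minimal pure-Python sketch of a 1-nearest-neighbor classifier run once on every stored point and once on per-class centers. It is an illustration under stated assumptions, not the repo's code: `embed` is a hypothetical identity stand-in for the pretrained embedder, the toy points are invented, and it keeps a single center per class (the class mean, which is where KMeans with k = 1 places its center) instead of the several centers per class the project uses.

```python
from collections import defaultdict


def sq_dist(a, b):
    # Squared Euclidean distance; it ranks neighbors exactly like Euclidean distance.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nn_predict(z, train_x, train_y):
    # 1-nearest-neighbor rule: return the label of the closest stored vector.
    best = min(range(len(train_x)), key=lambda i: sq_dist(z, train_x[i]))
    return train_y[best]


def class_means(xs, ys):
    # One center per class. With k = 1, KMeans places the center at the class mean.
    groups = defaultdict(list)
    for x, y in zip(xs, ys):
        groups[y].append(x)
    protos, labels = [], []
    for y, members in groups.items():
        dim = len(members[0])
        protos.append([sum(m[d] for m in members) / len(members) for d in range(dim)])
        labels.append(y)
    return protos, labels


def embed(x):
    # Hypothetical stand-in for the pretrained embedder: an identity map on toy 2-D data.
    return [float(v) for v in x]


train_x = [[0, 0], [1, 0], [0, 1], [5, 5],  # class 0, where (5, 5) is an outlier
           [6, 6], [7, 6], [6, 7]]          # class 1
train_y = [0, 0, 0, 0, 1, 1, 1]
latent = [embed(x) for x in train_x]
query = embed([4.8, 4.8])  # lies close to the class-1 group

print(nn_predict(query, latent, train_y))  # 0: the class-0 outlier is the nearest point
protos, labels = class_means(latent, train_y)
print(nn_predict(query, protos, labels))   # 1: the class-1 center is closer
```

Averaging dilutes the single outlier, so the query is no longer captured by it; the same reasoning motivates using cluster centers rather than raw points as the nearest neighbor training set.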
The Prototype Selection Algorithm computes the KMeans cluster centers on each class's embedded data. The original training set is
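The per-class selection step can be sketched roughly as follows. This is a hedged, self-contained illustration rather than the project's implementation: it runs plain Lloyd's iterations with random initialization from each class's own points, uses a fixed number of centers per class, and the names `kmeans`, `select_prototypes` and `k_per_class` are hypothetical. The actual code may use a library KMeans with a different initialization.

```python
import random
from collections import defaultdict


def sq_dist(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, n_iters=100, seed=0):
    """Plain Lloyd's KMeans; returns k centers (assumes len(points) >= k)."""
    rng = random.Random(seed)
    # Initialize the centers with k distinct points drawn from this class.
    centers = [list(p) for p in rng.sample(points, k)]
    for _ in range(n_iters):
        # Assignment step: attach every point to its nearest center.
        clusters = [[] for _ in range(k)]
        for p in points:
            j = min(range(k), key=lambda c: sq_dist(p, centers[c]))
            clusters[j].append(p)
        # Update step: move each center to the mean of its members.
        new_centers = []
        for j, members in enumerate(clusters):
            if not members:
                new_centers.append(centers[j])  # leave an empty cluster's center in place
                continue
            dim = len(members[0])
            new_centers.append([sum(m[d] for m in members) / len(members) for d in range(dim)])
        if new_centers == centers:  # centers stopped moving, so Lloyd's has converged
            break
        centers = new_centers
    return centers


def select_prototypes(embeddings, labels, k_per_class, seed=0):
    """Run KMeans separately inside each class; the centers become labeled prototypes."""
    by_class = defaultdict(list)
    for z, y in zip(embeddings, labels):
        by_class[y].append(z)
    proto_x, proto_y = [], []
    for y in sorted(by_class):
        centers = kmeans(by_class[y], k_per_class, seed=seed)
        proto_x.extend(centers)
        proto_y.extend([y] * len(centers))
    return proto_x, proto_y


# Toy 2-D "embeddings": each class has two well-separated groups.
emb = [[0.0, 0.0], [0.2, 0.0], [10.0, 10.0], [10.2, 10.0],  # class 0
       [0.0, 10.0], [0.2, 10.0], [10.0, 0.0], [10.2, 0.0]]  # class 1
lab = [0, 0, 0, 0, 1, 1, 1, 1]
px, py = select_prototypes(emb, lab, k_per_class=2)
print(py)  # [0, 0, 1, 1]
```

Clustering inside each class, rather than over the whole set, guarantees every class keeps exactly `k_per_class` prototypes with an unambiguous label, and the eight toy points shrink to four prototypes that a nearest neighbor classifier can then search.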
Figure 1 shows that the latent KMeans clustering method using the MNIST embedder (LK-M) beats or matches all other methods, with nearly 100% accuracy across all experimental settings.
If you find the code useful or like the repo, please consider starring it. If you use the code in this repo, please cite it as:
Wang, Xinyuan. (Feb 2024). Latent KMeans Prototype Selection.
https://github.com/XinyuanWangCS/LatentKMeansPrototypeSelection