diff --git a/crates/service/src/algorithms/clustering/elkan_k_means.rs b/crates/service/src/algorithms/clustering/elkan_k_means.rs
index 494a10792..287f729da 100644
--- a/crates/service/src/algorithms/clustering/elkan_k_means.rs
+++ b/crates/service/src/algorithms/clustering/elkan_k_means.rs
@@ -83,6 +83,44 @@ impl ElkanKMeans {
         }
     }
 
+    /// Quick approach if we have little data (no more samples than centroids).
+    ///
+    /// Samples are visited in sorted order so that equal samples are adjacent.
+    /// The first occurrence of each distinct sample is copied into a centroid;
+    /// each repeated sample, and each centroid slot beyond the sample count,
+    /// gets random coordinates taken from the range `0.0..1.0` instead.
+    ///
+    /// Always returns `true`.
+    fn quick_centroids(&mut self) -> bool {
+        let c = self.c;
+        let samples = &self.samples;
+        let rand = &mut self.rand;
+        let centroids = &mut self.centroids;
+        let n = samples.len();
+        let dims = samples.dims();
+        let sorted_index = samples.argsort();
+        for i in 0..n {
+            let index = sorted_index[i];
+            // Position 0 has no predecessor, so it can never be a duplicate. The
+            // test must use the sorted position `i`, not the original sample index.
+            if i == 0 || samples[sorted_index[i - 1]] != samples[index] {
+                centroids[i].copy_from_slice(&samples[index]);
+            } else {
+                let rand_centroids: Vec<_> = (0..dims)
+                    .map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
+                    .collect();
+                centroids[i].copy_from_slice(rand_centroids.as_slice());
+            }
+        }
+        for i in n..c {
+            let rand_centroids: Vec<_> = (0..dims)
+                .map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
+                .collect();
+            centroids[i].copy_from_slice(rand_centroids.as_slice());
+        }
+        true
+    }
+
     pub fn iterate(&mut self) -> bool {
         let c = self.c;
         let dims = self.dims;
@@ -94,6 +132,9 @@ impl ElkanKMeans {
         let upperbound = &mut self.upperbound;
         let mut change = 0;
         let n = samples.len();
+        if n <= c {
+            return self.quick_centroids();
+        }
 
         // Step 1
         let mut dist0 = Square::new(c, c);
diff --git a/crates/service/src/utils/vec2.rs b/crates/service/src/utils/vec2.rs
index 318383b73..b671b681e 100644
--- a/crates/service/src/utils/vec2.rs
+++ b/crates/service/src/utils/vec2.rs
@@ -20,6 +20,14 @@ impl Vec2 {
     pub fn len(&self) -> usize {
         self.v.len() / self.dims as usize
     }
+    /// Returns the indices `0..self.len()` ordered by the entries they refer to.
+    ///
+    /// The sort is stable, so indices of equal entries keep their original order.
+    pub fn argsort(&self) -> Vec<usize> {
+        let mut index: Vec<usize> = (0..self.len()).collect();
+        index.sort_by_key(|i| &self[*i]);
+        index
+    }
     pub fn copy_within(&mut self, i: usize, j: usize) {
         assert!(i < self.len() && j < self.len());
         unsafe {