From ab9a7aff3c2685a5470cdcc44a7525f7f22c8321 Mon Sep 17 00:00:00 2001 From: dparthiban Date: Fri, 2 Sep 2022 00:35:49 -0400 Subject: [PATCH 1/2] initial commit --- .../linfa-clustering/src/k_means/algorithm.rs | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index 726b10ab8..b03450919 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -248,7 +248,8 @@ impl, T, D: Distance> self.init_method() .run(self.dist_fn(), self.n_clusters(), observations, &mut rng); let mut converged_iter: Option = None; - for n_iter in 0..self.max_n_iterations() { + let mut n_iter = 0; + while n_iter < self.max_n_iterations() { update_memberships_and_dists( self.dist_fn(), &centroids, @@ -262,7 +263,8 @@ impl, T, D: Distance> .dist_fn() .distance(centroids.view(), new_centroids.view()); centroids = new_centroids; - if distance < self.tolerance() { + n_iter += 1; + if distance < self.tolerance() || n_iter == self.max_n_iterations() { converged_iter = Some(n_iter); break; } @@ -851,6 +853,23 @@ mod tests { assert!(params.fit_with(None, &data).is_ok()); } + #[test] + fn test_max_n_iterations() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data.clone()); + // For data created using the above rng and seed, for 6 clusters, it would take 8 iterations to converge. + // However, when specifying max_n_iterations as 5, the algorithm should stop early gracefully. 
+ let _model = KMeans::params_with(6, rng.clone(), L2Dist) + .n_runs(1) + .max_n_iterations(5) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("KMeans fitted"); + } + fn fittable, (), KMeansError>>(_: T) {} #[test] fn thread_rng_fittable() { From e8ad15f951ac3b31e7610303081a3987f8c3e021 Mon Sep 17 00:00:00 2001 From: dparthiban Date: Mon, 5 Sep 2022 23:11:04 -0400 Subject: [PATCH 2/2] use loop over while and related changes --- .../linfa-clustering/src/k_means/algorithm.rs | 47 ++++++++----------- .../linfa-clustering/src/k_means/errors.rs | 3 -- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index b03450919..bd8297326 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -236,20 +236,17 @@ impl, T, D: Distance> let mut min_inertia = F::infinity(); let mut best_centroids = None; - let mut best_iter = None; let mut memberships = Array1::zeros(n_samples); let mut dists = Array1::zeros(n_samples); let n_runs = self.n_runs(); for _ in 0..n_runs { - let mut inertia = min_inertia; let mut centroids = self.init_method() .run(self.dist_fn(), self.n_clusters(), observations, &mut rng); - let mut converged_iter: Option = None; let mut n_iter = 0; - while n_iter < self.max_n_iterations() { + let inertia = loop { update_memberships_and_dists( self.dist_fn(), &centroids, @@ -258,17 +255,15 @@ impl, T, D: Distance> &mut dists, ); let new_centroids = compute_centroids(&centroids, &observations, &memberships); - inertia = dists.sum(); let distance = self .dist_fn() .distance(centroids.view(), new_centroids.view()); centroids = new_centroids; n_iter += 1; if distance < self.tolerance() || n_iter == self.max_n_iterations() { - converged_iter = Some(n_iter); - break; + break dists.sum(); } - } + }; // We keep the centroids which minimize the inertia (defined as the sum of // the 
squared distances of the closest centroid for all observations) @@ -276,27 +271,23 @@ impl, T, D: Distance> if inertia < min_inertia { min_inertia = inertia; best_centroids = Some(centroids.clone()); - best_iter = converged_iter; } } - match best_iter { - Some(_n_iter) => match best_centroids { - Some(centroids) => { - let mut cluster_count = Array1::zeros(self.n_clusters()); - memberships - .iter() - .for_each(|&c| cluster_count[c] += F::one()); - Ok(KMeans { - centroids, - cluster_count, - inertia: min_inertia / F::cast(dataset.nsamples()), - dist_fn: self.dist_fn().clone(), - }) - } - _ => Err(KMeansError::InertiaError), - }, - None => Err(KMeansError::NotConverged), + match best_centroids { + Some(centroids) => { + let mut cluster_count = Array1::zeros(self.n_clusters()); + memberships + .iter() + .for_each(|&c| cluster_count[c] += F::one()); + Ok(KMeans { + centroids, + cluster_count, + inertia: min_inertia / F::cast(dataset.nsamples()), + dist_fn: self.dist_fn().clone(), + }) + } + _ => Err(KMeansError::InertiaError), } } } @@ -855,11 +846,11 @@ mod tests { #[test] fn test_max_n_iterations() { - let mut rng = Xoshiro256Plus::seed_from_u64(42); + let mut rng = Xoshiro256Plus::seed_from_u64(42); let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); let yt = function_test_1d(&xt); let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); - let dataset = DatasetBase::from(data.clone()); + let dataset = DatasetBase::from(data.clone()); // For data created using the above rng and seed, for 6 clusters, it would take 8 iterations to converge. // However, when specifying max_n_iterations as 5, the algorithm should stop early gracefully. 
let _model = KMeans::params_with(6, rng.clone(), L2Dist) diff --git a/algorithms/linfa-clustering/src/k_means/errors.rs b/algorithms/linfa-clustering/src/k_means/errors.rs index 49fae5695..bcc26b569 100644 --- a/algorithms/linfa-clustering/src/k_means/errors.rs +++ b/algorithms/linfa-clustering/src/k_means/errors.rs @@ -22,9 +22,6 @@ pub enum KMeansError { /// When inertia computation fails #[error("Fitting failed: No inertia improvement (-inf)")] InertiaError, - /// When fitting algorithm does not converge - #[error("Fitting failed: Did not converge. Try different init parameters or check for degenerate data.")] - NotConverged, #[error(transparent)] LinfaError(#[from] linfa::error::Error), }