diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index b03450919..ff2a40ee8 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -243,13 +243,11 @@ impl, T, D: Distance> let n_runs = self.n_runs(); for _ in 0..n_runs { - let mut inertia = min_inertia; let mut centroids = self.init_method() .run(self.dist_fn(), self.n_clusters(), observations, &mut rng); - let mut converged_iter: Option = None; let mut n_iter = 0; - while n_iter < self.max_n_iterations() { + let (inertia, converged_iter) = loop { update_memberships_and_dists( self.dist_fn(), ¢roids, @@ -258,17 +256,15 @@ impl, T, D: Distance> &mut dists, ); let new_centroids = compute_centroids(¢roids, &observations, &memberships); - inertia = dists.sum(); let distance = self .dist_fn() .distance(centroids.view(), new_centroids.view()); centroids = new_centroids; n_iter += 1; if distance < self.tolerance() || n_iter == self.max_n_iterations() { - converged_iter = Some(n_iter); - break; + break (dists.sum(), Some(n_iter)); } - } + }; // We keep the centroids which minimize the inertia (defined as the sum of // the squared distances of the closest centroid for all observations)