diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala index 9ac473aabece..e4c29a789b52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -117,6 +117,24 @@ private[spark] abstract class DistanceMeasure extends Serializable { packedValues } + /** + * @param centers the clustering centers + * @param statistics optional statistics to accelerate the computation, which should not + * change the result. + * @param point given point + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: Array[VectorWithNorm], + statistics: Option[Array[Double]], + point: VectorWithNorm): (Int, Double) = { + if (statistics.nonEmpty) { + findClosest(centers, statistics.get, point) + } else { + findClosest(centers, point) + } + } + /** * @return the index of the closest center to the given point, as well as the cost. */ @@ -253,6 +271,11 @@ object DistanceMeasure { case _ => false } } + + private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000 + + private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean = + k.toLong * k * numFeatures < 1000000 } private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 76e2928f1223..c140b1b9e091 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -269,15 +269,22 @@ class KMeans private ( instr.foreach(_.logNumFeatures(numFeatures)) - val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L + val shouldComputeStats = + DistanceMeasure.shouldComputeStatistics(centers.length) + val shouldComputeStatsLocally = + DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures) // Execute iterations of Lloyd's algorithm until converged while (iteration < maxIterations && !converged) { val bcCenters = sc.broadcast(centers) - val stats = if (shouldDistributed) { - distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters) + val stats = if (shouldComputeStats) { + if (shouldComputeStatsLocally) { + Some(distanceMeasureInstance.computeStatistics(centers)) + } else { + Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)) + } } else { - distanceMeasureInstance.computeStatistics(centers) + None } val bcStats = sc.broadcast(stats) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a24493bb7a8f..64b352157caf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -50,9 +50,16 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], // TODO: computation of statistics may take seconds, so save it to KMeansModel in training @transient private lazy val statistics = if (clusterCenters == null) { - null + None } else { - distanceMeasureInstance.computeStatistics(clusterCentersWithNorm) + val k = clusterCenters.length + val numFeatures = clusterCenters.head.size + if (DistanceMeasure.shouldComputeStatistics(k) && + DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) { + Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)) + } else { + None + } } @Since("2.4.0")