Skip to content

Commit d399c4f

Browse files
committed
init
init init
1 parent 0af4fc8 commit d399c4f

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,17 @@ private[spark] abstract class DistanceMeasure extends Serializable {
117117
packedValues
118118
}
119119

120+
/**
 * Finds the closest center to `point`, forwarding to the statistics-aware overload
 * when `statistics` is defined and to the plain overload otherwise.
 *
 * @param centers candidate cluster centers
 * @param statistics optional precomputed statistics for `centers`
 * @param point the point to assign
 * @return the result of the selected `findClosest` overload
 *         (presumably the closest center's index and the cost - confirm against the overloads)
 */
def findClosest(
    centers: Array[VectorWithNorm],
    statistics: Option[Array[Double]],
    point: VectorWithNorm): (Int, Double) = {
  // Exhaustive match instead of nonEmpty + get: the value is bound where it is known to exist.
  statistics match {
    case Some(stats) => findClosest(centers, stats, point)
    case None => findClosest(centers, point)
  }
}
130+
120131
/**
121132
* @return the index of the closest center to the given point, as well as the cost.
122133
*/
@@ -253,6 +264,11 @@ object DistanceMeasure {
253264
case _ => false
254265
}
255266
}
267+
268+
/**
 * Decides whether statistics should be computed at all for `k` centers.
 *
 * @param k number of cluster centers
 * @return true iff `k` is strictly below 1000
 */
private[clustering] def shouldComputeStatistics(k: Int): Boolean = {
  val maxCentersExclusive = 1000
  k < maxCentersExclusive
}
269+
270+
/**
 * Decides whether statistics should be computed locally.
 *
 * @param k number of cluster centers
 * @param numFeatures dimensionality of each center
 * @return true iff `k * k * numFeatures`, evaluated in Long arithmetic,
 *         is strictly below 1,000,000
 */
private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean = {
  // Widen to Long before multiplying so large k cannot overflow Int.
  val numCenters = k.toLong
  val workload = numCenters * numCenters * numFeatures
  workload < 1000000L
}
256272
}
257273

258274
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,22 @@ class KMeans private (
269269

270270
instr.foreach(_.logNumFeatures(numFeatures))
271271

272-
val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
272+
val shouldComputeStats =
273+
DistanceMeasure.shouldComputeStatistics(centers.length)
274+
val shouldComputeStatsLocally =
275+
DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures)
273276

274277
// Execute iterations of Lloyd's algorithm until converged
275278
while (iteration < maxIterations && !converged) {
276279
val bcCenters = sc.broadcast(centers)
277-
val stats = if (shouldDistributed) {
278-
distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
280+
val stats = if (shouldComputeStats) {
281+
if (shouldComputeStatsLocally) {
282+
Some(distanceMeasureInstance.computeStatistics(centers))
283+
} else {
284+
Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters))
285+
}
279286
} else {
280-
distanceMeasureInstance.computeStatistics(centers)
287+
None
281288
}
282289
val bcStats = sc.broadcast(stats)
283290

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
5252
// Lazily computed statistics for the cluster centers.
// Yields null when `clusterCenters` is null, Some(stats) when both DistanceMeasure
// thresholds allow local computation, and None otherwise.
@transient private lazy val statistics = if (clusterCenters == null) {
  null
} else {
  val k = clusterCenters.length
  // headOption instead of head: an empty centers array must not throw here.
  val numFeatures = clusterCenters.headOption.map(_.size).getOrElse(0)
  if (DistanceMeasure.shouldComputeStatistics(k) &&
      DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) {
    Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm))
  } else {
    None
  }
}
5764

5865
@Since("2.4.0")

0 commit comments

Comments
 (0)