19 changes: 16 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -145,8 +145,12 @@ class KMeansModel private[ml] (
/**
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
* model on the given data.
*
* @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator
* instead. You can also get the cost on the training dataset in the summary.
*/
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
@deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " +
"instead. You can also get the cost on the training dataset in the summary.", "2.4.0")
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
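For reference, a minimal sketch of the migration path the deprecation message points to: ClusteringEvaluator for evaluation, plus the new summary field for the training cost. The DataFrame `df` with a "features" vector column is a hypothetical input, not part of this change.

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// Hypothetical input: a DataFrame `df` with a "features" vector column.
val model = new KMeans().setK(3).setSeed(1L).fit(df)
val predictions = model.transform(df)

// ClusteringEvaluator (silhouette) is the suggested replacement for computeCost.
val silhouette = new ClusteringEvaluator()
  .setFeaturesCol("features")
  .setPredictionCol("prediction")
  .evaluate(predictions)

// The cost on the training dataset is exposed through the model summary (added below).
val trainingCost = model.summary.trainingCost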
@@ -356,7 +360,12 @@ class KMeans @Since("1.5.0") (
val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter)
model.transform(dataset),
$(predictionCol),
$(featuresCol),
$(k),
parentModel.numIter,
parentModel.trainingCost)

model.setSummary(Some(summary))
instr.logNamedValue("clusterSizes", summary.clusterSizes)
@@ -389,6 +398,8 @@ object KMeans extends DefaultParamsReadable[KMeans] {
* @param featuresCol Name for column of features in `predictions`.
* @param k Number of clusters.
* @param numIter Number of iterations.
* @param trainingCost K-means cost (sum of squared distances to the nearest centroid for all
* points in the training dataset). This is equivalent to sklearn's inertia.
*/
@Since("2.0.0")
@Experimental
@@ -397,4 +408,6 @@ class KMeansSummary private[clustering] (
predictionCol: String,
featuresCol: String,
k: Int,
numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter)
numIter: Int,
@Since("2.4.0") val trainingCost: Double)
extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter)
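A short usage sketch of the extended summary, assuming a hypothetical DataFrame `df` with a "features" vector column; `trainingCost` plays the same role as sklearn's inertia.

import org.apache.spark.ml.clustering.KMeans

val model = new KMeans().setK(2).setMaxIter(20).fit(df)
val summary = model.summary

summary.trainingCost    // sum of squared distances to the nearest centroid
summary.numIter         // iterations actually run
summary.clusterSizes    // number of points assigned to each cluster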
@@ -348,7 +348,7 @@ class KMeans private (

logInfo(s"The cost is $cost.")

new KMeansModel(centers.map(_.vector), distanceMeasure, iteration)
new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
}

/**
@@ -38,6 +38,7 @@ import org.apache.spark.sql.{Row, SparkSession}
@Since("0.8.0")
class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
@Since("2.4.0") val distanceMeasure: String,
@Since("2.4.0") val trainingCost: Double,
private[spark] val numIter: Int)
extends Saveable with Serializable with PMMLExportable {

@@ -49,11 +50,11 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],

@Since("2.4.0")
private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) =
this(clusterCenters: Array[Vector], distanceMeasure, -1)
this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1)

@Since("1.1.0")
def this(clusterCenters: Array[Vector]) =
this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN)
this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN, 0.0, -1)

/**
* A Java-friendly constructor that takes an Iterable of Vectors.
@@ -187,7 +188,8 @@ object KMeansModel extends Loader[KMeansModel] {
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure)))
~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure)
~ ("trainingCost" -> model.trainingCost)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
Cluster(id, p.vector)
@@ -207,7 +209,8 @@ object KMeansModel extends Loader[KMeansModel] {
val localCentroids = centroids.rdd.map(Cluster.apply).collect()
assert(k == localCentroids.length)
val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure)
val trainingCost = (metadata \ "trainingCost").extract[Double]
new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure, trainingCost, -1)
}
}
}
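A sketch of the corresponding round-trip through the RDD-based API, showing that the new field is written into and read back from the JSON metadata above. The SparkContext `sc`, the RDD[Vector] `points`, and `path` are hypothetical.

import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}

val model = KMeans.train(points, 3, 20)      // k = 3, maxIterations = 20
model.trainingCost                           // cost computed while training

model.save(sc, path)                         // "trainingCost" is written into the JSON metadata
val restored = KMeansModel.load(sc, path)
restored.trainingCost                        // read back from the metadata on load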
@@ -131,6 +131,8 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
assert(summary.predictions.columns.contains(c))
}
assert(summary.cluster.columns === Array(predictionColName))
assert(summary.trainingCost < 0.1)
assert(model.computeCost(dataset) == summary.trainingCost)
val clusterSizes = summary.clusterSizes
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
19 changes: 18 additions & 1 deletion python/pyspark/ml/clustering.py
@@ -16,6 +16,7 @@
#

import sys
import warnings

from pyspark import since, keyword_only
from pyspark.ml.util import *
@@ -303,7 +304,15 @@ class KMeansSummary(ClusteringSummary):

.. versionadded:: 2.1.0
"""
pass

@property
@since("2.4.0")
def trainingCost(self):
"""
K-means cost (sum of squared distances to the nearest centroid for all points in the
training dataset). This is equivalent to sklearn's inertia.
"""
return self._call_java("trainingCost")


class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
@@ -323,7 +332,13 @@ def computeCost(self, dataset):
"""
Return the K-means cost (sum of squared distances of points to their nearest center)
for this model on the given data.

.. note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead.
You can also get the cost on the training dataset in the summary.
"""
warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator "
"instead. You can also get the cost on the training dataset in the summary.",
DeprecationWarning)
return self._call_java("computeCost", dataset)

@property
@@ -379,6 +394,8 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
2
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
2.000...
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)