diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index a97bd0fb16fd7..add8ee2a4ff8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) + + private var trainingSummary: Option[BisectingKMeansSummary] = None + + private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.1.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.1.0") + def summary: BisectingKMeansSummary = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { @@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") ( case Row(point: Vector) => OldVectors.fromML(point) } + val instr = Instrumentation.create(this, rdd) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val bkm = new MLlibBisectingKMeans() .setK($(k)) .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) val parentModel = bkm.run(rdd) - val model = new BisectingKMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) + val summary = new BisectingKMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(summary) + val m = model.setSummary(summary) + instr.logSuccess(m) + m } @Since("2.0.0") @@ -251,3 +283,41 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { @Since("2.0.0") override def load(path: String): BisectingKMeans = super.load(path) } + + +/** + * :: Experimental :: + * Summary of BisectingKMeans. + * + * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.1.0") +@Experimental +class BisectingKMeansSummary private[clustering] ( + @Since("2.1.0") @transient val predictions: DataFrame, + @Since("2.1.0") val predictionCol: String, + @Since("2.1.0") val featuresCol: String, + @Since("2.1.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.1.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of (number of data points in) each cluster. + */ + @Since("2.1.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 4f7d4418a8d09..f2368a9f8dad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -68,7 +68,7 @@ class BisectingKMeansSuite } } - test("fit & transform") { + test("fit, transform and summary") { val predictionColName = "bisecting_kmeans_prediction" val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) @@ -85,6 +85,22 @@ class BisectingKMeansSuite assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: BisectingKMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 04366f5250287..003fa6abf6597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "gm_prediction" val probabilityColName = "gm_probability" val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c9ba5a288aadf..ca392653557c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit, transform, and summary") { + test("fit, transform and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset)