Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
Expand Down Expand Up @@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] (

/** Returns an [[MLWriter]] instance bound to this model. */
@Since("2.0.0")
override def write: MLWriter = {
  val modelWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
  modelWriter
}

// Holds the summary attached through `setSummary`; stays `None` until then.
private var trainingSummary: Option[BisectingKMeansSummary] = None

/**
 * Attaches `summary` to this model and returns the model itself so that
 * calls can be chained.
 */
private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
  val attached: Option[BisectingKMeansSummary] = Some(summary)
  trainingSummary = attached
  this
}

/**
 * Returns true if a summary has been attached to this model.
 */
@Since("2.1.0")
def hasSummary: Boolean = trainingSummary.isDefined

/**
 * Gets the summary of this model on the training set.
 *
 * @return the attached [[BisectingKMeansSummary]]
 * @throws SparkException if no summary has been attached (see `hasSummary`)
 */
@Since("2.1.0")
def summary: BisectingKMeansSummary = trainingSummary match {
  case Some(attached) =>
    attached
  case None =>
    throw new SparkException(
      s"No training summary available for the ${this.getClass.getSimpleName}")
}
}

object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
Expand Down Expand Up @@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") (
case Row(point: Vector) => OldVectors.fromML(point)
}

val instr = Instrumentation.create(this, rdd)
instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)

val bkm = new MLlibBisectingKMeans()
.setK($(k))
.setMaxIterations($(maxIter))
.setMinDivisibleClusterSize($(minDivisibleClusterSize))
.setSeed($(seed))
val parentModel = bkm.run(rdd)
val model = new BisectingKMeansModel(uid, parentModel)
copyValues(model.setParent(this))
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a superfluous call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will delete this line ASAP. Thanks for your comment!

val m = model.setSummary(summary)
instr.logSuccess(m)
m
}

@Since("2.0.0")
Expand All @@ -251,3 +283,41 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
@Since("2.0.0")
override def load(path: String): BisectingKMeans = super.load(path)
}


/**
 * :: Experimental ::
 * Summary of BisectingKMeans.
 *
 * @param predictions  [[DataFrame]] produced by [[BisectingKMeansModel.transform()]]
 * @param predictionCol  Name for column of predicted clusters in `predictions`
 * @param featuresCol  Name for column of features in `predictions`
 * @param k  Number of clusters
 */
@Since("2.1.0")
@Experimental
class BisectingKMeansSummary private[clustering] (
    @Since("2.1.0") @transient val predictions: DataFrame,
    @Since("2.1.0") val predictionCol: String,
    @Since("2.1.0") val featuresCol: String,
    @Since("2.1.0") val k: Int) extends Serializable {

  /**
   * Predicted cluster of each row in `predictions`, i.e. `predictions`
   * restricted to the `predictionCol` column. (These are per-row cluster
   * assignments, not cluster centers.)
   */
  @Since("2.1.0")
  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)

  /**
   * Size of (number of data points in) each cluster, indexed by cluster id.
   * The array has length `k`; ids that never occur in `predictions` keep size 0.
   */
  @Since("2.1.0")
  lazy val clusterSizes: Array[Long] = {
    val sizes = Array.fill[Long](k)(0)
    cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
      // Named `clusterId` so the pattern variable does not shadow the `cluster` member.
      case Row(clusterId: Int, count: Long) => sizes(clusterId) = count
    }
    sizes
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class BisectingKMeansSuite
}
}

test("fit & transform") {
test("fit, transform and summary") {
val predictionColName = "bisecting_kmeans_prediction"
val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = bkm.fit(dataset)
Expand All @@ -85,6 +85,22 @@ class BisectingKMeansSuite
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)

// Check validity of model summary
val numRows = dataset.count()
assert(model.hasSummary)
val summary: BisectingKMeansSummary = model.summary
assert(summary.predictionCol === predictionColName)
assert(summary.featuresCol === "features")
assert(summary.predictions.count() === numRows)
for (c <- Array(predictionColName, "features")) {
assert(summary.predictions.columns.contains(c))
}
assert(summary.cluster.columns === Array(predictionColName))
val clusterSizes = summary.clusterSizes
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))
}

test("read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
}
}

test("fit, transform, and summary") {
test("fit, transform and summary") {
val predictionColName = "gm_prediction"
val probabilityColName = "gm_probability"
val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}

test("fit, transform, and summary") {
test("fit, transform and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
Expand Down