diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 4c20e6563bad1..aed09e0d62fd0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
 
 
 /**
@@ -250,10 +251,19 @@ class BisectingKMeans @Since("2.0.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    fit(dataset, handlePersistence)
+  }
+
+  @Since("2.2.0")
+  protected def fit(dataset: Dataset[_], handlePersistence: Boolean): BisectingKMeansModel = {
     transformSchema(dataset.schema, logging = true)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
+    if (handlePersistence) {
+      rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    }
 
     val instr = Instrumentation.create(this, rdd)
     instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
@@ -268,6 +278,9 @@ class BisectingKMeans @Since("2.0.0") (
     val summary = new BisectingKMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
+    if (handlePersistence) {
+      rdd.unpersist()
+    }
     instr.logSuccess(model)
     model
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e168a418cbb42..2a3a2454daf42 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -302,7 +302,7 @@ class KMeans @Since("1.5.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     fit(dataset, handlePersistence)
   }
 
@@ -330,10 +330,10 @@ class KMeans @Since("1.5.0") (
     val summary = new KMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
-    instr.logSuccess(model)
     if (handlePersistence) {
       instances.unpersist()
     }
+    instr.logSuccess(model)
     model
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c1b7242e11a8f..bdd23ce8dff74 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.storage.StorageLevel
 
 private[clustering] case class TestRow(features: Vector)
 
@@ -143,6 +144,14 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(model.getPredictionCol == predictionColName)
   }
 
+  test("DataFrame storage level check") {
+    val df = KMeansSuite.generateKMeansData(spark, 5, 3, 2)
+    assert(df.storageLevel == StorageLevel.NONE)
+    df.persist(StorageLevel.MEMORY_AND_DISK)
+    assert(df.storageLevel != StorageLevel.NONE)
+    df.unpersist()
+  }
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)
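
For context on why the check moves from `dataset.rdd.getStorageLevel` to `dataset.storageLevel`: calling `.rdd` on a Dataset builds a fresh, uncached RDD, so the old check always saw `StorageLevel.NONE` and `fit` would persist input the caller had already cached. A minimal Scala sketch of the difference, assuming a local `SparkSession`; the app name and data here are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// Illustrative only: shows why Dataset.storageLevel is the right check.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("storage-level-demo")
  .getOrCreate()

val df = spark.range(100).toDF("id")
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count() // materialize the cache

// The Dataset-level API sees the cache...
assert(df.storageLevel == StorageLevel.MEMORY_AND_DISK)

// ...but .rdd builds a new, uncached RDD on every call, so the old
// handlePersistence check concluded the input was uncached and
// cached it a second time.
assert(df.rdd.getStorageLevel == StorageLevel.NONE)

df.unpersist()
spark.stop()
```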