From a4c88ee89f175c0672eeebe004ef3ac87a7a464a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 26 Nov 2016 21:34:18 -0800 Subject: [PATCH 1/3] add checking and caching to bisecting kmeans --- .../apache/spark/ml/clustering/BisectingKMeans.scala | 11 +++++++++++ .../scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index c7a170ddc7351..fe36fdd0f0939 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.storage.StorageLevel /** @@ -255,10 +256,19 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + fit(dataset, handlePersistence) + } + + @Since("2.2.0") + protected def fit(dataset: Dataset[_], handlePersistence: Boolean): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } + if (handlePersistence) { + rdd.persist(StorageLevel.MEMORY_AND_DISK) + } val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) @@ -273,6 +283,7 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + if (handlePersistence) rdd.unpersist() instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index ad4f79a79c90f..2867b31a26855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -334,10 +334,8 @@ class KMeans @Since("1.5.0") ( val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + if (handlePersistence) instances.unpersist() instr.logSuccess(model) - if (handlePersistence) { - instances.unpersist() - } model } From 0105220c27e6ddadef91b5ea2cb448865bfe79a6 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 28 Nov 2016 05:33:47 -0800 Subject: [PATCH 2/3] change storage check --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 6 ++++-- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index fe36fdd0f0939..75b67e6084841 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -256,7 +256,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + val handlePersistence = dataset.storageLevel == StorageLevel.NONE fit(dataset, handlePersistence) } @@ -283,7 +283,9 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - if (handlePersistence) rdd.unpersist() + if (handlePersistence) { + rdd.unpersist() + } instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2867b31a26855..8c5bd70bdbb5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -334,7 +334,9 @@ class KMeans @Since("1.5.0") ( val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - if (handlePersistence) instances.unpersist() + if (handlePersistence) { + instances.unpersist() + } instr.logSuccess(model) model } From 678cd43321e98a20968e7d4ede561792d578dde4 Mon Sep 17 00:00:00 2001 From: Yuhao Date: Fri, 2 Dec 2016 10:44:20 -0800 Subject: [PATCH 3/3] kmeans cache --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 3194c35d5afa4..2a3a2454daf42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -302,7 +302,7 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + val handlePersistence = dataset.storageLevel == StorageLevel.NONE fit(dataset, handlePersistence) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c1b7242e11a8f..bdd23ce8dff74 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.storage.StorageLevel private[clustering] case class TestRow(features: Vector) @@ -143,6 +144,14 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("DataFrame storage level check") { + val df = KMeansSuite.generateKMeansData(spark, 5, 3, 2) + assert(df.storageLevel == StorageLevel.NONE) + df.persist(StorageLevel.MEMORY_AND_DISK) + assert(df.storageLevel != StorageLevel.NONE) + df.unpersist() + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters)