BisectingKMeans.scala

@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel


 /**
@@ -250,10 +251,19 @@ class BisectingKMeans @Since("2.0.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+    fit(dataset, handlePersistence)
+  }
+
+  @Since("2.2.0")
+  protected def fit(dataset: Dataset[_], handlePersistence: Boolean): BisectingKMeansModel = {
     transformSchema(dataset.schema, logging = true)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
+    if (handlePersistence) {
+      rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    }

     val instr = Instrumentation.create(this, rdd)
     instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
@@ -268,6 +278,9 @@ class BisectingKMeans @Since("2.0.0") (
     val summary = new BisectingKMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
+    if (handlePersistence) {
+      rdd.unpersist()
+    }
     instr.logSuccess(model)
     model
   }
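For context on why the checks key off dataset.storageLevel: calling .rdd on a Dataset builds a fresh RDD from the logical plan, and that RDD has never itself been persisted, so dataset.rdd.getStorageLevel reports NONE even when the Dataset is cached through the SQL CacheManager. The old check could therefore re-persist input the caller had already cached. A minimal sketch of the distinction, assuming a local SparkSession named spark:

import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object StorageLevelSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("storage-level-sketch").getOrCreate()

    val df = spark.range(0, 100).toDF("value")
    df.persist(StorageLevel.MEMORY_AND_DISK)

    // Dataset.storageLevel consults the CacheManager, so it sees the cache.
    assert(df.storageLevel != StorageLevel.NONE)

    // df.rdd is a newly built RDD that was never persisted itself, so the
    // old-style check reports NONE despite the cached Dataset above.
    assert(df.rdd.getStorageLevel == StorageLevel.NONE)

    spark.stop()
  }
}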
KMeans.scala

@@ -302,7 +302,7 @@ class KMeans @Since("1.5.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     fit(dataset, handlePersistence)
   }

@@ -330,10 +330,10 @@
     val summary = new KMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
-    instr.logSuccess(model)
     if (handlePersistence) {
       instances.unpersist()

[Inline review comment from a Contributor] prefer to keep this form according to style guide.

     }
+    instr.logSuccess(model)

[Inline review comment from a Contributor] The handlePersistence check in KMeans at L309 should also be updated to use dataset.storageLevel. Since we're touching KMeans here anyway we may as well do it now.

     model
   }
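Both fit methods now follow the same bracketing: persist the input only when the caller has not already cached it, and unpersist only what was persisted here. Factored out, that pattern might look like the sketch below; withInputPersistence is an illustrative name, not a Spark API:

import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel

// Hypothetical helper, not part of Spark: cache `ds` around `body` only when
// nothing upstream has persisted it, and release exactly what we persisted.
def withInputPersistence[T, U](ds: Dataset[T])(body: Dataset[T] => U): U = {
  val handlePersistence = ds.storageLevel == StorageLevel.NONE
  if (handlePersistence) ds.persist(StorageLevel.MEMORY_AND_DISK)
  try body(ds)
  finally if (handlePersistence) ds.unpersist()
}

The PR keeps the flag explicit instead, which fits BisectingKMeans, where it is the derived RDD[OldVector] rather than the Dataset itself that gets persisted.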
KMeansSuite.scala

@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.storage.StorageLevel

 private[clustering] case class TestRow(features: Vector)

@@ -143,6 +144,14 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
     assert(model.getPredictionCol == predictionColName)
   }

+  test("DataFrame storage level check") {
+    val df = KMeansSuite.generateKMeansData(spark, 5, 3, 2)
+    assert(df.storageLevel == StorageLevel.NONE)
+    df.persist(StorageLevel.MEMORY_AND_DISK)
+    assert(df.storageLevel != StorageLevel.NONE)
+    df.unpersist()
+  }
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)
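The new test relies on the suite's existing generateKMeansData(spark, rows, dim, k) helper, whose body is outside this diff. A plausible reconstruction from the call site, assuming the TestRow case class defined in the same file and deterministic synthetic clusters, might be:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

object KMeansSuite {
  // Sketch only: rows points of dimension dim, spread over k integer-valued
  // centers, wrapped in the suite's TestRow case class.
  def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
    val rdd = spark.sparkContext.parallelize(1 to rows)
      .map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
      .map(v => TestRow(v))
    spark.createDataFrame(rdd)
  }
}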