diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b294ac..fe357accdc56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,7 +17,10 @@ package org.apache.spark.ml.clustering
 
+import scala.util.{Failure, Success}
+
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.InvalidInputException
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
@@ -35,7 +38,25 @@ import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
 /**
- * Common params for KMeans and KMeansModel
+ * Params for KMeans that configure warm start from an initial model.
+ */
+
+private[clustering] trait KMeansInitialModelParams extends HasInitialModel[KMeansModel] {
+  /**
+   * Param for the KMeansModel to use for warm start.
+   * Whenever initialModel is set:
+   *  1. k is overridden by the k of the initialModel;
+   *  2. initMode is set to "initialModel" and any manually set value is ignored;
+   *  3. all other params are untouched.
+   * @group param
+   */
+  final val initialModel: Param[KMeansModel] =
+    new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.")
+
+}
+
+/**
+ * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
   with HasSeed with HasPredictionCol with HasTol {
@@ -58,6 +79,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * Param for the initialization algorithm. This can be either "random" to choose random points as
    * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
    * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
+   * This param is ignored if the param initialModel is set.
    * @group expertParam
    */
   @Since("1.5.0")
@@ -82,6 +104,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitSteps: Int = $(initSteps)
 
+
   /**
    * Validates and transforms the input schema.
   * @param schema input schema
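Editor's note: a minimal end-to-end warm-start sketch against the API this patch adds. The input path is illustrative; any DataFrame with a Vector "features" column works.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession

object KMeansWarmStartExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("KMeansWarmStart").getOrCreate()
    // Illustrative input; any DataFrame with a Vector "features" column works.
    val data = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    // Train a cheap model first, then use it to warm-start a longer run.
    val coarse = new KMeans().setK(3).setMaxIter(2).setSeed(1L).fit(data)

    // With initialModel set, k comes from the model and initMode is forced to "initialModel".
    val refined = new KMeans()
      .setInitialModel(coarse)
      .setMaxIter(20)
      .fit(data)

    refined.clusterCenters.foreach(println)
    spark.stop()
  }
}
```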
@@ -103,7 +126,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel)
+    private[ml] val parentModel: MLlibKMeansModel)
   extends Model[KMeansModel] with KMeansParams with MLWritable {
 
   @Since("1.5.0")
@@ -124,7 +147,8 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val predictUDF = udf((vector: Vector) => predict(vector))
+    val tmpParent: MLlibKMeansModel = parentModel
+    val predictUDF = udf((vector: Vector) => tmpParent.predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
@@ -133,8 +157,6 @@ class KMeansModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
-
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
 
@@ -210,6 +232,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
+
       // Save model data: cluster centers
       val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
         Data(idx, center)
@@ -244,6 +267,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       }
       val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
       DefaultParamsReader.getAndSetParams(model, metadata)
+
       model
     }
   }
@@ -259,7 +283,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
 @Experimental
 class KMeans @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
-  extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
+  extends Estimator[KMeansModel]
+    with KMeansParams with KMeansInitialModelParams with DefaultParamsWritable {
 
   setDefault(
     k -> 2,
@@ -284,11 +309,26 @@ class KMeans @Since("1.5.0") (
 
   /** @group setParam */
   @Since("1.5.0")
-  def setK(value: Int): this.type = set(k, value)
+  def setK(value: Int): this.type = {
+    if (isSet(initialModel)) {
+      logWarning("initialModel is set, so k will be ignored. Clear initialModel first.")
+      this
+    } else {
+      set(k, value)
+    }
+  }
 
   /** @group expertSetParam */
   @Since("1.5.0")
-  def setInitMode(value: String): this.type = set(initMode, value)
+  def setInitMode(value: String): this.type = {
+    if (isSet(initialModel)) {
+      logWarning("initialModel is set, so initMode will be ignored. Clear initialModel first.")
+    }
+    if (value == MLlibKMeans.K_MEANS_INITIAL_MODEL) {
+      logWarning(s"initMode of $value is not supported here; use setInitialModel instead.")
+    }
+    set(initMode, value)
+  }
 
   /** @group expertSetParam */
   @Since("1.5.0")
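Editor's note: the two guarded setters above, together with setInitialModel below, give initialModel precedence over k and initMode. A sketch of the intended behaviour, assuming `warmStart` is a previously trained KMeansModel with 4 cluster centers:

```scala
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

def precedenceDemo(warmStart: KMeansModel): Unit = {
  val est = new KMeans().setK(2)

  est.setInitialModel(warmStart) // k is overwritten to 4; initMode becomes "initialModel"
  assert(est.getK == 4)

  est.setK(10)                   // only logs a warning; k stays 4
  assert(est.getK == 4)

  est.clear(est.initialModel)    // once initialModel is cleared, setK takes effect again
  est.setK(10)
  assert(est.getK == 10)
}
```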
+ + s" Previous value is $previousK.") + } + } else { + set(k, kOfInitialModel) + logWarning(s"Param K is set to $kOfInitialModel by the initialModel.") + } + set(initMode, "initialModel") + set(initialModel, value) + } + @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { transformSchema(dataset.schema, logging = true) @@ -323,6 +382,24 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + + if (isDefined(initialModel)) { + // Check that the feature dimensions are equal + val dimOfData = rdd.first().size + val dimOfInitialModel = $(initialModel).clusterCenters.head.size + require(dimOfData == dimOfInitialModel, + s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.") + + // Check that the number of clusters are equal + val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length + if (kOfInitialModel != $(k)) { + logWarning(s"mismatched cluster count, ${$(k)} cluster centers required but" + + s" $kOfInitialModel found in the initial model.") + } + + algo.setInitialModel($(initialModel).parentModel) + } + val parentModel = algo.run(rdd, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( @@ -336,13 +413,48 @@ class KMeans @Since("1.5.0") ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + @Since("2.1.0") + override def write: MLWriter = new KMeans.KMeansWriter(this) } @Since("1.6.0") object KMeans extends DefaultParamsReadable[KMeans] { + // TODO: [SPARK-17784]: Add a fromCenters method + @Since("1.6.0") override def load(path: String): KMeans = super.load(path) + + @Since("1.6.0") + override def read: MLReader[KMeans] = new KMeansReader + + /** [[MLWriter]] instance for [[KMeans]] */ + private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter { + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveInitialModel(instance, path) + DefaultParamsWriter.saveMetadata(instance, path, sc) + } + } + + private class KMeansReader extends MLReader[KMeans] { + + /** Checked against metadata when loading estimator */ + private val className = classOf[KMeans].getName + + override def load(path: String): KMeans = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val instance = new KMeans(metadata.uid) + + DefaultParamsReader.getAndSetParams(instance, metadata) + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { + case Success(v) => instance.setInitialModel(v) + case Failure(_: InvalidInputException) => // initialModel doesn't exist, do nothing + case Failure(e) => throw e + } + instance + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala new file mode 100644 index 000000000000..c67380edaa60 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
new file mode 100644
index 000000000000..c67380edaa60
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param._
+
+private[ml] trait HasInitialModel[T <: Model[T]] extends Params {
+
+  def initialModel: Param[T]
+
+  /** @group getParam */
+  final def getInitialModel: T = $(initialModel)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index bc4f9e6716ee..17483a17eb6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.util
 
 import java.io.IOException
 
+import scala.util.Try
+
 import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.{DefaultFormats, JObject}
@@ -32,6 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasInitialModel
 import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
@@ -300,7 +303,8 @@ private[ml] object DefaultParamsWriter {
       paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.extractParamMap().toSeq
+      .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
     val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
     }.toList))
@@ -309,6 +313,7 @@ private[ml] object DefaultParamsWriter {
       ("sparkVersion" -> sc.version) ~
       ("uid" -> uid) ~
       ("paramMap" -> jsonParams)
+
     val metadata = extraMetadata match {
       case Some(jObject) =>
         basicMetadata ~ jObject
@@ -318,6 +323,20 @@ private[ml] object DefaultParamsWriter {
     val metadataJson: String = compact(render(metadata))
     metadataJson
   }
+
+  def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
+      instance: T, path: String): Unit = {
+    if (instance.isDefined(instance.initialModel)) {
+      val initialModelPath = new Path(path, "initialModel").toString
+      val initialModel = instance.getOrDefault(instance.initialModel)
+      // When saving, keep only the direct initialModel and drop any nested initialModel
+      // it may carry, to avoid recursing through a chain of initial models.
+      if (initialModel.hasParam("initialModel")) {
+        initialModel.clear(initialModel.getParam("initialModel"))
+      }
+      initialModel.save(initialModelPath)
+    }
+  }
 }
 
 /**
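Editor's note: the "initialModel" filter in the metadata above plus saveInitialModel mean the initial model is persisted as a nested model directory rather than as a JSON param. A quick check of the resulting layout; a sketch assuming `est` is a KMeans with initialModel set and `sc` is the active SparkContext, with an illustrative path:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.ml.clustering.KMeans

def inspectSavedLayout(est: KMeans, sc: SparkContext): Unit = {
  est.write.overwrite().save("/tmp/kmeans-warm")

  // The JSON metadata omits "initialModel" entirely ...
  val metadataJson = sc.textFile("/tmp/kmeans-warm/metadata").first()
  assert(!metadataJson.contains("initialModel"))

  // ... while the model itself sits alongside it as an ordinary saved KMeansModel:
  //   /tmp/kmeans-warm/metadata/       estimator params (initialModel filtered out)
  //   /tmp/kmeans-warm/initialModel/   nested KMeansModel (its own metadata + data)
}
```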
@@ -446,6 +465,11 @@ private[ml] object DefaultParamsReader {
     val cls = Utils.classForName(metadata.className)
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
+
+  def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Try[M] = {
+    val initialModelPath = new Path(path, "initialModel").toString
+    Try(loadParamsInstance[M](initialModelPath, sc))
+  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index ed9c064879d0..60ebab39186e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -414,6 +414,8 @@ object KMeans {
   val RANDOM = "random"
   @Since("0.8.0")
   val K_MEANS_PARALLEL = "k-means||"
+  @Since("2.1.0")
+  val K_MEANS_INITIAL_MODEL = "initialModel"
 
   /**
    * Trains a k-means model using the given set of parameters.
@@ -589,6 +591,7 @@ object KMeans {
     initMode match {
       case KMeans.RANDOM => true
       case KMeans.K_MEANS_PARALLEL => true
+      case KMeans.K_MEANS_INITIAL_MODEL => true
       case _ => false
     }
   }
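Editor's note: the new constant mainly lets the "initialModel" init-mode string pass validateInitMode on the RDD side; the actual warm start goes through the pre-existing setInitialModel on mllib.clustering.KMeans, which ml.KMeans.fit calls above. A direct mllib-level sketch with illustrative centers:

```scala
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.Vectors

object MLlibWarmStartSketch {
  // Illustrative centers; in practice these come from a previous run.
  val warmStart = new MLlibKMeansModel(
    Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)))

  val algo = new MLlibKMeans()
    .setK(2)
    .setMaxIterations(5)
    .setInitialModel(warmStart)
  // algo.run(rdd) then starts from these centers instead of running the
  // random or k-means|| initialization.
}
```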
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 73972557d263..0783c9953902 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,11 +17,15 @@ package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamMap, ParamPair}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
@@ -29,13 +33,14 @@ private[clustering] case class TestRow(features: Vector)
 
 class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  final val k = 5
+  final val k: Int = 5
+  final val dim: Int = 3
   @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    dataset = KMeansSuite.generateKMeansData(spark, 50, dim, k)
   }
 
   test("default parameters") {
@@ -145,18 +150,67 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData,
+      Map("initialModel" -> (checkModelData _).asInstanceOf[(Any, Any) => Unit]))
+  }
+
+  test("Initialize using a trained model") {
+    val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1)
+    val oneIterModel = kmeans.fit(dataset)
+    val twoIterModel =
+      kmeans.copy(ParamMap(ParamPair(kmeans.maxIter, 2))).fit(dataset)
+    val oneMoreIterModel = kmeans.setInitialModel(oneIterModel).fit(dataset)
+
+    assert(oneMoreIterModel.getK === k)
+
+    twoIterModel.clusterCenters.zip(oneMoreIterModel.clusterCenters)
+      .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) }
+  }
+
+  test("Initialize using a model with wrong dimension of cluster centers") {
+    val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1)
+
+    val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k)
+    val wrongDimModelThrown = intercept[IllegalArgumentException] {
+      kmeans.setInitialModel(wrongDimModel).fit(dataset)
+    }
+    assert(wrongDimModelThrown.getMessage.contains("mismatched dimension"))
+  }
+
+  test("Infer K from an initial model") {
+    val kmeans = new KMeans().setK(5)
+    val testNewK = 10
+    val randomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK)
+    assert(kmeans.setInitialModel(randomModel).getK === testNewK)
+  }
+
+  test("Ignore k if initialModel is set") {
+    val kmeans = new KMeans()
+
+    val randomModel = KMeansSuite.generateRandomKMeansModel(dim, k)
+    // k is ignored while initialModel is set
+    assert(kmeans.setInitialModel(randomModel).setK(k - 1).getK === k)
+    kmeans.clear(kmeans.initialModel)
+    // k takes effect again once initialModel is cleared
+    assert(kmeans.setK(k - 1).getK === k - 1)
   }
 }
 
 object KMeansSuite {
+
   def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = {
     val sc = spark.sparkContext
     val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
-      .map(v => new TestRow(v))
+      .map(v => TestRow(v))
     spark.createDataFrame(rdd)
   }
 
+  def generateRandomKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = {
+    val rng = new Random(seed)
+    val clusterCenters = (1 to k)
+      .map(i => MLlibVectors.dense(Array.fill(dim)(rng.nextDouble)))
+    new KMeansModel("test model", new MLlibKMeansModel(clusterCenters.toArray))
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
@@ -166,6 +220,7 @@ object KMeansSuite {
     "predictionCol" -> "myPrediction",
     "k" -> 3,
     "maxIter" -> 2,
-    "tol" -> 0.01
+    "tol" -> 0.01,
+    "initialModel" -> generateRandomKMeansModel(3, 3)
   )
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b30a..5f1ef837b527 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -22,7 +22,7 @@ import java.io.{File, IOException}
 import org.scalatest.Suite
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
@@ -78,6 +78,34 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     another
   }
 
+  /**
+   * Compare Params with complex types that cannot be compared with [[===]].
+   *
+   * @param stage A pipeline stage that contains these params.
+   * @param stage2 Another pipeline stage to compare against.
+   * @param testParams Params to compare.
+   * @param testFunctions Functions used to compare complex-typed params.
+   */
+  def compareParamsWithComplexTypes(
+      stage: PipelineStage,
+      stage2: PipelineStage,
+      testParams: Map[String, Any],
+      testFunctions: Map[String, (Any, Any) => Unit]): Unit = {
+    testParams.foreach { case (p, v) =>
+      if (stage.hasParam(p)) {
+        assert(stage2.hasParam(p))
+        val param = stage.getParam(p)
+        val paramVal = stage.get(param).get
+        val paramVal2 = stage2.get(param).get
+        if (testFunctions.contains(p)) {
+          testFunctions(p)(paramVal, paramVal2)
+        } else {
+          assert(paramVal === paramVal2)
+        }
+      }
+    }
+  }
+
   /**
    * Default test for Estimator, Model pairs:
    *  - Explicitly set Params, and train model
@@ -100,7 +128,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
       estimator: E,
       dataset: Dataset[_],
       testParams: Map[String, Any],
-      checkModelData: (M, M) => Unit): Unit = {
+      checkModelData: (M, M) => Unit,
+      checkParamsFunctions: Map[String, (Any, Any) => Unit] = Map.empty): Unit = {
     // Set some Params to make sure set Params are serialized.
     testParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
@@ -108,18 +137,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val model = estimator.fit(dataset)
 
     // Test Estimator save/load
-    val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
-      val param = estimator.getParam(p)
-      assert(estimator.get(param).get === estimator2.get(param).get)
-    }
+    val estimator2 = testDefaultReadWrite(estimator, testParams = false)
+    compareParamsWithComplexTypes(estimator, estimator2, testParams, checkParamsFunctions)
 
     // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
+    val model2 = testDefaultReadWrite(model, testParams = false)
+    compareParamsWithComplexTypes(model, model2, testParams, checkParamsFunctions)
 
     checkModelData(model, model2)
   }
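Editor's note: a hypothetical use of the new checkParamsFunctions hook from another suite, written without the cast used in KMeansSuite above. It assumes the suite mixes in DefaultReadWriteTest and has its own `dataset` fixture; names are illustrative.

```scala
// Compare the complex-typed "initialModel" param by its cluster centers;
// every other param still falls back to the === comparison.
val compareByCenters: (Any, Any) => Unit = (a: Any, b: Any) => {
  val (m1, m2) = (a.asInstanceOf[KMeansModel], b.asInstanceOf[KMeansModel])
  assert(m1.clusterCenters === m2.clusterCenters)
}

testEstimatorAndModelReadWrite(new KMeans(), dataset, KMeansSuite.allParamSettings,
  checkModelData, Map("initialModel" -> compareByCenters))
```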