From bb990f1a6511d8ce20f4fff254dfe0ff43262a10 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Jan 2018 03:51:59 +0000 Subject: [PATCH 01/20] Add multi-column support to StringIndexer. --- .../fulltests/test_mllib_classification.R | 6 +- .../spark/ml/feature/StringIndexer.scala | 295 +++++++++++++----- .../org/apache/spark/ml/param/params.scala | 10 + .../spark/ml/feature/RFormulaSuite.scala | 24 +- .../spark/ml/feature/StringIndexerSuite.scala | 79 ++++- project/MimaExcludes.scala | 10 +- python/pyspark/ml/classification.py | 9 +- 7 files changed, 336 insertions(+), 97 deletions(-) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index ad47717ddc12..b88c487964f7 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -313,7 +313,7 @@ test_that("spark.mlp", { # Test predict method mlpTestDF <- df mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) + expect_equal(head(mlpPredictions$prediction, 6), c("0.0", "1.0", "1.0", "1.0", "1.0", "1.0")) # Test model save/load if (windows_with_hadoop()) { @@ -348,12 +348,12 @@ test_that("spark.mlp", { # Test random seed # default seed - model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 100) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # seed equals 10 - model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 100, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab7..cdb81a31b8e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -26,18 +26,19 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.VersionUtils.majorMinorVersion import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol - with HasOutputCol { + with HasOutputCol with HasInputCols with HasOutputCols { /** * Param for how to handle invalid data (unseen labels or NULL values). 
@@ -79,26 +80,59 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
   @Since("2.3.0")
   def getStringOrderType: String = $(stringOrderType)
 
-  /** Validates and transforms the input schema. */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
+  /** Returns the input and output column names, corresponding pairwise. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
+    ParamValidators.checkExclusiveParams(this, "inputCol", "outputCols")
+    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
+    ParamValidators.checkExclusiveParams(this, "inputCols", "outputCol")
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "The number of input columns does not match the number of output columns.")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
+  private def validateAndTransformField(
+      schema: StructType,
+      inputColName: String,
+      outputColName: String): StructField = {
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be either string type or numeric type, " +
         s"but got $inputDataType.")
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
+    require(schema.fields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    NominalAttribute.defaultAttr.withName(outputColName).toStructField()
+  }
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols()
+
+    val outputFields = for (i <- 0 until inputColNames.length) yield {
+      if (schema.fieldNames.contains(inputColNames(i))) {
+        validateAndTransformField(schema, inputColNames(i), outputColNames(i))
+      } else {
+        if (skipNonExistsCol) {
+          null
+        } else {
+          throw new SparkException(s"Input column ${inputColNames(i)} does not exist.")
+        }
+      }
+    }
+    StructType(schema.fields ++ outputFields.filter(_ != null))
   }
 }
 
 /**
- * A label indexer that maps a string column of labels to an ML column of label indices.
- * If the input column is numeric, we cast it to string and index the string values.
+ * A label indexer that maps string column(s) of labels to ML column(s) of label indices.
+ * If the input columns are numeric, we cast them to string and index the string values.
  * The indices are in [0, numLabels). By default, this is ordered by label frequencies
  * so the most frequent label gets index 0. The ordering behavior is controlled by
  * setting `stringOrderType`.
@@ -130,21 +164,53 @@ class StringIndexer @Since("1.4.0") (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val values = dataset.na.drop(Array($(inputCol)))
-      .select(col($(inputCol)).cast(StringType))
-      .rdd.map(_.getString(0))
-    val labels = $(stringOrderType) match {
-      case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
-        .map(_._1).toArray
-      case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
-        .map(_._1).toArray
-      case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
-      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
-    }
-    copyValues(new StringIndexerModel(uid, labels).setParent(this))
+
+    val (inputCols, _) = getInOutCols()
+    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]())
+
+    // Counts by the string values in the dataset.
+    val countByValueArray = dataset.na.drop(inputCols)
+      .select(inputCols.map(col(_).cast(StringType)): _*)
+      .rdd.treeAggregate(zeroState)(
+        (state: Array[OpenHashMap[String, Long]], row: Row) => {
+          for (i <- 0 until inputCols.length) {
+            state(i).changeValue(row.getString(i), 1L, _ + 1)
+          }
+          state
+        },
+        (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => {
+          for (i <- 0 until inputCols.length) {
+            state2(i).foreach { case (key: String, count: Long) =>
+              state1(i).changeValue(key, count, _ + count)
+            }
+          }
+          state1
+        }
+      )
+
+    // In case of equal frequency when frequencyDesc/Asc, we further sort the strings by alphabet.
+    val labelsArray = countByValueArray.map { countByValue =>
+      $(stringOrderType) match {
+        case StringIndexer.frequencyDesc =>
+          countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray
+        case StringIndexer.frequencyAsc =>
+          countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray
+        case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
+        case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray
+      }
+    }
+    copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
   }
 
   @Since("1.4.0")
@@ -177,32 +243,47 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
 /**
  * Model fitted by [[StringIndexer]].
  *
- * @param labels Ordered list of labels, corresponding to indices to be assigned.
+ * @param labelsArray Array of ordered label lists, corresponding to the indices to be assigned
+ *                    for each input column.
  *
- * @note During transformation, if the input column does not exist,
- * `StringIndexerModel.transform` would return the input dataset unmodified.
+ * @note During transformation, if any input column does not exist,
+ * `StringIndexerModel.transform` skips that column.
+ * If none of the input columns exist, it returns the input dataset unmodified.
 * This is a temporary fix for the case when target labels do not exist during prediction.
 */
@Since("1.4.0")
class StringIndexerModel (
    @Since("1.4.0") override val uid: String,
-    @Since("1.5.0") val labels: Array[String])
+    @Since("2.3.0") val labelsArray: Array[Array[String]])
  extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {

  import StringIndexerModel._

  @Since("1.5.0")
-  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
-
-  private val labelToIndex: OpenHashMap[String, Double] = {
-    val n = labels.length
-    val map = new OpenHashMap[String, Double](n)
-    var i = 0
-    while (i < n) {
-      map.update(labels(i), i)
-      i += 1
+  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels))
+
+  @Since("2.3.0")
+  def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray)
+
+  @Since("1.5.0")
+  def labels: Array[String] = {
+    require(labelsArray.length == 1, "This StringIndexerModel was fit on multiple columns. " +
+      "Call `labelsArray` instead.")
+    labelsArray(0)
+  }
+
+  // Prepares the maps from string values to their corresponding index values.
+  private val labelsToIndexArray: Array[OpenHashMap[String, Double]] = {
+    for (labels <- labelsArray) yield {
+      val n = labels.length
+      val map = new OpenHashMap[String, Double](n)
+      var i = 0
+      while (i < n) {
+        map.update(labels(i), i)
+        i += 1
+      }
+      map
    }
-    map
  }

  /** @group setParam */
@@ -217,33 +298,32 @@ class StringIndexerModel (
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

-  @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    if (!dataset.schema.fieldNames.contains($(inputCol))) {
-      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
-        "Skip StringIndexerModel.")
-      return dataset.toDF
-    }
-    transformSchema(dataset.schema, logging = true)
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

-    val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
-      case _ => labels
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
+    var filteredDataset = dataset.na.drop(inputColNames.filter(
+      dataset.schema.fieldNames.contains(_)))
+    for (i <- 0 until inputColNames.length) {
+      val inputColName = inputColNames(i)
+      val labelToIndex = labelsToIndexArray(i)
+      val filterer = udf { label: String =>
+        labelToIndex.contains(label)
+      }
+      filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
    }
+    filteredDataset
+  }

-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
-    // If we are skipping invalid records, filter them out.
- val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_INVALID => - val filterer = udf { label: String => - labelToIndex.contains(label) - } - (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) - } + private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = { + val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) - val indexer = udf { label: String => + udf { label: String => if (label == null) { if (keepInvalid) { labels.length @@ -257,29 +337,72 @@ class StringIndexerModel ( } else if (keepInvalid) { labels.length } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") } } }.asNondeterministic() + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + var (inputColNames, outputColNames) = getInOutCols() + val outputColumns = new Array[Column](outputColNames.length) + + // Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`. + val filteredDataset = if (getHandleInvalid == StringIndexer.SKIP_INVALID) { + filterInvalidData(dataset, inputColNames) + } else { + dataset + } + + for (i <- 0 until outputColNames.length) { + val inputColName = inputColNames(i) + val outputColName = outputColNames(i) + val labelToIndex = labelsToIndexArray(i) + val labels = labelsArray(i) + + if (!dataset.schema.fieldNames.contains(inputColName)) { + logInfo(s"Input column ${inputColName} does not exist during transformation. " + + "Skip StringIndexerModel for this column.") + outputColNames(i) = null + } else { + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_INVALID => labels :+ "__unknown" + case _ => labels + } + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(filteredLabels).toMetadata() + val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) + + val indexer = getIndexer(labels, labelToIndex) + + outputColumns(i) = indexer(dataset(inputColName).cast(StringType)) + .as(outputColName, metadata) + } + } + + val filteredOutputColNames = outputColNames.filter(_ != null) + val filteredOutputColumns = outputColumns.filter(_ != null) - filteredDataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) + require(filteredOutputColNames.length == filteredOutputColumns.length) + if (filteredOutputColNames.length > 0) { + filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns) + } else { + filteredDataset.toDF() + } } @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains($(inputCol))) { - validateAndTransformSchema(schema) - } else { - // If the input column does not exist during transformation, we skip StringIndexerModel. 
-      schema
-    }
+    validateAndTransformSchema(schema, skipNonExistsCol = true)
   }
 
   @Since("1.4.1")
   override def copy(extra: ParamMap): StringIndexerModel = {
-    val copied = new StringIndexerModel(uid, labels)
+    val copied = new StringIndexerModel(uid, labelsArray)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -293,11 +416,11 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
 
   private[StringIndexerModel] class StringIndexModelWriter(instance: StringIndexerModel)
     extends MLWriter {
 
-    private case class Data(labels: Array[String])
+    private case class Data(labelsArray: Array[Array[String]])
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.labels)
+      val data = Data(instance.labelsArray)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -310,11 +433,23 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-        .select("labels")
-        .head()
-      val labels = data.getAs[Seq[String]](0).toArray
-      val model = new StringIndexerModel(metadata.uid, labels)
+
+      val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion)
+      val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) {
+        // Spark 2.2 and before.
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labels")
+          .head()
+        val labels = data.getAs[Seq[String]](0).toArray
+        Array(labels)
+      } else {
+        // After Spark 2.3.
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labelsArray")
+          .head()
+        data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray
+      }
+      val model = new StringIndexerModel(metadata.uid, labelsArray)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1b4b401ac4aa..461c3eb46ea9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -249,6 +249,16 @@ object ParamValidators {
   def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = {
     (value: Array[T]) => value.length > lowerBound
   }
+
+  /** Checks whether more than one param in a set of mutually exclusive params is set. */
+  def checkExclusiveParams(model: Params, params: String*): Unit = {
+    if (params.filter(paramName => model.hasParam(paramName) &&
+      model.isSet(model.getParam(paramName))).size > 1) {
+      val paramString = params.mkString("`", "`, `", "`")
+      throw new IllegalArgumentException(s"$paramString are exclusive, " +
+        "but more than one of them is set.")
+    }
+  }
 }
 
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
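Note: the new `ParamValidators.checkExclusiveParams` guard above is what makes the single-column and multi-column APIs mutually exclusive. A minimal usage sketch (hypothetical snippet, not part of this patch; `df` stands for any DataFrame with the referenced columns):

    // Mixing single-column and multi-column params on one instance now
    // fails fast in getInOutCols(), which fit/transformSchema call first.
    val indexer = new StringIndexer()
      .setInputCol("in")
      .setOutputCols(Array("out1", "out2"))
    // indexer.fit(df) throws IllegalArgumentException:
    //   `inputCol`, `outputCols` are exclusive, but more than one of them is set.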
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6df..2d3b8d202846 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -114,7 +114,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), + (5, "bar", 6), (6, "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -123,7 +124,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0), + (5, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 5.0), + (6, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 6.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -299,7 +302,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = - Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) + Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5), + ("female", "bar", 6), ("female", "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -307,7 +311,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), - ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0), + ("female", "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0), + ("female", "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 0.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -316,7 +322,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("force to index label even it is numeric type") { val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true) val original = spark.createDataFrame( - Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) + Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5), + (1.0, "bar", 6), (0.0, "foo", 6)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -325,14 +332,17 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), - (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) + (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0), + (1.0, "bar", 6, Vectors.dense(1.0, 
0.0, 6.0), 0.0), + (0.0, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 1.0)) ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) } test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), + (1, "bar", 6), (0, "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df05..f4aac66b81f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -33,12 +33,38 @@ class StringIndexerSuite test("params") { ParamsSuite.checkParams(new StringIndexer) - val model = new StringIndexerModel("indexer", Array("a", "b")) + val model = new StringIndexerModel("indexer", Array(Array("a", "b"))) val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) ParamsSuite.checkParams(modelWithoutUid) } + test("params: input/output columns") { + val stringIndexerSingleCol = new StringIndexer() + .setInputCol("in").setOutputCol("out") + val inOutCols1 = stringIndexerSingleCol.getInOutCols() + assert(inOutCols1._1 === Array("in")) + assert(inOutCols1._2 === Array("out")) + + val stringIndexerMultiCol = new StringIndexer() + .setInputCols(Array("in1", "in2")).setOutputCols(Array("out1", "out2")) + val inOutCols2 = stringIndexerMultiCol.getInOutCols() + assert(inOutCols2._1 === Array("in1", "in2")) + assert(inOutCols2._2 === Array("out1", "out2")) + + intercept[IllegalArgumentException] { + new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).getInOutCols() + } + intercept[IllegalArgumentException] { + new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").getInOutCols() + } + intercept[IllegalArgumentException] { + new StringIndexer().setInputCols(Array("in1", "in2")) + .setOutputCols(Array("out1", "out2", "out3")) + .getInOutCols() + } + } + test("StringIndexer") { val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) val df = data.toDF("id", "label") @@ -167,7 +193,7 @@ class StringIndexerSuite } test("StringIndexerModel should keep silent if the input column does not exist.") { - val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + val indexerModel = new StringIndexerModel("indexer", Array(Array("a", "b", "c"))) .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() @@ -202,7 +228,7 @@ class StringIndexerSuite } test("StringIndexerModel read/write") { - val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + val instance = new StringIndexerModel("myStringIndexerModel", Array(Array("a", "b", "c"))) .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setHandleInvalid("skip") @@ -331,4 +357,51 @@ class StringIndexerSuite val dfWithIndex = model.transform(dfNoBristol) assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) } + + test("StringIndexer multiple input columns") { + val data = Seq( + Row("a", 0.0, "e", 1.0), + Row("b", 2.0, "f", 0.0), + Row("c", 1.0, "e", 1.0), + Row("a", 0.0, "f", 0.0), + Row("a", 0.0, "f", 0.0), + Row("c", 1.0, "f", 0.0)) + + val schema = 
StructType(Array( + StructField("label1", StringType), + StructField("expected1", DoubleType), + StructField("label2", StringType), + StructField("expected2", DoubleType))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val indexer = new StringIndexer() + .setInputCols(Array("label1", "label2")) + .setOutputCols(Array("labelIndex1", "labelIndex2")) + val indexerModel = indexer.fit(df) + + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) + + val transformed = indexerModel.transform(df) + + // Checks output attribute correctness. + val attr1 = Attribute.fromStructField(transformed.schema("labelIndex1")) + .asInstanceOf[NominalAttribute] + assert(attr1.values.get === Array("a", "c", "b")) + val attr2 = Attribute.fromStructField(transformed.schema("labelIndex2")) + .asInstanceOf[NominalAttribute] + assert(attr2.values.get === Array("f", "e")) + + transformed.select("labelIndex1", "expected1").rdd.map { r => + (r.getDouble(0), r.getDouble(1)) + }.collect().foreach { case (index, expected) => + assert(index == expected) + } + + transformed.select("labelIndex2", "expected2").rdd.map { r => + (r.getDouble(0), r.getDouble(1)) + }.collect().foreach { case (index, expected) => + assert(index == expected) + } + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b452f35c5ec..495f6a3f9925 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -98,7 +98,15 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.this"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.outputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.getOutputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_=") ) // Exclude rules for 2.2.x diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d..0d2712196200 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -914,7 +914,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... 
stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") @@ -1050,7 +1051,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) @@ -1188,7 +1190,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) From 26cc94bb335cf0ba3bcdbc2b78effd447026792c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 7 Jan 2018 01:42:28 +0000 Subject: [PATCH 02/20] Fix glm test. --- R/pkg/tests/fulltests/test_mllib_regression.R | 42 +++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/R/pkg/tests/fulltests/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R index 23daca75fcc2..b40c4cb9a969 100644 --- a/R/pkg/tests/fulltests/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -102,10 +102,18 @@ test_that("spark.glm and predict", { }) test_that("spark.glm summary", { + # prepare dataset + Sepal.Length <- c(2.0, 1.5, 1.8, 3.4, 5.1, 1.8, 1.0, 2.3) + Sepal.Width <- c(2.1, 2.3, 5.4, 4.7, 3.1, 2.1, 3.1, 5.5) + Petal.Length <- c(1.8, 2.1, 7.1, 2.5, 3.7, 6.3, 2.2, 7.2) + Species <- c("setosa", "versicolor", "versicolor", "versicolor", "virginica", "virginica", + "versicolor", "virginica") + dataset <- data.frame(Sepal.Length, Sepal.Width, Petal.Length, Species, stringsAsFactors = TRUE) + # gaussian family - training <- suppressWarnings(createDataFrame(iris)) + training <- suppressWarnings(createDataFrame(dataset)) stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = dataset)) # test summary coefficients return matrix type expect_true(class(stats$coefficients) == "matrix") @@ -126,15 +134,15 @@ test_that("spark.glm summary", { out <- capture.output(print(stats)) expect_match(out[2], "Deviance Residuals:") - expect_true(any(grepl("AIC: 59.22", out))) + expect_true(any(grepl("AIC: 35.84", out))) # binomial family - df <- suppressWarnings(createDataFrame(iris)) + df <- suppressWarnings(createDataFrame(dataset)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, family = binomial(link = "logit"))) - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rTraining <- dataset[dataset$Species %in% c("versicolor", "virginica"), ] rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, 
family = binomial(link = "logit"))) @@ -174,17 +182,17 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # Test spark.glm works with offset - training <- suppressWarnings(createDataFrame(iris)) + training <- suppressWarnings(createDataFrame(dataset)) stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = poisson(), offsetCol = "Petal_Length")) rStats <- suppressWarnings(summary(glm(Sepal.Width ~ Sepal.Length + Species, - data = iris, family = poisson(), offset = iris$Petal.Length))) + data = dataset, family = poisson(), offset = dataset$Petal.Length))) expect_true(all(abs(rStats$coefficients - stats$coefficients) < 1e-3)) # Test summary works on base GLM models - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = dataset) baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + expect_true(abs(baseSummary$deviance - 11.84013) < 1e-4) # Test spark.glm works with regularization parameter data <- as.data.frame(cbind(a1, a2, b)) @@ -300,11 +308,19 @@ test_that("glm and predict", { }) test_that("glm summary", { + # prepare dataset + Sepal.Length <- c(2.0, 1.5, 1.8, 3.4, 5.1, 1.8, 1.0, 2.3) + Sepal.Width <- c(2.1, 2.3, 5.4, 4.7, 3.1, 2.1, 3.1, 5.5) + Petal.Length <- c(1.8, 2.1, 7.1, 2.5, 3.7, 6.3, 2.2, 7.2) + Species <- c("setosa", "versicolor", "versicolor", "versicolor", "virginica", "virginica", + "versicolor", "virginica") + dataset <- data.frame(Sepal.Length, Sepal.Width, Petal.Length, Species, stringsAsFactors = TRUE) + # gaussian family - training <- suppressWarnings(createDataFrame(iris)) + training <- suppressWarnings(createDataFrame(dataset)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = dataset)) coefs <- stats$coefficients rCoefs <- rStats$coefficients @@ -320,12 +336,12 @@ test_that("glm summary", { expect_equal(stats$aic, rStats$aic) # binomial family - df <- suppressWarnings(createDataFrame(iris)) + df <- suppressWarnings(createDataFrame(dataset)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = binomial(link = "logit"))) - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rTraining <- dataset[dataset$Species %in% c("versicolor", "virginica"), ] rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit"))) From 18acbbf7b70b87c75ba62be863580fe9accc23b4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Jan 2018 12:03:26 +0000 Subject: [PATCH 03/20] Improve test cases. 
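The added order-type test pins down the tie-breaking rule when frequencies are equal. A standalone sketch of the intended ordering (plain Scala with assumed label counts, not the patch's code):

    // Labels with counts a=2, b=2, c=1, d=1; ties break alphabetically.
    val counts = Seq("a" -> 2L, "b" -> 2L, "c" -> 1L, "d" -> 1L)
    val freqDesc = counts.sortBy { case (label, n) => (-n, label) }.map(_._1)
    // freqDesc == List(a, b, c, d) -> indices 0.0, 1.0, 2.0, 3.0
    val freqAsc = counts.sortBy { case (label, n) => (n, label) }.map(_._1)
    // freqAsc == List(c, d, a, b) -> indices 0.0, 1.0, 2.0, 3.0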
--- .../spark/ml/feature/StringIndexerSuite.scala | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index f4aac66b81f6..6a7569fd215f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -52,16 +52,19 @@ class StringIndexerSuite assert(inOutCols2._1 === Array("in1", "in2")) assert(inOutCols2._2 === Array("out1", "out2")) + + val df = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")).toDF("id", "label") + intercept[IllegalArgumentException] { - new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).getInOutCols() + new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).fit(df) } intercept[IllegalArgumentException] { - new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").getInOutCols() + new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").fit(df) } intercept[IllegalArgumentException] { new StringIndexer().setInputCols(Array("in1", "in2")) .setOutputCols(Array("out1", "out2", "out3")) - .getInOutCols() + .fit(df) } } @@ -341,6 +344,27 @@ class StringIndexerSuite } } + test("StringIndexer order types: secondary sort by alphabets when frequency equal") { + val data = Seq((0, "a"), (1, "a"), (2, "b"), (3, "b"), (4, "c"), (5, "d")) + val df = data.toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 1.0), (3, 1.0), (4, 2.0), (5, 3.0)), + Set((0, 2.0), (1, 2.0), (2, 3.0), (3, 3.0), (4, 0.0), (5, 1.0))) + + var idx = 0 + for (orderType <- Seq("frequencyDesc", "frequencyAsc")) { + val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) + val output = transformed.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + assert(output === expected(idx)) + idx += 1 + } + } + test("SPARK-22446: StringIndexerModel's indexer UDF should not apply on filtered data") { val df = List( ("A", "London", "StrA"), From 50af02eaccce7cecb7c3093d5bc14675ca860c22 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 11:30:46 +0000 Subject: [PATCH 04/20] Change from 2.3 to 2.4. 
--- .../spark/ml/feature/StringIndexer.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 20c9fa79ad68..f723a00467b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -162,11 +162,11 @@ class StringIndexer @Since("1.4.0") ( def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - @Since("2.3.0") + @Since("2.4.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) /** @group setParam */ - @Since("2.3.0") + @Since("2.4.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) @Since("2.0.0") @@ -251,7 +251,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.4.0") class StringIndexerModel ( @Since("1.4.0") override val uid: String, - @Since("2.3.0") val labelsArray: Array[Array[String]]) + @Since("2.4.0") val labelsArray: Array[Array[String]]) extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -259,7 +259,7 @@ class StringIndexerModel ( @Since("1.5.0") def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels)) - @Since("2.3.0") + @Since("2.4.0") def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray) @Since("1.5.0") @@ -296,11 +296,11 @@ class StringIndexerModel ( def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - @Since("2.3.0") + @Since("2.4.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) /** @group setParam */ - @Since("2.3.0") + @Since("2.4.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = { @@ -432,15 +432,15 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { val dataPath = new Path(path, "data").toString val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) - val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) { - // Spark 2.2 and before. + val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 3)) { + // Spark 2.3 and before. val data = sparkSession.read.parquet(dataPath) .select("labels") .head() val labels = data.getAs[Seq[String]](0).toArray Array(labels) } else { - // After Spark 2.3. + // After Spark 2.4. val data = sparkSession.read.parquet(dataPath) .select("labelsArray") .head() From c1be2c7e28ebdfed580577a108d2f254834caed7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Apr 2018 10:15:49 +0000 Subject: [PATCH 05/20] Address comments. 
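The filterInvalidData rewrite below builds one keep-this-row predicate per input column and combines them with a single reduce(_ and _) instead of chaining where() calls. A simplified standalone sketch of the idea (hypothetical helper; Set-based lookups stand in for OpenHashMap):

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.udf

    // Keep a row only if every indexed column holds a label seen at fit time.
    def keepSeenRows(df: DataFrame, seenLabels: Map[String, Set[String]]): DataFrame = {
      val conditions: Seq[Column] = seenLabels.toSeq.map { case (colName, seen) =>
        val isSeen = udf { label: String => label != null && seen.contains(label) }
        isSeen(df(colName))
      }
      df.where(conditions.reduce(_ and _))
    }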
--- .../spark/ml/feature/StringIndexer.scala | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f723a00467b9..1e01151fcd7c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -112,18 +112,15 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi skipNonExistsCol: Boolean = false): StructType = { val (inputColNames, outputColNames) = getInOutCols() - val outputFields = for (i <- 0 until inputColNames.length) yield { - if (schema.fieldNames.contains(inputColNames(i))) { - validateAndTransformField(schema, inputColNames(i), outputColNames(i)) - } else { - if (skipNonExistsCol) { - null - } else { - throw new SparkException(s"Input column ${inputColNames(i)} does not exist.") + val outputFields = inputColNames.zip(outputColNames).flatMap { + case (inputColName, outputColName) => + schema.fieldNames.contains(inputColName) match { + case true => Some(validateAndTransformField(schema, inputColName, outputColName)) + case false if skipNonExistsCol => None + case _ => throw new SparkException(s"Input column $inputColName does not exist.") } - } } - StructType(schema.fields ++ outputFields.filter(_ != null)) + StructType(schema.fields ++ outputFields) } } @@ -303,18 +300,20 @@ class StringIndexerModel ( @Since("2.4.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + // This filters out any null values and also the input labels which are not in + // the dataset used for fitting. private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = { - var filteredDataset = dataset.na.drop(inputColNames.filter( - dataset.schema.fieldNames.contains(_))) - for (i <- 0 until inputColNames.length) { + val conditions: Seq[Column] = (0 until inputColNames.length).map { i => val inputColName = inputColNames(i) val labelToIndex = labelsToIndexArray(i) - val filterer = udf { label: String => + val filter = udf { label: String => labelToIndex.contains(label) } - filteredDataset = filteredDataset.where(filterer(dataset(inputColName))) + filter(dataset(inputColName)) } - filteredDataset + + dataset.na.drop(inputColNames.filter(dataset.schema.fieldNames.contains(_))) + .where(conditions.reduce(_ and _)) } private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = { From ed35d875414ba3cf8751a77463f61665e9c373b0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 23 Apr 2018 14:00:16 +0000 Subject: [PATCH 06/20] Address comment. 
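For the alphabetDesc/alphabetAsc order types, the change below skips label counting entirely and collects each column's distinct values through a DataFrame sort. A minimal sketch of that path (assumes a SparkSession `spark` and a DataFrame `df` with a string column "label"):

    import spark.implicits._

    // Distinct labels in ascending alphabetical order, as alphabetAsc does.
    val labels: Array[String] = df.select("label").distinct()
      .sort(df("label").asc)
      .as[String]
      .collect()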
--- .../spark/ml/feature/StringIndexer.scala | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1e01151fcd7c..7c7b0b249df6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -166,15 +166,13 @@ class StringIndexer @Since("1.4.0") ( @Since("2.4.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - @Since("2.0.0") - override def fit(dataset: Dataset[_]): StringIndexerModel = { - transformSchema(dataset.schema, logging = true) + private def countByValue( + dataset: Dataset[_], + inputCols: Array[String]): Array[OpenHashMap[String, Long]] = { - val (inputCols, _) = getInOutCols() val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]()) - // Counts by the string values in the dataset. - val countByValueArray = dataset.na.drop(inputCols) + dataset.na.drop(inputCols) .select(inputCols.map(col(_).cast(StringType)): _*) .rdd.treeAggregate(zeroState)( (state: Array[OpenHashMap[String, Long]], row: Row) => { @@ -192,17 +190,34 @@ class StringIndexer @Since("1.4.0") ( state1 } ) + } + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { + transformSchema(dataset.schema, logging = true) + + val (inputCols, _) = getInOutCols() // In case of equal frequency when frequencyDesc/Asc, we further sort the strings by alphabet. - val labelsArray = countByValueArray.map { countByValue => - $(stringOrderType) match { - case StringIndexer.frequencyDesc => - countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray - case StringIndexer.frequencyAsc => - countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray - case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray - case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray - } + val labelsArray = $(stringOrderType) match { + case StringIndexer.frequencyDesc => + countByValue(dataset, inputCols).map { counts => + counts.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray + } + case StringIndexer.frequencyAsc => + countByValue(dataset, inputCols).map { counts => + counts.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray + } + case StringIndexer.alphabetDesc => + import dataset.sparkSession.implicits._ + inputCols.map { inputCol => + dataset.select(inputCol).distinct().sort(dataset(s"$inputCol").desc).as[String].collect() + } + case StringIndexer.alphabetAsc => + import dataset.sparkSession.implicits._ + inputCols.map { inputCol => + dataset.select(inputCol).distinct().sort(dataset(s"$inputCol").asc).as[String].collect() + } } copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) } From a1dcfda85243a1e2210177f2acfb78821c539b17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Apr 2018 06:41:07 +0000 Subject: [PATCH 07/20] Use SQL Aggregator for counting string labels. 
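For reference, the shape of a typed SQL Aggregator as introduced below: a minimal single-column sketch (hypothetical CountByValue; an immutable Map buffer stands in for OpenHashMap):

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator

    // Counts occurrences of each string; the same zero/reduce/merge/finish
    // shape as the StringIndexerAggregator added in this patch.
    class CountByValue extends Aggregator[String, Map[String, Long], Map[String, Long]] {
      def zero: Map[String, Long] = Map.empty
      def reduce(buf: Map[String, Long], s: String): Map[String, Long] =
        buf.updated(s, buf.getOrElse(s, 0L) + 1L)
      def merge(b1: Map[String, Long], b2: Map[String, Long]): Map[String, Long] =
        b2.foldLeft(b1) { case (b, (k, v)) => b.updated(k, b.getOrElse(k, 0L) + v) }
      def finish(buf: Map[String, Long]): Map[String, Long] = buf
      def bufferEncoder: Encoder[Map[String, Long]] = Encoders.kryo[Map[String, Long]]
      def outputEncoder: Encoder[Map[String, Long]] = Encoders.kryo[Map[String, Long]]
    }

    // Applied over a Dataset[String]:
    //   val counts = ds.select(new CountByValue().toColumn).first()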
--- .../spark/ml/feature/StringIndexer.scala | 66 ++++++++++++++----- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7c7b0b249df6..821895ab3705 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -28,7 +28,8 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row} +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.VersionUtils.majorMinorVersion @@ -170,26 +171,15 @@ class StringIndexer @Since("1.4.0") ( dataset: Dataset[_], inputCols: Array[String]): Array[OpenHashMap[String, Long]] = { - val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]()) + val aggregator = new StringIndexerAggregator(inputCols.length) + implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]] dataset.na.drop(inputCols) .select(inputCols.map(col(_).cast(StringType)): _*) - .rdd.treeAggregate(zeroState)( - (state: Array[OpenHashMap[String, Long]], row: Row) => { - for (i <- 0 until inputCols.length) { - state(i).changeValue(row.getString(i), 1L, _ + 1) - } - state - }, - (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => { - for (i <- 0 until inputCols.length) { - state2(i).foreach { case (key: String, count: Long) => - state1(i).changeValue(key, count, _ + count) - } - } - state1 - } - ) + .toDF + .groupBy().agg(aggregator.toColumn) + .as[Array[OpenHashMap[String, Long]]] + .collect()(0) } @Since("2.0.0") @@ -567,3 +557,43 @@ object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") override def load(path: String): IndexToString = super.load(path) } + +/** + * A SQL `Aggregator` used by `StringIndexer` to count labels in string columns during fitting. 
+ */ +private class StringIndexerAggregator(numColumns: Int) + extends Aggregator[Row, Array[OpenHashMap[String, Long]], Array[OpenHashMap[String, Long]]] { + + override def zero: Array[OpenHashMap[String, Long]] = + Array.fill(numColumns)(new OpenHashMap[String, Long]()) + + def reduce( + array: Array[OpenHashMap[String, Long]], + row: Row): Array[OpenHashMap[String, Long]] = { + for (i <- 0 until numColumns) { + array(i).changeValue(row.getString(i), 1L, _ + 1) + } + array + } + + def merge( + array1: Array[OpenHashMap[String, Long]], + array2: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, Long]] = { + for (i <- 0 until numColumns) { + array2(i).foreach { case (key: String, count: Long) => + array1(i).changeValue(key, count, _ + count) + } + } + array1 + } + + def finish(array: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, Long]] = array + + override def bufferEncoder: Encoder[Array[OpenHashMap[String, Long]]] = { + Encoders.kryo[Array[OpenHashMap[String, Long]]] + } + + override def outputEncoder: Encoder[Array[OpenHashMap[String, Long]]] = { + Encoders.kryo[Array[OpenHashMap[String, Long]]] + } +} From a6551b02a10428d66e0dadcfcb5a8da3798ec814 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Apr 2018 04:13:09 +0000 Subject: [PATCH 08/20] Drop NA values for both frequency and alphabet order types. --- .../spark/ml/feature/StringIndexer.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 569d8bd7075b..832c5414827f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -174,8 +174,7 @@ class StringIndexer @Since("1.4.0") ( val aggregator = new StringIndexerAggregator(inputCols.length) implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]] - dataset.na.drop(inputCols) - .select(inputCols.map(col(_).cast(StringType)): _*) + dataset.select(inputCols.map(col(_).cast(StringType)): _*) .toDF .groupBy().agg(aggregator.toColumn) .as[Array[OpenHashMap[String, Long]]] @@ -188,25 +187,29 @@ class StringIndexer @Since("1.4.0") ( val (inputCols, _) = getInOutCols() + val filteredDF = dataset.na.drop(inputCols) + // In case of equal frequency when frequencyDesc/Asc, we further sort the strings by alphabet. 
val labelsArray = $(stringOrderType) match { case StringIndexer.frequencyDesc => - countByValue(dataset, inputCols).map { counts => + countByValue(filteredDF, inputCols).map { counts => counts.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray } case StringIndexer.frequencyAsc => - countByValue(dataset, inputCols).map { counts => + countByValue(filteredDF, inputCols).map { counts => counts.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray } case StringIndexer.alphabetDesc => import dataset.sparkSession.implicits._ inputCols.map { inputCol => - dataset.select(inputCol).distinct().sort(dataset(s"$inputCol").desc).as[String].collect() + filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").desc) + .as[String].collect() } case StringIndexer.alphabetAsc => import dataset.sparkSession.implicits._ inputCols.map { inputCol => - dataset.select(inputCol).distinct().sort(dataset(s"$inputCol").asc).as[String].collect() + filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").asc) + .as[String].collect() } } copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) @@ -375,7 +378,9 @@ class StringIndexerModel ( case _ => labels } val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(filteredLabels).toMetadata() + .withName(outputColName) + .withValues(filteredLabels) + .toMetadata() val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) val indexer = getIndexer(labels, labelToIndex) From f7102e92fe512c893ff066e43b48b124f0a117e6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 8 Dec 2018 12:17:28 +0800 Subject: [PATCH 09/20] Address comments. --- .../spark/ml/feature/StringIndexer.scala | 40 ++++++++++++------- .../spark/ml/feature/StringIndexerSuite.scala | 5 +++ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 832c5414827f..ea45e3b77008 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -68,6 +68,9 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi * - 'alphabetAsc': ascending alphabetical order * Default is 'frequencyDesc'. * + * Note: In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted + * by alphabet. 
+ * * @group param */ @Since("2.3.0") @@ -113,6 +116,9 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi skipNonExistsCol: Boolean = false): StructType = { val (inputColNames, outputColNames) = getInOutCols() + require(outputColNames.distinct.length == outputColNames.length, + s"Output columns should not be duplicate.") + val outputFields = inputColNames.zip(outputColNames).flatMap { case (inputColName, outputColName) => schema.fieldNames.contains(inputColName) match { @@ -160,11 +166,11 @@ class StringIndexer @Since("1.4.0") ( def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - @Since("2.4.0") + @Since("3.0.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) /** @group setParam */ - @Since("2.4.0") + @Since("3.0.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) private def countByValue( @@ -201,16 +207,22 @@ class StringIndexer @Since("1.4.0") ( } case StringIndexer.alphabetDesc => import dataset.sparkSession.implicits._ - inputCols.map { inputCol => + filteredDF.persist() + val labels = inputCols.map { inputCol => filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").desc) .as[String].collect() } + filteredDF.unpersist() + labels case StringIndexer.alphabetAsc => import dataset.sparkSession.implicits._ - inputCols.map { inputCol => + filteredDF.persist() + val labels = inputCols.map { inputCol => filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").asc) .as[String].collect() } + filteredDF.unpersist() + labels } copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) } @@ -256,7 +268,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.4.0") class StringIndexerModel ( @Since("1.4.0") override val uid: String, - @Since("2.4.0") val labelsArray: Array[Array[String]]) + @Since("3.0.0") val labelsArray: Array[Array[String]]) extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -264,7 +276,7 @@ class StringIndexerModel ( @Since("1.5.0") def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels)) - @Since("2.4.0") + @Since("3.0.0") def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray) @Since("1.5.0") @@ -279,10 +291,8 @@ class StringIndexerModel ( for (labels <- labelsArray) yield { val n = labels.length val map = new OpenHashMap[String, Double](n) - var i = 0 - while (i < n) { - map.update(labels(i), i) - i += 1 + labels.zipWithIndex.foreach { case (label, idx) => + map.update(label, idx) } map } @@ -301,11 +311,11 @@ class StringIndexerModel ( def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - @Since("2.4.0") + @Since("3.0.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) /** @group setParam */ - @Since("2.4.0") + @Since("3.0.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) // This filters out any null values and also the input labels which are not in @@ -441,15 +451,15 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { val dataPath = new Path(path, "data").toString val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) - val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 3)) { - // Spark 2.3 and before. + val labelsArray = if (majorVersion < 3) { + // Spark 2.4 and before. 
      val data = sparkSession.read.parquet(dataPath)
         .select("labels")
         .head()
       val labels = data.getAs[Seq[String]](0).toArray
       Array(labels)
     } else {
-      // After Spark 2.4.
+      // Spark 3.0 and after.
       val data = sparkSession.read.parquet(dataPath)
         .select("labelsArray")
         .head()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0d66fce9291..41a9ebe12fed 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -63,6 +63,11 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
         .setOutputCols(Array("out1", "out2", "out3"))
         .fit(df)
     }
+    intercept[IllegalArgumentException] {
+      new StringIndexer().setInputCols(Array("in1", "in2"))
+        .setOutputCols(Array("out1", "out1"))
+        .fit(df)
+    }
   }
 
   test("StringIndexer") {

From 301fa4cbb4d62d9a180dcbafbe0d2b68dac5a3c8 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Sat, 8 Dec 2018 12:51:06 +0800
Subject: [PATCH 10/20] Update ml document.

---
 docs/ml-features.md | 6 ++++--
 docs/ml-guide.md    | 9 +++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/docs/ml-features.md b/docs/ml-features.md
index a140bc6e7a22..33373e094656 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -585,11 +585,13 @@ for more details on the API.
 ## StringIndexer
 
 `StringIndexer` encodes a string column of labels to a column of label indices.
-The indices are in `[0, numLabels)`, and four ordering options are supported:
+`StringIndexer` can encode multiple columns. The indices are in `[0, numLabels)`, and four ordering options are supported:
 "frequencyDesc": descending order by label frequency (most frequent label assigned 0),
 "frequencyAsc": ascending order by label frequency (least frequent label assigned 0),
 "alphabetDesc": descending alphabetical order, and "alphabetAsc": ascending alphabetical order
-(default = "frequencyDesc").
+(default = "frequencyDesc"). Note that in case of equal frequency under
+"frequencyDesc"/"frequencyAsc", the strings are further sorted alphabetically.
+
 The unseen labels will be put at index numLabels if user chooses to keep them.
 If the input column is numeric, we cast it to string and index the string values.
 When downstream pipeline components such as `Estimator` or
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 57d4e1fe9d33..cffe41940eed 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -110,6 +110,15 @@ and the migration guide below will explain all changes between releases.
 * `OneHotEncoder` which is deprecated in 2.3, is removed in 3.0 and
   `OneHotEncoderEstimator` is now renamed to `OneHotEncoder`.
 
+### Changes of behavior
+
+* [SPARK-11215](https://issues.apache.org/jira/browse/SPARK-11215):
+  In Spark 2.4 and previous versions, when specifying `frequencyDesc` or `frequencyAsc` as
+  `stringOrderType` param in `StringIndexer`, in case of equal frequency, the order of
+  strings is undefined. Since Spark 3.0, the strings with equal frequency are further
+  sorted by alphabet. And since Spark 3.0, `StringIndexer` supports encoding multiple
+  columns.
+
 ## From 2.2 to 2.3
 
 ### Breaking changes

From 196db6356f56cf48bc4617a939c27daffb8aa3c2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Fri, 21 Dec 2018 16:16:03 +0800
Subject: [PATCH 11/20] Address comments.
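
As a rough illustration of the tie-breaking rule documented above (the data
and column names here are made up, and a running SparkSession `spark` is
assumed):

    import org.apache.spark.ml.feature.StringIndexer

    // "a" occurs three times; "b" and "c" both occur twice. Under the default
    // "frequencyDesc" ordering, "a" gets index 0, and the tie between "b" and
    // "c" is broken alphabetically, so "b" gets 1 and "c" gets 2.
    val df = spark.createDataFrame(
      Seq("a", "a", "a", "b", "b", "c", "c").map(Tuple1.apply)).toDF("label")
    val model = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    assert(model.labels.sameElements(Array("a", "b", "c")))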
--- .../spark/ml/feature/StringIndexer.scala | 28 +++++++++++++++---- project/MimaExcludes.scala | 20 ++++++------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ea45e3b77008..cda1e9b176cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -69,7 +69,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi * Default is 'frequencyDesc'. * * Note: In case of equal frequency when under frequencyDesc/Asc, the strings are further sorted - * by alphabet. + * alphabetically. * * @group param */ @@ -195,15 +195,18 @@ class StringIndexer @Since("1.4.0") ( val filteredDF = dataset.na.drop(inputCols) - // In case of equal frequency when frequencyDesc/Asc, we further sort the strings by alphabet. + // In case of equal frequency when frequencyDesc/Asc, the strings are further sorted + // alphabetically. val labelsArray = $(stringOrderType) match { case StringIndexer.frequencyDesc => countByValue(filteredDF, inputCols).map { counts => - counts.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray + val func = StringIndexer.getSortFunc(ascending = false) + counts.toSeq.sortWith(func).map(_._1).toArray } case StringIndexer.frequencyAsc => + val func = StringIndexer.getSortFunc(ascending = true) countByValue(filteredDF, inputCols).map { counts => - counts.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray + counts.toSeq.sortWith(func).map(_._1).toArray } case StringIndexer.alphabetDesc => import dataset.sparkSession.implicits._ @@ -252,6 +255,19 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) + + // Returns a function used to sort strings by frequency (ascending or descending). + // In case of equal frequency, it sorts strings by alphabet (ascending). + private[feature] def getSortFunc( + ascending: Boolean): ((String, Long), (String, Long)) => Boolean = { + (a: (String, Long), b: (String, Long)) => { + if (a._2 == b._2) { + a._1 < b._1 + } else { + if (ascending) a._2 < b._2 else a._2 > b._2 + } + } + } } /** @@ -379,7 +395,7 @@ class StringIndexerModel ( val labels = labelsArray(i) if (!dataset.schema.fieldNames.contains(inputColName)) { - logInfo(s"Input column ${inputColName} does not exist during transformation. " + + logWarning(s"Input column ${inputColName} does not exist during transformation. " + "Skip StringIndexerModel for this column.") outputColNames(i) = null } else { @@ -450,6 +466,8 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString + // We support to load old `StringIndexerModel` saved by previous Spark versions. + // Previous model has `labels`, but new model has `labelsArray`. val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) val labelsArray = if (majorVersion < 3) { // Spark 2.4 and before. 
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index baf985f941f1..a81039c462a2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -234,7 +234,15 @@ object MimaExcludes { // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction") + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.this"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.outputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.getOutputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_=") ) // Exclude rules for 2.4.x @@ -457,15 +465,7 @@ object MimaExcludes { ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid"), - - // [SPARK-11215][ML] Add multiple columns support to StringIndexer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.this"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.outputCols"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.getOutputCols"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_=") + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x From cd1eda0a2ae939770c7ef4bce1f5524f9097cbfc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 21 Dec 2018 18:19:13 +0800 Subject: [PATCH 12/20] Add a comment. 
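
The comment added below documents the extra per-row lookup on the
skip-invalid path of `StringIndexerModel.transform`. Conceptually it boils
down to a filtering UDF like this simplified sketch (not the exact
production code, which uses an `OpenHashMap`):

    import org.apache.spark.sql.functions.udf

    // `labelToIndex` maps every label seen during fitting to its index. With
    // handleInvalid = "skip", rows whose label is missing from the map are
    // filtered out before indexing, at the cost of one map lookup per row.
    def skipInvalidFilter(labelToIndex: Map[String, Double]) =
      udf { label: String => label != null && labelToIndex.contains(label) }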
--- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index cda1e9b176cb..fdf8d533bef6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -340,6 +340,11 @@ class StringIndexerModel (
     val conditions: Seq[Column] = (0 until inputColNames.length).map { i =>
       val inputColName = inputColNames(i)
       val labelToIndex = labelsToIndexArray(i)
+      // We perform this additional lookup in `labelToIndex` when `handleInvalid` is set to
+      // `StringIndexer.SKIP_INVALID`. An alternative is to do the lookup natively with a SQL
+      // expression; however, looking up a key in a map is currently not efficient in Spark SQL.
+      // See the `ElementAt` and `GetMapValue` expressions. If SQL's map lookup is improved,
+      // we can consider changing this.
       val filter = udf { label: String =>
         labelToIndex.contains(label)
       }

From 70009a5d23f121cf538523c6670d8fa1735d662d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Sat, 22 Dec 2018 18:24:15 +0800
Subject: [PATCH 13/20] Address part of comments.

---
 docs/ml-guide.md                              |  3 +-
 .../spark/ml/feature/StringIndexer.scala      | 31 ++++++++++++-------
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index cffe41940eed..cb936162d64d 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -117,7 +117,8 @@ and the migration guide below will explain all changes between releases.
   `stringOrderType` param in `StringIndexer`, in case of equal frequency, the order of
   strings is undefined. Since Spark 3.0, the strings with equal frequency are further
   sorted by alphabet. And since Spark 3.0, `StringIndexer` supports encoding multiple
-  columns.
+  columns. Because of this change, `StringIndexerModel`'s public constructor `def this(uid: String, labels: Array[String])`
+  is not available. Since Spark 3.0, Developers can use `def this(uid: String, labelsArray: Array[Array[String]])` instead.
 
 ## From 2.2 to 2.3
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index fdf8d533bef6..bfa23865f6b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -199,14 +199,14 @@ class StringIndexer @Since("1.4.0") (
     // alphabetically.
     val labelsArray = $(stringOrderType) match {
       case StringIndexer.frequencyDesc =>
+        val sortFunc = StringIndexer.getSortFunc(ascending = false)
         countByValue(filteredDF, inputCols).map { counts =>
-          val func = StringIndexer.getSortFunc(ascending = false)
-          counts.toSeq.sortWith(func).map(_._1).toArray
+          counts.toSeq.sortWith(sortFunc).map(_._1).toArray
         }
       case StringIndexer.frequencyAsc =>
-        val func = StringIndexer.getSortFunc(ascending = true)
+        val sortFunc = StringIndexer.getSortFunc(ascending = true)
         countByValue(filteredDF, inputCols).map { counts =>
-          counts.toSeq.sortWith(func).map(_._1).toArray
+          counts.toSeq.sortWith(sortFunc).map(_._1).toArray
         }
       case StringIndexer.alphabetDesc =>
         import dataset.sparkSession.implicits._
@@ -260,11 +260,21 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
   // Returns a function used to sort strings by frequency (ascending or descending).
   // In case of equal frequency, it sorts strings by alphabet (ascending).
private[feature] def getSortFunc( ascending: Boolean): ((String, Long), (String, Long)) => Boolean = { - (a: (String, Long), b: (String, Long)) => { - if (a._2 == b._2) { - a._1 < b._1 - } else { - if (ascending) a._2 < b._2 else a._2 > b._2 + if (ascending) { + (a: (String, Long), b: (String, Long)) => { + if (a._2 == b._2) { + a._1 < b._1 + } else { + a._2 < b._2 + } + } + } else { + (a: (String, Long), b: (String, Long)) => { + if (a._2 == b._2) { + a._1 < b._1 + } else { + a._2 > b._2 + } } } } @@ -412,7 +422,6 @@ class StringIndexerModel ( .withName(outputColName) .withValues(filteredLabels) .toMetadata() - val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) val indexer = getIndexer(labels, labelToIndex) @@ -471,7 +480,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - // We support to load old `StringIndexerModel` saved by previous Spark versions. + // We support loading old `StringIndexerModel` saved by previous Spark versions. // Previous model has `labels`, but new model has `labelsArray`. val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) val labelsArray = if (majorVersion < 3) { From 3c6ffc731bf75382c117bd83802b6eb7f99c5da8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 22 Dec 2018 19:58:38 +0800 Subject: [PATCH 14/20] Fix null and NaN issue. --- .../spark/ml/feature/StringIndexer.scala | 31 ++++++++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 32 +++++++++++++------ 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index bfa23865f6b5..c2b9952859f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -177,7 +177,8 @@ class StringIndexer @Since("1.4.0") ( dataset: Dataset[_], inputCols: Array[String]): Array[OpenHashMap[String, Long]] = { - val aggregator = new StringIndexerAggregator(inputCols.length) + val inputColTypes = inputCols.map(dataset.col(_).expr.dataType) + val aggregator = new StringIndexerAggregator(inputCols.length, inputColTypes) implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]] dataset.select(inputCols.map(col(_).cast(StringType)): _*) @@ -193,38 +194,36 @@ class StringIndexer @Since("1.4.0") ( val (inputCols, _) = getInOutCols() - val filteredDF = dataset.na.drop(inputCols) - // In case of equal frequency when frequencyDesc/Asc, the strings are further sorted // alphabetically. 
val labelsArray = $(stringOrderType) match { case StringIndexer.frequencyDesc => val sortFunc = StringIndexer.getSortFunc(ascending = false) - countByValue(filteredDF, inputCols).map { counts => + countByValue(dataset, inputCols).map { counts => counts.toSeq.sortWith(sortFunc).map(_._1).toArray } case StringIndexer.frequencyAsc => val sortFunc = StringIndexer.getSortFunc(ascending = true) - countByValue(filteredDF, inputCols).map { counts => + countByValue(dataset, inputCols).map { counts => counts.toSeq.sortWith(sortFunc).map(_._1).toArray } case StringIndexer.alphabetDesc => import dataset.sparkSession.implicits._ - filteredDF.persist() + dataset.persist() val labels = inputCols.map { inputCol => - filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").desc) + dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").desc) .as[String].collect() } - filteredDF.unpersist() + dataset.unpersist() labels case StringIndexer.alphabetAsc => import dataset.sparkSession.implicits._ - filteredDF.persist() + dataset.persist() val labels = inputCols.map { inputCol => - filteredDF.select(inputCol).distinct().sort(dataset(s"$inputCol").asc) + dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").asc) .as[String].collect() } - filteredDF.unpersist() + dataset.unpersist() labels } copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) @@ -608,7 +607,7 @@ object IndexToString extends DefaultParamsReadable[IndexToString] { /** * A SQL `Aggregator` used by `StringIndexer` to count labels in string columns during fitting. */ -private class StringIndexerAggregator(numColumns: Int) +private class StringIndexerAggregator(numColumns: Int, inputColTypes: Seq[DataType]) extends Aggregator[Row, Array[OpenHashMap[String, Long]], Array[OpenHashMap[String, Long]]] { override def zero: Array[OpenHashMap[String, Long]] = @@ -618,7 +617,13 @@ private class StringIndexerAggregator(numColumns: Int) array: Array[OpenHashMap[String, Long]], row: Row): Array[OpenHashMap[String, Long]] = { for (i <- 0 until numColumns) { - array(i).changeValue(row.getString(i), 1L, _ + 1) + val stringValue = row.getString(i) + // We don't count for null and NaN values. + // For NaN values, because the values in the row are converted to string before aggregation, + // we skip for `NaN` string if the original column type is not string type. 
+ if (stringValue != null && (inputColTypes(i) == StringType || stringValue != "NaN")) { + array(i).changeValue(stringValue, 1L, _ + 1) + } } array } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 41a9ebe12fed..12548010a4d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -85,7 +85,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { (2, 1.0), (3, 0.0), (4, 0.0), - (5, 1.0) + (5, 1.0) ).toDF("id", "labelIndex") testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => @@ -380,9 +380,9 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { test("SPARK-22446: StringIndexerModel's indexer UDF should not apply on filtered data") { val df = List( - ("A", "London", "StrA"), - ("B", "Bristol", null), - ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT") + ("A", "London", "StrA"), + ("B", "Bristol", null), + ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT") val dfNoBristol = df.filter($"CONTENT".isNotNull) @@ -409,10 +409,10 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { Row("c", 1.0, "f", 0.0)) val schema = StructType(Array( - StructField("label1", StringType), - StructField("expected1", DoubleType), - StructField("label2", StringType), - StructField("expected2", DoubleType))) + StructField("label1", StringType), + StructField("expected1", DoubleType), + StructField("label2", StringType), + StructField("expected2", DoubleType))) val df = spark.createDataFrame(sc.parallelize(data), schema) @@ -434,15 +434,27 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { assert(attr2.values.get === Array("f", "e")) transformed.select("labelIndex1", "expected1").rdd.map { r => - (r.getDouble(0), r.getDouble(1)) + (r.getDouble(0), r.getDouble(1)) }.collect().foreach { case (index, expected) => assert(index == expected) } transformed.select("labelIndex2", "expected2").rdd.map { r => - (r.getDouble(0), r.getDouble(1)) + (r.getDouble(0), r.getDouble(1)) }.collect().foreach { case (index, expected) => assert(index == expected) } } + + test("Correctly skipping NULL and NaN values") { + val df = Seq(("a", Double.NaN), (null, 1.0), ("b", 2.0), (null, 3.0)).toDF("str", "double") + + val indexer = new StringIndexer() + .setInputCols(Array("str", "double")) + .setOutputCols(Array("strIndex", "doubleIndex")) + + val model = indexer.fit(df) + assert(model.labelsArray(0) === Array("a", "b")) + assert(model.labelsArray(1) === Array("1.0", "2.0", "3.0")) + } } From b6ad1e4830f908b2a11040d335cd4f7afb5cc90f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 23 Dec 2018 11:53:50 +0800 Subject: [PATCH 15/20] Add test for loading model prior to Spark 3.0. 
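
The resource files added below were generated with a 2.4.x build, so loading
them exercises the `majorVersion < 3` branch of the reader. Roughly (the
path here is hypothetical):

    import org.apache.spark.ml.feature.StringIndexerModel

    // A model saved by Spark 2.4 stores a single `labels` array on disk;
    // the new reader wraps it into a one-element `labelsArray`.
    val model = StringIndexerModel.load("/path/to/strIndexerModel")
    assert(model.labelsArray.length == 1)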
--- ...980-4c42-b8a7-a5a94265c479-c000.snappy.parquet | Bin 0 -> 478 bytes
 .../test-data/strIndexerModel/metadata/part-00000 |   1 +
 .../spark/ml/feature/StringIndexerSuite.scala     |  11 +++++++++++
 3 files changed, 12 insertions(+)
 create mode 100644 mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet
 create mode 100644 mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000

diff --git a/mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet b/mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..917984c2608be1128fcbf635ee995d88fcc0029d
GIT binary patch
literal 478
zcmX|;O-sW-5QaCYtUdI)A%Q)Vg_ag-Ncv?Af_f7X5%DG>lVsK4W?QpcMM{5zcYm#u
zT6-C0-tP4=!hQLb4kGc;i7&CVF^ePZh=Dcu0kYxUr8I*i_fr<
z77DQ7e>-bozm`&@!q|G1TQL&PrBsCx>BCxRsG8E>3RwDU-H$+B!xzfCHSYIDrG7y<
zUuX-ZHa9D!M2hxuGZGpfkSVlngpv415$SrfdE5~HY6p5$)!h$Sn{uJMaTrHo)E#82
yQWf2362p8xn?})zoMg#38b?t)hB%o{rui%xXT|X}Kc2ua;2gi@2k&!>m;D9ek9wg1

literal 0
HcmV?d00001

diff --git a/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000 b/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000
new file mode 100644
index 000000000000..5650199c36dc
--- /dev/null
+++ b/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000
@@ -0,0 +1 @@
+{"class":"org.apache.spark.ml.feature.StringIndexerModel","timestamp":1545536052048,"sparkVersion":"2.4.1-SNAPSHOT","uid":"strIdx_056bb5da1bf2","paramMap":{"outputCol":"index","inputCol":"str"},"defaultParamMap":{"outputCol":"strIdx_056bb5da1bf2__output","stringOrderType":"frequencyDesc","handleInvalid":"error"}}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 12548010a4d7..f542e342ffaa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -457,4 +457,15 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
     assert(model.labelsArray(0) === Array("a", "b"))
     assert(model.labelsArray(1) === Array("1.0", "2.0", "3.0"))
   }
+
+  test("Load StringIndexerModel prior to Spark 3.0") {
+    val modelPath = testFile("test-data/strIndexerModel")
+
+    val loadedModel = StringIndexerModel.load(modelPath)
+    assert(loadedModel.labelsArray === Array(Array("b", "c", "a")))
+
+    val metadata = spark.read.json(s"$modelPath/metadata")
+    val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
+    assert(sparkVersionStr == "2.4.1-SNAPSHOT")
+  }
 }

From d6fed351c1bc1980d9785b087b68290dc695183a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Sun, 30 Dec 2018 17:55:45 +0800
Subject: [PATCH 16/20] Address comment.
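
This moves NaN handling out of the aggregator: for non-string columns, NaN
is replaced with null before the cast to string, so the aggregator only has
to skip nulls. The effect is roughly this DataFrame-level sketch (`df` and
the column name "x" are illustrative):

    import org.apache.spark.sql.functions.{col, when}
    import org.apache.spark.sql.types.StringType

    // For a numeric column "x": NaN -> null, then cast to string; nulls are
    // later ignored when label frequencies are counted.
    val selected = df.select(
      when(col("x").isNaN, null).otherwise(col("x")).cast(StringType).as("x"))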
--- .../spark/ml/feature/StringIndexer.scala | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index c2b9952859f0..e5b2c1b58672 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -29,6 +29,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row}
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -177,11 +178,21 @@ class StringIndexer @Since("1.4.0") (
       dataset: Dataset[_],
       inputCols: Array[String]): Array[OpenHashMap[String, Long]] = {
 
-    val inputColTypes = inputCols.map(dataset.col(_).expr.dataType)
-    val aggregator = new StringIndexerAggregator(inputCols.length, inputColTypes)
+    val aggregator = new StringIndexerAggregator(inputCols.length)
     implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]]
 
-    dataset.select(inputCols.map(col(_).cast(StringType)): _*)
+    val selectedCols = inputCols.map { colName =>
+      val col = dataset.col(colName)
+      if (col.expr.dataType == StringType) {
+        col
+      } else {
+        // We don't count NaN values. Because `StringIndexerAggregator` only processes strings,
+        // we replace NaNs with null in advance.
+        new Column(If(col.isNaN.expr, Literal(null), col.expr)).cast(StringType)
+      }
+    }
+
+    dataset.select(selectedCols: _*)
      .toDF
      .groupBy().agg(aggregator.toColumn)
      .as[Array[OpenHashMap[String, Long]]]
@@ -607,7 +618,7 @@ object IndexToString extends DefaultParamsReadable[IndexToString] {
 /**
  * A SQL `Aggregator` used by `StringIndexer` to count labels in string columns during fitting.
  */
-private class StringIndexerAggregator(numColumns: Int, inputColTypes: Seq[DataType])
+private class StringIndexerAggregator(numColumns: Int)
   extends Aggregator[Row, Array[OpenHashMap[String, Long]], Array[OpenHashMap[String, Long]]] {
 
   override def zero: Array[OpenHashMap[String, Long]] =
@@ -618,10 +629,8 @@ private class StringIndexerAggregator(numColumns: Int, inputColTypes: Seq[DataTy
     row: Row): Array[OpenHashMap[String, Long]] = {
     for (i <- 0 until numColumns) {
       val stringValue = row.getString(i)
-      // We don't count for null and NaN values.
-      // For NaN values, because the values in the row are converted to string before aggregation,
-      // we skip for `NaN` string if the original column type is not string type.
-      if (stringValue != null && (inputColTypes(i) == StringType || stringValue != "NaN")) {
+      // We don't count null values.
+      if (stringValue != null) {
         array(i).changeValue(stringValue, 1L, _ + 1)
       }
     }

From 0137d670c3ea3c8e080b054434d14b8c43f62430 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Mon, 7 Jan 2019 08:19:27 +0800
Subject: [PATCH 17/20] Unpersist if input is not originally cached. Add deprecated info.
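
The caching change below follows an "only unpersist what you persisted"
pattern. A condensed sketch of the idea, assuming some Dataset `ds` that is
scanned once per input column (this is not the exact production control
flow):

    import org.apache.spark.storage.StorageLevel

    // Leave the caller's caching decision alone: persist only if the input
    // was not already cached, and unpersist only what we persisted ourselves.
    val needUnpersist = ds.storageLevel == StorageLevel.NONE
    if (needUnpersist) ds.persist()
    try {
      // ... one pass over ds per input column ...
    } finally {
      if (needUnpersist) ds.unpersist()
    }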
--- .../apache/spark/ml/feature/StringIndexer.scala | 17 +++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index e5b2c1b58672..3de8cf7b6ec2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal}
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorMinorVersion
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -205,6 +206,10 @@ class StringIndexer @Since("1.4.0") (
 
     val (inputCols, _) = getInOutCols()
 
+    // If the input dataset is not originally cached, we need to unpersist it
+    // after we persist it below.
+    val needUnpersist = dataset.storageLevel == StorageLevel.NONE
+
     // In case of equal frequency when frequencyDesc/Asc, the strings are further sorted
     // alphabetically.
     val labelsArray = $(stringOrderType) match {
@@ -225,7 +230,9 @@ class StringIndexer @Since("1.4.0") (
         dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").desc)
           .as[String].collect()
       }
-      dataset.unpersist()
+      if (needUnpersist) {
+        dataset.unpersist()
+      }
       labels
     case StringIndexer.alphabetAsc =>
       import dataset.sparkSession.implicits._
@@ -234,7 +241,9 @@ class StringIndexer @Since("1.4.0") (
         dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").asc)
           .as[String].collect()
       }
-      dataset.unpersist()
+      if (needUnpersist) {
+        dataset.unpersist()
+      }
       labels
     }
     copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
@@ -309,12 +318,16 @@ class StringIndexerModel (
 
   import StringIndexerModel._
 
+  @deprecated("`this(labels: Array[String])` is deprecated and will be removed in 3.1.0. " +
+    "Use `this(labelsArray: Array[Array[String]])` instead.", "3.0.0")
   @Since("1.5.0")
   def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels))
 
   @Since("3.0.0")
   def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray)
 
+  @deprecated("`labels` is deprecated and will be removed in 3.1.0. Use `labelsArray` " +
+    "instead.", "3.0.0")
   @Since("1.5.0")
   def labels: Array[String] = {
     require(labelsArray.length == 1, "This StringIndexerModel is fitted by multi-columns, " +

From 7a5be1202a0f57a7bd25ba2027b7fc391b80d201 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh 
Date: Mon, 7 Jan 2019 09:05:36 +0800
Subject: [PATCH 18/20] Revert classification doctests change.

---
 python/pyspark/ml/classification.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index aee0018d08f4..6ddfce95a3d4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -911,8 +911,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 >>> df = spark.createDataFrame([
 ...     (1.0, Vectors.dense(1.0)),
 ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
->>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
-... 
stringOrderType="alphabetAsc") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") @@ -1048,8 +1047,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", - ... stringOrderType="alphabetAsc") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) @@ -1216,8 +1214,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", - ... stringOrderType="alphabetAsc") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) From b33556b2f62f8d8dc0f13d5bb2d180310c679683 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 13 Jan 2019 00:02:58 +0800 Subject: [PATCH 19/20] Fix style. Revert deprecated methods. --- docs/ml-guide.md | 3 +-- .../spark/ml/feature/StringIndexer.scala | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index cb936162d64d..cffe41940eed 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -117,8 +117,7 @@ and the migration guide below will explain all changes between releases. `stringOrderType` param in `StringIndexer`, in case of equal frequency, the order of strings is undefined. Since Spark 3.0, the strings with equal frequency are further sorted by alphabet. And since Spark 3.0, `StringIndexer` supports encoding multiple - columns. Because of this change, `StringIndexerModel`'s public constructor `def this(uid: String, labels: Array[String])` - is not available. Since Spark 3.0, Developers can use `def this(uid: String, labelsArray: Array[Array[String]])` instead. + columns. 
## From 2.2 to 2.3 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3de8cf7b6ec2..8982a419902b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -280,19 +280,19 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] def getSortFunc( ascending: Boolean): ((String, Long), (String, Long)) => Boolean = { if (ascending) { - (a: (String, Long), b: (String, Long)) => { - if (a._2 == b._2) { - a._1 < b._1 - } else { - a._2 < b._2 - } - } + { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => + if (freqA == freqB) { + strA < strB + } else { + freqA < freqB + } + } } else { - (a: (String, Long), b: (String, Long)) => { - if (a._2 == b._2) { - a._1 < b._1 + { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => + if (freqA == freqB) { + strA < strB } else { - a._2 > b._2 + freqA > freqB } } } @@ -318,8 +318,9 @@ class StringIndexerModel ( import StringIndexerModel._ - @deprecated("`this(labels: Array[String])` is deprecated and will be removed in 3.1.0. " + - "Use `this(labelsArray: Array[Array[String]])` instead.", "3.0.0") + @Since("1.5.0") + def this(uid: String, labels: Array[String]) = this(uid, Array(labels)) + @Since("1.5.0") def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels)) From 867e0019c33e0bfd46247968ae16648a662e60fa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Jan 2019 08:52:36 +0800 Subject: [PATCH 20/20] Address comments. --- .../apache/spark/ml/feature/StringIndexer.scala | 16 +++++++--------- project/MimaExcludes.scala | 4 ---- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 8982a419902b..f2e6012050af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -280,21 +280,19 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] def getSortFunc( ascending: Boolean): ((String, Long), (String, Long)) => Boolean = { if (ascending) { - { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => + case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => if (freqA == freqB) { - strA < strB + strA < strB } else { - freqA < freqB + freqA < freqB } - } } else { - { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => + case ((strA: String, freqA: Long), (strB: String, freqB: Long)) => if (freqA == freqB) { - strA < strB + strA < strB } else { freqA > freqB } - } } } } @@ -331,8 +329,8 @@ class StringIndexerModel ( "instead.", "3.0.0") @Since("1.5.0") def labels: Array[String] = { - require(labelsArray.length == 1, "This StringIndexerModel is fitted by multi-columns, " + - "call for `labelsArray` instead.") + require(labelsArray.length == 1, "This StringIndexerModel is fit on multiple columns. 
" + + "Call `labelsArray` instead.") labelsArray(0) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1ed3a4388d3c..a0d85f60f9fb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -278,10 +278,6 @@ object MimaExcludes { // [SPARK-11215][ML] Add multiple columns support to StringIndexer ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.this"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.outputCols"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.getOutputCols"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), // [SPARK-26616][MLlib] Expose document frequency in IDFModel ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"),