From b970728f48f22f0c2789a941c1fe1ac6b94a3b49 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 13:50:30 +0800 Subject: [PATCH 01/12] [SPARK-17498][ML] StringIndexer handles unseen labels This PR is an enhancement to ML StringIndexer. Before this PR, String Indexer only supports "skip"/"error" options to deal with unseen records. But sometimes those unseen records might still be useful in certain use cases, so user would like to keep the unseen labels. This PR enables StringIndexer to support keeping unseen labels as indices [numLabels]. '''Before StringIndexer().setHandleInvalid("skip") StringIndexer().setHandleInvalid("error") '''After support the third option "keep" StringIndexer().setHandleInvalid("keep") Signed-off-by: VinceShieh --- .../spark/ml/feature/StringIndexer.scala | 66 ++++++++++++------- .../spark/ml/feature/StringIndexerSuite.scala | 32 +++++---- 2 files changed, 63 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a503411b6361..c607ce4aafe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -34,9 +34,34 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + val SKIP_UNSEEN_LABEL: String = "skip" + val ERROR_UNSEEN_LABEL: String = "error" + val KEEP_UNSEEN_LABEL: String = "keep" + val supportedHandleInvalids: Array[String] = + Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + /** + * Param for how to handle unseen labels. Options are 'skip' (filter out rows with + * unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with + * indices [numLabels]). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + + "error (throw an error), or 'keep' (map unseen labels with indices [numLabels]).", + ParamValidators.inArray(supportedHandleInvalids)) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, ERROR_UNSEEN_LABEL) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) @@ -70,11 +95,6 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def this() = this(Identifiable.randomUID("strIdx")) - /** @group setParam */ - @Since("1.6.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") - /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -141,11 +161,6 @@ class StringIndexerModel ( map } - /** @group setParam */ - @Since("1.6.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") - /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -163,25 +178,28 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) + val metadata = NominalAttribute.defaultAttr + .withName($(outputCol)).withValues(labels).toMetadata() + // If we are skipping invalid records, filter them out. + val (filteredDataset, keepInvalid) = getHandleInvalid match { + case SKIP_UNSEEN_LABEL => + val filterer = udf { label: String => + labelToIndex.contains(label) + } + (dataset.where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == KEEP_UNSEEN_LABEL) + } + val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) + } else if (keepInvalid) { + labels.length } else { throw new SparkException(s"Unseen label: $label.") } } - val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() - // If we are skipping invalid records, filter them out. - val filteredDataset = getHandleInvalid match { - case "skip" => - val filterer = udf { label: String => - labelToIndex.contains(label) - } - dataset.where(filterer(dataset($(inputCol)))) - case _ => dataset - } filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d0e63c9d669..daf2b29f3fdd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -64,7 +64,7 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) - val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) val df = data.toDF("id", "label") val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() @@ -75,22 +75,32 @@ class StringIndexerSuite intercept[SparkException] { indexer.transform(df2).collect() } - val indexerSkipInvalid = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .setHandleInvalid("skip") - .fit(df) + + indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformed = indexerSkipInvalid.transform(df2) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + var transformed = indexer.transform(df2) + var attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + val outputSkip = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 - val expected = Set((0, 1.0), (1, 0.0)) - assert(output === expected) + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + // Verify that we keep the unseen records + transformed = indexer.transform(df2) + attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val outputKeep = transformed.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 + val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexer with a numeric input column") { From 5d4b07f517cdf52e5b3b0b786e1dba1993659b2e Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 15:02:44 +0800 Subject: [PATCH 02/12] fix compilation issue Signed-off-by: VinceShieh --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index c607ce4aafe2..f0707bbe2682 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException From 0eb7f0784a71cb695f4d936255abbe8ad30bd95d Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 16:16:57 +0800 Subject: [PATCH 03/12] code refactoring Signed-off-by: VinceShieh --- .../spark/ml/feature/StringIndexer.scala | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f0707bbe2682..44766c871f42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -56,14 +56,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha "error (throw an error), or 'keep' (map unseen labels with indices [numLabels]).", ParamValidators.inArray(supportedHandleInvalids)) - /** @group getParam */ - @Since("2.1.0") - def getHandleInvalid: String = $(handleInvalid) - - /** @group setParam */ - @Since("2.1.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, ERROR_UNSEEN_LABEL) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) @@ -105,6 +97,15 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, ERROR_UNSEEN_LABEL) + @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) @@ -171,6 +172,15 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + + /** @group setParam */ + @Since("2.1.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, ERROR_UNSEEN_LABEL) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { From 9a4174579aa811c99a81967dd829e506c0096ccd Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 17:08:30 +0800 Subject: [PATCH 04/12] add exclusion rules in mima to pass binary compability check Signed-off-by: VinceShieh --- project/MimaExcludes.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9d359427f27a..a45d9b3bc23f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -905,8 +905,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime") ) ++ Seq( - // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") + // [SPARK-17498][ML] Enhance StringIndexer to handle unseen labels + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") + ) ++ Seq( + // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") From 1736057d055ad4a01dac3e9e79950bfcd9b91e1e Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 17:33:31 +0800 Subject: [PATCH 05/12] update document Signed-off-by: VinceShieh --- docs/ml-features.md | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 13d97a2290dc..84889b53fac0 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -502,7 +502,7 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -542,12 +542,13 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. -Additionally, there are two strategies regarding how `StringIndexer` will handle +Additionally, there are three strategies regarding how `StringIndexer` will handle unseen labels when you have fit a `StringIndexer` on one dataset and then use it to transform another: - throw an exception (which is the default) - skip the row containing the unseen label entirely +- map the unseen labels with indices [numLabels] **Examples** @@ -561,6 +562,7 @@ Let's go back to our previous example but this time reuse our previously defined 1 | b 2 | c 3 | d + 4 | e ~~~~ If you've not set how `StringIndexer` handles unseen labels or set it to @@ -576,7 +578,22 @@ will be generated: 2 | c | 1.0 ~~~~ -Notice that the row containing "d" does not appear. +Notice that the rows containing "d" or "e" do not appear. + +If you had called `setHandleInvalid("keep")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | d | 3.0 + 4 | e | 3.0 +~~~~ + +Notice that the rows containing "d" or "e" are mapped with indices "3.0"
From ebe9ddb0dc3dd597d435f8a641fce790b4033a64 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 17:37:43 +0800 Subject: [PATCH 06/12] Revert "add exclusion rules in mima to pass binary compability check" This reverts commit 9a4174579aa811c99a81967dd829e506c0096ccd. --- project/MimaExcludes.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a45d9b3bc23f..9d359427f27a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -905,12 +905,8 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime") ) ++ Seq( - // [SPARK-17498][ML] Enhance StringIndexer to handle unseen labels - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") - ) ++ Seq( - // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") + // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") From 27c1b10f25db851cd1e670bd6a0d6e6f59c2ce1e Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Feb 2017 17:42:56 +0800 Subject: [PATCH 07/12] Mima changes to pass binary compatibility check Signed-off-by: VinceShieh --- project/MimaExcludes.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9d359427f27a..9d18f9f311c5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -907,6 +907,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") + ) ++ Seq( + // [SPARK-17498] StringIndexer enhancement for handling unseen labels + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") From 9bcaffc19e7a11d31aa6bb9ebbcd96367fc1cd38 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 1 Mar 2017 10:09:36 +0800 Subject: [PATCH 08/12] update Signed-off-by: VinceShieh --- docs/ml-features.md | 5 ++- .../spark/ml/feature/StringIndexer.scala | 45 +++++++++---------- .../spark/ml/feature/StringIndexerSuite.scala | 16 +++---- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 84889b53fac0..f3c64a6132b5 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -502,7 +502,8 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`. +The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -580,7 +581,7 @@ will be generated: Notice that the rows containing "d" or "e" do not appear. -If you had called `setHandleInvalid("keep")`, the following dataset +If you call `setHandleInvalid("keep")`, the following dataset will be generated: ~~~~ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 44766c871f42..8175321db411 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.language.existentials - import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -37,24 +35,26 @@ import org.apache.spark.util.collection.OpenHashMap * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { - val SKIP_UNSEEN_LABEL: String = "skip" - val ERROR_UNSEEN_LABEL: String = "error" - val KEEP_UNSEEN_LABEL: String = "keep" - val supportedHandleInvalids: Array[String] = - Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) /** * Param for how to handle unseen labels. Options are 'skip' (filter out rows with - * unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with - * indices [numLabels]). + * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional + * bucket, at index numLabels. * Default: "error" * @group param */ @Since("2.1.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (map unseen labels with indices [numLabels]).", - ParamValidators.inArray(supportedHandleInvalids)) + "error (throw an error), or 'keep' (put unseen labels in a special additional bucket," + + "at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -97,14 +97,9 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group getParam */ - @Since("2.1.0") - def getHandleInvalid: String = $(handleInvalid) - /** @group setParam */ - @Since("2.1.0") + @Since("2.2.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, ERROR_UNSEEN_LABEL) @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { @@ -128,7 +123,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { - + private[feature] val SKIP_UNSEEN_LABEL: String = "skip" + private[feature] val ERROR_UNSEEN_LABEL: String = "error" + private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) } @@ -172,14 +171,10 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group getParam */ - @Since("2.1.0") - def getHandleInvalid: String = $(handleInvalid) - /** @group setParam */ @Since("2.1.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, ERROR_UNSEEN_LABEL) + setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -194,12 +189,12 @@ class StringIndexerModel ( .withName($(outputCol)).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. val (filteredDataset, keepInvalid) = getHandleInvalid match { - case SKIP_UNSEEN_LABEL => + case StringIndexer.SKIP_UNSEEN_LABEL => val filterer = udf { label: String => labelToIndex.contains(label) } (dataset.where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == KEEP_UNSEEN_LABEL) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) } val indexer = udf { label: String => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index daf2b29f3fdd..68b562a23795 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -78,11 +78,11 @@ class StringIndexerSuite indexer.setHandleInvalid("skip") // Verify that we skip the c record - var transformed = indexer.transform(df2) - var attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedSkip = indexer.transform(df2) + val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val outputSkip = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 @@ -91,11 +91,11 @@ class StringIndexerSuite indexer.setHandleInvalid("keep") // Verify that we keep the unseen records - transformed = indexer.transform(df2) - attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedKeep = indexer.transform(df2) + val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val outputKeep = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrKeep.values.get === Array("b", "a")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0, c -> 2, d -> 3 From 4dc10e6390b30fa8df9789479430e0a3f7c65c39 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 1 Mar 2017 10:16:29 +0800 Subject: [PATCH 09/12] update target version Signed-off-by: VinceShieh --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 8175321db411..7b6cb6b961e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -53,7 +53,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) /** @group getParam */ - @Since("2.1.0") + @Since("2.2.0") def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. */ @@ -172,7 +172,7 @@ class StringIndexerModel ( def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ - @Since("2.1.0") + @Since("2.2.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) From fa24e433c3f9fe6f76fe0a55df4551881f194d7b Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 1 Mar 2017 10:26:43 +0800 Subject: [PATCH 10/12] fix compilation on val (filteredDataset, keepInvalid) = getHandleInvalid match { case .. } Signed-off-by: VinceShieh --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7b6cb6b961e7..62181e8e02ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException From d1acfdbf6ca3cb51b8abbb3696f245faafd74fef Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Mon, 6 Mar 2017 21:48:18 +0800 Subject: [PATCH 11/12] update Signed-off-by: VinceShieh --- docs/ml-features.md | 4 ++-- .../spark/ml/feature/StringIndexer.scala | 22 ++++++++++++------- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index f3c64a6132b5..25695b0462f6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -549,7 +549,7 @@ to transform another: - throw an exception (which is the default) - skip the row containing the unseen label entirely -- map the unseen labels with indices [numLabels] +- put unseen labels in a special additional bucket, at index numLabels **Examples** @@ -594,7 +594,7 @@ will be generated: 4 | e | 3.0 ~~~~ -Notice that the rows containing "d" or "e" are mapped with indices "3.0" +Notice that the rows containing "d" or "e" are mapped to index "3.0"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 62181e8e02ab..1e780d220080 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -48,7 +48,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha @Since("2.1.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special additional bucket," + + "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + "at index numLabels).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) @@ -91,6 +91,10 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + @Since("2.2.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -99,10 +103,6 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group setParam */ - @Since("2.2.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) @@ -130,6 +130,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_UNSEEN_LABEL: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) } @@ -176,7 +177,6 @@ class StringIndexerModel ( /** @group setParam */ @Since("2.2.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -187,8 +187,13 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case _ => labels + } + val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. val (filteredDataset, keepInvalid) = getHandleInvalid match { case StringIndexer.SKIP_UNSEEN_LABEL => @@ -205,7 +210,8 @@ class StringIndexerModel ( } else if (keepInvalid) { labels.length } else { - throw new SparkException(s"Unseen label: $label.") + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 68b562a23795..188dffb3dd55 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -94,7 +94,7 @@ class StringIndexerSuite val transformedKeep = indexer.transform(df2) val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a")) + assert(attrKeep.values.get === Array("b", "a", "__unknown")) val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet From c70e0034a2f84f4d5455a80262075777b72a54d3 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Tue, 7 Mar 2017 08:42:15 +0800 Subject: [PATCH 12/12] annotation update Signed-off-by: VinceShieh --- .../apache/spark/ml/feature/StringIndexer.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1e780d220080..810b02febbe7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -45,7 +45,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * Default: "error" * @group param */ - @Since("2.1.0") + @Since("1.6.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + @@ -55,7 +55,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) /** @group getParam */ - @Since("2.2.0") + @Since("1.6.0") def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. */ @@ -92,7 +92,7 @@ class StringIndexer @Since("1.4.0") ( def this() = this(Identifiable.randomUID("strIdx")) /** @group setParam */ - @Since("2.2.0") + @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) /** @group setParam */ @@ -166,6 +166,10 @@ class StringIndexerModel ( map } + /** @group setParam */ + @Since("1.6.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -174,10 +178,6 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group setParam */ - @Since("2.2.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) {