22 changes: 20 additions & 2 deletions docs/ml-features.md
@@ -503,6 +503,7 @@ for more details on the API.

`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
Unseen labels will be put at index `numLabels` if the user chooses to keep them.
If the input column is numeric, we cast it to string and index the string
values. When downstream pipeline components such as `Estimator` or
`Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +543,13 @@ column, we should get the following:
"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.

Additionally, there are two strategies regarding how `StringIndexer` will handle
Additionally, there are three strategies regarding how `StringIndexer` will handle
unseen labels when you have fit a `StringIndexer` on one dataset and then use it
to transform another:

- throw an exception (which is the default)
- skip the row containing the unseen label entirely
- put unseen labels in a special additional bucket, at index numLabels

**Examples**

@@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined
1 | b
2 | c
3 | d
4 | e
~~~~

If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +579,22 @@ will be generated:
2 | c | 1.0
~~~~

Notice that the row containing "d" does not appear.
Notice that the rows containing "d" or "e" do not appear.

If you call `setHandleInvalid("keep")`, the following dataset
will be generated:

~~~~
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | c | 1.0
3 | d | 3.0
4 | e | 3.0
~~~~

Notice that the rows containing "d" or "e" are mapped to index "3.0", the extra bucket reserved for unseen labels.
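
For reference, a minimal Scala sketch of configuring this behavior (the `trainingDf` and `testDf` names are illustrative placeholders, not part of the datasets above):

~~~~
import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")  // or "skip", or "error" (the default)

// Fit on the first dataset (labels "a", "b", "c"), then transform the second;
// with "keep", the unseen labels "d" and "e" both map to index 3.0.
val model = indexer.fit(trainingDf)
val indexed = model.transform(testDf)
~~~~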

<div class="codetabs">

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,14 +17,16 @@

package org.apache.spark.ml.feature

import scala.language.existentials

Member: Is this needed?

Author: Local build & test are fine, but without this import it hits a compilation error on line 193 on Jenkins.


import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
with HasHandleInvalid {
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

/**
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with
* unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
* bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("1.6.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
"at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)

/** @group getParam */
@Since("1.6.0")
def getHandleInvalid: String = $(handleInvalid)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
@Since("1.4.0")
@@ -105,6 +125,11 @@

@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
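// These constants back both the handleInvalid validator in StringIndexerBase
// and the transform-time matches in StringIndexerModel.transform.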

@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
@@ -144,7 +169,6 @@ class StringIndexerModel (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
@Since("1.4.0")
@@ -163,25 +187,34 @@
}
transformSchema(dataset.schema, logging = true)

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
throw new SparkException(s"Unseen label: $label.")
}
val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
case _ => labels
}
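// With "keep", an extra attribute value ("__unknown") is appended so the column
// metadata also covers the additional index (labels.length) that unseen labels
// receive in the indexer UDF below.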

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val filteredDataset = getHandleInvalid match {
case "skip" =>
val (filteredDataset, keepInvalid) = getHandleInvalid match {
case StringIndexer.SKIP_UNSEEN_LABEL =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
dataset.where(filterer(dataset($(inputCol))))
case _ => dataset
(dataset.where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
}
}

filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -64,7 +64,7 @@ class StringIndexerSuite

test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")
val indexer = new StringIndexer()
@@ -75,22 +75,32 @@
intercept[SparkException] {
indexer.transform(df2).collect()
}
val indexerSkipInvalid = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.setHandleInvalid("skip")
.fit(df)

indexer.setHandleInvalid("skip")
// Verify that we skip the c and d records
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
val transformedSkip = indexer.transform(df2)
val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("b", "a"))
val output = transformed.select("id", "labelIndex").rdd.map { r =>
assert(attrSkip.values.get === Array("b", "a"))
val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expected = Set((0, 1.0), (1, 0.0))
assert(output === expected)
val expectedSkip = Set((0, 1.0), (1, 0.0))
assert(outputSkip === expectedSkip)

indexer.setHandleInvalid("keep")
// Verify that we keep the unseen records
val transformedKeep = indexer.transform(df2)
val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrKeep.values.get === Array("b", "a", "__unknown"))
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0; unseen c and d both map to the unknown bucket at index 2
val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
assert(outputKeep === expectedKeep)
}

test("StringIndexer with a numeric input column") {
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
@@ -907,6 +907,10 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
) ++ Seq(
// [SPARK-17498] StringIndexer enhancement for handling unseen labels
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")