Skip to content

Commit d1acfdb

Browse files
author
VinceShieh
committed
update
Signed-off-by: VinceShieh <vincent.xie@intel.com>
1 parent fa24e43 commit d1acfdb

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

docs/ml-features.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ to transform another:
549549

550550
- throw an exception (which is the default)
551551
- skip the row containing the unseen label entirely
552-
- map the unseen labels with indices [numLabels]
552+
- put unseen labels in a special additional bucket, at index numLabels
553553

554554
**Examples**
555555

@@ -594,7 +594,7 @@ will be generated:
594594
4 | e | 3.0
595595
~~~~
596596

597-
Notice that the rows containing "d" or "e" are mapped with indices "3.0"
597+
Notice that the rows containing "d" or "e" are mapped to index "3.0"
598598

599599
<div class="codetabs">
600600

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
4848
@Since("2.1.0")
4949
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
5050
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
51-
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
51+
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
5252
"at index numLabels).",
5353
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
5454

@@ -91,6 +91,10 @@ class StringIndexer @Since("1.4.0") (
9191
@Since("1.4.0")
9292
def this() = this(Identifiable.randomUID("strIdx"))
9393

94+
/** @group setParam */
95+
@Since("2.2.0")
96+
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
97+
9498
/** @group setParam */
9599
@Since("1.4.0")
96100
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -99,10 +103,6 @@ class StringIndexer @Since("1.4.0") (
99103
@Since("1.4.0")
100104
def setOutputCol(value: String): this.type = set(outputCol, value)
101105

102-
/** @group setParam */
103-
@Since("2.2.0")
104-
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
105-
106106
@Since("2.0.0")
107107
override def fit(dataset: Dataset[_]): StringIndexerModel = {
108108
transformSchema(dataset.schema, logging = true)
@@ -130,6 +130,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
130130
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
131131
private[feature] val supportedHandleInvalids: Array[String] =
132132
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
133+
133134
@Since("1.6.0")
134135
override def load(path: String): StringIndexer = super.load(path)
135136
}
@@ -176,7 +177,6 @@ class StringIndexerModel (
176177
/** @group setParam */
177178
@Since("2.2.0")
178179
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
179-
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
180180

181181
@Since("2.0.0")
182182
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -187,8 +187,13 @@ class StringIndexerModel (
187187
}
188188
transformSchema(dataset.schema, logging = true)
189189

190+
val filteredLabels = getHandleInvalid match {
191+
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
192+
case _ => labels
193+
}
194+
190195
val metadata = NominalAttribute.defaultAttr
191-
.withName($(outputCol)).withValues(labels).toMetadata()
196+
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
192197
// If we are skipping invalid records, filter them out.
193198
val (filteredDataset, keepInvalid) = getHandleInvalid match {
194199
case StringIndexer.SKIP_UNSEEN_LABEL =>
@@ -205,7 +210,8 @@ class StringIndexerModel (
205210
} else if (keepInvalid) {
206211
labels.length
207212
} else {
208-
throw new SparkException(s"Unseen label: $label.")
213+
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
214+
s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
209215
}
210216
}
211217

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class StringIndexerSuite
9494
val transformedKeep = indexer.transform(df2)
9595
val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
9696
.asInstanceOf[NominalAttribute]
97-
assert(attrKeep.values.get === Array("b", "a"))
97+
assert(attrKeep.values.get === Array("b", "a", "__unknown"))
9898
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
9999
(r.getInt(0), r.getDouble(1))
100100
}.collect().toSet

0 commit comments

Comments
 (0)