Skip to content

Commit 9bcaffc

Browse files
author
VinceShieh
committed
update
Signed-off-by: VinceShieh <vincent.xie@intel.com>
1 parent 27c1b10 commit 9bcaffc

File tree

3 files changed

+31
-35
lines changed

3 files changed

+31
-35
lines changed

docs/ml-features.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,8 @@ for more details on the API.
502502
## StringIndexer
503503

504504
`StringIndexer` encodes a string column of labels to a column of label indices.
505-
The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`.
505+
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
506+
The unseen labels will be put at index numLabels if user chooses to keep them.
506507
If the input column is numeric, we cast it to string and index the string
507508
values. When downstream pipeline components such as `Estimator` or
508509
`Transformer` make use of this string-indexed label, you must set the input
@@ -580,7 +581,7 @@ will be generated:
580581

581582
Notice that the rows containing "d" or "e" do not appear.
582583

583-
If you had called `setHandleInvalid("keep")`, the following dataset
584+
If you call `setHandleInvalid("keep")`, the following dataset
584585
will be generated:
585586

586587
~~~~

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import scala.language.existentials
21-
2220
import org.apache.hadoop.fs.Path
2321

2422
import org.apache.spark.SparkException
@@ -37,24 +35,26 @@ import org.apache.spark.util.collection.OpenHashMap
3735
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
3836
*/
3937
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
40-
val SKIP_UNSEEN_LABEL: String = "skip"
41-
val ERROR_UNSEEN_LABEL: String = "error"
42-
val KEEP_UNSEEN_LABEL: String = "keep"
43-
val supportedHandleInvalids: Array[String] =
44-
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
4538

4639
/**
4740
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with
48-
* unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with
49-
* indices [numLabels]).
41+
* unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
42+
* bucket, at index numLabels.
5043
* Default: "error"
5144
* @group param
5245
*/
5346
@Since("2.1.0")
5447
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
5548
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
56-
"error (throw an error), or 'keep' (map unseen labels with indices [numLabels]).",
57-
ParamValidators.inArray(supportedHandleInvalids))
49+
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
50+
"at index numLabels).",
51+
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
52+
53+
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
54+
55+
/** @group getParam */
56+
@Since("2.1.0")
57+
def getHandleInvalid: String = $(handleInvalid)
5858

5959
/** Validates and transforms the input schema. */
6060
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -97,14 +97,9 @@ class StringIndexer @Since("1.4.0") (
9797
@Since("1.4.0")
9898
def setOutputCol(value: String): this.type = set(outputCol, value)
9999

100-
/** @group getParam */
101-
@Since("2.1.0")
102-
def getHandleInvalid: String = $(handleInvalid)
103-
104100
/** @group setParam */
105-
@Since("2.1.0")
101+
@Since("2.2.0")
106102
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
107-
setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
108103

109104
@Since("2.0.0")
110105
override def fit(dataset: Dataset[_]): StringIndexerModel = {
@@ -128,7 +123,11 @@ class StringIndexer @Since("1.4.0") (
128123

129124
@Since("1.6.0")
130125
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
131-
126+
private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
127+
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
128+
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
129+
private[feature] val supportedHandleInvalids: Array[String] =
130+
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
132131
@Since("1.6.0")
133132
override def load(path: String): StringIndexer = super.load(path)
134133
}
@@ -172,14 +171,10 @@ class StringIndexerModel (
172171
@Since("1.4.0")
173172
def setOutputCol(value: String): this.type = set(outputCol, value)
174173

175-
/** @group getParam */
176-
@Since("2.1.0")
177-
def getHandleInvalid: String = $(handleInvalid)
178-
179174
/** @group setParam */
180175
@Since("2.1.0")
181176
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
182-
setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
177+
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
183178

184179
@Since("2.0.0")
185180
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -194,12 +189,12 @@ class StringIndexerModel (
194189
.withName($(outputCol)).withValues(labels).toMetadata()
195190
// If we are skipping invalid records, filter them out.
196191
val (filteredDataset, keepInvalid) = getHandleInvalid match {
197-
case SKIP_UNSEEN_LABEL =>
192+
case StringIndexer.SKIP_UNSEEN_LABEL =>
198193
val filterer = udf { label: String =>
199194
labelToIndex.contains(label)
200195
}
201196
(dataset.where(filterer(dataset($(inputCol)))), false)
202-
case _ => (dataset, getHandleInvalid == KEEP_UNSEEN_LABEL)
197+
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
203198
}
204199

205200
val indexer = udf { label: String =>

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ class StringIndexerSuite
7878

7979
indexer.setHandleInvalid("skip")
8080
// Verify that we skip the c record
81-
var transformed = indexer.transform(df2)
82-
var attr = Attribute.fromStructField(transformed.schema("labelIndex"))
81+
val transformedSkip = indexer.transform(df2)
82+
val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
8383
.asInstanceOf[NominalAttribute]
84-
assert(attr.values.get === Array("b", "a"))
85-
val outputSkip = transformed.select("id", "labelIndex").rdd.map { r =>
84+
assert(attrSkip.values.get === Array("b", "a"))
85+
val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
8686
(r.getInt(0), r.getDouble(1))
8787
}.collect().toSet
8888
// a -> 1, b -> 0
@@ -91,11 +91,11 @@ class StringIndexerSuite
9191

9292
indexer.setHandleInvalid("keep")
9393
// Verify that we keep the unseen records
94-
transformed = indexer.transform(df2)
95-
attr = Attribute.fromStructField(transformed.schema("labelIndex"))
94+
val transformedKeep = indexer.transform(df2)
95+
val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
9696
.asInstanceOf[NominalAttribute]
97-
assert(attr.values.get === Array("b", "a"))
98-
val outputKeep = transformed.select("id", "labelIndex").rdd.map { r =>
97+
assert(attrKeep.values.get === Array("b", "a"))
98+
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
9999
(r.getInt(0), r.getDouble(1))
100100
}.collect().toSet
101101
// a -> 1, b -> 0, c -> 2, d -> 3

0 commit comments

Comments
 (0)