1717
1818package org .apache .spark .ml .feature
1919
20- import scala .language .existentials
21-
2220import org .apache .hadoop .fs .Path
2321
2422import org .apache .spark .SparkException
@@ -37,24 +35,26 @@ import org.apache.spark.util.collection.OpenHashMap
3735 * Base trait for [[StringIndexer ]] and [[StringIndexerModel ]].
3836 */
3937private [feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
40- val SKIP_UNSEEN_LABEL : String = " skip"
41- val ERROR_UNSEEN_LABEL : String = " error"
42- val KEEP_UNSEEN_LABEL : String = " keep"
43- val supportedHandleInvalids : Array [String ] =
44- Array (SKIP_UNSEEN_LABEL , ERROR_UNSEEN_LABEL , KEEP_UNSEEN_LABEL )
4538
4639 /**
4740 * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
48- * unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with
49- * indices [ numLabels]) .
41+ * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
42+ * bucket, at index numLabels.
5043 * Default: "error"
5144 * @group param
5245 */
5346 @ Since (" 2.1.0" )
5447 val handleInvalid : Param [String ] = new Param [String ](this , " handleInvalid" , " how to handle " +
5548 " unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
56- " error (throw an error), or 'keep' (map unseen labels with indices [numLabels])." ,
57- ParamValidators .inArray(supportedHandleInvalids))
49+ " error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
50+ " at index numLabels)." ,
51+ ParamValidators .inArray(StringIndexer .supportedHandleInvalids))
52+
53+ setDefault(handleInvalid, StringIndexer .ERROR_UNSEEN_LABEL )
54+
55+ /** @group getParam */
56+ @ Since (" 2.1.0" )
57+ def getHandleInvalid : String = $(handleInvalid)
5858
5959 /** Validates and transforms the input schema. */
6060 protected def validateAndTransformSchema (schema : StructType ): StructType = {
@@ -97,14 +97,9 @@ class StringIndexer @Since("1.4.0") (
9797 @ Since (" 1.4.0" )
9898 def setOutputCol (value : String ): this .type = set(outputCol, value)
9999
100- /** @group getParam */
101- @ Since (" 2.1.0" )
102- def getHandleInvalid : String = $(handleInvalid)
103-
104100 /** @group setParam */
105- @ Since (" 2.1 .0" )
101+ @ Since (" 2.2 .0" )
106102 def setHandleInvalid (value : String ): this .type = set(handleInvalid, value)
107- setDefault(handleInvalid, ERROR_UNSEEN_LABEL )
108103
109104 @ Since (" 2.0.0" )
110105 override def fit (dataset : Dataset [_]): StringIndexerModel = {
@@ -128,7 +123,11 @@ class StringIndexer @Since("1.4.0") (
128123
129124@ Since (" 1.6.0" )
130125object StringIndexer extends DefaultParamsReadable [StringIndexer ] {
131-
126+ private [feature] val SKIP_UNSEEN_LABEL : String = " skip"
127+ private [feature] val ERROR_UNSEEN_LABEL : String = " error"
128+ private [feature] val KEEP_UNSEEN_LABEL : String = " keep"
129+ private [feature] val supportedHandleInvalids : Array [String ] =
130+ Array (SKIP_UNSEEN_LABEL , ERROR_UNSEEN_LABEL , KEEP_UNSEEN_LABEL )
132131 @ Since (" 1.6.0" )
133132 override def load (path : String ): StringIndexer = super .load(path)
134133}
@@ -172,14 +171,10 @@ class StringIndexerModel (
172171 @ Since (" 1.4.0" )
173172 def setOutputCol (value : String ): this .type = set(outputCol, value)
174173
175- /** @group getParam */
176- @ Since (" 2.1.0" )
177- def getHandleInvalid : String = $(handleInvalid)
178-
179174 /** @group setParam */
180175 @ Since (" 2.1.0" )
181176 def setHandleInvalid (value : String ): this .type = set(handleInvalid, value)
182- setDefault(handleInvalid, ERROR_UNSEEN_LABEL )
177+ setDefault(handleInvalid, StringIndexer . ERROR_UNSEEN_LABEL )
183178
184179 @ Since (" 2.0.0" )
185180 override def transform (dataset : Dataset [_]): DataFrame = {
@@ -194,12 +189,12 @@ class StringIndexerModel (
194189 .withName($(outputCol)).withValues(labels).toMetadata()
195190 // If we are skipping invalid records, filter them out.
196191 val (filteredDataset, keepInvalid) = getHandleInvalid match {
197- case SKIP_UNSEEN_LABEL =>
192+ case StringIndexer . SKIP_UNSEEN_LABEL =>
198193 val filterer = udf { label : String =>
199194 labelToIndex.contains(label)
200195 }
201196 (dataset.where(filterer(dataset($(inputCol)))), false )
202- case _ => (dataset, getHandleInvalid == KEEP_UNSEEN_LABEL )
197+ case _ => (dataset, getHandleInvalid == StringIndexer . KEEP_UNSEEN_LABEL )
203198 }
204199
205200 val indexer = udf { label : String =>
0 commit comments