@@ -48,7 +48,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
4848 @ Since (" 2.1.0" )
4949 val handleInvalid : Param [String ] = new Param [String ](this , " handleInvalid" , " how to handle " +
5050 " unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
51- " error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
51+ " error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
5252 " at index numLabels)." ,
5353 ParamValidators .inArray(StringIndexer .supportedHandleInvalids))
5454
@@ -91,6 +91,10 @@ class StringIndexer @Since("1.4.0") (
9191 @ Since (" 1.4.0" )
9292 def this () = this (Identifiable .randomUID(" strIdx" ))
9393
94+ /** @group setParam */
95+ @ Since (" 2.2.0" )
96+ def setHandleInvalid (value : String ): this .type = set(handleInvalid, value)
97+
9498 /** @group setParam */
9599 @ Since (" 1.4.0" )
96100 def setInputCol (value : String ): this .type = set(inputCol, value)
@@ -99,10 +103,6 @@ class StringIndexer @Since("1.4.0") (
99103 @ Since (" 1.4.0" )
100104 def setOutputCol (value : String ): this .type = set(outputCol, value)
101105
102- /** @group setParam */
103- @ Since (" 2.2.0" )
104- def setHandleInvalid (value : String ): this .type = set(handleInvalid, value)
105-
106106 @ Since (" 2.0.0" )
107107 override def fit (dataset : Dataset [_]): StringIndexerModel = {
108108 transformSchema(dataset.schema, logging = true )
@@ -130,6 +130,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
130130 private [feature] val KEEP_UNSEEN_LABEL : String = " keep"
131131 private [feature] val supportedHandleInvalids : Array [String ] =
132132 Array (SKIP_UNSEEN_LABEL , ERROR_UNSEEN_LABEL , KEEP_UNSEEN_LABEL )
133+
133134 @ Since (" 1.6.0" )
134135 override def load (path : String ): StringIndexer = super .load(path)
135136}
@@ -176,7 +177,6 @@ class StringIndexerModel (
176177 /** @group setParam */
177178 @ Since (" 2.2.0" )
178179 def setHandleInvalid (value : String ): this .type = set(handleInvalid, value)
179- setDefault(handleInvalid, StringIndexer .ERROR_UNSEEN_LABEL )
180180
181181 @ Since (" 2.0.0" )
182182 override def transform (dataset : Dataset [_]): DataFrame = {
@@ -187,8 +187,13 @@ class StringIndexerModel (
187187 }
188188 transformSchema(dataset.schema, logging = true )
189189
190+ val filteredLabels = getHandleInvalid match {
191+ case StringIndexer .KEEP_UNSEEN_LABEL => labels :+ " __unknown"
192+ case _ => labels
193+ }
194+
190195 val metadata = NominalAttribute .defaultAttr
191- .withName($(outputCol)).withValues(labels ).toMetadata()
196+ .withName($(outputCol)).withValues(filteredLabels ).toMetadata()
192197 // If we are skipping invalid records, filter them out.
193198 val (filteredDataset, keepInvalid) = getHandleInvalid match {
194199 case StringIndexer .SKIP_UNSEEN_LABEL =>
@@ -205,7 +210,8 @@ class StringIndexerModel (
205210 } else if (keepInvalid) {
206211 labels.length
207212 } else {
208- throw new SparkException (s " Unseen label: $label. " )
213+ throw new SparkException (s " Unseen label: $label. To handle unseen labels, " +
214+ s " set Param handleInvalid to ${StringIndexer .KEEP_UNSEEN_LABEL }. " )
209215 }
210216 }
211217
0 commit comments