Commit 2df7e35 (parent: d2d2a5d)

RFormula should handle invalid for both features and label column.

3 files changed (+56, -7 lines)

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 5 additions & 4 deletions

@@ -134,16 +134,16 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   def getFormula: String = $(formula)
 
   /**
-   * Param for how to handle invalid data (unseen labels or NULL values).
-   * Options are 'skip' (filter out rows with invalid data),
+   * Param for how to handle invalid data (unseen or NULL values) in features and label column
+   * of string type. Options are 'skip' (filter out rows with invalid data),
    * 'error' (throw an error), or 'keep' (put invalid data in a special additional
    * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
-  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " +
+    "handle invalid data (unseen or NULL values) in features and label column of string type. " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -265,6 +265,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
+        .setHandleInvalid($(handleInvalid))
     }
 
     val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
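With the label-side StringIndexer now receiving handleInvalid, one setting covers unseen string values on both sides of the formula. A minimal usage sketch of the new behavior (the data mirrors the test suite below; the SparkSession setup and variable names are illustrative, not part of the commit):

import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("rformula-sketch").getOrCreate()
import spark.implicits._

// "b" is a string label column; "zy" never appears in the training data.
val train = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
val test = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")

val formula = new RFormula()
  .setFormula("b ~ a + id")
  .setHandleInvalid("keep") // now also applied to the label StringIndexer

// Before this change, the unseen label "zy" would throw a SparkException at
// transform time regardless of handleInvalid; with "keep" it is indexed into
// the extra bucket at index numLabels instead.
formula.fit(train).transform(test).show()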

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 48 additions & 1 deletion

@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
     }
   }
+
+  test("handle unseen features or labels") {
+    val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+    val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")
+
+    // Handle unseen features.
+    val formula1 = new RFormula().setFormula("id ~ a + b")
+    intercept[SparkException] {
+      formula1.fit(df1).transform(df2).collect()
+    }
+    val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected1 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected2 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result1.collect() === expected1.collect())
+    assert(result2.collect() === expected2.collect())
+
+    // Handle unseen labels.
+    val formula2 = new RFormula().setFormula("b ~ a + id")
+    intercept[SparkException] {
+      formula2.fit(df1).transform(df2).collect()
+    }
+    val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected3 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected4 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result3.collect() === expected3.collect())
+    assert(result4.collect() === expected4.collect())
+  }
 }
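The label of 2.0 that expected4 assigns to the unseen value "zy" comes straight from StringIndexer's "keep" semantics: invalid values land at index numLabels. A standalone sketch of just that step (same data as the test; assumes a SparkSession with implicits in scope, as MLlibTestSparkContext provides):

import org.apache.spark.ml.feature.StringIndexer

// Same frames as the test: "b" has seen labels "zq" (twice) and "zz" in df1.
val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")

val indexer = new StringIndexer()
  .setInputCol("b")
  .setOutputCol("bIndex")
  .setHandleInvalid("keep")

// By frequency, "zq" -> 0.0 and "zz" -> 1.0; the unseen "zy" goes to the
// extra bucket at index numLabels = 2.0, matching the label in expected4.
indexer.fit(df1).transform(df2).show()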

python/pyspark/ml/feature.py

Lines changed: 3 additions & 2 deletions

@@ -2107,8 +2107,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                     typeConverter=TypeConverters.toString)
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "labels or NULL values). Options are 'skip' (filter out rows with " +
-                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "or NULL values) in features and label column of string type. " +
+                          "Options are 'skip' (filter out rows with invalid data), " +
+                          "error (throw an error), or 'keep' (put invalid data " +
                           "in a special additional bucket, at index numLabels).",
                           typeConverter=TypeConverters.toString)
 