|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml.feature |
19 | 19 |
|
20 | | -import org.apache.spark.SparkFunSuite |
| 20 | +import org.apache.spark.{SparkException, SparkFunSuite} |
21 | 21 | import org.apache.spark.ml.attribute._ |
22 | 22 | import org.apache.spark.ml.linalg.Vectors |
23 | 23 | import org.apache.spark.ml.param.ParamsSuite |
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul |
501 | 501 | assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) |
502 | 502 | } |
503 | 503 | } |
| 504 | + |
| 505 | + test("handle unseen features or labels") { |
| 506 | + val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") |
| 507 | + val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b") |
| 508 | + |
| 509 | + // Handle unseen features. |
| 510 | + val formula1 = new RFormula().setFormula("id ~ a + b") |
| 511 | + intercept[SparkException] { |
| 512 | + formula1.fit(df1).transform(df2).collect() |
| 513 | + } |
| 514 | + val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) |
| 515 | + val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) |
| 516 | + |
| 517 | + val expected1 = Seq( |
| 518 | + (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), |
| 519 | + (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0) |
| 520 | + ).toDF("id", "a", "b", "features", "label") |
| 521 | + val expected2 = Seq( |
| 522 | + (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0), |
| 523 | + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0), |
| 524 | + (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) |
| 525 | + ).toDF("id", "a", "b", "features", "label") |
| 526 | + |
| 527 | + assert(result1.collect() === expected1.collect()) |
| 528 | + assert(result2.collect() === expected2.collect()) |
| 529 | + |
| 530 | + // Handle unseen labels. |
| 531 | + val formula2 = new RFormula().setFormula("b ~ a + id") |
| 532 | + intercept[SparkException] { |
| 533 | + formula2.fit(df1).transform(df2).collect() |
| 534 | + } |
| 535 | + val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) |
| 536 | + val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) |
| 537 | + |
| 538 | + val expected3 = Seq( |
| 539 | + (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), |
| 540 | + (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) |
| 541 | + ).toDF("id", "a", "b", "features", "label") |
| 542 | + val expected4 = Seq( |
| 543 | + (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), |
| 544 | + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), |
| 545 | + (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) |
| 546 | + ).toDF("id", "a", "b", "features", "label") |
| 547 | + |
| 548 | + assert(result3.collect() === expected3.collect()) |
| 549 | + assert(result4.collect() === expected4.collect()) |
| 550 | + } |
504 | 551 | } |
0 commit comments