diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 5c40c35eeaa4..33cdccb2f81f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -84,7 +84,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
     val counts = dataset.select(col($(inputCol)).cast(StringType))
       .map(_.getString(0))
       .countByValue()
+    // Because we treat a null label as invalid,
+    // we will always filter it out first. By the time we get to the transform stage,
+    // we will look at the value of handleInvalid and either filter out invalid records
+    // or throw an error.
     val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
+      .filterNot({ case (v) => v == null })
     copyValues(new StringIndexerModel(uid, labels).setParent(this))
   }
 
@@ -151,10 +156,16 @@ class StringIndexerModel (
     }
 
     val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
+      if (label == null) {
+        // The default way to handle a null value is to throw an error
+        throw new SparkException("The input column contains null value." +
+          " You can use StringIndexer.setHandleInvalid(\"skip\") to filter out null value.")
       } else {
-        throw new SparkException(s"Unseen label: $label.")
+        if (labelToIndex.contains(label)) {
+          labelToIndex(label)
+        } else {
+          throw new SparkException(s"Unseen label: $label.")
+        }
       }
     }
 
@@ -164,7 +175,7 @@ class StringIndexerModel (
     val filteredDataset = (getHandleInvalid) match {
       case "skip" => {
         val filterer = udf { label: String =>
-          labelToIndex.contains(label)
+          label != null
         }
         dataset.where(filterer(dataset($(inputCol))))
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index cb2a060a34dd..b319bd7be7ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -255,7 +255,8 @@ private[ml] trait HasFitIntercept extends Params {
 private[ml] trait HasHandleInvalid extends Params {
 
   /**
-   * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later..
+   * Param for how to handle invalid entries. Options are skip (which will filter out rows with null values), or error
+   * (which will throw an error). More options may be added later.
    * @group param
    */
   final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
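For illustration only (not part of the patch): a minimal sketch of the fit-stage behavior the new comment describes, assuming a SQLContext named sqlContext as in the test suite below; the DataFrame, column names, and variable names here are made up.

    import org.apache.spark.ml.feature.StringIndexer

    // Null labels are treated as invalid and filtered out while fitting,
    // so the fitted model's label array never contains null.
    val df = sqlContext.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, null))
    ).toDF("id", "str")
    val model = new StringIndexer()
      .setInputCol("str")
      .setOutputCol("strIndex")
      .fit(df)
    assert(!model.labels.contains(null))  // only "a" and "b" remain as labels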
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 749bfac74782..613e274a683a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -73,22 +73,6 @@ class StringIndexerSuite
     intercept[SparkException] {
       indexer.transform(df2).collect()
     }
-    val indexerSkipInvalid = new StringIndexer()
-      .setInputCol("label")
-      .setOutputCol("labelIndex")
-      .setHandleInvalid("skip")
-      .fit(df)
-    // Verify that we skip the c record
-    val transformed = indexerSkipInvalid.transform(df2)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attr.values.get === Array("b", "a"))
-    val output = transformed.select("id", "labelIndex").map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
-    // a -> 1, b -> 0
-    val expected = Set((0, 1.0), (1, 0.0))
-    assert(output === expected)
   }
 
   test("StringIndexer with a numeric input column") {
@@ -199,4 +183,61 @@
     .setLabels(Array("a", "b", "c"))
     testDefaultReadWrite(t)
   }
+
+  test("StringIndexer with null value (SPARK-11569)") {
+    val df = sqlContext.createDataFrame(
+      Seq(("asd2s", "1e1e", 1.1, 0, 0.0), ("asd2s", "1e1e", 0.1, 0, 0.0),
+        (null, "1e3e", 1.2, 0, 9.9), (null, "1e1e", 5.1, 1, 9.9),
+        ("asd2s", "1e3e", 0.2, 0, 0.0), ("bd34t", "1e2e", 4.3, 1, 1.0))
+    ).toDF("x0", "x1", "x2", "x3", "expected")
+
+    // setHandleInvalid("skip") after fit
+    val indexer1 = new StringIndexer().setInputCol("x0").setOutputCol("actual").fit(df)
+      .setHandleInvalid("skip")
+    val transformed1 = indexer1.transform(df)
+    // Verify that we skip the null record
+    val attr = Attribute.fromStructField(transformed1.schema("actual"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("asd2s", "bd34t"))
+    // asd2s -> 0, bd34t -> 1, null is filtered out
+    transformed1.select("expected", "actual").collect().foreach {
+      case Row(expected, actual) =>
+        assert(actual === expected)
+    }
+
+    // setHandleInvalid("skip") before fit
+    val indexer2 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
+      .setHandleInvalid("skip").fit(df)
+    val transformed2 = indexer2.transform(df)
+    // Verify that we skip the null record
+    val attr2 = Attribute.fromStructField(transformed2.schema("actual"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr2.values.get === Array("asd2s", "bd34t"))
+    // asd2s -> 0, bd34t -> 1, null is filtered out
+    transformed2.select("expected", "actual").collect().foreach {
+      case Row(expected, actual) =>
+        assert(actual === expected)
+    }
+
+    // setHandleInvalid("error") before fit
+    intercept[SparkException] {
+      val indexer3 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
+        .setHandleInvalid("error").fit(df)
+      indexer3.transform(df).collect()
+    }
+
+    // setHandleInvalid("error") after fit
+    intercept[SparkException] {
+      val indexer4 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
+        .fit(df).setHandleInvalid("error")
+      indexer4.transform(df).collect()
+    }
+
+    // default is setHandleInvalid("error")
+    intercept[SparkException] {
+      val indexer5 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
+        .fit(df)
+      indexer5.transform(df).collect()
+    }
+  }
 }
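Also for illustration: a hedged sketch of the transform-stage behavior exercised by the new test, where sqlContext, the data, and the names below are assumptions rather than part of the patch.

    import org.apache.spark.ml.feature.StringIndexer

    val df = sqlContext.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, null))
    ).toDF("id", "str")
    val model = new StringIndexer()
      .setInputCol("str")
      .setOutputCol("strIndex")
      .fit(df)
    // With the default handleInvalid ("error"), the null row causes a SparkException
    // as soon as the transformed DataFrame is materialized:
    //   model.transform(df).collect()
    // With "skip", the null row is filtered out before indexing:
    val skipped = model.setHandleInvalid("skip").transform(df)
    assert(skipped.count() == 2)  // the null row is gone; "a" and "b" are indexed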