@@ -84,7 +84,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
val counts = dataset.select(col($(inputCol)).cast(StringType))
.map(_.getString(0))
.countByValue()
// Because we treat a null label as invalid, we always filter it out first.
// By the time we get to the transform stage, we look at the value of
// handleInvalid and either filter out invalid records or throw an error.
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
  .filterNot(_ == null)
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}
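A minimal usage sketch of the fit-time behavior described in the comment above (column and variable names are illustrative, written in the style of the test suite below): null labels are dropped before the label array is built, so they never receive an index.

// Sketch: nulls are excluded when building the label array, which is
// ordered by descending frequency over the non-null values.
val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, null), (3, "a"))
).toDF("id", "label")

val model = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .fit(df)
// model.labels === Array("a", "b")   // "a" is most frequent -> index 0.0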

@@ -151,10 +156,16 @@ class StringIndexerModel (
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
if (label == null) {
// The default way to handle a null value is to throw an error.
throw new SparkException("The input column contains null values." +
  " You can use StringIndexer.setHandleInvalid(\"skip\") to filter out null values.")
} else {
throw new SparkException(s"Unseen label: $label.")
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
throw new SparkException(s"Unseen label: $label.")
}
}
}
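To make the two failure modes in this udf concrete, a hedged sketch (reusing a frame like the df with nulls from the earlier sketch): under the default handleInvalid of "error", a null input hits the first branch and throws, while "skip" filters null rows out before the udf ever runs.

// Illustrative: the default handleInvalid is "error".
val model = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .fit(df)                      // fitting on data with nulls is fine

// Throws SparkException: "The input column contains null values. ..."
// model.transform(df).collect()

// Filters the null rows out instead of throwing:
model.setHandleInvalid("skip").transform(df).count()   // 3 of the 4 rows remain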

@@ -164,7 +175,7 @@ class StringIndexerModel (
val filteredDataset = (getHandleInvalid) match {
case "skip" => {
val filterer = udf { label: String =>
labelToIndex.contains(label)
label != null
Contributor:
This no longer filters out values that are not present in labelToIndex.

}
dataset.where(filterer(dataset($(inputCol))))
}
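To make the reviewer's concern concrete, a hypothetical sketch: with the filter reduced to a null check, a non-null label that was unseen at fit time now reaches the indexer udf and throws, even under handleInvalid = "skip".

// Hypothetical illustration of the behavior change flagged above.
val train = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"))
).toDF("id", "label")
val test = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "c"), (2, null))
).toDF("id", "label")

val model = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .setHandleInvalid("skip")
  .fit(train)                   // labels: Array("a", "b")

// Previously both the null row and the unseen "c" row were skipped.
// Now only the null row is skipped; "c" passes the null-only filter and
// the indexer udf throws SparkException("Unseen label: c.").
model.transform(test).collect()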
@@ -255,7 +255,8 @@ private[ml] trait HasFitIntercept extends Params {
private[ml] trait HasHandleInvalid extends Params {

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
* Param for how to handle invalid entries. Options are skip (which will filter out rows with null values), or error
Contributor:
You've changed the meaning of an existing parameter - this is generally something we want to avoid doing.

* (which will throw an error). More options may be added later.
* @group param
*/
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
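A hedged note on what the validator actually enforces: values outside the array are rejected at set time, before any fit or transform runs. A small illustrative sketch (column names are assumptions, not from the diff):

// Illustrative: ParamValidators.inArray makes an out-of-range value fail
// immediately with an IllegalArgumentException when the param is set.
val indexer = new StringIndexer()
  .setInputCol("x0")
  .setOutputCol("out")
indexer.setHandleInvalid("skip")      // accepted
// indexer.setHandleInvalid("keep")   // rejected: not in Array("skip", "error")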
@@ -73,22 +73,6 @@ class StringIndexerSuite
intercept[SparkException] {
indexer.transform(df2).collect()
}
val indexerSkipInvalid = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.setHandleInvalid("skip")
.fit(df)
// Verify that we skip the c record
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("b", "a"))
val output = transformed.select("id", "labelIndex").map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expected = Set((0, 1.0), (1, 0.0))
assert(output === expected)
}

test("StringIndexer with a numeric input column") {
@@ -199,4 +183,61 @@
.setLabels(Array("a", "b", "c"))
testDefaultReadWrite(t)
}

test("StringIndexer with null value (SPARK-11569)") {
val df = sqlContext.createDataFrame(
Seq(("asd2s", "1e1e", 1.1, 0, 0.0), ("asd2s", "1e1e", 0.1, 0, 0.0),
(null, "1e3e", 1.2, 0, 9.9), (null, "1e1e", 5.1, 1, 9.9),
("asd2s", "1e3e", 0.2, 0, 0.0), ("bd34t", "1e2e", 4.3, 1, 1.0))
).toDF("x0", "x1", "x2", "x3", "expected")

// setHandleInvalid("skip") after fit
val indexer1 = new StringIndexer().setInputCol("x0").setOutputCol("actual").fit(df)
.setHandleInvalid("skip")
val transformed1 = indexer1.transform(df)
// Verify that we skip the null record
val attr = Attribute.fromStructField(transformed1.schema("actual"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("asd2s", "bd34t"))
// asd2s -> 0.0, bd34t -> 1.0, null is filtered out
transformed1.select("expected", "actual").collect().foreach {
  case Row(expected, actual) =>
assert(actual === expected)
}

// setHandleInvalid("skip") before fit
val indexer2 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
.setHandleInvalid("skip").fit(df)
val transformed2 = indexer2.transform(df)
// Verify that we skip the null record
val attr2 = Attribute.fromStructField(transformed2.schema("actual"))
.asInstanceOf[NominalAttribute]
assert(attr2.values.get === Array("asd2s", "bd34t"))
// asd2s -> 0.0, bd34t -> 1.0, null is filtered out
transformed2.select("expected", "actual").collect().foreach {
  case Row(expected, actual) =>
assert(actual === expected)
}

// setHandleInvalid("error") before fit
intercept[SparkException] {
val indexer3 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
.setHandleInvalid("error").fit(df)
indexer3.transform(df).collect()
}

// setHandleInvalid("error") after fit
intercept[SparkException] {
val indexer4 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
.fit(df).setHandleInvalid("error")
indexer4.transform(df).collect()
}

// default is setHandleInvalid("error")
intercept[SparkException] {
val indexer5 = new StringIndexer().setInputCol("x0").setOutputCol("actual")
.fit(df)
indexer5.transform(df).collect()
}
}
}