diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index bf7be363b822..eb8def3ccef6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -32,7 +32,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -64,13 +65,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -110,6 +114,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,14 +135,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index f7ae1de522e0..41c38e943658 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,6 +53,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 65e48e4ee508..26556547f675 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -232,6 +232,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. + */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 99f82bea4268..a46ba3a3dbd7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -49,6 +50,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(output === expected) } + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label")