StringIndexer.scala
@@ -32,7 +32,8 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
Contributor:
I think that HasSkipInvalid should be mixed in with StringIndexerModel rather than StringIndexerBase. Skipping invalid is really a parameter that can be set for the resulting model and is not used by StringIndexer#fit except in copyValues.

Contributor Author:
Why?

Contributor:
Discussion on L69.

Contributor:
OK as is, sorry for my mix-up.
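
For context on the thread above, a minimal self-contained sketch of the copyValues point (hypothetical names, not Spark's actual Params API): a param set on the estimator is copied onto the model it produces, so declaring the param once in a shared base trait serves both objects.

// Hypothetical simplified sketch -- not Spark source. Illustrates how a
// handleInvalid-style param set on an estimator ends up on the fitted model.
trait HasHandleInvalidSketch {
  var handleInvalid: String = "error"  // shared param with its default
}

class EstimatorSketch extends HasHandleInvalidSketch {
  def fit(): ModelSketch = {
    val model = new ModelSketch
    model.handleInvalid = this.handleInvalid  // the "copyValues" step
    model
  }
}

class ModelSketch extends HasHandleInvalidSketch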

with HasHandleInvalid {

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -64,13 +65,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]

def this() = this(Identifiable.randomUID("strIdx"))

/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

// TODO: handle unseen labels

override def fit(dataset: DataFrame): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
@@ -110,6 +114,10 @@ class StringIndexerModel private[ml] (
map
}

/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

@@ -127,14 +135,24 @@ class StringIndexerModel private[ml] (
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
// TODO: handle unseen labels
throw new SparkException(s"Unseen label: $label.")
}
}

val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"),
// If we are skipping invalid records, filter them out.
val filteredDataset = getHandleInvalid match {
  case "skip" =>
    val filterer = udf { label: String =>
      labelToIndex.contains(label)
    }
    dataset.where(filterer(dataset($(inputCol))))
  case _ => dataset
}
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}
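
A minimal usage sketch of the behavior added above (assumes a SQLContext named sqlContext in scope, as in the test suite below; the column and label values are illustrative):

// Fit on labels {a, b}, then transform data containing the unseen label "c".
// With handleInvalid = "skip" the "c" row is silently filtered out; with the
// default "error" the transform would throw a SparkException instead.
val train = sqlContext.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("id", "label")
val test = sqlContext.createDataFrame(Seq((0, "a"), (1, "c"))).toDF("id", "label")
val model = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .setHandleInvalid("skip")
  .fit(train)
model.transform(test).show()  // only the "a" row survives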

SharedParamsCodeGen.scala
@@ -53,6 +53,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
"will filter out rows with bad values), or error (which will throw an errror). More " +
"options may be added later.",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
" before fitting the model.", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
sharedParams.scala
@@ -232,6 +232,21 @@ private[ml] trait HasFitIntercept extends Params {
final def getFitIntercept: Boolean = $(fitIntercept)
}

/**
* Trait for shared param handleInvalid.
*/
private[ml] trait HasHandleInvalid extends Params {

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
* @group param
*/
final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))

/** @group getParam */
final def getHandleInvalid: String = $(handleInvalid)
}
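
Because the param now lives in a generated shared trait, other Params classes can opt in by mixing it in. A hedged sketch follows; MyCleaner is hypothetical and not part of this PR, and since HasHandleInvalid is private[ml] it only compiles inside the org.apache.spark.ml package.

// Hypothetical component reusing the shared param (not in this PR).
package org.apache.spark.ml.feature

import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasHandleInvalid
import org.apache.spark.ml.util.Identifiable

class MyCleaner(override val uid: String) extends Params with HasHandleInvalid {
  def this() = this(Identifiable.randomUID("myCleaner"))
  setDefault(handleInvalid, "error")  // same default as StringIndexer
  override def copy(extra: ParamMap): MyCleaner = defaultCopy(extra)
}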

/**
* Trait for shared param standardization (default: true).
*/
StringIndexerSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -49,6 +50,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(output === expected)
}

test("StringIndexerUnseen") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
// Verify we throw by default with unseen values
intercept[SparkException] {
indexer.transform(df2).collect()
}
val indexerSkipInvalid = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.setHandleInvalid("skip")
.fit(df)
// Verify that we skip the c record
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("b", "a"))
val output = transformed.select("id", "labelIndex").map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expected = Set((0, 1.0), (1, 0.0))
assert(output === expected)
}

test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")