10 changes: 7 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -73,7 +73,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
s" and outputCols(${$(outputCols).length}) should have the same length")
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
val inputField = schema(inputCol)
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, inputCol)
StructField(outputCol, inputField.dataType, inputField.nullable)
}
StructType(schema ++ outputFields)
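The change above swaps SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) for SchemaUtils.checkNumericType(schema, inputCol), so schema validation now accepts any numeric input column rather than only Double/Float. A minimal sketch of that kind of check, written against the public SQL types rather than Spark's internal SchemaUtils helper (the method name below is made up for illustration):

import org.apache.spark.sql.types.{NumericType, StructType}

// Illustrative stand-in for the relaxed check: accept any NumericType column
// (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType).
def requireNumericColumn(schema: StructType, colName: String): Unit = {
  val dt = schema(colName).dataType
  require(dt.isInstanceOf[NumericType],
    s"Column $colName must be of numeric type but was ${dt.catalogString}.")
}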
@@ -84,9 +84,13 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean or the median
* of the columns in which the missing values are located. The input columns should be of
* DoubleType or FloatType. Currently Imputer does not support categorical features
* numeric type. Currently Imputer does not support categorical features
* (SPARK-15041) and possibly creates incorrect values for a categorical feature.
*
* Note that when an input column is of an integer type, the imputed value is cast (truncated) to an integer type.
* For example, if the input column is IntegerType (1, 2, 4, null),
* the output will be IntegerType (1, 2, 4, 2) after mean imputation.
*
* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input columns are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
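To make the documented truncation concrete, here is a small, self-contained usage sketch (the column names and the local SparkSession setup are invented for the example): mean-imputing an IntegerType column keeps the integer type, so the surrogate 2.33 is written back as 2.

import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("ImputerIntExample").getOrCreate()
import spark.implicits._

// IntegerType column mirroring the (1, 2, 4, null) example above
val df = Seq[Integer](1, 2, 4, null).toDF("value")

val model = new Imputer()
  .setInputCols(Array("value"))
  .setOutputCols(Array("value_out"))
  .setStrategy("mean")              // mean of (1, 2, 4) = 2.33...
  .fit(df)

// value_out stays IntegerType; the null row is imputed as 2 (truncated mean)
model.transform(df).show()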
@@ -218,7 +222,7 @@ class ImputerModel private[ml] (
val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
val ic = col(inputCol).cast(DoubleType)
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
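For context, the cast to DoubleType above means the null/missing-value comparison and the surrogate substitution happen in double space; the remaining lines of this expression (below the visible hunk) are assumed to cast the result back to the input column's type, which is what keeps an IntegerType output for an IntegerType input. A hedged sketch of that per-column expression:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, when}
import org.apache.spark.sql.types.{DataType, DoubleType}

// Hypothetical helper mirroring the expression ImputerModel builds for each column;
// the trailing cast back to inputType is an assumption based on the surrounding code.
def imputeExpr(inputCol: String, inputType: DataType,
               surrogate: Double, missingValue: Double): Column = {
  val ic = col(inputCol).cast(DoubleType)    // compare as Double
  when(ic.isNull, surrogate)                 // null -> surrogate
    .when(ic === missingValue, surrogate)    // configured missing value -> surrogate
    .otherwise(ic)
    .cast(inputType)                         // restore the original column type
}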
mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -20,6 +20,8 @@ import org.apache.spark.SparkException
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class ImputerSuite extends MLTest with DefaultReadWriteTest {

@@ -176,6 +178,48 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
}

test("Imputer for IntegerType with default missing value null") {

val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
(1, 1, 1),
(11, 11, 11),
(3, 3, 3),
(null, 5, 3)
)).toDF("value1", "expected_mean_value1", "expected_median_value1")
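// With the last value1 missing (null), the mean of (1, 11, 3) is 5 and the
// approximate median is 3, which is what the expected_* columns encode.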

val imputer = new Imputer()
.setInputCols(Array("value1"))
.setOutputCols(Array("out1"))

val types = Seq(IntegerType, LongType)
for (mType <- types) {
// cast all columns to desired data type for testing
val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
ImputerSuite.iterateStrategyTest(imputer, df2)
}
}

test("Imputer for IntegerType with missing value -1") {

val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
(1, 1, 1),
(11, 11, 11),
(3, 3, 3),
(-1, 5, 3)
)).toDF("value1", "expected_mean_value1", "expected_median_value1")
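// With -1 configured as the missing value, the mean of the remaining values
// (1, 11, 3) is 5 and the approximate median is 3, matching the expected_* columns.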

val imputer = new Imputer()
.setInputCols(Array("value1"))
.setOutputCols(Array("out1"))
.setMissingValue(-1.0)

val types = Seq(IntegerType, LongType)
for (mType <- types) {
// cast all columns to desired data type for testing
val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
ImputerSuite.iterateStrategyTest(imputer, df2)
}
}
}

object ImputerSuite {
@@ -190,13 +234,26 @@ object ImputerSuite {
val model = imputer.fit(df)
val resultDF = model.transform(df)
imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>

// check dataType is consistent between input and output
val inputType = resultDF.schema(inputCol).dataType
val outputType = resultDF.schema(outputCol).dataType
assert(inputType == outputType, "Output type is not the same as input type.")

// check value
resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
case Row(exp: Float, out: Float) =>
assert((exp.isNaN && out.isNaN) || (exp == out),
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Double, out: Double) =>
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Integer, out: Integer) =>
assert(exp == out,
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Long, out: Long) =>
assert(exp == out,
s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}