diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1c074e204ad9..bdad804083b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -73,7 +73,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu s" and outputCols(${$(outputCols).length}) should have the same length") val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => val inputField = schema(inputCol) - SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) + SchemaUtils.checkNumericType(schema, inputCol) StructField(outputCol, inputField.dataType, inputField.nullable) } StructType(schema ++ outputFields) @@ -84,9 +84,13 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu * :: Experimental :: * Imputation estimator for completing missing values, either using the mean or the median * of the columns in which the missing values are located. The input columns should be of - * DoubleType or FloatType. Currently Imputer does not support categorical features + * numeric type. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * + * Note when an input column is integer, the imputed value is casted (truncated) to an integer type. + * For example, if the input column is IntegerType (1, 2, 4, null), + * the output will be IntegerType (1, 2, 4, 2) after mean imputation. + * * Note that the mean/median value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. @@ -218,7 +222,7 @@ class ImputerModel private[ml] ( val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType - val ic = col(inputCol) + val ic = col(inputCol).cast(DoubleType) when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 75f63a623e6d..02ef261a6c06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -20,6 +20,8 @@ import org.apache.spark.SparkException import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class ImputerSuite extends MLTest with DefaultReadWriteTest { @@ -176,6 +178,48 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) } + test("Imputer for IntegerType with default missing value null") { + + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (null, 5, 3) + )).toDF("value1", "expected_mean_value1", "expected_median_value1") + + val imputer = new Imputer() + .setInputCols(Array("value1")) + .setOutputCols(Array("out1")) + + val types = Seq(IntegerType, LongType) + for (mType <- types) { + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(imputer, df2) + } + } + + test("Imputer for IntegerType with missing value -1") { + + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (-1, 5, 3) + )).toDF("value1", "expected_mean_value1", "expected_median_value1") + + val imputer = new Imputer() + .setInputCols(Array("value1")) + .setOutputCols(Array("out1")) + .setMissingValue(-1.0) + + val types = Seq(IntegerType, LongType) + for (mType <- types) { + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(imputer, df2) + } + } } object ImputerSuite { @@ -190,6 +234,13 @@ object ImputerSuite { val model = imputer.fit(df) val resultDF = model.transform(df) imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + + // check dataType is consistent between input and output + val inputType = resultDF.schema(inputCol).dataType + val outputType = resultDF.schema(outputCol).dataType + assert(inputType == outputType, "Output type is not the same as input type.") + + // check value resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { case Row(exp: Float, out: Float) => assert((exp.isNaN && out.isNaN) || (exp == out), @@ -197,6 +248,12 @@ object ImputerSuite { case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Integer, out: Integer) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Long, out: Long) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") } } }