56 changes: 41 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
   override def fit(dataset: Dataset[_]): ImputerModel = {
     transformSchema(dataset.schema, logging = true)
     val spark = dataset.sparkSession
-    import spark.implicits._
-    val surrogates = $(inputCols).map { inputCol =>
-      val ic = col(inputCol)
-      val filtered = dataset.select(ic.cast(DoubleType))
-        .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
-      if(filtered.take(1).length == 0) {
-        throw new SparkException(s"surrogate cannot be computed. " +
-          s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
-      }
-      val surrogate = $(strategy) match {
-        case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
-        case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
-      }
-      surrogate
+
+    val cols = $(inputCols).map { inputCol =>
+      when(col(inputCol).equalTo($(missingValue)), null)
+        .when(col(inputCol).isNaN, null)
+        .otherwise(col(inputCol))
+        .cast("double")
+        .as(inputCol)
     }
+
+    val results = $(strategy) match {
+      case Imputer.mean =>
+        // Function avg will ignore null automatically.
+        // For a column only containing null, avg will return null.
+        val row = dataset.select(cols.map(avg): _*).head()
+        Array.range(0, $(inputCols).length).map { i =>
+          if (row.isNullAt(i)) {
+            Double.NaN
+          } else {
+            row.getDouble(i)
+          }
+        }
+
+      case Imputer.median =>
+        // Function approxQuantile will ignore null automatically.
+        // For a column only containing null, approxQuantile will return an empty array.
+        dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
+          .map { array =>
+            if (array.isEmpty) {
+              Double.NaN
+            } else {
+              array.head
+            }
+          }
+    }
+
+    val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
+    if (emptyCols.nonEmpty) {
+      throw new SparkException(s"surrogate cannot be computed. " +
+        s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
+        s"missingValue(${$(missingValue)})")
+    }
 
-    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
     val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
     val surrogateDF = spark.createDataFrame(rows, schema)
     copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
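The rewritten fit computes all column surrogates in a single pass (one select over every input column) rather than launching a separate Spark job per column, and only afterwards checks which columns produced no usable value. Below is a minimal, self-contained sketch verifying the two behaviors the inline comments rely on; the column names, object name, and local session are illustrative and not part of the patch. It checks that avg skips nulls and returns null for an all-null column, and that the multi-column approxQuantile skips nulls and returns an empty array for an all-null column.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg

object SurrogateNullBehavior {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Column "a" has one non-null value; column "b" is entirely null.
    val df = Seq((Some(1.0), Option.empty[Double]), (None, None)).toDF("a", "b")

    // avg ignores nulls; for an all-null column it returns null,
    // which the patch maps to Double.NaN via row.isNullAt(i).
    val row = df.select(avg("a"), avg("b")).head()
    assert(row.getDouble(0) == 1.0)
    assert(row.isNullAt(1))

    // approxQuantile ignores nulls; for an all-null column it returns
    // an empty array, which the patch likewise maps to Double.NaN.
    val quantiles = df.stat.approxQuantile(Array("a", "b"), Array(0.5), 0.001)
    assert(quantiles(0).head == 1.0)
    assert(quantiles(1).isEmpty)

    spark.stop()
  }
}

Note that Double.NaN serves only as an internal sentinel in the patched fit: any column whose surrogate comes back NaN is reported in the SparkException before the surrogate DataFrame is built, so a NaN never reaches ImputerModel.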