Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
def fill(value: Long): DataFrame = fill(value, df.columns)
def fill(value: Long): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1
*/
def fill(value: Double): DataFrame = fill(value, df.columns)
def fill(value: Double): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null values in string columns with `value`.
*
* @since 1.3.1
*/
def fill(value: String): DataFrame = fill(value, df.columns)
def fill(value: String): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
Expand All @@ -168,15 +168,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored.
*
* @since 1.3.1
*/
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))


/**
Expand All @@ -193,22 +193,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
*
* @since 2.3.0
*/
def fill(value: Boolean): DataFrame = fill(value, df.columns)
def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
* boolean columns. If a specified column is not a boolean column, it is ignored.
*
* @since 2.3.0
*/
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* Returns a new `DataFrame` that replaces null values in specified boolean columns.
Expand Down Expand Up @@ -434,15 +434,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {

/**
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
* It selects a column based on its name.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
val quotedColName = "`" + col.name + "`"
val colValue = col.dataType match {
fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
}

/**
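* Returns a [[Column]] expression that replaces null values (and, for `DoubleType` and
* `FloatType`, NaN values) in `expr` with `replacement`.
* The given `expr` is used directly as the column, and the result is aliased to `name`.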
*/
private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
val colValue = dataType match {
case DoubleType | FloatType =>
nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
case _ => df.col(quotedColName)
nanvl(expr, lit(null)) // nanvl only supports these types
case _ => expr
}
coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
coalesce(colValue, lit(replacement).cast(dataType)).as(name)
}

/**
Expand All @@ -469,12 +478,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${v.getClass.getName} ($v).")
}

/**
 * Resolves every name in `cols` through `df.col` and keeps only the ones whose
 * resolved expression is an [[Attribute]]. Names that resolve to any other kind
 * of expression are dropped from the result.
 */
private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
  cols.flatMap { colName =>
    df.col(colName).expr match {
      case attr: Attribute => Some(attr)
      case _ => None
    }
  }
}

/** Returns the output attributes exposed by `df.queryExecution.analyzed`. */
private def outputAttributes: Seq[Attribute] = df.queryExecution.analyzed.output

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric, string columns. If a specified column is not a numeric, string
* or boolean column it is ignored.
* Returns a new `DataFrame` that replaces null or NaN values in the specified
* columns. If a specified column is not a numeric, string or boolean column,
* it is ignored.
*/
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
// fill[T], where T is Long or Double,
// should apply to every NumericType column, for example:
// val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
Expand All @@ -488,20 +507,19 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${value.getClass.getName} ($value).")
}

val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
val typeMatches = (targetType, f.dataType) match {
val projections = outputAttributes.map { col =>
val typeMatches = (targetType, col.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
case (StringType, dt) => dt == StringType
case (BooleanType, dt) => dt == BooleanType
case _ =>
throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
fillCol[T](f, value)
if (typeMatches && cols.exists(_.semanticEquals(col))) {
fillCol(col.dataType, col.name, Column(col), value)
} else {
df.col(f.name)
Column(col)
}
}
df.select(projections : _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}

class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -239,6 +240,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
}
}

test("fill with col(*)") {
  val input = createDF()
  // A "*" entry in the column list is ignored, so the data must come back unchanged.
  val filled = input.na.fill("new name", Seq("*"))
  val expected = input.collect()
  checkAnswer(filled, expected)
}

test("fill with nested columns") {
  // A single struct column "c1" holding two string fields.
  val innerStruct = new StructType()
    .add("c1-1", StringType)
    .add("c1-2", StringType)
  val schema = new StructType().add("c1", innerStruct)

  val rows = Seq(
    Row(Row(null, "a2")),
    Row(Row("b1", "b2")),
    Row(null))

  val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

  // Sanity check: the nested field can be selected and yields nulls for two of the rows.
  val nestedValues = df.select("c1.c1-1")
  checkAnswer(nestedValues, Seq(Row(null), Row("b1"), Row(null)))

  // fill() ignores nested columns, so the data must come back unchanged.
  val filled = df.na.fill("a1", Seq("c1.c1-1"))
  checkAnswer(filled, rows)
}

test("replace") {
val input = createDF()

Expand Down Expand Up @@ -349,4 +377,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}

test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
  val lhs = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
  val rhs = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
  // Joining on "col1" leaves a "col2" from each side in the result.
  val joined = lhs.join(rhs, Seq("col1"))

  // Naming "col2" explicitly must fail because the reference is ambiguous.
  val error = intercept[AnalysisException] {
    joined.na.fill("hello", Seq("col2"))
  }
  assert(error.getMessage.contains("Reference 'col2' is ambiguous"))

  // Without explicit column names, fill() applies to every eligible column.
  val filled = joined.na.fill("hello")
  val expectedRows = Seq(Row("1", "hello", "2"), Row("3", "4", "hello"))
  checkAnswer(filled, expectedRows)
}
}