Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
def fill(value: Long): DataFrame = fill(value, df.columns)
def fill(value: Long): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1
*/
def fill(value: Double): DataFrame = fill(value, df.columns)
def fill(value: Double): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null values in string columns with `value`.
*
* @since 1.3.1
*/
def fill(value: String): DataFrame = fill(value, df.columns)
def fill(value: String): DataFrame = fillValue(value, outputAttributes)

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
Expand All @@ -167,15 +167,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric columns. If a specified column is not a numeric column, it is ignored.
*
* @since 1.3.1
*/
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))


/**
Expand All @@ -192,22 +192,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
*
* @since 2.3.0
*/
def fill(value: Boolean): DataFrame = fill(value, df.columns)
def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
* boolean columns. If a specified column is not a boolean column, it is ignored.
*
* @since 2.3.0
*/
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))

/**
* Returns a new `DataFrame` that replaces null values in specified boolean columns.
Expand Down Expand Up @@ -433,15 +433,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {

/**
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
* It selects a column based on its name.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
val quotedColName = "`" + col.name + "`"
val colValue = col.dataType match {
fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
}

/**
* Returns a [[Column]] expression that replaces null value in `expr` with `replacement`.
* It uses the given `expr` as a column.
*/
private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
val colValue = dataType match {
case DoubleType | FloatType =>
nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
case _ => df.col(quotedColName)
nanvl(expr, lit(null)) // nanvl only supports these types
case _ => expr
}
coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
coalesce(colValue, lit(replacement).cast(dataType)).as(name)
}

/**
Expand All @@ -468,12 +477,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${v.getClass.getName} ($v).")
}

private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
cols.map(name => df.col(name).expr).collect {
case a: Attribute => a
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan I noticed that drop() is resolving column names by cols.map(name => df.resolve(name) instead of df.col(name). The difference (other than the return type) is that df.col() will try to resolve using regex and adding metadata. Do you think we need to make this consistent?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should. To avoid breaking change, I think we should change drop to follow fill to make it more powerful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I will do this in a separate PR which will address the issue with handling duplicate columns in drop when no columns are specified.


private def outputAttributes: Seq[Attribute] = {
df.queryExecution.analyzed.output
}

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric, string columns. If a specified column is not a numeric, string
* or boolean column it is ignored.
* Returns a new `DataFrame` that replaces null or NaN values in the specified
* columns. If a specified column is not a numeric, string or boolean column,
* it is ignored.
*/
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
// the fill[T] which T is Long/Double,
// should apply on all the NumericType Column, for example:
// val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
Expand All @@ -487,18 +506,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${value.getClass.getName} ($value).")
}

val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val filledColumns = df.schema.fields.filter { f =>
val typeMatches = (targetType, f.dataType) match {
val projections = outputAttributes.map { col =>
val typeMatches = (targetType, col.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
case (StringType, dt) => dt == StringType
case (BooleanType, dt) => dt == BooleanType
case _ =>
throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
}
// Only fill if the column is part of the cols list.
typeMatches && cols.exists(col => columnEquals(f.name, col))
if (typeMatches && cols.exists(_.semanticEquals(col))) {
fillCol(col.dataType, col.name, Column(col), value)
} else {
Column(col)
}
}
df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value)))
df.select(projections : _*)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{StringType, StructType}

class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
Expand Down Expand Up @@ -266,6 +267,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
assert(message.contains("Reference 'f2' is ambiguous"))
}

test("fill with col(*)") {
val df = createDF()
// If columns are specified with "*", they are ignored.
checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
}

test("fill with nested columns") {
val schema = new StructType()
.add("c1", new StructType()
.add("c1-1", StringType)
.add("c1-2", StringType))

val data = Seq(
Row(Row(null, "a2")),
Row(Row("b1", "b2")),
Row(null))

val df = spark.createDataFrame(
spark.sparkContext.parallelize(data), schema)

checkAnswer(df.select("c1.c1-1"),
Row(null) :: Row("b1") :: Row(null) :: Nil)

// Nested columns are ignored for fill().
checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can expose the bug if the data is

val data = Seq(
  Row(Row(null, "a2")),
  Row(Row("b1", "b2")),
  Row(null))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example works fine, and you also need to update the previous line to:

checkAnswer(df.select("c1.c1-1"),
  Row(null) :: Row("b1") :: Row(null) :: Nil)

(This is just to illustrate the nested type, but I can remove it if you think it's confusing.)

The reason the nested types are ignored is the following check:

  case (NumericType, dt) => dt.isInstanceOf[NumericType]
  case (StringType, dt) => dt == StringType
  case (BooleanType, dt) => dt == BooleanType

The datatype for the nested column that is resolved to Attribute is StructType, so this will not be matched.

}

test("replace") {
val input = createDF()

Expand Down Expand Up @@ -340,4 +368,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
)).na.drop("name" :: Nil).select("name"),
Row("Alice") :: Row("David") :: Nil)
}

test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
val df = left.join(right, Seq("col1"))

// If column names are specified, the following fails due to ambiguity.
val exception = intercept[AnalysisException] {
df.na.fill("hello", Seq("col2"))
}
assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))

// If column names are not specified, fill() is applied to all the eligible columns.
checkAnswer(
df.na.fill("hello"),
Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
}
}