diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 6dd21f114c90..07b0a54ba077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -130,20 +130,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long): DataFrame = fill(value, df.columns) + def fill(value: Long): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DataFrame = fill(value, df.columns) + def fill(value: Double): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DataFrame = fill(value, df.columns) + def fill(value: String): DataFrame = fillValue(value, outputAttributes) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. @@ -167,7 +167,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -175,7 +175,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** @@ -192,14 +192,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): DataFrame = fill(value, df.columns) + def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified @@ -207,7 +207,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols) + def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols)) /** * Returns a new `DataFrame` that replaces null values in specified boolean columns. @@ -433,15 +433,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + * It selects a column based on its name. */ private def fillCol[T](col: StructField, replacement: T): Column = { val quotedColName = "`" + col.name + "`" - val colValue = col.dataType match { + fillCol(col.dataType, col.name, df.col(quotedColName), replacement) + } + + /** + * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`. + * It uses the given `expr` as a column. + */ + private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = { + val colValue = dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types - case _ => df.col(quotedColName) + nanvl(expr, lit(null)) // nanvl only supports these types + case _ => expr } - coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) + coalesce(colValue, lit(replacement).cast(dataType)).as(name) } /** @@ -468,12 +477,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { s"Unsupported value type ${v.getClass.getName} ($v).") } + private def toAttributes(cols: Seq[String]): Seq[Attribute] = { + cols.map(name => df.col(name).expr).collect { + case a: Attribute => a + } + } + + private def outputAttributes: Seq[Attribute] = { + df.queryExecution.analyzed.output + } + /** - * Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric, string columns. If a specified column is not a numeric, string - * or boolean column it is ignored. + * Returns a new `DataFrame` that replaces null or NaN values in the specified + * columns. If a specified column is not a numeric, string or boolean column, + * it is ignored. */ - private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { + private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = { // the fill[T] which T is Long/Double, // should apply on all the NumericType Column, for example: // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b") @@ -487,9 +506,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { s"Unsupported value type ${value.getClass.getName} ($value).") } - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val filledColumns = df.schema.fields.filter { f => - val typeMatches = (targetType, f.dataType) match { + val projections = outputAttributes.map { col => + val typeMatches = (targetType, col.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType case (BooleanType, dt) => dt == BooleanType @@ -497,8 +515,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { throw new IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. - typeMatches && cols.exists(col => columnEquals(f.name, col)) + if (typeMatches && cols.exists(_.semanticEquals(col))) { + fillCol(col.dataType, col.name, Column(col), value) + } else { + Column(col) + } } - df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value))) + df.select(projections : _*) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 75642a0bd932..1afe733b855b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -266,6 +267,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { assert(message.contains("Reference 'f2' is ambiguous")) } + test("fill with col(*)") { + val df = createDF() + // If columns are specified with "*", they are ignored. + checkAnswer(df.na.fill("new name", Seq("*")), df.collect()) + } + + test("fill with nested columns") { + val schema = new StructType() + .add("c1", new StructType() + .add("c1-1", StringType) + .add("c1-2", StringType)) + + val data = Seq( + Row(Row(null, "a2")), + Row(Row("b1", "b2")), + Row(null)) + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(data), schema) + + checkAnswer(df.select("c1.c1-1"), + Row(null) :: Row("b1") :: Row(null) :: Nil) + + // Nested columns are ignored for fill(). + checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data) + } + test("replace") { val input = createDF() @@ -340,4 +368,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { )).na.drop("name" :: Nil).select("name"), Row("Alice") :: Row("David") :: Nil) } + + test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") { + val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2") + val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2") + val df = left.join(right, Seq("col1")) + + // If column names are specified, the following fails due to ambiguity. + val exception = intercept[AnalysisException] { + df.na.fill("hello", Seq("col2")) + } + assert(exception.getMessage.contains("Reference 'col2' is ambiguous")) + + // If column names are not specified, fill() is applied to all the eligible columns. + checkAnswer( + df.na.fill("hello"), + Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil) + } }