diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 78df89d491ac..fda8dddcefa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -437,12 +437,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
     val quotedColName = "`" + col.name + "`"
-    val colValue = col.dataType match {
+    fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`.
+   * It uses the given `expr` as a column.
+   */
+  private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
+    val colValue = dataType match {
       case DoubleType | FloatType =>
-        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
-      case _ => df.col(quotedColName)
+        nanvl(expr, lit(null)) // nanvl only supports these types
+      case _ => expr
     }
-    coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(colValue, lit(replacement).cast(dataType)).as(name)
   }
 
   /**
@@ -489,6 +497,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     }
 
     val columnEquals = df.sparkSession.sessionState.analyzer.resolver
+
     val projections = df.schema.fields.map { f =>
       val typeMatches = (targetType, f.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
@@ -499,7 +508,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       }
       // Only fill if the column is part of the cols list.
       if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
-        fillCol[T](f, value)
+        fillCol(f.dataType, f.name, Column(f.name), value)
       } else {
         df.col(f.name)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 7cf0d25b07fd..7a8f82936a92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -349,4 +349,93 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
         Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
+
+  test("SPARK-28897:coalesce error executing dataframe.na.fill " +
+    "with spark.sql.parser.quotedRegexColumnNames true") {
+    val input = createDF()
+
+    val boolInput = Seq[(String, java.lang.Boolean)](
+      ("Bob", false),
+      ("Alice", null),
+      ("Mallory", true),
+      (null, null)
+    ).toDF("name", "spy")
+
+    withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") {
+      val fillNumeric = input.na.fill(50.6)
+      checkAnswer(
+        fillNumeric,
+        Row("Bob", 16, 176.5) ::
+          Row("Alice", 50, 164.3) ::
+          Row("David", 60, 50.6) ::
+          Row("Nina", 25, 50.6) ::
+          Row("Amy", 50, 50.6) ::
+          Row(null, 50, 50.6) :: Nil)
+
+      // Make sure the columns are properly named.
+      assert(fillNumeric.columns.toSeq === input.columns.toSeq)
+
+      // string
+      checkAnswer(
+        input.na.fill("unknown").select("name"),
+        Row("Bob") :: Row("Alice") :: Row("David") ::
+          Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
+      assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
+
+      // boolean
+      checkAnswer(
+        boolInput.na.fill(true).select("spy"),
+        Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil)
+      assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq)
+
+      // fill double with subset columns
+      checkAnswer(
+        input.na.fill(50.6, "age" :: Nil).select("name", "age"),
+        Row("Bob", 16) ::
+          Row("Alice", 50) ::
+          Row("David", 60) ::
+          Row("Nina", 25) ::
+          Row("Amy", 50) ::
+          Row(null, 50) :: Nil)
+
+      // fill boolean with subset columns
+      checkAnswer(
+        boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"),
+        Row("Bob", false) ::
+          Row("Alice", true) ::
+          Row("Mallory", true) ::
+          Row(null, true) :: Nil)
+
+      // fill string with subset columns
+      checkAnswer(
+        Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
+        Row("test", null))
+
+      checkAnswer(
+        Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L))
+          .toDF("a", "b").na.fill(0),
+        Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil
+      )
+
+      checkAnswer(
+        Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null),
+          (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2),
+        Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6)
+          :: Row(0, 0.2) :: Nil
+      )
+
+      checkAnswer(
+        Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null),
+          (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2),
+        Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f)
+          :: Row(0, 0.2f) :: Nil
+      )
+
+      checkAnswer(
+        Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
+          .toDF("a", "b").na.fill(2.34),
+        Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil
+      )
+    }
+  }
 }
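
For context, the following is a minimal standalone sketch (not part of the patch) of the scenario this change targets: with spark.sql.parser.quotedRegexColumnNames set to true, DataFrame.na.fill previously hit an analysis error in the generated coalesce projection, because the backtick-quoted column name used by fillCol could be interpreted as a regex column pattern. The object name and sample data below are illustrative only.

// Illustrative repro sketch; assumes a local Spark build containing this patch lineage.
import org.apache.spark.sql.SparkSession

object NaFillQuotedRegexRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("na-fill-quoted-regex-repro")
      .master("local[*]")
      // The config exercised by the new test above.
      .config("spark.sql.parser.quotedRegexColumnNames", "true")
      .getOrCreate()
    import spark.implicits._

    val df = Seq[(String, java.lang.Double)](
      ("Bob", 176.5),
      ("Alice", null)
    ).toDF("name", "height")

    // Before this change, the call below could fail during analysis of the
    // coalesce(...) expression built by na.fill; with the fix it replaces the
    // null height with 50.6 and keeps the original column names.
    df.na.fill(50.6).show()

    spark.stop()
  }
}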