Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
val quotedColName = "`" + col.name + "`"
val colValue = col.dataType match {
fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
}

/**
* Returns a [[Column]] expression that replaces null values in `expr` with `replacement`.
* The given `expr` is used directly as the column to fill, and the result is aliased to `name`.
*/
private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
val colValue = dataType match {
case DoubleType | FloatType =>
nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
case _ => df.col(quotedColName)
nanvl(expr, lit(null)) // nanvl only supports these types
case _ => expr
}
coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
coalesce(colValue, lit(replacement).cast(dataType)).as(name)
}

/**
Expand Down Expand Up @@ -489,6 +497,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
}

val columnEquals = df.sparkSession.sessionState.analyzer.resolver

val projections = df.schema.fields.map { f =>
val typeMatches = (targetType, f.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
Expand All @@ -499,7 +508,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
fillCol[T](f, value)
fillCol(f.dataType, f.name, Column(f.name), value)
} else {
df.col(f.name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,93 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}

test("SPARK-28897:coalesce error executing dataframe.na.fill " +
  "with spark.sql.parser.quotedRegexColumnNames true") {
  // Fixtures are built before the config is switched on, exactly as the fills are
  // the only operations that need to run under the quoted-regex setting.
  val people = createDF()

  val spies = Seq[(String, java.lang.Boolean)](
    ("Bob", false),
    ("Alice", null),
    ("Mallory", true),
    (null, null)
  ).toDF("name", "spy")

  withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") {
    // Numeric fill over every column.
    val numericFilled = people.na.fill(50.6)
    checkAnswer(
      numericFilled,
      Seq(
        Row("Bob", 16, 176.5),
        Row("Alice", 50, 164.3),
        Row("David", 60, 50.6),
        Row("Nina", 25, 50.6),
        Row("Amy", 50, 50.6),
        Row(null, 50, 50.6)))

    // Filling must not rename any column.
    assert(numericFilled.columns.toSeq === people.columns.toSeq)

    // String fill.
    val stringFilled = people.na.fill("unknown")
    checkAnswer(
      stringFilled.select("name"),
      Seq(Row("Bob"), Row("Alice"), Row("David"), Row("Nina"), Row("Amy"), Row("unknown")))
    assert(stringFilled.columns.toSeq === people.columns.toSeq)

    // Boolean fill.
    val booleanFilled = spies.na.fill(true)
    checkAnswer(
      booleanFilled.select("spy"),
      Seq(Row(false), Row(true), Row(true), Row(true)))
    assert(booleanFilled.columns.toSeq === spies.columns.toSeq)

    // Double fill restricted to a subset of columns.
    checkAnswer(
      people.na.fill(50.6, "age" :: Nil).select("name", "age"),
      Seq(
        Row("Bob", 16),
        Row("Alice", 50),
        Row("David", 60),
        Row("Nina", 25),
        Row("Amy", 50),
        Row(null, 50)))

    // Boolean fill restricted to a subset of columns.
    checkAnswer(
      spies.na.fill(true, "spy" :: Nil).select("name", "spy"),
      Seq(
        Row("Bob", false),
        Row("Alice", true),
        Row("Mallory", true),
        Row(null, true)))

    // String fill restricted to a subset of columns: only "col1" is replaced.
    val stringPair = Seq[(String, String)]((null, null)).toDF("col1", "col2")
    checkAnswer(
      stringPair.na.fill("test", "col1" :: Nil),
      Row("test", null))

    // Large Long values with no nulls come back unchanged after a fill.
    val longPairs =
      Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L))
        .toDF("a", "b")
    checkAnswer(
      longPairs.na.fill(0),
      Seq(
        Row(1, 2),
        Row(-1, -2),
        Row(9123146099426677101L, 9123146560113991650L)))

    // Fractional replacement against a nullable Long column and a nullable Double column.
    val longAndDouble = Seq[(java.lang.Long, java.lang.Double)](
      (null, 3.14),
      (9123146099426677101L, null),
      (9123146560113991650L, 1.6),
      (null, null)
    ).toDF("a", "b")
    checkAnswer(
      longAndDouble.na.fill(0.2),
      Seq(
        Row(0, 3.14),
        Row(9123146099426677101L, 0.2),
        Row(9123146560113991650L, 1.6),
        Row(0, 0.2)))

    // Same data shape, with a nullable Float column instead of Double.
    val longAndFloat = Seq[(java.lang.Long, java.lang.Float)](
      (null, 3.14f),
      (9123146099426677101L, null),
      (9123146560113991650L, 1.6f),
      (null, null)
    ).toDF("a", "b")
    checkAnswer(
      longAndFloat.na.fill(0.2),
      Seq(
        Row(0, 3.14f),
        Row(9123146099426677101L, 0.2f),
        Row(9123146560113991650L, 1.6f),
        Row(0, 0.2f)))

    // Fractional replacement 2.34 is expected to show up as 2 in the Long column.
    val smallLongAndDouble = Seq[(java.lang.Long, java.lang.Double)](
      (null, 1.23),
      (3L, null),
      (4L, 3.45)
    ).toDF("a", "b")
    checkAnswer(
      smallLongAndDouble.na.fill(2.34),
      Seq(Row(2, 1.23), Row(3, 2.34), Row(4, 3.45)))
  }
}
}
}