diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 8447ada88a70..2a86b65b8f79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -456,7 +456,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => - Seq(buildExpr(source), buildExpr(target)) + Seq(Literal(source), buildExpr(target)) }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 6cb35656835a..fb1ca69b6f73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -37,6 +37,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { ).toDF("name", "age", "height") } + def createNaNDF(): DataFrame = { + Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0), + (0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + } + test("drop") { val input = createDF() val rows = input.collect() @@ -404,4 +412,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { df.na.drop("any"), Row("5", "6", "6") :: Nil) } + + test("replace nan with float") { + checkAnswer( + createNaNDF().na.replace("*", Map( + Float.NaN -> 10.0f + )), + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) + } + + test("replace nan with double") { + checkAnswer( + createNaNDF().na.replace("*", Map( + Double.NaN -> 10.0 + )), + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) + } + + test("replace float with nan") { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0f -> Float.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } + + test("replace double with nan") { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0 -> Double.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } }