diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 5288907b7d7f..78df89d491ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -455,7 +455,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => - Seq(buildExpr(source), buildExpr(target)) + Seq(Literal(source), buildExpr(target)) }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e6983b6be555..7cf0d25b07fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -36,6 +36,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("name", "age", "height") } + def createNaNDF(): DataFrame = { + Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0), + (0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + } + test("drop") { val input = createDF() val rows = input.collect() @@ -305,4 +313,40 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { )).na.drop("name" :: Nil).select("name"), Row("Alice") :: Row("David") :: Nil) } + + test("replace nan with float") { + checkAnswer( + createNaNDF().na.replace("*", Map( + Float.NaN -> 10.0f + )), + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) + } + + test("replace nan with double") { + checkAnswer( + createNaNDF().na.replace("*", Map( + Double.NaN -> 10.0 + )), + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) + } + + test("replace float with nan") { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0f -> Float.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } + + test("replace double with nan") { + checkAnswer( + createNaNDF().na.replace("*", Map( + 1.0 -> Double.NaN + )), + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) + } }