Skip to content

Commit 41bbd2c

Browse files
committed
fix bug in Nanvl type coercion
1 parent 3dee204 commit 41bbd2c

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -563,11 +563,10 @@ object HiveTypeCoercion {
563563
case None => c
564564
}
565565

566-
case n @ NaNvl(l, r) if l.dataType != r.dataType =>
567-
l.dataType match {
568-
case DoubleType => NaNvl(l, Cast(r, DoubleType))
569-
case FloatType => NaNvl(Cast(l, DoubleType), r)
570-
}
566+
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
567+
NaNvl(l, Cast(r, DoubleType))
568+
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
569+
NaNvl(Cast(l, DoubleType), r)
571570
}
572571
}
573572

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
227227

228228
test("nanvl") {
229229
val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
230-
Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
230+
Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
231231
StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
232-
StructField("c", DoubleType), StructField("d", DoubleType))))
232+
StructField("c", DoubleType), StructField("d", DoubleType),
233+
StructField("e", FloatType), StructField("f", IntegerType))))
233234

234235
checkAnswer(
235236
testData.select(
236-
nanvl($"a", lit(5)), nanvl($"b", lit(10)),
237-
nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
238-
Row(null, 3.0, null, Double.PositiveInfinity)
237+
nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"),
238+
nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)),
239+
nanvl($"b", $"e"), nanvl($"e", $"f")),
240+
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
239241
)
240242
testData.registerTempTable("t")
241243
checkAnswer(
242-
ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"),
243-
Row(null, 3.0, null, Double.PositiveInfinity)
244+
ctx.sql(
245+
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
246+
" nanvl(b, e), nanvl(e, f) from t"),
247+
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
244248
)
245249
}
246250

0 commit comments

Comments
 (0)