diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 3a78f14c8b2c..1199a347eafe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -363,9 +363,9 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper left: Expression, right: Expression, f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType] - val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType] - MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f)) + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) } val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), @@ -402,6 +402,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mii0, miin, multiplyKeyWithValues), null) + assert(map_zip_with(mii0, mii1, multiplyKeyWithValues).dataType === + MapType(IntegerType, IntegerType, valueContainsNull = true)) val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), MapType(StringType, StringType, valueContainsNull = false)) @@ -437,6 +439,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mss0, mssn, concat), null) + assert(map_zip_with(mss0, mss1, concat).dataType === + MapType(StringType, StringType, valueContainsNull = true)) def b(data: Byte*): Array[Byte] = Array[Byte](data: _*)