diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 27839d72c630..10d9ee52faca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -153,19 +153,26 @@ object TypeCoercion { t2: DataType, findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findTypeFunc(et1, et2).map { et => + ArrayType(et, containsNull1 || containsNull2 || + Cast.forceNullable(et1, et) || Cast.forceNullable(et2, et)) + } case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findTypeFunc(kt1, kt2).flatMap { kt => - findTypeFunc(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } + findTypeFunc(kt1, kt2) + .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) } + .flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2 || + Cast.forceNullable(vt1, vt) || Cast.forceNullable(vt2, vt)) + } } case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => val resolver = SQLConf.get.resolver fields1.zip(fields2).foldLeft(Option(new StructType())) { case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => - findTypeFunc(field1.dataType, field2.dataType).map { - dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + findTypeFunc(field1.dataType, field2.dataType).map { dt => + struct.add(field1.name, dt, field1.nullable || field2.nullable || + Cast.forceNullable(field1.dataType, dt) || Cast.forceNullable(field2.dataType, dt)) } case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d71bbb322713..2c6cb3ae1274 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -499,6 +499,10 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(new StructType().add("num", ShortType), containsNull = false), ArrayType(new StructType().add("num", LongType), containsNull = false), Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(IntegerType, containsNull = false), + ArrayType(DecimalType.IntDecimal, containsNull = false), + Some(ArrayType(DecimalType.IntDecimal, containsNull = true))) // MapType widenTestWithStringPromotion( @@ -517,6 +521,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType, valueContainsNull = false), + MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType, valueContainsNull = false), + MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + None) // StructType widenTestWithStringPromotion( @@ -540,6 +552,10 @@ class TypeCoercionSuite extends AnalysisTest { .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), Some(new StructType() .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", IntegerType, nullable = false), + new StructType().add("num", DecimalType.IntDecimal, nullable = false), + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = true))) widenTestWithStringPromotion( new StructType().add("num", IntegerType),