diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 116227224fdd1..8fb597155fcb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -481,12 +481,12 @@ case class Add( // TODO: do not reorder consecutive `Add`s with different `evalMode` val reorderResult = orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode)) - if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) { - reorderResult + if (resolved && reorderResult.resolved && reorderResult.dataType != dataType) { + // SPARK-40903: Append cast for the canonicalization of decimal Add if the result data type is + // changed. Otherwise, it may cause data checking error within ComplexTypeMergingExpression. + Cast(reorderResult, dataType) } else { - // SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is - // changed, which may cause data checking error within ComplexTypeMergingExpression. - withCanonicalizedChildren + reorderResult } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 057fb98c23985..35ed48855080a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -206,4 +206,16 @@ class CanonicalizeSuite extends SparkFunSuite { assert(!Add(Add(literal4, literal5), literal1).semanticEquals( Add(Add(literal1, literal5), literal4))) } + + test("SPARK-40903: Append Casting if the canonicalization result type is changed") { + val d = Decimal(1.2) + val literal1 = Literal.create(d, DecimalType(2, 1)) + val literal2 = Literal.create(d, DecimalType(12, 5)) + val literal3 = Literal.create(d, DecimalType(12, 6)) + val add1 = Add(literal1, Add(literal2, literal3)).canonicalized + val add2 = Add(Add(literal3, literal2), literal1).canonicalized + assert(add1.isInstanceOf[Cast] && add1.dataType == DecimalType(15, 6)) + assert(add2.isInstanceOf[Cast] && add2.dataType == DecimalType(15, 6)) + assert(add1 == add2) + } }