diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6df03aa8e84e9..6d8c2e83ef797 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -242,7 +242,13 @@ abstract class Expression extends TreeNode[Expression] {
    * This means that the lazy `cannonicalized` is called and computed only on the root of the
    * adjacent expressions.
    */
-  lazy val canonicalized: Expression = {
+  lazy val canonicalized: Expression = withCanonicalizedChildren
+
+  /**
+   * The default process of canonicalization. It is a one pass, bottom-up expression tree
+   * computation based on canonicalizing children before canonicalizing the current node.
+   */
+  final protected def withCanonicalizedChildren: Expression = {
     val canonicalizedChildren = children.map(_.canonicalized)
     withNewChildren(canonicalizedChildren)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d82108aa3c9f8..a72b84978c002 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -479,7 +479,15 @@ case class Add(
 
   override lazy val canonicalized: Expression = {
     // TODO: do not reorder consecutive `Add`s with different `evalMode`
-    orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+    val reorderResult =
+      orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+    if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
+      reorderResult
+    } else {
+      // SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is
+      // changed, which may cause data checking error within ComplexTypeMergingExpression.
+      withCanonicalizedChildren
+    }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
index 43b7f35f7bb24..057fb98c23985 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Range
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}
 
 class CanonicalizeSuite extends SparkFunSuite {
 
@@ -187,7 +187,23 @@ class CanonicalizeSuite extends SparkFunSuite {
   test("SPARK-40362: Commutative operator under BinaryComparison") {
     Seq(EqualTo, EqualNullSafe, GreaterThan, LessThan, GreaterThanOrEqual, LessThanOrEqual)
       .foreach { bc =>
-        assert(bc(Add($"a", $"b"), Literal(10)).semanticEquals(bc(Add($"b", $"a"), Literal(10))))
+        assert(bc(Multiply($"a", $"b"), Literal(10)).semanticEquals(
+          bc(Multiply($"b", $"a"), Literal(10))))
       }
   }
+
+  test("SPARK-40903: Only reorder decimal Add when the result data type is not changed") {
+    val d = Decimal(1.2)
+    val literal1 = Literal.create(d, DecimalType(2, 1))
+    val literal2 = Literal.create(d, DecimalType(2, 1))
+    val literal3 = Literal.create(d, DecimalType(3, 2))
+    assert(Add(literal1, literal2).semanticEquals(Add(literal2, literal1)))
+    assert(Add(Add(literal1, literal2), literal3).semanticEquals(
+      Add(Add(literal3, literal2), literal1)))
+
+    val literal4 = Literal.create(d, DecimalType(12, 5))
+    val literal5 = Literal.create(d, DecimalType(12, 6))
+    assert(!Add(Add(literal4, literal5), literal1).semanticEquals(
+      Add(Add(literal1, literal5), literal4)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 030e68d227aae..dd3ad0f4d6bd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4518,6 +4518,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-40903: Don't reorder Add for canonicalize if it is decimal type") {
+    val tableName = "decimalTable"
+    withTable(tableName) {
+      sql(s"create table $tableName(a decimal(12, 5), b decimal(12, 6)) using orc")
+      checkAnswer(sql(s"select sum(coalesce(a + b + 1.75, a)) from $tableName"), Row(null))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])