From ffb3e929605b608c3571c0e7ad4ac4465fd3454e Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 25 Nov 2025 21:04:07 +0000 Subject: [PATCH] Track conditionally evaluated expressions to resolve as subexpressions for cases they are already being evaluated --- .../expressions/EquivalentExpressions.scala | 157 +++++++++++++----- .../sql/catalyst/expressions/Expression.scala | 6 + .../expressions/conditionalExpressions.scala | 27 +-- .../expressions/nullExpressions.scala | 14 +- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../SubexpressionEliminationSuite.scala | 124 +++++++++++--- 6 files changed, 243 insertions(+), 96 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 78f73f8778b8..8c1515f51de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -61,15 +61,22 @@ class EquivalentExpressions( private def updateExprInMap( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats], - useCount: Int = 1): Boolean = { + useCount: Int = 1, + conditional: Boolean = false): Boolean = { if (expr.deterministic) { val wrapper = ExpressionEquals(expr) map.get(wrapper) match { case Some(stats) => - stats.useCount += useCount - if (stats.useCount > 0) { + val count = if (conditional) { + stats.conditionalUseCount += useCount + stats.conditionalUseCount + } else { + stats.useCount += useCount + stats.useCount + } + if (count > 0) { true - } else if (stats.useCount == 0) { + } else if (count == 0) { map -= wrapper false } else { @@ -79,7 +86,12 @@ class EquivalentExpressions( } case _ => if (useCount > 0) { - map.put(wrapper, ExpressionStats(expr)(useCount)) + val stats = if (conditional) { + ExpressionStats(expr)(useCount = 0, conditionalUseCount = useCount) + } else { + ExpressionStats(expr)(useCount) + } + map.put(wrapper, stats) } false } @@ -89,44 +101,47 @@ class EquivalentExpressions( } /** - * Adds or removes only expressions which are common in each of given expressions, in a recursive + * Returns a list of expressions which are common in each of given expressions, in a recursive * way. * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common - * expression `(c + 1)` will be added into `equivalenceMap`. + * expression `(c + 1)` will be returned. * * Note that as we don't know in advance if any child node of an expression will be common across * all given expressions, we compute local equivalence maps for all given expressions and filter * only the common nodes. - * Those common nodes are then removed from the local map and added to the final map of + * Those common nodes are then removed from the local map and added to the final list of * expressions. + * + * Conditional expressions are not considered because we are simply looking for expressions + * evaluated once in each parent expression. 
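+   * For example, given branch values `If(p, a + b, d)` and `(a + b) + e`, the `a + b` inside the
+   * `If` is only conditionally evaluated, so it is not treated as common here; conditionally
+   * evaluated children are instead handled separately via `conditionalChildren`.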
*/ - private def updateCommonExprs( - exprs: Seq[Expression], - map: mutable.HashMap[ExpressionEquals, ExpressionStats], - useCount: Int): Unit = { + private def getCommonExprs(exprs: Seq[Expression]): Seq[ExpressionEquals] = { assert(exprs.length > 1) var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(exprs.head, localEquivalenceMap) + updateExprTree(exprs.head, localEquivalenceMap, conditionalsEnabled = false) exprs.tail.foreach { expr => val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(expr, otherLocalEquivalenceMap) + updateExprTree(expr, otherLocalEquivalenceMap, conditionalsEnabled = false) localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => otherLocalEquivalenceMap.contains(key) } } + val commonExpressions = mutable.ListBuffer.empty[ExpressionEquals] + // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. // The remaining highest expression in `localEquivalenceMap` is also common expression so loop // until `localEquivalenceMap` is not empty. var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) while (statsOption.nonEmpty) { val stats = statsOption.get - updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) - updateExprTree(stats.expr, map, useCount) + updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount, conditionalsEnabled = false) + commonExpressions += ExpressionEquals(stats.expr) statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) } + commonExpressions.toSeq } private def skipForShortcut(expr: Expression): Expression = { @@ -143,21 +158,17 @@ class EquivalentExpressions( } } - // There are some special expressions that we should not recurse into all of its children. - // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. ConditionalExpression: use its children that will always be evaluated. - private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) - case other => skipForShortcut(other).children - } - - // For some special expressions we cannot just recurse into all of its children, but we can - // recursively add the common expressions shared between all of its children. - private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.branchGroups - case _ => Nil + /** + * There are some expressions that need special handling: + * 1. CodegenFallback: It's children will not be used to generate code (call eval() instead). + * 2. ConditionalExpression: use its children that will always be evaluated. 
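+   *    Its branch groups and conditionally evaluated inputs are also returned so the caller
+   *    can decide how to count them (see `RecurseChildren`).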
+ */ + private def childrenToRecurse(expr: Expression): RecurseChildren = expr match { + case _: CodegenFallback => RecurseChildren(Nil) + case c: ConditionalExpression => + RecurseChildren(c.alwaysEvaluatedInputs.map(skipForShortcut), c.branchGroups, + c.conditionallyEvaluatedInputs) + case other => RecurseChildren(skipForShortcut(other).children) } private def supportedExpression(e: Expression): Boolean = { @@ -184,13 +195,48 @@ class EquivalentExpressions( private def updateExprTree( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, - useCount: Int = 1): Unit = { - val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] + useCount: Int = 1, + conditionalsEnabled: Boolean = SQLConf.get.subexpressionEliminationConditionalsEnabled, + conditional: Boolean = false, + skipExpressions: Set[ExpressionEquals] = Set.empty[ExpressionEquals] + ): Unit = { + val skip = useCount == 0 || + expr.isInstanceOf[LeafExpression] || + skipExpressions.contains(ExpressionEquals(expr)) - if (!skip && !updateExprInMap(expr, map, useCount)) { + if (!skip && !updateExprInMap(expr, map, useCount, conditional)) { val uc = useCount.sign - childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) + val recurseChildren = childrenToRecurse(expr) + recurseChildren.alwaysChildren.foreach { child => + updateExprTree(child, map, uc, conditionalsEnabled, conditional, skipExpressions) + } + + /** + * If the `commonExpressions` already appears in the equivalence map, calling + * `updateExprTree` will increase the `useCount` and mark it as a common subexpression. + * Otherwise, `updateExprTree` will recursively add `commonExpressions` and its descendant to + * the equivalence map, in case they also appear in other places. For example, + * `If(a + b > 1, a + b + c, a + b + c)`, `a + b` also appears in the condition and should + * be treated as common subexpression. + */ + val commonExpressions = recurseChildren.commonChildren.flatMap { exprs => + if (exprs.nonEmpty) { + getCommonExprs(exprs) + } else { + Nil + } + } + commonExpressions.foreach { ce => + updateExprTree(ce.e, map, uc, conditionalsEnabled, conditional, skipExpressions) + } + + if (conditionalsEnabled) { + // Add all conditional expressions, skipping those that were already counted as common + // expressions. + recurseChildren.conditionalChildren.foreach { cc => + updateExprTree(cc, map, uc, true, true, commonExpressions.toSet) + } + } } } @@ -208,7 +254,7 @@ class EquivalentExpressions( // Exposed for testing. 
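+  // Note: `getUseCount()` only includes conditional uses when the expression also has at least
+  // one unconditional use, so expressions that are only conditionally evaluated are not
+  // reported here.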
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { - equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2) + equivalenceMap.filter(_._2.getUseCount() > count).toSeq.sortBy(_._1.height).map(_._2) } /** @@ -225,8 +271,11 @@ class EquivalentExpressions( def debugString(all: Boolean = false): String = { val sb = new java.lang.StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => - sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') + equivalenceMap.values.filter(stats => all || stats.getUseCount() > 1).foreach { stats => + sb.append(" ") + .append(s"${stats.expr}: useCount = ${stats.useCount} ") + .append(s"conditionalUseCount = ${stats.conditionalUseCount}") + .append('\n') } sb.toString() } @@ -255,4 +304,32 @@ case class ExpressionEquals(e: Expression) { * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" * useCount in this wrapper in-place. */ -case class ExpressionStats(expr: Expression)(var useCount: Int) +case class ExpressionStats(expr: Expression)( + var useCount: Int = 1, + var conditionalUseCount: Int = 0) { + def getUseCount(): Int = if (useCount > 0) { + useCount + conditionalUseCount + } else { + 0 + } +} + +/** + * A wrapper for the different types of children of expressions. + * + * `alwaysChildren` are child expressions that will always be evaluated and should be considered + * for subexpressions. + * + * `commonChildren` are children such that each of the children might be evaluated, but at last once + * will definitely be evaluated. If there are any common expressions among them, those expressions + * will definitely be evaluated and should be considered for subexpressions. + * + * `conditionalChildren` are children that are conditionally evaluated, such as in If, CaseWhen, + * or Coalesce expressions, and should only be considered for subexpressions if they are evaluated + * non-conditionally elsewhere. + */ +case class RecurseChildren( + alwaysChildren: Seq[Expression], + commonChildren: Seq[Seq[Expression]] = Nil, + conditionalChildren: Seq[Expression] = Nil + ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b61f7ee0ee16..a53445a586e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -572,6 +572,12 @@ trait ConditionalExpression extends Expression { * so that we can eagerly evaluate the common expressions of a group. */ def branchGroups: Seq[Seq[Expression]] + + /** + * Returns children expressions which are conditionally evaluated. If the same expression + * will be always evaluated elsewhere, we can make it a subexpression. 
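+   * For example, in `If(a + b > 1, a + b, c)` the `a + b` in the true branch is only
+   * conditionally evaluated, but because it is always evaluated in the predicate it can still
+   * be extracted as a subexpression.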
+ */ + def conditionallyEvaluatedInputs: Seq[Expression] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 621f02ca18b8..5b23437a1161 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -65,6 +65,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue)) + override def conditionallyEvaluatedInputs: Seq[Expression] = Seq(trueValue, falseValue) + final override val nodePatterns : Seq[TreePattern] = Seq(IF) override def checkInputDataTypes(): TypeCheckResult = { @@ -241,29 +243,12 @@ case class CaseWhen( } override def branchGroups: Seq[Seq[Expression]] = { - // We look at subexpressions in conditions and values of `CaseWhen` separately. It is - // because a subexpression in conditions will be run no matter which condition is matched - // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, - // a subexpression among values doesn't need to be in conditions because no matter which - // condition is true, it will be evaluated. - val conditions = if (branches.length > 1) { - branches.map(_._1) - } else { - // If there is only one branch, the first condition is already covered by - // `alwaysEvaluatedInputs` and we should exclude it here. - Nil - } - // For an expression to be in all branch values of a CaseWhen statement, it must also be in - // the elseValue. - val values = if (elseValue.nonEmpty) { - branches.map(_._2) ++ elseValue - } else { - Nil - } - - Seq(conditions, values) + // If there's an else value then we will definitely evaluate at least one branch value + if (elseValue.isDefined) Seq(branches.map(_._2) ++ elseValue) else Nil } + override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail + override def eval(input: InternalRow): Any = { var i = 0 val size = branches.size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 1aa1d0b25e44..56a9cc2b923e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -75,13 +75,9 @@ case class Coalesce(children: Seq[Expression]) withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)) } - override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) { - // If there is only one child, the first child is already covered by - // `alwaysEvaluatedInputs` and we should exclude it here. 
- Seq(children) - } else { - Nil - } + override def branchGroups: Seq[Seq[Expression]] = Nil + + override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail override def eval(input: InternalRow): Any = { var result: Any = null @@ -348,7 +344,9 @@ case class NaNvl(left: Expression, right: Expression) copy(left = alwaysEvaluatedInputs.head) } - override def branchGroups: Seq[Seq[Expression]] = Seq(children) + override def branchGroups: Seq[Seq[Expression]] = Nil + + override def conditionallyEvaluatedInputs: Seq[Expression] = right :: Nil override def eval(input: InternalRow): Any = { val value = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2f7706c859ba..8dd7807f78ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1251,6 +1251,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED = + buildConf("spark.sql.subexpressionElimination.conditionals.enabled") + .internal() + .doc("When true, common conditional subexpressions will be eliminated.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CASE_SENSITIVE = buildConf(SqlApiConfHelper.CASE_SENSITIVE_KEY) .internal() .doc("Whether the query analyzer should be case sensitive or not. " + @@ -7314,6 +7322,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def subexpressionEliminationSkipForShotcutExpr: Boolean = getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR) + def subexpressionEliminationConditionalsEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index e9faeba2411c..576b5ea144b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, LongType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -197,13 +197,15 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel (GreaterThan(add2, Literal(4)), add1) :: (GreaterThan(add2, Literal(5)), add1) :: Nil - val caseWhenExpr1 = CaseWhen(conditions1, None) - val equivalence1 = new EquivalentExpressions - equivalence1.addExprTree(caseWhenExpr1) + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val caseWhenExpr1 = CaseWhen(conditions1, None) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) - 
// `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + // `add2` is repeatedly in all conditions. + assert(equivalence1.getAllExprStates().count(_.getUseCount() == 3) == 1) + assert(equivalence1.getAllExprStates().filter(_.getUseCount() == 3).head.expr eq add2) + } val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: @@ -235,13 +237,15 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel GreaterThan(add2, Literal(4)) :: GreaterThan(add2, Literal(5)) :: Nil - val coalesceExpr1 = Coalesce(conditions1) - val equivalence1 = new EquivalentExpressions - equivalence1.addExprTree(coalesceExpr1) + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val coalesceExpr1 = Coalesce(conditions1) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(coalesceExpr1) - // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + // `add2` is repeatedly in all conditions. + assert(equivalence1.getAllExprStates().count(_.getUseCount() == 3) == 1) + assert(equivalence1.getAllExprStates().filter(_.getUseCount() == 3).head.expr eq add2) + } // Negative case. `add1` and `add2` both are not used in all branches. val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -402,6 +406,71 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0) } + test("SPARK-35564: Subexpressions should be extracted from conditional values if that value " + + "will always be evaluated elsewhere") { + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: Nil + val caseWhenExpr1 = CaseWhen(conditions1, None) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + + // `add1` is evaluated once in the first condition, and optionally in the first value + assert(equivalence1.getCommonSubexpressions.size == 1) + + val ifExpr = If(GreaterThan(add1, Literal(3)), add1, add2) + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(ifExpr) + + // `add1` is evaluated once in the condition, and optionally in the true value + assert(equivalence2.getCommonSubexpressions.size == 1) + } + } + + test("SPARK-35564: Common expressions don't infinite loop with conditional expressions") { + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + val inner = CaseWhen((GreaterThan(add2, Literal(2)), add1) :: Nil) + val outer = CaseWhen((GreaterThan(add1, Literal(2)), inner) :: Nil, add1) + + val equivalence = new EquivalentExpressions + equivalence.addExprTree(outer) + + // `add1` is evaluated in the outer condition, and optionally in the inner value + assert(equivalence.getCommonSubexpressions.size == 1) + + val when1 = CaseWhen((GreaterThan(Literal(1), Literal(1)), Cast(Literal(1), LongType)) :: Nil) + val when2 = CaseWhen((GreaterThan(when1, Literal(2)), when1) :: Nil, when1) + val when3 = CaseWhen((GreaterThan(when1, 
Literal(1)), when2) :: Nil) + + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(when3) + + // `when1` is evaluated in the outer condition, and optionally in the inner value multiple + // times including in a nested conditional + assert(equivalence2.getCommonSubexpressions.size == 1) + } + } + + test("SPARK-35564: Don't double count conditional expressions if present in all branches") { + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + val add3 = Add(add2, Literal(4)) + + val caseWhenExpr1 = CaseWhen((GreaterThan(add1, Literal(3)), add3) :: Nil, add2) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + + // `add2` will only be evaluated once so don't create a subexpression + assert(equivalence1.getCommonSubexpressions.size == 0) + } + } + test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { val add1 = Add(Literal(1), Literal(2)) val add2 = Add(add1, add1) @@ -439,17 +508,19 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel } test("SPARK-39040: Respect NaNvl in EquivalentExpressions for expression elimination") { - val add = Add(Literal(1), Literal(0)) - val n1 = NaNvl(Literal(1.0d), Add(add, add)) - val e1 = new EquivalentExpressions - e1.addExprTree(n1) - assert(e1.getCommonSubexpressions.isEmpty) - - val n2 = NaNvl(add, add) - val e2 = new EquivalentExpressions - e2.addExprTree(n2) - assert(e2.getCommonSubexpressions.size == 1) - assert(e2.getCommonSubexpressions.head == add) + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED.key -> "true") { + val add = Add(Literal(1), Literal(0)) + val n1 = NaNvl(Literal(1.0d), Add(add, add)) + val e1 = new EquivalentExpressions + e1.addExprTree(n1) + assert(e1.getCommonSubexpressions.isEmpty) + + val n2 = NaNvl(add, add) + val e2 = new EquivalentExpressions + e2.addExprTree(n2) + assert(e2.getCommonSubexpressions.size == 1) + assert(e2.getCommonSubexpressions.head == add) + } } test("SPARK-42851: Handle supportExpression consistently across add and get") { @@ -498,10 +569,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel test("Equivalent ternary expressions have different children") { val add1 = Add(Add(Literal(1), Literal(2)), Literal(3)) val add2 = Add(Add(Literal(3), Literal(1)), Literal(2)) - val conditions1 = (GreaterThan(add1, Literal(3)), Literal(1)) :: - (GreaterThan(add2, Literal(0)), Literal(2)) :: Nil + val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: Nil - val caseWhenExpr1 = CaseWhen(conditions1, Literal(0)) + val caseWhenExpr1 = CaseWhen(conditions1, add2) val equivalence1 = new EquivalentExpressions equivalence1.addExprTree(caseWhenExpr1) assert(equivalence1.getCommonSubexpressions.size == 1)
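
For reviewers, a minimal sketch (not part of the patch) of the new flag in action, mirroring the
added suite tests; it assumes the internal catalyst APIs already used above
(`EquivalentExpressions`, `Add`, `If`, `GreaterThan`, `Literal`) and the new conf key:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.internal.SQLConf

    // Enable conditional subexpression tracking (new internal conf added by this patch).
    SQLConf.get.setConfString(
      "spark.sql.subexpressionElimination.conditionals.enabled", "true")

    val add = Add(Literal(1), Literal(2))  // stands in for `a + b`
    val expr = If(GreaterThan(add, Literal(3)), add, Literal(0))

    val equivalence = new EquivalentExpressions
    equivalence.addExprTree(expr)
    // `add` is always evaluated in the predicate and conditionally in the true branch,
    // so with the flag on it is reported as a common subexpression.
    assert(equivalence.getCommonSubexpressions.size == 1)

With the flag off (the default), the conditional use in the true branch is ignored and no common
subexpression is extracted, matching the pre-patch behavior.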