diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 364f546ffcbe..32eb650aa6d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -83,6 +83,13 @@ class EquivalentExpressions {
    * Adds only expressions which are common in each of given expressions, in a recursive way.
    * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
    * the common expression `(c + 1)` will be added into `equivalenceMap`.
+   *
+   * Note that as we don't know in advance if any child node of an expression will be common
+   * across all given expressions, we count all child nodes when looking through the given
+   * expressions. But when we call `addExprTree` to add common expressions into the map, we
+   * will add recursively the child nodes. So we need to filter the child expressions first.
+   * For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
+   * `((a + b) + c)`.
    */
   private def addCommonExprs(
       exprs: Seq[Expression],
@@ -90,13 +97,21 @@ class EquivalentExpressions {
     val exprSetForAll = mutable.Set[Expr]()
     addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
 
-    val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
+    val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
       val otherExprSet = mutable.Set[Expr]()
       addExprTree(expr, addExprToSet(_, otherExprSet))
       exprSet.intersect(otherExprSet)
     }
 
-    commonExprSet.foreach(expr => addFunc(expr.e))
+    // Not all expressions in the set should be added. We should filter out the related
+    // children nodes.
+    val commonExprSet = candidateExprs.filter { candidateExpr =>
+      candidateExprs.forall { expr =>
+        expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
+      }
+    }
+
+    commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
   }
 
   // There are some special expressions that we should not recurse into all of its children.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index bdb08de6d727..5e6b07422c76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -310,6 +310,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     }
   }
 
+  test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
+    "for conditional expressions") {
+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    val add3 = Add(add1, add2)
+    val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil
+
+    val caseWhenExpr = CaseWhen(condition, None)
+    val equivalence = new EquivalentExpressions
+    equivalence.addExprTree(caseWhenExpr)
+
+    val commonExprs = equivalence.getAllEquivalentExprs(1)
+    assert(commonExprs.size == 1)
+    assert(commonExprs.head === Seq(add3, add3))
+  }
+
   test("SPARK-35439: Children subexpr should come first than parent subexpr") {
     val add = Add(Literal(1), Literal(2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d7d85d43544e..19c905326042 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2861,6 +2861,25 @@ class DataFrameSuite extends QueryTest
       checkAnswer(result, Row(0, 0, 0, 0, 100))
     }
   }
+
+  test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
+    "for conditional expressions") {
+    val accum = sparkContext.longAccumulator("call")
+    val simpleUDF = udf((s: String) => {
+      accum.add(1)
+      s
+    })
+    val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
+      functions.length(simpleUDF($"id"))))
+    df1.collect()
+    assert(accum.value == 5)
+
+    val nondeterministicUDF = simpleUDF.asNondeterministic()
+    val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
+      functions.length(nondeterministicUDF($"id"))))
+    df2.collect()
+    assert(accum.value == 15)
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
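For reference, the filtering step in `addCommonExprs` above can be illustrated outside
Catalyst. The following is a minimal, self-contained Scala sketch of the same idea,
using a toy expression tree instead of Catalyst's `Expression`/`Expr` types; the names
`Node`, `Lit`, `AddExpr`, `subtrees`, `contains`, and `commonMaximal` are all hypothetical
stand-ins, not Spark APIs. It intersects the subtree sets of the given expressions and
then keeps only the maximal candidates, mirroring the `candidateExprs.filter { ... forall
... }` logic in the patch.

object CommonExprFilterSketch {
  sealed trait Node { def children: Seq[Node] }
  case class Lit(v: Int) extends Node { def children: Seq[Node] = Nil }
  case class AddExpr(l: Node, r: Node) extends Node { def children: Seq[Node] = Seq(l, r) }

  // All subtrees of a node, analogous to the sets that addExprTree collects
  // into exprSetForAll/otherExprSet.
  def subtrees(n: Node): Set[Node] = n.children.flatMap(subtrees).toSet + n

  // True if `inner` occurs somewhere inside `outer`, playing the role of
  // expr.e.find(_.semanticEquals(candidateExpr.e)) in the patch.
  def contains(outer: Node, inner: Node): Boolean =
    outer == inner || outer.children.exists(contains(_, inner))

  def commonMaximal(exprs: Seq[Node]): Set[Node] = {
    // Intersect the subtree sets of all expressions, as the foldLeft does.
    val candidates = exprs.map(subtrees).reduce(_ intersect _)
    // Keep a candidate only if no other candidate contains it; its children
    // are re-registered later by the recursive addExprTree call, so nothing
    // is lost by dropping them here.
    candidates.filter { c =>
      candidates.forall(other => other == c || !contains(other, c))
    }
  }

  def main(args: Array[String]): Unit = {
    val ab = AddExpr(Lit(1), Lit(2))   // (a + b)
    val abc = AddExpr(ab, Lit(3))      // ((a + b) + c)
    // Both (a + b) and ((a + b) + c) appear in every input, but only the
    // maximal expression ((a + b) + c) survives the filter.
    println(commonMaximal(Seq(abc, abc)))
  }
}

This is the same reason the DataFrameSuite test above expects `accum.value == 5` after the
first collect: the two `simpleUDF($"id")` occurrences are deduplicated into one evaluation
per row across 5 rows, while the nondeterministic variant is evaluated twice per row,
adding 10 more calls.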