diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 458c48df6d0c..1dfff412d9a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -65,11 +65,94 @@ class EquivalentExpressions {
     }
   }
 
+  /**
+   * Records `expr` in `set` if it is deterministic.
+   *
+   * @return true if `set` already contained the expression; false if it was newly added or
+   *         if `expr` is non-deterministic (in which case `set` is left untouched).
+   */
+  private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
+    if (expr.deterministic) {
+      val e = Expr(expr)
+      if (set.contains(e)) {
+        true
+      } else {
+        set.add(e)
+        false
+      }
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Adds only expressions which are common in each of given expressions, in a recursive way.
+   * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
+   * the common expression `(c + 1)` will be added into `equivalenceMap`.
+   * `exprs` must be non-empty, since the first element seeds the intersection.
+   */
+  private def addCommonExprs(
+      exprs: Seq[Expression],
+      addFunc: Expression => Boolean = addExpr): Unit = {
+    val exprSetForAll = mutable.Set[Expr]()
+    addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
+
+    val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
+      val otherExprSet = mutable.Set[Expr]()
+      addExprTree(expr, addExprToSet(_, otherExprSet))
+      exprSet.intersect(otherExprSet)
+    }
+
+    commonExprSet.foreach(expr => addFunc(expr.e))
+  }
+
+  // There are some special expressions that we should not recurse into all of their children.
+  //   1. CodegenFallback: its children will not be used to generate code (call eval() instead)
+  //   2. If: common subexpressions will always be evaluated at the beginning, but the true and
+  //          false expressions in `If` may not get accessed, according to the predicate
+  //          expression. We should only recurse into the predicate expression.
+  //   3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
+  //                condition. We should only recurse into the first condition expression as it
+  //                will always get accessed.
+  //   4. Coalesce: it's also a conditional expression, we should only recurse into the first
+  //                child, because others may not get accessed.
+  private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
+    case _: CodegenFallback => Nil
+    case i: If => i.predicate :: Nil
+    case c: CaseWhen => c.children.head :: Nil
+    case c: Coalesce => c.children.head :: Nil
+    case other => other.children
+  }
+
+  // For some special expressions we cannot just recurse into all of their children, but we can
+  // recursively add the common expressions shared between all of their children.
+  private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
+    case i: If => Seq(Seq(i.trueValue, i.falseValue))
+    case c: CaseWhen =>
+      // We look at subexpressions in conditions and values of `CaseWhen` separately. A
+      // subexpression shared by every condition is always evaluated, because the first
+      // condition always runs, so it does not need to be shared with the values. With a single
+      // branch that condition is already covered by `childrenToRecurse`, so it is skipped here.
+      // NOTE(review): without an else value, no value runs when every condition is false.
+      val conditions = if (c.branches.length > 1) c.branches.map(_._1) else Nil
+      val values = c.branches.map(_._2) ++ c.elseValue
+      Seq(conditions, values)
+    // Include the first child so a shared subexpression is guaranteed to be evaluated; a single
+    // child is already covered by `childrenToRecurse`.
+    case c: Coalesce if c.children.length > 1 => Seq(c.children)
+    case _ => Nil
+  }
+
   /**
    * Adds the expression to this data structure recursively. Stops if a matching expression
    * is found. That is, if `expr` has already been added, its children are not added.
+   *
+   * @param addFunc records each visited expression; when it returns true the expression counts
+   *                as already added and its children are not visited.
    */
-  def addExprTree(expr: Expression): Unit = {
+  def addExprTree(
+      expr: Expression,
+      addFunc: Expression => Boolean = addExpr): Unit = {
     val skip = expr.isInstanceOf[LeafExpression] ||
       // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
       // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
@@ -78,26 +161,9 @@ class EquivalentExpressions {
       // can cause error like NPE.
       (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
 
-    // There are some special expressions that we should not recurse into all of its children.
-    //   1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
-    //   2. If: common subexpressions will always be evaluated at the beginning, but the true and
-    //          false expressions in `If` may not get accessed, according to the predicate
-    //          expression. We should only recurse into the predicate expression.
-    //   3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
-    //                condition. We should only recurse into the first condition expression as it
-    //                will always get accessed.
-    //   4. Coalesce: it's also a conditional expression, we should only recurse into the first
-    //                children, because others may not get accessed.
-    def childrenToRecurse: Seq[Expression] = expr match {
-      case _: CodegenFallback => Nil
-      case i: If => i.predicate :: Nil
-      case c: CaseWhen => c.children.head :: Nil
-      case c: Coalesce => c.children.head :: Nil
-      case other => other.children
-    }
-
-    if (!skip && !addExpr(expr)) {
-      childrenToRecurse.foreach(addExprTree)
+    if (!skip && !addFunc(expr)) {
+      childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
+      commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9a26c388f59a..9aa827a58d87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1044,7 +1044,7 @@ class CodegenContext extends Logging {
     val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
 
     // Add each expression tree and compute the common subexpressions.
-    expressions.foreach(equivalentExpressions.addExprTree)
+    expressions.foreach(equivalentExpressions.addExprTree(_))
 
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1fa185cc77eb..4725a40781c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -146,20 +146,111 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     equivalence.addExprTree(add)
     // the `two` inside `fallback` should not be added
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
-    assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
+    assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
   }
 
-  test("Children of conditional expressions") {
-    val condition = And(Literal(true), Literal(false))
+  test("Children of conditional expressions: If") {
     val add = Add(Literal(1), Literal(2))
-    val ifExpr = If(condition, add, add)
+    val condition = GreaterThan(add, Literal(3))
 
-    val equivalence = new EquivalentExpressions
-    equivalence.addExprTree(ifExpr)
-    // the `add` inside `If` should not be added
-    assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
-    // only ifExpr and its predicate expression
-    assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
+    val ifExpr1 = If(condition, add, add)
+    val equivalence1 = new EquivalentExpressions
+    equivalence1.addExprTree(ifExpr1)
+
+    // `add` is in both branches of `If` and in the predicate.
+    assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
+    assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
+    // one-time expressions: only ifExpr and its predicate (membership checks, order-independent)
+    assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
+    assert(equivalence1.getAllEquivalentExprs.contains(Seq(ifExpr1)))
+    assert(equivalence1.getAllEquivalentExprs.contains(Seq(condition)))
+
+    // Repeated `add` is only in one branch, so we don't count it.
+    val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
+    val equivalence2 = new EquivalentExpressions
+    equivalence2.addExprTree(ifExpr2)
+
+    assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
+    assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)
+
+    val ifExpr3 = If(condition, ifExpr1, ifExpr1)
+    val equivalence3 = new EquivalentExpressions
+    equivalence3.addExprTree(ifExpr3)
+
+    // `add`: 2, `condition`: 2 (membership checks, so group ordering does not matter)
+    assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
+    assert(equivalence3.getAllEquivalentExprs.contains(Seq(add, add)))
+    assert(equivalence3.getAllEquivalentExprs.contains(Seq(condition, condition)))
+
+    // `ifExpr1`, `ifExpr3`
+    assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
+    assert(equivalence3.getAllEquivalentExprs.contains(Seq(ifExpr1)))
+    assert(equivalence3.getAllEquivalentExprs.contains(Seq(ifExpr3)))
+  }
+
+  test("Children of conditional expressions: CaseWhen") {
+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    val conditions1 = (GreaterThan(add2, Literal(3)), add1) ::
+      (GreaterThan(add2, Literal(4)), add1) ::
+      (GreaterThan(add2, Literal(5)), add1) :: Nil
+
+    val caseWhenExpr1 = CaseWhen(conditions1, None)
+    val equivalence1 = new EquivalentExpressions
+    equivalence1.addExprTree(caseWhenExpr1)
+
+    // `add2` is repeatedly in all conditions.
+    assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
+    assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
+
+    val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
+      (GreaterThan(add2, Literal(4)), add1) ::
+      (GreaterThan(add2, Literal(5)), add1) :: Nil
+
+    val caseWhenExpr2 = CaseWhen(conditions2, None)
+    val equivalence2 = new EquivalentExpressions
+    equivalence2.addExprTree(caseWhenExpr2)
+
+    // `add1` is repeatedly in all branch values, and first predicate.
+    assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1)
+    assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1))
+
+    // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
+    val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
+      (GreaterThan(add2, Literal(4)), add1) ::
+      (GreaterThan(add2, Literal(5)), add1) :: Nil
+
+    val caseWhenExpr3 = CaseWhen(conditions3, None)
+    val equivalence3 = new EquivalentExpressions
+    equivalence3.addExprTree(caseWhenExpr3)
+    assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0)
+  }
+
+  test("Children of conditional expressions: Coalesce") {
+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    val conditions1 = GreaterThan(add2, Literal(3)) ::
+      GreaterThan(add2, Literal(4)) ::
+      GreaterThan(add2, Literal(5)) :: Nil
+
+    val coalesceExpr1 = Coalesce(conditions1)
+    val equivalence1 = new EquivalentExpressions
+    equivalence1.addExprTree(coalesceExpr1)
+
+    // `add2` is repeatedly in all conditions.
+    assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
+    assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
+
+    // Negative case. Neither `add1` nor `add2` is used in all branches.
+    val conditions2 = GreaterThan(add1, Literal(3)) ::
+      GreaterThan(add2, Literal(4)) ::
+      GreaterThan(add2, Literal(5)) :: Nil
+
+    val coalesceExpr2 = Coalesce(conditions2)
+    val equivalence2 = new EquivalentExpressions
+    equivalence2.addExprTree(coalesceExpr2)
+
+    assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
   }
 }
 