[SPARK-33337][SQL] Support subexpression elimination in branches of conditional expressions #30245
Changes from all commits
db0cfcc
cd3776c
9182e3d
cc0648a
16314a9
33f3bd3
b415728
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,11 +65,82 @@ class EquivalentExpressions { | |
| } | ||
| } | ||
|
|
||
| private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = { | ||
| if (expr.deterministic) { | ||
| val e = Expr(expr) | ||
| if (set.contains(e)) { | ||
| true | ||
| } else { | ||
| set.add(e) | ||
| false | ||
| } | ||
| } else { | ||
| false | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Adds only expressions which are common in each of given expressions, in a recursive way. | ||
| * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, | ||
| * the common expression `(c + 1)` will be added into `equivalenceMap`. | ||
| */ | ||
| private def addCommonExprs( | ||
| exprs: Seq[Expression], | ||
| addFunc: Expression => Boolean = addExpr): Unit = { | ||
| val exprSetForAll = mutable.Set[Expr]() | ||
| addExprTree(exprs.head, addExprToSet(_, exprSetForAll)) | ||
|
|
||
| val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) => | ||
| val otherExprSet = mutable.Set[Expr]() | ||
| addExprTree(expr, addExprToSet(_, otherExprSet)) | ||
| exprSet.intersect(otherExprSet) | ||
| } | ||
|
|
||
| commonExprSet.foreach(expr => addFunc(expr.e)) | ||
| } | ||
|
|
||
| // There are some special expressions that we should not recurse into all of its children. | ||
| // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) | ||
| // 2. If: common subexpressions will always be evaluated at the beginning, but the true and | ||
| // false expressions in `If` may not get accessed, according to the predicate | ||
| // expression. We should only recurse into the predicate expression. | ||
| // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain | ||
| // condition. We should only recurse into the first condition expression as it | ||
| // will always get accessed. | ||
| // 4. Coalesce: it's also a conditional expression, we should only recurse into the first | ||
| // children, because others may not get accessed. | ||
| private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { | ||
| case _: CodegenFallback => Nil | ||
| case i: If => i.predicate :: Nil | ||
| case c: CaseWhen => c.children.head :: Nil | ||
| case c: Coalesce => c.children.head :: Nil | ||
| case other => other.children | ||
| } | ||
|
|
||
| // For some special expressions we cannot just recurse into all of its children, but we can | ||
| // recursively add the common expressions shared between all of its children. | ||
| private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { | ||
| case i: If => Seq(Seq(i.trueValue, i.falseValue)) | ||
| case c: CaseWhen => | ||
| // We look at subexpressions in conditions and values of `CaseWhen` separately. It is | ||
| // because a subexpression in conditions will be run no matter which condition is matched | ||
| // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, | ||
| // a subexpression among values doesn't need to be in conditions because no matter which | ||
| // condition is true, it will be evaluated. | ||
| val conditions = c.branches.tail.map(_._1) | ||
|
Contributor
There is a flaw here: we exclude the first condition, so a common subexpression in the rest of the conditions is not necessarily always evaluated. e.g.
Member
Author
Yes, this is related to #32977. This looks like a more aggressive optimization. If we respect short-circuit evaluation for CaseWhen, this might be an issue when users rely on short-circuit evaluation to guard later conditions. The safest approach is to only consider all conditions.
Member
Author
WDYT? Should we only consider all conditions?
Contributor
I think we should. I hit an issue caused by this in my refactor, and I'll open a PR for the refactor that also fixes multiple bugs.
Member
Author
OK, thanks!
Member
Author
BTW, does #32980 conflict with your refactor?
Contributor
Only some trivial conflicts. #32980 should be merged first, as it has been reviewed and approved.
Contributor
FWIW, I also addressed this issue in #32987, which assumes that CaseWhen (and Coalesce) should short-circuit and guard later conditions. The main benefit/difference is if you have
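To make the concern in this thread concrete, here is a minimal sketch in plain Scala, not Spark code. `caseWhen`, `sharedEvaluations`, and the `shared` thunk are hypothetical names introduced only for illustration. The sketch assumes nothing beyond what the thread states: branches are tried in order and evaluation stops at the first condition that holds.

```scala
object CaseWhenShortCircuitSketch {
  // Hypothetical stand-in for a CaseWhen evaluator (an assumption, not Spark's implementation).
  // Conditions and values are thunks so that we can observe when they are evaluated.
  def caseWhen[A](branches: Seq[(() => Boolean, () => A)], elseValue: () => A): A =
    branches.find { case (cond, _) => cond() } match {
      case Some((_, value)) => value()
      case None => elseValue()
    }

  // Returns how many times the subexpression shared by the second and third
  // conditions is evaluated for a given outcome of the first condition.
  def sharedEvaluations(firstConditionHolds: Boolean): Int = {
    var count = 0
    def shared(): Int = { count += 1; 2 + 3 }
    caseWhen[String](
      Seq(
        (() => firstConditionHolds, () => "first"),
        (() => shared() > 3, () => "second"),
        (() => shared() > 4, () => "third")),
      () => "else")
    count
  }
}
```

With these definitions, `CaseWhenShortCircuitSketch.sharedEvaluations(true)` returns 0 and `sharedEvaluations(false)` returns 1. A subexpression that is common only to the non-first conditions is therefore not guaranteed to run, so hoisting it to the beginning can evaluate something the first condition was guarding against.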
|
||
| val values = c.branches.map(_._2) ++ c.elseValue | ||
| Seq(conditions, values) | ||
|
Member
Why do we need to handle
Member
Author
A subexpression in conditions is definitely run because it is shared among
Member
Ah, OK. Thanks. Could you leave a comment about it in the code? |
||
| case c: Coalesce => Seq(c.children.tail) | ||
| case _ => Nil | ||
| } | ||
|
|
||
| /** | ||
| * Adds the expression to this data structure recursively. Stops if a matching expression | ||
| * is found. That is, if `expr` has already been added, its children are not added. | ||
| */ | ||
| def addExprTree(expr: Expression): Unit = { | ||
| def addExprTree( | ||
| expr: Expression, | ||
| addFunc: Expression => Boolean = addExpr): Unit = { | ||
| val skip = expr.isInstanceOf[LeafExpression] || | ||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||
| // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. | ||
|
|
@@ -78,26 +149,9 @@ class EquivalentExpressions { | |
| // can cause error like NPE. | ||
| (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) | ||
|
|
||
| // There are some special expressions that we should not recurse into all of its children. | ||
| // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) | ||
| // 2. If: common subexpressions will always be evaluated at the beginning, but the true and | ||
| // false expressions in `If` may not get accessed, according to the predicate | ||
| // expression. We should only recurse into the predicate expression. | ||
| // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain | ||
| // condition. We should only recurse into the first condition expression as it | ||
| // will always get accessed. | ||
| // 4. Coalesce: it's also a conditional expression, we should only recurse into the first | ||
| // children, because others may not get accessed. | ||
| def childrenToRecurse: Seq[Expression] = expr match { | ||
| case _: CodegenFallback => Nil | ||
| case i: If => i.predicate :: Nil | ||
| case c: CaseWhen => c.children.head :: Nil | ||
| case c: Coalesce => c.children.head :: Nil | ||
| case other => other.children | ||
| } | ||
|
|
||
| if (!skip && !addExpr(expr)) { | ||
| childrenToRecurse.foreach(addExprTree) | ||
| if (!skip && !addFunc(expr)) { | ||
| childrenToRecurse(expr).foreach(addExprTree(_, addFunc)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc)) | ||
cloud-fan marked this conversation as resolved.
|
||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -146,20 +146,111 @@ class SubexpressionEliminationSuite extends SparkFunSuite { | |
| equivalence.addExprTree(add) | ||
| // the `two` inside `fallback` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode | ||
| } | ||
|
|
||
| test("Children of conditional expressions") { | ||
| val condition = And(Literal(true), Literal(false)) | ||
| test("Children of conditional expressions: If") { | ||
viirya marked this conversation as resolved.
|
||
| val add = Add(Literal(1), Literal(2)) | ||
| val ifExpr = If(condition, add, add) | ||
| val condition = GreaterThan(add, Literal(3)) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(ifExpr) | ||
| // the `add` inside `If` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| // only ifExpr and its predicate expression | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| val ifExpr1 = If(condition, add, add) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(ifExpr1) | ||
|
|
||
| // `add` is in both two branches of `If` and predicate. | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) | ||
| // one-time expressions: only ifExpr and its predicate expression | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) | ||
|
Contributor
Should we use the `contains` method? A HashMap does not guarantee iteration order.
Member
Author
OK, I will create a follow-up to make sure the test cannot be flaky. Thanks.
Member
Author
Created #30371.
Member
Thank you, @leoluan2009 and @viirya. The follow-up is merged to reduce the flakiness. |
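As a rough illustration of the ordering concern raised above, the sketch below groups values through a `mutable.HashMap` and then asserts on the groups without relying on their position. It is plain Scala and is not taken from the follow-up PR; `groups` is a hypothetical stand-in for `getAllEquivalentExprs`, and strings stand in for expressions.

```scala
import scala.collection.mutable

object OrderIndependentAssertSketch {
  // Hypothetical stand-in for getAllEquivalentExprs (an assumption for illustration):
  // equal items are grouped via a hash map, whose iteration order is unspecified.
  def groups(exprs: Seq[String]): Seq[Seq[String]] = {
    val map = mutable.HashMap.empty[String, mutable.ArrayBuffer[String]]
    exprs.foreach(e => map.getOrElseUpdate(e, mutable.ArrayBuffer.empty[String]) += e)
    map.values.map(_.toSeq).toSeq
  }
}
```

Instead of `filter(_.size == 1).head == Seq(...)`, which depends on which group the map happens to yield first, a test can check membership:

```scala
val all = OrderIndependentAssertSketch.groups(Seq("add", "ifExpr", "condition", "add"))
assert(all.contains(Seq("add", "add")))
assert(all.filter(_.size == 1).toSet == Set(Seq("ifExpr"), Seq("condition")))
```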
||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition)) | ||
|
|
||
| // Repeated `add` is only in one branch, so we don't count it. | ||
| val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(ifExpr2) | ||
|
|
||
| assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3) | ||
|
|
||
| val ifExpr3 = If(condition, ifExpr1, ifExpr1) | ||
| val equivalence3 = new EquivalentExpressions | ||
| equivalence3.addExprTree(ifExpr3) | ||
|
|
||
| // `add`: 2, `condition`: 2 | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition)) | ||
|
|
||
| // `ifExpr1`, `ifExpr3` | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3)) | ||
| } | ||
|
|
||
| test("Children of conditional expressions: CaseWhen") { | ||
| val add1 = Add(Literal(1), Literal(2)) | ||
| val add2 = Add(Literal(2), Literal(3)) | ||
| val conditions1 = (GreaterThan(add2, Literal(3)), add1) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr1 = CaseWhen(conditions1, None) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(caseWhenExpr1) | ||
|
|
||
| // `add2` is repeatedly in all conditions. | ||
|
Contributor
Member
Author
We treat the first condition specially because it is definitely run. So it counts one for For |
||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) | ||
|
|
||
| val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr2 = CaseWhen(conditions2, None) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(caseWhenExpr2) | ||
|
|
||
| // `add1` is repeatedly in all branch values, and first predicate. | ||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1)) | ||
|
|
||
| // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. | ||
| val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr3 = CaseWhen(conditions3, None) | ||
| val equivalence3 = new EquivalentExpressions | ||
| equivalence3.addExprTree(caseWhenExpr3) | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0) | ||
| } | ||
|
|
||
| test("Children of conditional expressions: Coalesce") { | ||
| val add1 = Add(Literal(1), Literal(2)) | ||
| val add2 = Add(Literal(2), Literal(3)) | ||
| val conditions1 = GreaterThan(add2, Literal(3)) :: | ||
| GreaterThan(add2, Literal(4)) :: | ||
| GreaterThan(add2, Literal(5)) :: Nil | ||
|
|
||
| val coalesceExpr1 = Coalesce(conditions1) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(coalesceExpr1) | ||
|
|
||
| // `add2` is repeatedly in all conditions. | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) | ||
|
|
||
| // Negative case. `add1` and `add2` both are not used in all branches. | ||
| val conditions2 = GreaterThan(add1, Literal(3)) :: | ||
| GreaterThan(add2, Literal(4)) :: | ||
| GreaterThan(add2, Literal(5)) :: Nil | ||
|
|
||
| val coalesceExpr2 = Coalesce(conditions2) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(coalesceExpr2) | ||
|
|
||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0) | ||
| } | ||
| } | ||
|
|
||
|
|
||
Do we need to handle `head` and `tail` separately?
For the `head` expression, we add its underlying expressions into the `exprSetForAll` set. But for the expressions in `tail`, we keep the intersection between `exprSetForAll` and `exprSet`. We could merge the two blocks, but inside the merged block we would need to check whether the current expression is the head expression and apply different logic based on that check.
I prefer the current one since it looks simpler.
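For readers following the `head`/`tail` discussion, here is a small self-contained sketch of the same shape: seed a set from the first expression, then fold the remaining expressions into it by intersection. It is plain Scala with a hypothetical miniature expression tree (`Expr`, `Lit`, `Attr`, `Add`, `subexprs`, and `common` are all illustrative names, not Catalyst classes), and it assumes a non-empty input, much as the patch filters with `_.nonEmpty` before calling `addCommonExprs`.

```scala
object CommonSubexprSketch {
  // Hypothetical miniature expression tree standing in for Catalyst's Expression.
  sealed trait Expr
  final case class Lit(v: Int) extends Expr
  final case class Attr(name: String) extends Expr
  final case class Add(l: Expr, r: Expr) extends Expr

  // All non-leaf subexpressions of `e`, including `e` itself; leaves are skipped.
  def subexprs(e: Expr): Set[Expr] = e match {
    case a @ Add(l, r) => subexprs(l) ++ subexprs(r) + a
    case _ => Set.empty
  }

  // The head seeds the set; each tail element narrows it by intersection.
  // Assumes `exprs` is non-empty.
  def common(exprs: Seq[Expr]): Set[Expr] =
    exprs.tail.foldLeft(subexprs(exprs.head)) { (acc, e) => acc.intersect(subexprs(e)) }
}
```

Using the example from the doc comment in the patch, `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))` share only `(c + 1)`:

```scala
import CommonSubexprSketch._
val cPlus1 = Add(Attr("c"), Lit(1))
val left = Add(Attr("a"), Add(Attr("b"), cPlus1))
val right = Add(Attr("d"), Add(Attr("e"), cPlus1))
assert(common(Seq(left, right)) == Set(cPlus1))
```

Seeding the fold with the head keeps the loop body uniform, which is the simplicity argument made in the reply above.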