-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35564][SQL] Support subexpression elimination for conditionally evaluated expressions #32987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,15 +61,22 @@ class EquivalentExpressions( | |
| private def updateExprInMap( | ||
| expr: Expression, | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats], | ||
| useCount: Int = 1): Boolean = { | ||
| useCount: Int = 1, | ||
| conditional: Boolean = false): Boolean = { | ||
| if (expr.deterministic) { | ||
| val wrapper = ExpressionEquals(expr) | ||
| map.get(wrapper) match { | ||
| case Some(stats) => | ||
| stats.useCount += useCount | ||
| if (stats.useCount > 0) { | ||
| val count = if (conditional) { | ||
| stats.conditionalUseCount += useCount | ||
| stats.conditionalUseCount | ||
| } else { | ||
| stats.useCount += useCount | ||
| stats.useCount | ||
| } | ||
| if (count > 0) { | ||
| true | ||
| } else if (stats.useCount == 0) { | ||
| } else if (count == 0) { | ||
| map -= wrapper | ||
| false | ||
| } else { | ||
|
|
@@ -79,7 +86,12 @@ class EquivalentExpressions( | |
| } | ||
| case _ => | ||
| if (useCount > 0) { | ||
| map.put(wrapper, ExpressionStats(expr)(useCount)) | ||
| val stats = if (conditional) { | ||
| ExpressionStats(expr)(useCount = 0, conditionalUseCount = useCount) | ||
| } else { | ||
| ExpressionStats(expr)(useCount) | ||
| } | ||
| map.put(wrapper, stats) | ||
| } | ||
| false | ||
| } | ||
|
|
@@ -89,44 +101,47 @@ class EquivalentExpressions( | |
| } | ||
|
|
||
| /** | ||
| * Adds or removes only expressions which are common in each of given expressions, in a recursive | ||
| * Returns a list of expressions which are common in each of given expressions, in a recursive | ||
| * way. | ||
| * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common | ||
| * expression `(c + 1)` will be added into `equivalenceMap`. | ||
| * expression `(c + 1)` will be returned. | ||
| * | ||
| * Note that as we don't know in advance if any child node of an expression will be common across | ||
| * all given expressions, we compute local equivalence maps for all given expressions and filter | ||
| * only the common nodes. | ||
| * Those common nodes are then removed from the local map and added to the final map of | ||
| * Those common nodes are then removed from the local map and added to the final list of | ||
| * expressions. | ||
| * | ||
| * Conditional expressions are not considered because we are simply looking for expressions | ||
| * evaluated once in each parent expression. | ||
| */ | ||
| private def updateCommonExprs( | ||
| exprs: Seq[Expression], | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats], | ||
| useCount: Int): Unit = { | ||
| private def getCommonExprs(exprs: Seq[Expression]): Seq[ExpressionEquals] = { | ||
| assert(exprs.length > 1) | ||
| var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| updateExprTree(exprs.head, localEquivalenceMap) | ||
| updateExprTree(exprs.head, localEquivalenceMap, conditionalsEnabled = false) | ||
|
|
||
| exprs.tail.foreach { expr => | ||
| val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| updateExprTree(expr, otherLocalEquivalenceMap) | ||
| updateExprTree(expr, otherLocalEquivalenceMap, conditionalsEnabled = false) | ||
| localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => | ||
| otherLocalEquivalenceMap.contains(key) | ||
| } | ||
| } | ||
|
|
||
| val commonExpressions = mutable.ListBuffer.empty[ExpressionEquals] | ||
|
|
||
| // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. | ||
| // The remaining highest expression in `localEquivalenceMap` is also common expression so loop | ||
| // until `localEquivalenceMap` is empty. | ||
| var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) | ||
| while (statsOption.nonEmpty) { | ||
| val stats = statsOption.get | ||
| updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) | ||
| updateExprTree(stats.expr, map, useCount) | ||
| updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount, conditionalsEnabled = false) | ||
| commonExpressions += ExpressionEquals(stats.expr) | ||
|
|
||
| statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) | ||
| } | ||
| commonExpressions.toSeq | ||
| } | ||
|
|
||
| private def skipForShortcut(expr: Expression): Expression = { | ||
|
|
@@ -143,21 +158,17 @@ class EquivalentExpressions( | |
| } | ||
| } | ||
|
|
||
| // There are some special expressions that we should not recurse into all of its children. | ||
| // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) | ||
| // 2. ConditionalExpression: use its children that will always be evaluated. | ||
| private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { | ||
| case _: CodegenFallback => Nil | ||
| case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) | ||
| case other => skipForShortcut(other).children | ||
| } | ||
|
|
||
| // For some special expressions we cannot just recurse into all of its children, but we can | ||
| // recursively add the common expressions shared between all of its children. | ||
| private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { | ||
| case _: CodegenFallback => Nil | ||
| case c: ConditionalExpression => c.branchGroups | ||
| case _ => Nil | ||
| /** | ||
| * There are some expressions that need special handling: | ||
| * 1. CodegenFallback: Its children will not be used to generate code (call eval() instead). | ||
| * 2. ConditionalExpression: use its children that will always be evaluated. | ||
| */ | ||
| private def childrenToRecurse(expr: Expression): RecurseChildren = expr match { | ||
| case _: CodegenFallback => RecurseChildren(Nil) | ||
| case c: ConditionalExpression => | ||
| RecurseChildren(c.alwaysEvaluatedInputs.map(skipForShortcut), c.branchGroups, | ||
| c.conditionallyEvaluatedInputs) | ||
| case other => RecurseChildren(skipForShortcut(other).children) | ||
| } | ||
|
|
||
| private def supportedExpression(e: Expression): Boolean = { | ||
|
|
@@ -184,13 +195,48 @@ class EquivalentExpressions( | |
| private def updateExprTree( | ||
| expr: Expression, | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, | ||
| useCount: Int = 1): Unit = { | ||
| val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] | ||
| useCount: Int = 1, | ||
| conditionalsEnabled: Boolean = SQLConf.get.subexpressionEliminationConditionalsEnabled, | ||
| conditional: Boolean = false, | ||
| skipExpressions: Set[ExpressionEquals] = Set.empty[ExpressionEquals] | ||
| ): Unit = { | ||
| val skip = useCount == 0 || | ||
| expr.isInstanceOf[LeafExpression] || | ||
| skipExpressions.contains(ExpressionEquals(expr)) | ||
|
|
||
| if (!skip && !updateExprInMap(expr, map, useCount)) { | ||
| if (!skip && !updateExprInMap(expr, map, useCount, conditional)) { | ||
| val uc = useCount.sign | ||
| childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) | ||
| val recurseChildren = childrenToRecurse(expr) | ||
| recurseChildren.alwaysChildren.foreach { child => | ||
| updateExprTree(child, map, uc, conditionalsEnabled, conditional, skipExpressions) | ||
| } | ||
|
|
||
| /** | ||
| * If a common expression already appears in the equivalence map, calling | ||
| * `updateExprTree` will increase the `useCount` and mark it as a common subexpression. | ||
| * Otherwise, `updateExprTree` will recursively add the common expression and its descendants to | ||
| * the equivalence map, in case they also appear in other places. For example, | ||
| * `If(a + b > 1, a + b + c, a + b + c)`, `a + b` also appears in the condition and should | ||
| * be treated as common subexpression. | ||
| */ | ||
| val commonExpressions = recurseChildren.commonChildren.flatMap { exprs => | ||
| if (exprs.nonEmpty) { | ||
| getCommonExprs(exprs) | ||
| } else { | ||
| Nil | ||
| } | ||
| } | ||
| commonExpressions.foreach { ce => | ||
| updateExprTree(ce.e, map, uc, conditionalsEnabled, conditional, skipExpressions) | ||
| } | ||
|
|
||
| if (conditionalsEnabled) { | ||
| // Add all conditional expressions, skipping those that were already counted as common | ||
| // expressions. | ||
| recurseChildren.conditionalChildren.foreach { cc => | ||
| updateExprTree(cc, map, uc, true, true, commonExpressions.toSet) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -208,7 +254,7 @@ class EquivalentExpressions( | |
|
|
||
| // Exposed for testing. | ||
| private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { | ||
| equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2) | ||
| equivalenceMap.filter(_._2.getUseCount() > count).toSeq.sortBy(_._1.height).map(_._2) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -225,8 +271,11 @@ class EquivalentExpressions( | |
| def debugString(all: Boolean = false): String = { | ||
| val sb = new java.lang.StringBuilder() | ||
| sb.append("Equivalent expressions:\n") | ||
| equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => | ||
| sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') | ||
| equivalenceMap.values.filter(stats => all || stats.getUseCount() > 1).foreach { stats => | ||
| sb.append(" ") | ||
| .append(s"${stats.expr}: useCount = ${stats.useCount} ") | ||
| .append(s"conditionalUseCount = ${stats.conditionalUseCount}") | ||
| .append('\n') | ||
| } | ||
| sb.toString() | ||
| } | ||
|
|
@@ -255,4 +304,32 @@ case class ExpressionEquals(e: Expression) { | |
| * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" | ||
| * useCount in this wrapper in-place. | ||
| */ | ||
| case class ExpressionStats(expr: Expression)(var useCount: Int) | ||
| case class ExpressionStats(expr: Expression)( | ||
| var useCount: Int = 1, | ||
| var conditionalUseCount: Int = 0) { | ||
| def getUseCount(): Int = if (useCount > 0) { | ||
| useCount + conditionalUseCount | ||
| } else { | ||
| 0 | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A wrapper for the different types of children of expressions. | ||
| * | ||
| * `alwaysChildren` are child expressions that will always be evaluated and should be considered | ||
| * for subexpressions. | ||
| * | ||
| * `commonChildren` are children such that each of the children might be evaluated, but at least one | ||
| * will definitely be evaluated. If there are any common expressions among them, those expressions | ||
| * will definitely be evaluated and should be considered for subexpressions. | ||
| * | ||
| * `conditionalChildren` are children that are conditionally evaluated, such as in If, CaseWhen, | ||
| * or Coalesce expressions, and should only be considered for subexpressions if they are evaluated | ||
| * non-conditionally elsewhere. | ||
| */ | ||
| case class RecurseChildren( | ||
| alwaysChildren: Seq[Expression], | ||
| commonChildren: Seq[Seq[Expression]] = Nil, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. do we have an example this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated the docs a little bit to clarify. Currently it's only |
||
| conditionalChildren: Seq[Expression] = Nil | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we update the doc of this method? no equivalenceMap in this method now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep good call, haven't kept up with some of the docs