-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-29013][SQL] Structurally equivalent subexpression elimination #25717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
13f5ca6
4dea062
f52cbde
cc0ee12
f447042
71e0239
4700b89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions | |||||||||||||||||
|
|
||||||||||||||||||
| import scala.collection.mutable | ||||||||||||||||||
|
|
||||||||||||||||||
| import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback | ||||||||||||||||||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ParameterizedBoundReference} | ||||||||||||||||||
| import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
|
|
@@ -40,23 +40,50 @@ class EquivalentExpressions { | |||||||||||||||||
| override def hashCode: Int = e.semanticHash() | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Wrapper around an Expression that provides structural semantic equality: two wrapped | ||||||||||||||||||
| * expressions are equal when they are semantically equal after the ordinal of every | ||||||||||||||||||
| * `BoundReference` in them has been replaced with -1. | ||||||||||||||||||
| */ | ||||||||||||||||||
| case class StructuralExpr(e: Expression) { | ||||||||||||||||||
| /** Returns a copy of `expr` in which every `BoundReference` has its ordinal set to -1. */ | ||||||||||||||||||
| def normalized(expr: Expression): Expression = { | ||||||||||||||||||
| expr.transformUp { | ||||||||||||||||||
| case b: BoundReference => | ||||||||||||||||||
| b.copy(ordinal = -1) | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // Normalized form of `e`, computed at most once per wrapper. This wrapper is used as a | ||||||||||||||||||
| // hash map key, so `hashCode` and `equals` run on every lookup and should not rebuild | ||||||||||||||||||
| // the normalized tree each time. | ||||||||||||||||||
| private lazy val normalizedExpr: Expression = normalized(e) | ||||||||||||||||||
|
|
||||||||||||||||||
| override def equals(o: Any): Boolean = o match { | ||||||||||||||||||
| case other: StructuralExpr => | ||||||||||||||||||
| normalizedExpr.semanticEquals(other.normalizedExpr) | ||||||||||||||||||
| case _ => false | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| override def hashCode: Int = normalizedExpr.semanticHash() | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| type EquivalenceMap = mutable.HashMap[Expr, mutable.ArrayBuffer[Expression]] | ||||||||||||||||||
|
|
||||||||||||||||||
| // For each expression, the set of equivalent expressions. | ||||||||||||||||||
| private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] | ||||||||||||||||||
|
|
||||||||||||||||||
| // For each expression, the set of structurally equivalent expressions. | ||||||||||||||||||
| // Expressions with the same structure may still fall into several subsets that are | ||||||||||||||||||
| // semantically different from each other. Thus, under each key, the value is the map | ||||||||||||||||||
| // used to do semantic sub-expression elimination within that structure. | ||||||||||||||||||
| private val structEquivalenceMap = mutable.HashMap.empty[StructuralExpr, EquivalenceMap] | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Adds each expression to this data structure, grouping them with existing equivalent | ||||||||||||||||||
| * expressions. Non-recursive. | ||||||||||||||||||
| * Returns true if there was already a matching expression. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def addExpr(expr: Expression): Boolean = { | ||||||||||||||||||
| def addExpr(expr: Expression, exprMap: EquivalenceMap = this.equivalenceMap): Boolean = { | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Do we need
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addExpr is also used at PhysicalAggregation: spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala Lines 222 to 229 in 2f3997f
|
||||||||||||||||||
| if (expr.deterministic) { | ||||||||||||||||||
| val e: Expr = Expr(expr) | ||||||||||||||||||
| val f = equivalenceMap.get(e) | ||||||||||||||||||
| val f = exprMap.get(e) | ||||||||||||||||||
| if (f.isDefined) { | ||||||||||||||||||
| f.get += expr | ||||||||||||||||||
| true | ||||||||||||||||||
| } else { | ||||||||||||||||||
| equivalenceMap.put(e, mutable.ArrayBuffer(expr)) | ||||||||||||||||||
| exprMap.put(e, mutable.ArrayBuffer(expr)) | ||||||||||||||||||
| false | ||||||||||||||||||
| } | ||||||||||||||||||
| } else { | ||||||||||||||||||
|
|
@@ -65,35 +92,97 @@ class EquivalentExpressions { | |||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Adds the expression to this data structure recursively. Stops if a matching expression | ||||||||||||||||||
| * is found. That is, if `expr` has already been added, its children are not added. | ||||||||||||||||||
| * Adds the expression to the structural expression data structure, grouping it with existing | ||||||||||||||||||
| * structurally equivalent expressions. Non-recursive. Returns false if the input expression | ||||||||||||||||||
| * was not actually added. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def addStructExpr(expr: Expression): Boolean = { | ||||||||||||||||||
| if (expr.deterministic) { | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot always share a function for non-deterministic cases? e.g.,
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Non-deterministic expressions can't do sub-expression elimination.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, I see.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. btw, this idea is limited to common subexprs? For example, the idea can cover a case like; ?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is probably suitable. Only if we want these functions sum(a+b)...etc. to be called in split functions. Their inputs can be parameterized. |
||||||||||||||||||
| // For structural equivalent expressions, we need to pass in int type ordinals into | ||||||||||||||||||
| // split functions. If the number of ordinals is more than JVM function limit, we skip | ||||||||||||||||||
| // this expression. | ||||||||||||||||||
| // We calculate the function parameter length as the number of ints plus `INPUT_ROW` plus | ||||||||||||||||||
| // an int type result array index. | ||||||||||||||||||
| val refs = expr.collect { | ||||||||||||||||||
| case _: BoundReference => Literal(0) | ||||||||||||||||||
| } | ||||||||||||||||||
| val parameterLength = CodeGenerator.calculateParamLength(refs) + 2 | ||||||||||||||||||
| if (CodeGenerator.isValidParamLength(parameterLength)) { | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the length goes over the limit, the current logic gives up eliminating common exprs? If so, can we fall back into the non-structural mode?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. |
||||||||||||||||||
| val e: StructuralExpr = StructuralExpr(expr) | ||||||||||||||||||
| val f = structEquivalenceMap.get(e) | ||||||||||||||||||
| if (f.isDefined) { | ||||||||||||||||||
| addExpr(expr, f.get) | ||||||||||||||||||
| } else { | ||||||||||||||||||
| val exprMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] | ||||||||||||||||||
| addExpr(expr, exprMap) | ||||||||||||||||||
| structEquivalenceMap.put(e, exprMap) | ||||||||||||||||||
| } | ||||||||||||||||||
| true | ||||||||||||||||||
| } else { | ||||||||||||||||||
| false | ||||||||||||||||||
| } | ||||||||||||||||||
| } else { | ||||||||||||||||||
| false | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Checks whether we should skip adding sub-expressions for the given expression. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def addExprTree(expr: Expression): Unit = { | ||||||||||||||||||
| val skip = expr.isInstanceOf[LeafExpression] || | ||||||||||||||||||
| private def skipExpr(expr: Expression): Boolean = expr match { | ||||||||||||||||||
| // Leaf expressions are always skipped. | ||||||||||||||||||
| case _: LeafExpression => true | ||||||||||||||||||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||||||||||||||||||
| // loop, so an expression containing one anywhere in its tree can't be evaluated up front. | ||||||||||||||||||
| case _ => expr.find(_.isInstanceOf[LambdaVariable]).nonEmpty | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| // There are some special expressions that we should not recurse into all of its children. | ||||||||||||||||||
| // 1. CodegenFallback: its children will not be used to generate code (call eval() instead) | ||||||||||||||||||
| // 2. If: common subexpressions will always be evaluated at the beginning, but the true and | ||||||||||||||||||
| // false expressions in `If` may not get accessed, according to the predicate | ||||||||||||||||||
| // expression. We should only recurse into the predicate expression. | ||||||||||||||||||
| // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain | ||||||||||||||||||
| // condition. We should only recurse into the first condition expression as it | ||||||||||||||||||
| // will always get accessed. | ||||||||||||||||||
| // 4. Coalesce: it's also a conditional expression, we should only recurse into the first | ||||||||||||||||||
| // child, because the others may not get accessed. | ||||||||||||||||||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
| // 5. SortPrefix: skip the direct child of SortPrefix, which is an unevaluable SortOrder. | ||||||||||||||||||
| private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { | ||||||||||||||||||
| // Nothing below a CodegenFallback is visited. | ||||||||||||||||||
| case _: CodegenFallback => Seq.empty | ||||||||||||||||||
| // Only the predicate of an If is visited. | ||||||||||||||||||
| case i: If => Seq(i.predicate) | ||||||||||||||||||
| // Only the first child of a CaseWhen or a Coalesce is visited. | ||||||||||||||||||
| case c: CaseWhen => Seq(c.children.head) | ||||||||||||||||||
| case c: Coalesce => Seq(c.children.head) | ||||||||||||||||||
| // Step over the direct child of SortPrefix and visit the grandchild instead. | ||||||||||||||||||
| case s: SortPrefix => Seq(s.child.child) | ||||||||||||||||||
| // Every other expression exposes all of its children. | ||||||||||||||||||
| case other => other.children | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Adds the expression to this data structure recursively. Stops if a matching expression | ||||||||||||||||||
| * is found. That is, if `expr` has already been added, its children are not added. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def addExprTree( | ||||||||||||||||||
| expr: Expression, | ||||||||||||||||||
| exprMap: EquivalenceMap = this.equivalenceMap): Unit = { | ||||||||||||||||||
| val skip = skipExpr(expr) | ||||||||||||||||||
|
|
||||||||||||||||||
| // There are some special expressions that we should not recurse into all of its children. | ||||||||||||||||||
| // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) | ||||||||||||||||||
| // 2. If: common subexpressions will always be evaluated at the beginning, but the true and | ||||||||||||||||||
| // false expressions in `If` may not get accessed, according to the predicate | ||||||||||||||||||
| // expression. We should only recurse into the predicate expression. | ||||||||||||||||||
| // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain | ||||||||||||||||||
| // condition. We should only recurse into the first condition expression as it | ||||||||||||||||||
| // will always get accessed. | ||||||||||||||||||
| // 4. Coalesce: it's also a conditional expression, we should only recurse into the first | ||||||||||||||||||
| // children, because others may not get accessed. | ||||||||||||||||||
| def childrenToRecurse: Seq[Expression] = expr match { | ||||||||||||||||||
| case _: CodegenFallback => Nil | ||||||||||||||||||
| case i: If => i.predicate :: Nil | ||||||||||||||||||
| case c: CaseWhen => c.children.head :: Nil | ||||||||||||||||||
| case c: Coalesce => c.children.head :: Nil | ||||||||||||||||||
| case other => other.children | ||||||||||||||||||
| if (!skip && !addExpr(expr, exprMap)) { | ||||||||||||||||||
| childrenToRecurse(expr).foreach(addExprTree(_, exprMap)) | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Adds the expression to the structural data structure recursively. Returns false if the input | ||||||||||||||||||
| * expression was not actually added. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def addStructuralExprTree(expr: Expression): Boolean = { | ||||||||||||||||||
| val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback] | ||||||||||||||||||
|
|
||||||||||||||||||
| if (!skip && !addExpr(expr)) { | ||||||||||||||||||
| childrenToRecurse.foreach(addExprTree) | ||||||||||||||||||
| if (!skip && addStructExpr(expr)) { | ||||||||||||||||||
| childrenToRecurse(expr).foreach(addStructuralExprTree) | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: do we want to add
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addExprTree doesn't add (_) too. Just followed it. If no special reason, I will leave it.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am neutral on this.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At that line, if we don't add a (_) when calling addExprTree, we will see a compilation error: Because addExprTree actually needs two arguments, it doesn't match foreach's argument type. |
||||||||||||||||||
| true | ||||||||||||||||||
| } else { | ||||||||||||||||||
| false | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -112,6 +201,18 @@ class EquivalentExpressions { | |||||||||||||||||
| equivalenceMap.values.map(_.toSeq).toSeq | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Returns the groups of semantically equivalent expressions stored under the structural key | ||||||||||||||||||
| * of `e`, or an empty sequence if nothing has been stored under that key. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def getStructurallyEquivalentExprs(e: Expression): Seq[Seq[Expression]] = { | ||||||||||||||||||
| structEquivalenceMap.get(StructuralExpr(e)) match { | ||||||||||||||||||
| case Some(semanticGroups) => semanticGroups.values.map(_.toSeq).toSeq | ||||||||||||||||||
| case None => Seq.empty | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Returns every structural key together with its map of semantically equivalent expression | ||||||||||||||||||
| * groups. The returned outer map is an immutable copy; the inner maps are not copied. | ||||||||||||||||||
| */ | ||||||||||||||||||
| def getAllStructuralExpressions: Map[StructuralExpr, EquivalenceMap] = | ||||||||||||||||||
| structEquivalenceMap.toMap | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
| * Returns the state of the data structure as a string. If `all` is false, skips sets of | ||||||||||||||||||
| * equivalent expressions with cardinality 1. | ||||||||||||||||||
|
|
||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.