Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ParameterizedBoundReference}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable

/**
Expand All @@ -40,23 +40,50 @@ class EquivalentExpressions {
override def hashCode: Int = e.semanticHash()
}

/**
 * Wrapper around an Expression that provides structural semantic equality.
 *
 * Two wrapped expressions are considered equal when they are semantically equal after every
 * `BoundReference` in them has had its ordinal replaced with -1, i.e. when they only differ
 * in which input ordinals they read.
 *
 * @param e the wrapped expression
 */
case class StructuralExpr(e: Expression) {
  // The normalized form of `e`, computed at most once per wrapper. `equals` and `hashCode`
  // run on every hash map lookup, so re-running `transformUp` on each call is wasted work.
  @transient private lazy val normalizedExpr: Expression = normalized(e)

  /**
   * Returns a copy of `expr` in which the ordinal of every `BoundReference` is set to -1.
   */
  def normalized(expr: Expression): Expression = {
    expr.transformUp {
      case b: BoundReference =>
        b.copy(ordinal = -1)
    }
  }

  override def equals(o: Any): Boolean = o match {
    case other: StructuralExpr =>
      normalizedExpr.semanticEquals(other.normalizedExpr)
    case _ => false
  }

  override def hashCode: Int = normalizedExpr.semanticHash()
}

type EquivalenceMap = mutable.HashMap[Expr, mutable.ArrayBuffer[Expression]]

// For each expression, the set of equivalent expressions.
private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]

// For each expression, the set of structurally equivalent expressions.
// Among expressions with the same structure, there can be different subsets of expressions
// that are semantically different from each other. Thus, under each key, the value is
// the map structure used to do semantic sub-expression elimination.
private val structEquivalenceMap = mutable.HashMap.empty[StructuralExpr, EquivalenceMap]

/**
* Adds each expression to this data structure, grouping them with existing equivalent
* expressions. Non-recursive.
* Returns true if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
def addExpr(expr: Expression, exprMap: EquivalenceMap = this.equivalenceMap): Boolean = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we need = this.equivalenceMap? It seems that all of the callers pass two arguments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addExpr is also used at PhysicalAggregation:

expr.collect {
// addExpr() always returns false for non-deterministic expressions and do not add them.
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}

if (expr.deterministic) {
val e: Expr = Expr(expr)
val f = equivalenceMap.get(e)
val f = exprMap.get(e)
if (f.isDefined) {
f.get += expr
true
} else {
equivalenceMap.put(e, mutable.ArrayBuffer(expr))
exprMap.put(e, mutable.ArrayBuffer(expr))
false
}
} else {
Expand All @@ -65,35 +92,97 @@ class EquivalentExpressions {
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
* Adds each expression to structural expression data structure, grouping them with existing
* structurally equivalent expressions. Non-recursive. Returns false if this doesn't add input
* expression actually.
*/
def addStructExpr(expr: Expression): Boolean = {
if (expr.deterministic) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot always share a function for non-deterministic cases? e.g.,

int subExpr1 = input[0] + random();
int subExpr2 = input[1] + random();
=>
int subExpr1 = subExpr(input[0]);
int subExpr2 = subExpr(input[1]);

int subExpr(int v) { return v + random(); }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-deterministic expressions can't do sub-expression elimination.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I see.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, this idea is limited to common subexprs? For example, the idea can cover a case like;

select sum(a + b), sum(b + c), sum(c + d), sum(d + e) from values (1, 1, 1, 1, 1) t(a, b, c, d, e)

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is probably suitable. Only if we want these functions sum(a+b)...etc. to be called in split functions. Their inputs can be parameterized.

// For structurally equivalent expressions, we need to pass int type ordinals into the
// split functions. If the number of ordinals exceeds the JVM function parameter limit, we
// skip this expression.
// We calculate the function parameter length as the number of ints plus `INPUT_ROW` plus
// an int type result array index.
val refs = expr.collect {
case _: BoundReference => Literal(0)
}
val parameterLength = CodeGenerator.calculateParamLength(refs) + 2
if (CodeGenerator.isValidParamLength(parameterLength)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the length goes over the limit, the current logic gives up eliminating common exprs? If so, can we fall back into the non-structural mode?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

val e: StructuralExpr = StructuralExpr(expr)
val f = structEquivalenceMap.get(e)
if (f.isDefined) {
addExpr(expr, f.get)
} else {
val exprMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
addExpr(expr, exprMap)
structEquivalenceMap.put(e, exprMap)
}
true
} else {
false
}
} else {
false
}
}

/**
* Checks whether we should skip adding sub-expressions for the given expression.
*/
def addExprTree(expr: Expression): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
private def skipExpr(expr: Expression): Boolean = expr match {
  // Leaf expressions are always skipped.
  case _: LeafExpression => true
  case nonLeaf =>
    // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of
    // the loop. So sub-expressions containing a `LambdaVariable` can't be evaluated at the
    // beginning and are skipped as well.
    nonLeaf.find {
      case _: LambdaVariable => true
      case _ => false
    }.nonEmpty
}


// Some expressions must not have all of their children visited when recursing:
// 1. CodegenFallback: its children will not be used to generate code (eval() is called
//    instead), so none of them is visited.
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
//    false expressions in `If` may not get accessed, depending on the predicate
//    expression. Only the predicate expression is visited.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed under a certain
//    condition. Only the first condition expression is visited, as it will always get
//    accessed.
// 4. Coalesce: it is also a conditional expression. Only the first child is visited,
//    because the others may not get accessed.
// 5. SortPrefix: skip the direct child of SortPrefix, which is an unevaluable SortOrder,
//    and visit that child's own child instead.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
  case _: CodegenFallback => Seq.empty
  case branch: If => Seq(branch.predicate)
  case caseWhen: CaseWhen => Seq(caseWhen.children.head)
  case coalesce: Coalesce => Seq(coalesce.children.head)
  case prefix: SortPrefix => Seq(prefix.child.child)
  case other => other.children
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
*/
def addExprTree(
expr: Expression,
exprMap: EquivalenceMap = this.equivalenceMap): Unit = {
val skip = skipExpr(expr)

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// children, because others may not get accessed.
def childrenToRecurse: Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
if (!skip && !addExpr(expr, exprMap)) {
childrenToRecurse(expr).foreach(addExprTree(_, exprMap))
}
}

/**
* Adds the expression to structural data structure recursively. Returns false if this doesn't add
* the input expression actually.
*/
def addStructuralExprTree(expr: Expression): Boolean = {
val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback]

if (!skip && !addExpr(expr)) {
childrenToRecurse.foreach(addExprTree)
if (!skip && addStructExpr(expr)) {
childrenToRecurse(expr).foreach(addStructuralExprTree)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we want to add (_)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addExprTree doesn't add (_) either; I just followed it. Unless there is a special reason, I will leave it as is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am neutral on this.
I was curious why this line added (_), but here does not add.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At that line, if we don't add (_) when calling addExprTree, we will see a compilation error:

[error]  found   : (org.apache.spark.sql.catalyst.expressions.Expression, equivalentExpressions.EquivalenceMap) => Unit                                                               
[error]     (which expands to)  (org.apache.spark.sql.catalyst.expressions.Expression, scala.collection.mutable.HashMap[equivalentExpressions.Expr,scala.collection.mutable.ArrayBuff$r[org.apache.spark.sql.catalyst.expressions.Expression]]) => Unit                                                                                                                     
[error]  required: org.apache.spark.sql.catalyst.expressions.Expression => ?               
[error]     expressions.foreach(equivalentExpressions.addExprTree)     
[error]                                               ^                                                                                                                               

Because addExprTree actually needs two arguments, it doesn't match foreach's argument type.

true
} else {
false
}
}

Expand All @@ -112,6 +201,18 @@ class EquivalentExpressions {
equivalenceMap.values.map(_.toSeq).toSeq
}

/**
 * Returns the groups of semantically equivalent expressions recorded under the structural
 * key of `e`, one inner sequence per group. Returns an empty sequence if nothing has been
 * recorded for that structure.
 */
def getStructurallyEquivalentExprs(e: Expression): Seq[Seq[Expression]] = {
  structEquivalenceMap.get(StructuralExpr(e)) match {
    case Some(semanticGroups) => semanticGroups.values.map(_.toSeq).toSeq
    case None => Seq.empty
  }
}

/**
 * Returns all the structurally equivalent sets of expressions, keyed by their structural
 * wrapper. Only the outer map is converted to an immutable `Map`; the per-structure
 * equivalence maps are returned as-is.
 */
def getAllStructuralExpressions: Map[StructuralExpr, EquivalenceMap] = {
  structEquivalenceMap.iterator.toMap
}

/**
* Returns the state of the data structure as a string. If `all` is false, skips sets of
* equivalent expressions with cardinality 1.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,28 @@ abstract class Expression extends TreeNode[Expression] {
""
}

// Appends necessary parameters from sub-expression elimination.
val arguments = (ctx.subExprEliminationParameters.map(_.ordinalParam) ++
Seq(ctx.INPUT_ROW)).mkString(", ")

val parameterString = (ctx.subExprEliminationParameters.map(p => s"int ${p.ordinalParam}") ++
Seq(s"InternalRow ${ctx.INPUT_ROW}")).mkString(", ")

val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")

val funcName = ctx.freshName(nodeName)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
|private $javaType $funcName($parameterString) {
| ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)

eval.value = JavaCode.variable(newValue, dataType)
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.code = code"$javaType $newValue = $funcFullName($arguments);"
}
}

Expand Down
Loading