Closed
@@ -65,11 +65,82 @@ class EquivalentExpressions {
}
}

private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
if (expr.deterministic) {
val e = Expr(expr)
if (set.contains(e)) {
true
} else {
set.add(e)
false
}
} else {
false
}
}

/**
* Adds only expressions which are common in each of given expressions, in a recursive way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
* the common expression `(c + 1)` will be added into `equivalenceMap`.
*/
private def addCommonExprs(
exprs: Seq[Expression],
addFunc: Expression => Boolean = addExpr): Unit = {
val exprSetForAll = mutable.Set[Expr]()
addExprTree(exprs.head, addExprToSet(_, exprSetForAll))

val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
val otherExprSet = mutable.Set[Expr]()
addExprTree(expr, addExprToSet(_, otherExprSet))
exprSet.intersect(otherExprSet)
}
Member

Do we need to handle head and tail separately?

Member Author

For the head expression, we add its underlying expressions into the `exprSetForAll` set. For the expressions in the tail, we keep only the intersection between the accumulated set and each expression's own set.

We could merge the two blocks, but then the merged block would need to check whether the current expression is the head and apply different logic based on that check.

I prefer the current one since it looks simpler.
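
To illustrate the trade-off discussed here, the following is a minimal standalone sketch, not the code from the patch. It assumes plain `Set[String]` values in place of Spark's `Expr` wrappers, and `commonHeadTail` and `commonMerged` are hypothetical names. Both forms compute the same intersection; the merged form has to track whether the accumulator has been seeded yet.

```scala
// Standalone sketch: strings stand in for Spark's `Expr` wrappers (an assumption).
object CommonSetSketch {
  // Head/tail form: seed with the head's set, then intersect with each tail set.
  // Like the patch, this assumes `sets` is non-empty.
  def commonHeadTail(sets: Seq[Set[String]]): Set[String] =
    sets.tail.foldLeft(sets.head) { (acc, other) => acc.intersect(other) }

  // Merged form: a single loop that must check whether it is handling the head.
  def commonMerged(sets: Seq[Set[String]]): Set[String] = {
    var acc: Option[Set[String]] = None
    sets.foreach { s =>
      acc = acc match {
        case None    => Some(s)              // head: take its set as-is
        case Some(a) => Some(a.intersect(s)) // tail: keep only what is shared
      }
    }
    acc.getOrElse(Set.empty)
  }
}
```

With the sets `Set("a", "b", "c + 1")` and `Set("d", "e", "c + 1")`, loosely modelled on the example in the doc comment of `addCommonExprs`, both functions keep only `"c + 1"`.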


commonExprSet.foreach(expr => addFunc(expr.e))
}

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: its children will not be used to generate code (call eval() instead)
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// children, because others may not get accessed.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
}

// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case i: If => Seq(Seq(i.trueValue, i.falseValue))
case c: CaseWhen =>
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
// a subexpression among values doesn't need to be in conditions because no matter which
// condition is true, it will be evaluated.
val conditions = c.branches.tail.map(_._1)
Contributor

There is a flaw here: we exclude the first condition, so a subexpression that is common to the rest of the conditions is not necessarily always evaluated.

E.g. in CaseWhen(cond1, ..., cond2, ..., cond2, ...), cond2 is shared among the remaining conditions, but it is not always evaluated.
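
The point can be seen in a small standalone sketch. It is only an illustration that assumes a short-circuiting, in-order evaluation of branches; `evaluatedConditions` is a hypothetical helper, not a Spark API. It records which conditions are actually evaluated.

```scala
import scala.collection.mutable.ArrayBuffer

object CaseWhenShortCircuitSketch {
  // Evaluates conditions in order and stops at the first one that is true,
  // returning the labels of the conditions that were actually evaluated.
  def evaluatedConditions(branches: Seq[(String, () => Boolean)]): List[String] = {
    val evaluated = ArrayBuffer.empty[String]
    val it = branches.iterator
    var matched = false
    while (!matched && it.hasNext) {
      val (label, cond) = it.next()
      evaluated += label
      matched = cond()
    }
    evaluated.toList
  }
}
```

When `cond1` is true, `cond2` is never evaluated even though it appears in every remaining branch, so evaluating it up front as a common subexpression would do work (or raise an error) that short-circuit evaluation would have avoided.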

Member Author (@viirya, Jun 29, 2021)

Yes, this is related to #32977. This looks like a more aggressive optimization. If we respect short-circuit evaluation for CaseWhen, this might be an issue when users rely on short-circuit evaluation to guard later conditions.

The safest approach is to consider all conditions, including the first one.

Member Author

WDYT? Should we only consider all conditions?

Contributor

I think we should. I hit an issue caused by it in my refactor and I'll open a PR for the refactor with multiple bugs fixed.

Member Author

Ok, thanks!

Member Author

BTW, does #32980 conflict with your refactor?

Contributor

Only some trivial conflicts. #32980 should be merged first, as it has been reviewed and approved.

Contributor

FWIW, I also addressed this issue in #32987, which assumes that CaseWhen (and Coalesce) should short-circuit and guard later conditions. The main benefit/difference is that if you have

CaseWhen(cond1, ..., cond1, ..., cond2, ...), cond1 gets pulled out as a subexpression, which I think it otherwise wouldn't be, even with #33142.

val values = c.branches.map(_._2) ++ c.elseValue
Seq(conditions, values)
Member

Why do we need to handle conditions and values separately?

Member Author

A subexpression in conditions is definitely run because it is shared among conditions. It doesn't need to be shared in values. Similarly, a subexpression among values doesn't need to be in conditions because no matter which condition is true, it will be evaluated.

Member (@maropu, Nov 9, 2020)

Ah, OK. Thanks. Could you leave a comment about it in the code?

case c: Coalesce => Seq(c.children.tail)
case _ => Nil
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
*/
def addExprTree(expr: Expression): Unit = {
def addExprTree(
expr: Expression,
addFunc: Expression => Boolean = addExpr): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
@@ -78,26 +149,9 @@ class EquivalentExpressions {
// can cause error like NPE.
(expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// children, because others may not get accessed.
def childrenToRecurse: Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
}

if (!skip && !addExpr(expr)) {
childrenToRecurse.foreach(addExprTree)
if (!skip && !addFunc(expr)) {
childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
}
}

@@ -1044,7 +1044,7 @@ class CodegenContext extends Logging {
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree)
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
@@ -146,20 +146,111 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("Children of conditional expressions") {
val condition = And(Literal(true), Literal(false))
test("Children of conditional expressions: If") {
val add = Add(Literal(1), Literal(2))
val ifExpr = If(condition, add, add)
val condition = GreaterThan(add, Literal(3))

val equivalence = new EquivalentExpressions
equivalence.addExprTree(ifExpr)
// the `add` inside `If` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
// only ifExpr and its predicate expression
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
val ifExpr1 = If(condition, add, add)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(ifExpr1)

// `add` is in both two branches of `If` and predicate.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
// one-time expressions: only ifExpr and its predicate expression
assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
Contributor

Should we use the `contains` method? A HashMap cannot guarantee the order.

Member Author

OK, I will create a follow-up to make sure it cannot be flaky. Thanks.

Member Author

Created #30371.

Member

Thank you, @leoluan2009 and @viirya . The follow-up is merged to reduce the flakiness.
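
For reference, the difference between the order-dependent and order-independent styles discussed in this thread can be sketched as follows. This is a standalone illustration, not the content of the follow-up: `groups` is a hypothetical stand-in for the result of `getAllEquivalentExprs`, with strings in place of expressions and an arbitrary iteration order.

```scala
object OrderIndependentCheckSketch {
  // Hypothetical result in one possible hash-map iteration order.
  val groups: Seq[Seq[String]] =
    Seq(Seq("condition"), Seq("add", "add"), Seq("ifExpr1"))

  val singles: Seq[Seq[String]] = groups.filter(_.size == 1)

  // Order-dependent: holds only if "ifExpr1" happens to come first.
  val fragile: Boolean = singles.head == Seq("ifExpr1")

  // Order-independent: holds wherever the groups sit in the result.
  val robust: Boolean =
    singles.contains(Seq("ifExpr1")) && singles.contains(Seq("condition"))
}
```

In this particular order the `head`-based check fails while the `contains`-based check still passes, which is the kind of flakiness the comment points at.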

assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition))

// Repeated `add` is only in one branch, so we don't count it.
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(ifExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)

val ifExpr3 = If(condition, ifExpr1, ifExpr1)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(ifExpr3)

// `add`: 2, `condition`: 2
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition))

// `ifExpr1`, `ifExpr3`
assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3))
}

test("Children of conditional expressions: CaseWhen") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val conditions1 = (GreaterThan(add2, Literal(3)), add1) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr1 = CaseWhen(conditions1, None)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(caseWhenExpr1)

// `add2` is repeatedly in all conditions.
Contributor

add1 is also repeated. Why is it not included?

Member Author

We treat the first condition specially because it is definitely run, so it counts once for add2. The other conditions all contain add2, so together they count once more. That is where the count of 2 for add2 comes from.

For add1, all values contain it, so it is definitely run and we count it once. If no other expression contains add1, we don't extract a subexpression for it, as it will run just once (we only run one value of CaseWhen).
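
The counting described above can be modelled with a small standalone sketch. It is a simplification of the patch as it stood at the time of this comment, under stated assumptions: only the named subexpressions (`add1`, `add2`) are tracked, each branch is given as a pair of string sets, and `counts` is a hypothetical helper rather than part of `EquivalentExpressions`.

```scala
object CaseWhenCountSketch {
  // Each branch is (subexpressions in its condition, subexpressions in its value).
  def counts(branches: Seq[(Set[String], Set[String])]): Map[String, Int] = {
    // The first condition is always evaluated, so its subexpressions are counted directly.
    val always = branches.head._1.toSeq
    val restConditions = branches.tail.map(_._1)
    val values = branches.map(_._2)

    // Subexpressions shared by every set in a group are counted once for the group.
    def common(sets: Seq[Set[String]]): Seq[String] =
      if (sets.isEmpty) Nil else sets.reduce(_ intersect _).toSeq

    (always ++ common(restConditions) ++ common(values))
      .groupBy(identity)
      .map { case (name, occurrences) => name -> occurrences.size }
  }
}
```

For three branches that each have `add2` in the condition and `add1` in the value, this yields a count of 2 for `add2` and 1 for `add1`, matching the explanation above; only entries with a count above 1 would be candidates for elimination.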

assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))

val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr2 = CaseWhen(conditions2, None)
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(caseWhenExpr2)

// `add1` is repeatedly in all branch values, and first predicate.
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1))

// Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr3 = CaseWhen(conditions3, None)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(caseWhenExpr3)
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0)
}

test("Children of conditional expressions: Coalesce") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val conditions1 = GreaterThan(add2, Literal(3)) ::
GreaterThan(add2, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil

val coalesceExpr1 = Coalesce(conditions1)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(coalesceExpr1)

// `add2` is repeatedly in all conditions.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))

// Negative case. `add1` and `add2` both are not used in all branches.
val conditions2 = GreaterThan(add1, Literal(3)) ::
GreaterThan(add2, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil

val coalesceExpr2 = Coalesce(conditions2)
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(coalesceExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
}
}
