Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,15 @@ class EquivalentExpressions {
// There are some special expressions that we should not recurse into children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
// 3. Conditional expressions: sub-expression elimination for children of conditional
// expressions would possibly cause not excepted exception and performance regression.
val shouldRecurse = root match {
// TODO: some expressions implements `CodegenFallback` but can still do codegen,
// e.g. `CaseWhen`, we should support them.
case _: CodegenFallback => false
case _: ReferenceToExpressions if skipReferenceToExpressions => false
case _: If => false
case _: CaseWhenBase => false
case _ => true
}
if (!skip && !addExpr(root) && shouldRecurse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("Conditional expressions") {
val one = Literal(1)
val two = Literal(2)

val add = Add(one, two)
val abs = Abs(add)
val add2 = Add(add, add)

// `If` expression
val ifExpr = If(
GreaterThan(one, two),
add,
one)

var equivalence = new EquivalentExpressions
equivalence.addExprTree(ifExpr, true)
equivalence.addExprTree(add, true)
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)

// `CaseWhen` expression
val caseWhenExpr = CaseWhen(
Seq(
(GreaterThan(one, two), add),
(GreaterThan(two, one), add2)))

equivalence = new EquivalentExpressions
equivalence.addExprTree(caseWhenExpr, true)
equivalence.addExprTree(add, true)
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
}
}