From 548e45f63f64a3f9e9709af39610ce50390b0fa1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Nov 2016 04:25:21 +0000 Subject: [PATCH] Skip subexpression elimination for conditional expressions. --- .../expressions/EquivalentExpressions.scala | 4 +++ .../SubexpressionEliminationSuite.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index b8e2b67b2fe9..7f1a583475a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -76,11 +76,15 @@ class EquivalentExpressions { // There are some special expressions that we should not recurse into children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. + // 3. Conditional expressions: sub-expression elimination for children of conditional + // expressions would possibly cause not excepted exception and performance regression. val shouldRecurse = root match { // TODO: some expressions implements `CodegenFallback` but can still do codegen, // e.g. `CaseWhen`, we should support them. case _: CodegenFallback => false case _: ReferenceToExpressions if skipReferenceToExpressions => false + case _: If => false + case _: CaseWhenBase => false case _ => true } if (!skip && !addExpr(root) && shouldRecurse) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 1e39b24fe877..62b7b2f7eae4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -171,4 +171,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite { assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode } + + test("Conditional expressions") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + // `If` expression + val ifExpr = If( + GreaterThan(one, two), + add, + one) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(ifExpr, true) + equivalence.addExprTree(add, true) + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) + + // `CaseWhen` expression + val caseWhenExpr = CaseWhen( + Seq( + (GreaterThan(one, two), add), + (GreaterThan(two, one), add2))) + + equivalence = new EquivalentExpressions + equivalence.addExprTree(caseWhenExpr, true) + equivalence.addExprTree(add, true) + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) + } }