From b440f6d309c0f94ca7e44f5e6a627aa97add1dc1 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Mon, 24 May 2021 07:40:57 -0400 Subject: [PATCH] CaseWhen needs to have an elseValue to be considered for subexpression elimination --- .../expressions/EquivalentExpressions.scala | 8 +++++++- .../SubexpressionEliminationSuite.scala | 18 +++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 1dfff412d9a8..d03dd53f31e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -128,7 +128,13 @@ class EquivalentExpressions { // a subexpression among values doesn't need to be in conditions because no matter which // condition is true, it will be evaluated. val conditions = c.branches.tail.map(_._1) - val values = c.branches.map(_._2) ++ c.elseValue + // For an expression to be in all branch values of a CaseWhen statement, it must also be in + // the elseValue. + val values = if (c.elseValue.nonEmpty) { + c.branches.map(_._2) ++ c.elseValue + } else { + Nil + } Seq(conditions, values) case c: Coalesce => Seq(c.children.tail) case _ => Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 65671d253dc5..dd2162e27923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -209,7 +209,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel (GreaterThan(add2, Literal(4)), add1) :: (GreaterThan(add2, Literal(5)), add1) :: Nil - val caseWhenExpr2 = CaseWhen(conditions2, None) + val caseWhenExpr2 = CaseWhen(conditions2, add1) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(caseWhenExpr2) @@ -309,6 +309,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel CodeGenerator.compile(code) } } + + test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an " + + "elseValue") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + val conditions = (GreaterThan(add1, Literal(3)), add1) :: + (GreaterThan(add2, Literal(4)), add1) :: + (GreaterThan(add2, Literal(5)), add1) :: Nil + + val caseWhenExpr = CaseWhen(conditions, None) + val equivalence = new EquivalentExpressions + equivalence.addExprTree(caseWhenExpr) + + // `add1` is not in the elseValue, so we can't extract it from the branches + assert(equivalence.getAllEquivalentExprs.count(_.size == 2) == 0) + } } case class CodegenFallbackExpression(child: Expression)