diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 32eb650aa6d2..67298678e8f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -143,7 +143,13 @@ class EquivalentExpressions { // a subexpression among values doesn't need to be in conditions because no matter which // condition is true, it will be evaluated. val conditions = c.branches.tail.map(_._1) - val values = c.branches.map(_._2) ++ c.elseValue + // For an expression to be in all branch values of a CaseWhen statement, it must also be in + // the elseValue. + val values = if (c.elseValue.nonEmpty) { + c.branches.map(_._2) ++ c.elseValue + } else { + Nil + } Seq(conditions, values) case c: Coalesce => Seq(c.children.tail) case _ => Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 5e6b07422c76..7a17a05439eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -209,7 +209,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel (GreaterThan(add2, Literal(4)), add1) :: (GreaterThan(add2, Literal(5)), add1) :: Nil - val caseWhenExpr2 = CaseWhen(conditions2, None) + val caseWhenExpr2 = CaseWhen(conditions2, add1) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(caseWhenExpr2) @@ -317,7 +317,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val add3 = Add(add1, add2) val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil - val caseWhenExpr = CaseWhen(condition, None) + val caseWhenExpr = CaseWhen(condition, Add(add3, Literal(1))) val equivalence = new EquivalentExpressions equivalence.addExprTree(caseWhenExpr) @@ -354,6 +354,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add)))) } + + test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an " + + "elseValue") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + val conditions = (GreaterThan(add1, Literal(3)), add1) :: + (GreaterThan(add2, Literal(4)), add1) :: + (GreaterThan(add2, Literal(5)), add1) :: Nil + + val caseWhenExpr = CaseWhen(conditions, None) + val equivalence = new EquivalentExpressions + equivalence.addExprTree(caseWhenExpr) + + // `add1` is not in the elseValue, so we can't extract it from the branches + assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 19c905326042..6cdcc88e17e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2870,13 +2870,15 @@ class DataFrameSuite extends QueryTest s }) val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0, - functions.length(simpleUDF($"id")))) + functions.length(simpleUDF($"id"))).otherwise( + functions.length(simpleUDF($"id")) + 1)) df1.collect() assert(accum.value == 5) val nondeterministicUDF = simpleUDF.asNondeterministic() val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0, - functions.length(nondeterministicUDF($"id")))) + functions.length(nondeterministicUDF($"id"))).otherwise( + functions.length(nondeterministicUDF($"id")) + 1)) df2.collect() assert(accum.value == 15) }