diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7666c4a53e5d..d1945478e205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -523,6 +523,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + + case e @ EqualTo(c @ CaseWhen(branches, elseValue), right) + if c.deterministic && + right.isInstanceOf[Literal] && branches.forall(_._2.isInstanceOf[Literal]) && + elseValue.forall(_.isInstanceOf[Literal]) => + if ((branches.map(_._2) ++ elseValue).forall(!_.equals(right))) { + FalseLiteral + } else { + e + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index bac962ced461..bdba456c3135 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -199,4 +199,38 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow)) } } + + test("SPARK-33315: simplify CaseWhen with EqualTo") { + val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + val b = UnresolvedAttribute("b") + val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) + val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + + assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3))) + assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), + FalseLiteral) + + assertEquivalent( + And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), + FalseLiteral) + + assertEquivalent( + EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)), + FalseLiteral) + + // Do not simplify if it contains non foldable expressions. + assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), + EqualTo(caseWhen, NonFoldableLiteral(true))) + val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None) + assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + + // Do not simplify if it contains non-deterministic expressions. + val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + } }