diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index b2625bddeecf4..6c5dec133d2a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -486,6 +486,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
       case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
       case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)
 
+      // CASE WHEN p THEN true ELSE false END is true only when p is true. A NULL p falls
+      // through to the ELSE branch, so a nullable p has to be compared null-safely.
+      case CaseWhen(Seq((pred, TrueLiteral)), Some(FalseLiteral)) =>
+        if (pred.nullable) EqualNullSafe(pred, TrueLiteral) else pred
+      // The mirrored shape (THEN false ELSE true) is the negation of the case above.
+      case CaseWhen(Seq((pred, FalseLiteral)), Some(TrueLiteral)) =>
+        val isTrue = if (pred.nullable) EqualNullSafe(pred, TrueLiteral) else pred
+        Not(isTrue)
+
       case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
         // If there are no more branches left, just use the else value.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
index 7c9a67d7554e2..0d5218ac629e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
@@ -141,7 +141,7 @@ class PushFoldableIntoBranchesSuite
       CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2)))
     assert(!nonDeterministic.deterministic)
     assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
-      CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral)))
+      GreaterThanOrEqual(Rand(1), Literal(0.5)))
     assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
       CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral)))
 
@@ -269,4 +269,13 @@ class PushFoldableIntoBranchesSuite
         Literal.create(null, BooleanType))
     }
   }
+
+  test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
+    // Comparing the CaseWhen with either of its branch literals is expected to reduce to a
+    // null-safe test of the branch condition (or to the negation of that test).
+    val predicate = 'a > 10
+    val caseWhen = CaseWhen(Seq((predicate, Literal(0))), Literal(1))
+    assertEquivalent(EqualTo(caseWhen, Literal(0)), predicate <=> TrueLiteral)
+    assertEquivalent(EqualTo(caseWhen, Literal(1)), Not(predicate <=> TrueLiteral))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 317984eba2261..f3edd70bcfb12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -243,4 +243,40 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
Literal.create(null, IntegerType))
     }
   }
+
+  test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
+    // Verify the boolean equivalence of every transformation involved, once with a
+    // non-nullable condition and once with a nullable one.
+    val fields = Seq(
+      'cond.boolean.notNull,
+      'cond_nullable.boolean
+    )
+    // Pair each attribute with its position so it lines up with the create_row values below.
+    val Seq(cond, cond_nullable) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
+
+    val exprs = Seq(
+      // actual expressions of the transformations: original -> transformed
+      CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond,
+      CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond,
+      CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true),
+      CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true)))
+
+    // check plans
+    for ((originalExpr, expectedExpr) <- exprs) {
+      assertEquivalent(originalExpr, expectedExpr)
+    }
+
+    // check evaluation: the transformed expression must yield the same value as the
+    // original for every input combination (two-valued for `cond`, three-valued for
+    // `cond_nullable`). Only the columns the expressions actually read are enumerated.
+    val binaryBooleanValues = Seq(true, false)
+    val ternaryBooleanValues = Seq(true, false, null)
+    for (condVal <- binaryBooleanValues;
+        condNullableVal <- ternaryBooleanValues;
+        (originalExpr, expectedExpr) <- exprs) {
+      val inputRow = create_row(condVal, condNullableVal)
+      val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
+      checkEvaluation(originalExpr, optimizedVal, inputRow)
+    }
+  }
 }