From c78882a1dd06c90c5971b6ac86a0f12561b7ca4f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 23 Dec 2020 09:36:37 +0800 Subject: [PATCH 1/6] simplify CaseWhen when one clause is null and another is boolean --- .../sql/catalyst/optimizer/expressions.scala | 13 +++++ .../optimizer/SimplifyConditionalSuite.scala | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 47b968f6ebdd..187698275217 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -484,6 +484,19 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l) case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l) + case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(FalseLiteral)) + if !cond.nullable => + And(cond, l) + case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(TrueLiteral)) + if !cond.nullable => + Or(Not(cond), l) + case CaseWhen(Seq((cond, FalseLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None)) + if !cond.nullable => + And(Not(cond), elseOpt.getOrElse(Literal(null, BooleanType))) + case CaseWhen(Seq((cond, TrueLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None)) + if !cond.nullable => + Or(cond, elseOpt.getOrElse(Literal(null, BooleanType))) + case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. // If there are no more branches left, just use the else value. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 328fc107e1c1..a5fe3a712d29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -215,4 +215,52 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) } + + test("SPARK-33884: simplify CaseWhen when one clause is null and another is boolean") { + val p = IsNull('a) + val nullLiteral = Literal(null, BooleanType) + assertEquivalent(CaseWhen(Seq((p, nullLiteral)), FalseLiteral), And(p, nullLiteral)) + assertEquivalent(CaseWhen(Seq((p, nullLiteral)), TrueLiteral), Or(IsNotNull('a), nullLiteral)) + assertEquivalent(CaseWhen(Seq((p, FalseLiteral)), nullLiteral), And(IsNotNull('a), nullLiteral)) + assertEquivalent(CaseWhen(Seq((p, FalseLiteral)), None), And(IsNotNull('a), nullLiteral)) + assertEquivalent(CaseWhen(Seq((p, TrueLiteral)), nullLiteral), Or(p, nullLiteral)) + assertEquivalent(CaseWhen(Seq((p, TrueLiteral)), None), Or(p, nullLiteral)) + + // the rule should not apply to nullable predicate + Seq(TrueLiteral, FalseLiteral).foreach { b => + assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), nullLiteral)), b), + CaseWhen(Seq((GreaterThan('a, 42), nullLiteral)), b)) + assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), b)), nullLiteral), + CaseWhen(Seq((GreaterThan('a, 42), b)), nullLiteral)) + assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), b)), None), + CaseWhen(Seq((GreaterThan('a, 42), b)), None)) + } + + // check evaluation also + Seq(TrueLiteral, FalseLiteral).foreach { b => + checkEvaluation(CaseWhen(Seq((b, nullLiteral)), FalseLiteral), + And(b, nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((b, nullLiteral)), TrueLiteral), + Or(Not(b), nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), nullLiteral), + And(Not(b), nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), None), + And(Not(b), nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), nullLiteral), + Or(b, nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), None), + Or(b, nullLiteral).eval(EmptyRow)) + } + + // should have no effect on expressions with nullable if condition + assert((Factorial(5) > 100L).nullable) + Seq(TrueLiteral, FalseLiteral).foreach { b => + checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, nullLiteral)), b), + CaseWhen(Seq((Factorial(5) > 100L, nullLiteral)), b).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, b)), nullLiteral), + CaseWhen(Seq((Factorial(5) > 100L, b)), nullLiteral).eval(EmptyRow)) + checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, b)), None), + CaseWhen(Seq((Factorial(5) > 100L, b)), None).eval(EmptyRow)) + } + } } From 44733f8ad27f4fb53256accf7652b5ef54f72cab Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 23 Dec 2020 09:49:53 +0800 Subject: [PATCH 2/6] Add test to PushFoldableIntoBranchesSuite --- .../optimizer/PushFoldableIntoBranchesSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 02307a52ebb8..e22b4c7b50e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -258,4 +258,17 @@ class PushFoldableIntoBranchesSuite EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")), CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) } + + test("SPARK-33884: simplify CaseWhen when one clause is null and another is boolean") { + val p = IsNull('a) + val nullLiteral = Literal(null, BooleanType) + assertEquivalent(EqualTo( + CaseWhen(Seq((p, Literal.create(null, IntegerType))), Literal(1)), + Literal(2)), + And(p, nullLiteral)) + assertEquivalent(EqualTo( + CaseWhen(Seq((p, Literal("str"))), Literal("1")).cast(IntegerType), + Literal(2)), + And(p, nullLiteral)) + } } From 91ec8b254314b904178b44bf1ba5ecabfdf7b505 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 25 Dec 2020 12:50:58 +0800 Subject: [PATCH 3/6] fix --- .../sql/catalyst/optimizer/expressions.scala | 20 +++--- .../PushFoldableIntoBranchesSuite.scala | 13 +--- .../optimizer/SimplifyConditionalSuite.scala | 64 ++++++------------- 3 files changed, 31 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index b40d2d56372a..ff0dc33f11ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(_, TrueLiteral, TrueLiteral) => TrueLiteral + case If(_, FalseLiteral, FalseLiteral) => FalseLiteral case If(cond, TrueLiteral, FalseLiteral) => cond case If(cond, FalseLiteral, TrueLiteral) => Not(cond) case If(cond, trueValue, falseValue) @@ -484,18 +486,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l) case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l) - case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(FalseLiteral)) - if !cond.nullable => - And(cond, l) - case CaseWhen(Seq((cond, l @ Literal(null, _))), Some(TrueLiteral)) - if !cond.nullable => - Or(Not(cond), l) - case CaseWhen(Seq((cond, FalseLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None)) - if !cond.nullable => - And(Not(cond), elseOpt.getOrElse(Literal(null, BooleanType))) - case CaseWhen(Seq((cond, TrueLiteral)), elseOpt @ (Some(Literal(null, BooleanType)) | None)) - if !cond.nullable => - Or(cond, elseOpt.getOrElse(Literal(null, BooleanType))) + case CaseWhen(branches, Some(TrueLiteral)) + if branches.forall(_._2 == TrueLiteral) => TrueLiteral + case CaseWhen(branches, Some(FalseLiteral)) + if branches.forall(_._2 == FalseLiteral) => FalseLiteral + case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) => cond + case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) => Not(cond) case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 5ea2577416b3..9ac9129e0cc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -270,16 +270,9 @@ class PushFoldableIntoBranchesSuite } } - test("SPARK-33884: simplify CaseWhen when one clause is null and another is boolean") { - val p = IsNull('a) - val nullLiteral = Literal(null, BooleanType) + test("SPARK-33884: simplify conditional if all branches are foldable boolean type") { assertEquivalent(EqualTo( - CaseWhen(Seq((p, Literal.create(null, IntegerType))), Literal(1)), - Literal(2)), - And(p, nullLiteral)) - assertEquivalent(EqualTo( - CaseWhen(Seq((p, Literal("str"))), Literal("1")).cast(IntegerType), - Literal(2)), - And(p, nullLiteral)) + CaseWhen(Seq((IsNull('a), Literal(0)), (IsNull('b), Literal(1))), Literal(2)), Literal(3)), + FalseLiteral) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 5e0b05e12258..3b2aeac1a0ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -207,6 +207,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent( If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), IsNull(UnresolvedAttribute("a"))) + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, TrueLiteral), TrueLiteral) + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, FalseLiteral), FalseLiteral) assertEquivalent( If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), @@ -224,51 +228,23 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } } - test("SPARK-33884: simplify CaseWhen when one clause is null and another is boolean") { - val p = IsNull('a) - val nullLiteral = Literal(null, BooleanType) - assertEquivalent(CaseWhen(Seq((p, nullLiteral)), FalseLiteral), And(p, nullLiteral)) - assertEquivalent(CaseWhen(Seq((p, nullLiteral)), TrueLiteral), Or(IsNotNull('a), nullLiteral)) - assertEquivalent(CaseWhen(Seq((p, FalseLiteral)), nullLiteral), And(IsNotNull('a), nullLiteral)) - assertEquivalent(CaseWhen(Seq((p, FalseLiteral)), None), And(IsNotNull('a), nullLiteral)) - assertEquivalent(CaseWhen(Seq((p, TrueLiteral)), nullLiteral), Or(p, nullLiteral)) - assertEquivalent(CaseWhen(Seq((p, TrueLiteral)), None), Or(p, nullLiteral)) - - // the rule should not apply to nullable predicate - Seq(TrueLiteral, FalseLiteral).foreach { b => - assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), nullLiteral)), b), - CaseWhen(Seq((GreaterThan('a, 42), nullLiteral)), b)) - assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), b)), nullLiteral), - CaseWhen(Seq((GreaterThan('a, 42), b)), nullLiteral)) - assertEquivalent(CaseWhen(Seq((GreaterThan('a, 42), b)), None), - CaseWhen(Seq((GreaterThan('a, 42), b)), None)) - } - - // check evaluation also - Seq(TrueLiteral, FalseLiteral).foreach { b => - checkEvaluation(CaseWhen(Seq((b, nullLiteral)), FalseLiteral), - And(b, nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((b, nullLiteral)), TrueLiteral), - Or(Not(b), nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), nullLiteral), - And(Not(b), nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((b, FalseLiteral)), None), - And(Not(b), nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), nullLiteral), - Or(b, nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((b, TrueLiteral)), None), - Or(b, nullLiteral).eval(EmptyRow)) + test("SPARK-33884: simplify conditional if all branches are foldable boolean type") { + Seq(IsNull('a), GreaterThan(Rand(0), 1)).foreach { condition => + assertEquivalent(CaseWhen(Seq((condition, FalseLiteral)), FalseLiteral), FalseLiteral) + assertEquivalent(CaseWhen(Seq((condition, TrueLiteral)), TrueLiteral), TrueLiteral) + assertEquivalent( + CaseWhen(Seq((condition, FalseLiteral), (IsNull('b), FalseLiteral)), FalseLiteral), + FalseLiteral) + assertEquivalent( + CaseWhen(Seq((condition, TrueLiteral), (IsNull('b), TrueLiteral)), TrueLiteral), + TrueLiteral) } - // should have no effect on expressions with nullable if condition - assert((Factorial(5) > 100L).nullable) - Seq(TrueLiteral, FalseLiteral).foreach { b => - checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, nullLiteral)), b), - CaseWhen(Seq((Factorial(5) > 100L, nullLiteral)), b).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, b)), nullLiteral), - CaseWhen(Seq((Factorial(5) > 100L, b)), nullLiteral).eval(EmptyRow)) - checkEvaluation(CaseWhen(Seq((Factorial(5) > 100L, b)), None), - CaseWhen(Seq((Factorial(5) > 100L, b)), None).eval(EmptyRow)) - } + assertEquivalent(CaseWhen(Seq((IsNull('a), TrueLiteral)), FalseLiteral), IsNull('a)) + assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), TrueLiteral)), FalseLiteral), + GreaterThan(Rand(0), 1.0)) + assertEquivalent(CaseWhen(Seq((IsNull('a), FalseLiteral)), TrueLiteral), IsNotNull('a)) + assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), FalseLiteral)), TrueLiteral), + LessThanOrEqual(Rand(0), 1.0)) } } From 871d29fa5bf5fb557483ac07c1f0a3cb660c40dc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 25 Dec 2020 13:30:00 +0800 Subject: [PATCH 4/6] fix --- .../optimizer/PushFoldableIntoBranchesSuite.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 9ac9129e0cc6..447a6f4b2dc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -68,8 +68,7 @@ class PushFoldableIntoBranchesSuite assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), GreaterThanOrEqual(Rand(1), Literal(0.5))) - assertEquivalent(EqualTo(nonDeterministic, Literal(3)), - If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), FalseLiteral) // Handle Null values. assertEquivalent( @@ -141,9 +140,8 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2))) assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), - CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral))) - assertEquivalent(EqualTo(nonDeterministic, Literal(3)), - CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral))) + GreaterThanOrEqual(Rand(1), Literal(0.5))) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), FalseLiteral) // Handle Null values. assertEquivalent( From ae3b284c5dbfc2bf62f2bf6dc6d830db13535755 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 28 Dec 2020 14:03:44 +0800 Subject: [PATCH 5/6] fix --- .../sql/catalyst/optimizer/expressions.scala | 6 ------ .../PushFoldableIntoBranchesSuite.scala | 15 +++++++------ .../optimizer/SimplifyConditionalSuite.scala | 21 +++---------------- 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ff0dc33f11ba..b33aa529a4be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,8 +475,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue - case If(_, TrueLiteral, TrueLiteral) => TrueLiteral - case If(_, FalseLiteral, FalseLiteral) => FalseLiteral case If(cond, TrueLiteral, FalseLiteral) => cond case If(cond, FalseLiteral, TrueLiteral) => Not(cond) case If(cond, trueValue, falseValue) @@ -486,10 +484,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l) case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l) - case CaseWhen(branches, Some(TrueLiteral)) - if branches.forall(_._2 == TrueLiteral) => TrueLiteral - case CaseWhen(branches, Some(FalseLiteral)) - if branches.forall(_._2 == FalseLiteral) => FalseLiteral case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) => cond case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) => Not(cond) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 447a6f4b2dc8..6f40009d90fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -68,7 +68,8 @@ class PushFoldableIntoBranchesSuite assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), GreaterThanOrEqual(Rand(1), Literal(0.5))) - assertEquivalent(EqualTo(nonDeterministic, Literal(3)), FalseLiteral) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) // Handle Null values. assertEquivalent( @@ -141,7 +142,8 @@ class PushFoldableIntoBranchesSuite assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), GreaterThanOrEqual(Rand(1), Literal(0.5))) - assertEquivalent(EqualTo(nonDeterministic, Literal(3)), FalseLiteral) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral))) // Handle Null values. assertEquivalent( @@ -268,9 +270,10 @@ class PushFoldableIntoBranchesSuite } } - test("SPARK-33884: simplify conditional if all branches are foldable boolean type") { - assertEquivalent(EqualTo( - CaseWhen(Seq((IsNull('a), Literal(0)), (IsNull('b), Literal(1))), Literal(2)), Literal(3)), - FalseLiteral) + test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { + assertEquivalent( + EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)), 'a > 10) + assertEquivalent( + EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)), 'a <= 10) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 3b2aeac1a0ab..7127dd1f757a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -207,10 +207,6 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent( If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), IsNull(UnresolvedAttribute("a"))) - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, TrueLiteral), TrueLiteral) - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, FalseLiteral), FalseLiteral) assertEquivalent( If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), @@ -228,21 +224,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } } - test("SPARK-33884: simplify conditional if all branches are foldable boolean type") { - Seq(IsNull('a), GreaterThan(Rand(0), 1)).foreach { condition => - assertEquivalent(CaseWhen(Seq((condition, FalseLiteral)), FalseLiteral), FalseLiteral) - assertEquivalent(CaseWhen(Seq((condition, TrueLiteral)), TrueLiteral), TrueLiteral) - assertEquivalent( - CaseWhen(Seq((condition, FalseLiteral), (IsNull('b), FalseLiteral)), FalseLiteral), - FalseLiteral) - assertEquivalent( - CaseWhen(Seq((condition, TrueLiteral), (IsNull('b), TrueLiteral)), TrueLiteral), - TrueLiteral) + test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { + Seq(IsNull('a), GreaterThan(Rand(0), 1.0)).foreach { cond => + assertEquivalent(CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral), cond) } - - assertEquivalent(CaseWhen(Seq((IsNull('a), TrueLiteral)), FalseLiteral), IsNull('a)) - assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), TrueLiteral)), FalseLiteral), - GreaterThan(Rand(0), 1.0)) assertEquivalent(CaseWhen(Seq((IsNull('a), FalseLiteral)), TrueLiteral), IsNotNull('a)) assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), FalseLiteral)), TrueLiteral), LessThanOrEqual(Rand(0), 1.0)) From d3b072e2d1db3aef0ea4ab80767ab739502f7e81 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 29 Dec 2020 09:45:33 +0800 Subject: [PATCH 6/6] fix --- .../sql/catalyst/optimizer/expressions.scala | 6 ++- .../PushFoldableIntoBranchesSuite.scala | 6 ++- .../optimizer/SimplifyConditionalSuite.scala | 37 ++++++++++++++++--- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 5b4041745dd5..6c5dec133d2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -486,8 +486,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l) case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l) - case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) => cond - case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) => Not(cond) + case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) => + if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond + case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) => + if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond) case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index dc82656db373..0d5218ac629e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -272,8 +272,10 @@ class PushFoldableIntoBranchesSuite test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { assertEquivalent( - EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)), 'a > 10) + EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)), + 'a > 10 <=> TrueLiteral) assertEquivalent( - EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)), 'a <= 10) + EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)), + Not('a > 10 <=> TrueLiteral)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 85d0ec79671a..f3edd70bcfb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -245,11 +245,38 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { - Seq(IsNull('a), GreaterThan(Rand(0), 1.0)).foreach { cond => - assertEquivalent(CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral), cond) + // verify the boolean equivalence of all transformations involved + val fields = Seq( + 'cond.boolean.notNull, + 'cond_nullable.boolean, + 'a.boolean, + 'b.boolean + ) + val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) } + + val exprs = Seq( + // actual expressions of the transformations: original -> transformed + CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond, + CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond, + CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true), + CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true))) + + // check plans + for ((originalExpr, expectedExpr) <- exprs) { + assertEquivalent(originalExpr, expectedExpr) + } + + // check evaluation + val binaryBooleanValues = Seq(true, false) + val ternaryBooleanValues = Seq(true, false, null) + for (condVal <- binaryBooleanValues; + condNullableVal <- ternaryBooleanValues; + aVal <- ternaryBooleanValues; + bVal <- ternaryBooleanValues; + (originalExpr, expectedExpr) <- exprs) { + val inputRow = create_row(condVal, condNullableVal, aVal, bVal) + val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow) + checkEvaluation(originalExpr, optimizedVal, inputRow) } - assertEquivalent(CaseWhen(Seq((IsNull('a), FalseLiteral)), TrueLiteral), IsNotNull('a)) - assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), FalseLiteral)), TrueLiteral), - LessThanOrEqual(Rand(0), 1.0)) } }