From 09edff51ad5bce7b6f2b4b982d6df550a37d502e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Dec 2020 16:42:15 +0800 Subject: [PATCH 1/7] Replace None of elseValue inside CaseWhen if all branches are FalseLiteral --- .../ReplaceNullWithFalseInPredicate.scala | 1 + ...ReplaceNullWithFalseInPredicateSuite.scala | 83 ++++++++++++++++--- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 4a71dba663b3..070c86fe3cda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -94,6 +94,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { replaceNullWithFalse(cond) -> replaceNullWithFalse(value) } val newElseValue = cw.elseValue.map(replaceNullWithFalse) + .orElse(if (newBranches.forall(_._2 == FalseLiteral)) Some(FalseLiteral) else None) CaseWhen(newBranches, newElseValue) case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 00433a549057..e9d3d01b56c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, EqualTo, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -38,6 +38,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, + PushFoldableIntoBranches, ReplaceNullWithFalseInPredicate) :: Nil } @@ -222,10 +223,17 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null, IntegerType), Literal(3)), FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val expectedCond = If( + UnresolvedAttribute("i") > Literal(10), + If( + UnresolvedAttribute("i") === Literal(15), + FalseLiteral, + TrueLiteral), + FalseLiteral) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -238,10 +246,18 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) val condition = CaseWhen(branches) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + + val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> If( + CaseWhen( + Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), + FalseLiteral), + TrueLiteral, + FalseLiteral))) + + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -252,10 +268,17 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val expectedCond = If( + If( + UnresolvedAttribute("i") === Literal(15), + FalseLiteral, + TrueLiteral), + TrueLiteral, + FalseLiteral) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("replace null in If used as a join condition") { @@ -375,6 +398,40 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace None of elseValue inside CaseWhen if all branches are FalseLiteral") { + val allFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val allFalseCond = CaseWhen(allFalseBranches) + + val nonAllFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val nonAllFalseCond = CaseWhen(nonAllFalseBranches) + + testFilter(allFalseCond, FalseLiteral) + testJoin(allFalseCond, FalseLiteral) + testDelete(allFalseCond, FalseLiteral) + testUpdate(allFalseCond, FalseLiteral) + + testFilter(nonAllFalseCond, nonAllFalseCond) + testJoin(nonAllFalseCond, nonAllFalseCond) + testDelete(nonAllFalseCond, nonAllFalseCond) + testUpdate(nonAllFalseCond, nonAllFalseCond) + } + + test("replace None of elseValue inside CaseWhen with PushFoldableIntoBranches") { + val allFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal("a"), + (UnresolvedAttribute("i") > Literal(40)) -> Literal("b")) + val allFalseCond = EqualTo(CaseWhen(allFalseBranches), "c") + + testFilter(allFalseCond, FalseLiteral) + testJoin(allFalseCond, FalseLiteral) + testDelete(allFalseCond, FalseLiteral) + testUpdate(allFalseCond, FalseLiteral) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } From 3c5f3dab42203e309f1cf500682d2622e32eebee Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Dec 2020 21:38:32 +0800 Subject: [PATCH 2/7] Fix --- ...ReplaceNullWithFalseInPredicateSuite.scala | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 531abd249316..0d226cab2cb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, EqualTo, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, EqualTo, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Not, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -225,10 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) val expectedCond = If( UnresolvedAttribute("i") > Literal(10), - If( - UnresolvedAttribute("i") === Literal(15), - FalseLiteral, - TrueLiteral), + Not(UnresolvedAttribute("i") === Literal(15)), FalseLiteral) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) @@ -246,14 +243,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) val condition = CaseWhen(branches) - - val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> If( + val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> CaseWhen( Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), - FalseLiteral), - TrueLiteral, - FalseLiteral))) - + FalseLiteral))) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -268,13 +261,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - val expectedCond = If( - If( - UnresolvedAttribute("i") === Literal(15), - FalseLiteral, - TrueLiteral), - TrueLiteral, - FalseLiteral) + val expectedCond = Not(UnresolvedAttribute("i") === Literal(15)) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) From b837e37b62355aeb9fdd9b67417ca2b43beba0dd Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Dec 2020 23:10:13 +0800 Subject: [PATCH 3/7] fix --- .../spark/sql/catalyst/optimizer/expressions.scala | 3 +++ .../optimizer/PushFoldableIntoBranchesSuite.scala | 9 +++++---- .../optimizer/ReplaceNullWithFalseInPredicateSuite.scala | 7 ++++--- .../catalyst/optimizer/SimplifyConditionalSuite.scala | 8 +++++++- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ac2caaeb1535..c92a9809350a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -525,6 +525,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + + case e @ CaseWhen(_, elseValue) if elseValue.isEmpty => + e.copy(elseValue = Some(Literal.create(null, e.dataType))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index de4f4be8ec33..521983f7f79b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -122,7 +122,7 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Literal.create(null, BooleanType))) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), @@ -130,11 +130,12 @@ class PushFoldableIntoBranchesSuite // Push down at most one branch is not foldable expressions. assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), - CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), + Literal.create(null, BooleanType))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), - EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) + EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), Literal.create(null, IntegerType)), Literal(1))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), - EqualTo(CaseWhen(Seq((a, b)), None), Literal(1))) + CaseWhen(Seq((a, b === Literal(1))), Literal.create(null, BooleanType))) // Push down non-deterministic expressions. val nonDeterministic = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 0d226cab2cb5..7fbe73153771 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -115,7 +115,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) + val expectedCond = CaseWhen(expectedBranches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -246,7 +246,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> CaseWhen( Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), - FalseLiteral))) + FalseLiteral)), + FalseLiteral) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -394,7 +395,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val nonAllFalseBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val nonAllFalseCond = CaseWhen(nonAllFalseBranches) + val nonAllFalseCond = CaseWhen(nonAllFalseBranches, FalseLiteral) testFilter(allFalseCond, FalseLiteral) testJoin(allFalseCond, FalseLiteral) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 328fc107e1c1..1a3201a299e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -83,7 +83,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P // i.e. removing branches whose conditions are always false assertEquivalent( CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), - CaseWhen(normalBranch :: Nil, None)) + CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType))) } test("remove entire CaseWhen if only the else branch is reachable") { @@ -215,4 +215,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) } + + test("SPARK-33847: Replace None of elseValue inside CaseWhen to null literal") { + assertEquivalent( + CaseWhen(normalBranch :: Nil, None), + CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType))) + } } From 81c38f8552b2b52aaea1ee88cb0ac14e606e3da7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 22 Dec 2020 23:21:38 +0800 Subject: [PATCH 4/7] fix --- .../sql/catalyst/optimizer/expressions.scala | 5 +++-- .../PushFoldableIntoBranchesSuite.scala | 9 ++++---- ...ReplaceNullWithFalseInPredicateSuite.scala | 22 +++++++++++++------ .../optimizer/SimplifyConditionalSuite.scala | 22 +++++++++++++++---- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6d7d70e447c5..c0c8ddbe206f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -526,8 +526,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - case e @ CaseWhen(_, elseValue) if elseValue.isEmpty => - e.copy(elseValue = Some(Literal.create(null, e.dataType))) + case e @ CaseWhen(branches, elseOpt) + if elseOpt.isEmpty && branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) => + Literal(null, e.dataType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 9da9a8a40444..02307a52ebb8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -122,7 +122,7 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Literal.create(null, BooleanType))) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), @@ -130,12 +130,11 @@ class PushFoldableIntoBranchesSuite // Push down at most one branch is not foldable expressions. assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), - CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), - Literal.create(null, BooleanType))) + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), - EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), Literal.create(null, IntegerType)), Literal(1))) + EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), - CaseWhen(Seq((a, b === Literal(1))), Literal.create(null, BooleanType))) + EqualTo(CaseWhen(Seq((a, b)), None), Literal(1))) // Push down non-deterministic expressions. val nonDeterministic = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 7fbe73153771..5f51f0ddf532 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -115,7 +115,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches, FalseLiteral) + val expectedCond = CaseWhen(expectedBranches) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -241,13 +241,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(2) === nestedCaseWhen, TrueLiteral, FalseLiteral) - val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) - val condition = CaseWhen(branches) + val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> - CaseWhen( - Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), - FalseLiteral)), - FalseLiteral) + CaseWhen(Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), FalseLiteral))) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -408,6 +404,18 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testUpdate(nonAllFalseCond, nonAllFalseCond) } + test("replace None of elseValue inside CaseWhen if all branches are null") { + val allFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal.create(null, BooleanType), + (UnresolvedAttribute("i") > Literal(40)) -> Literal.create(null, BooleanType)) + val allFalseCond = CaseWhen(allFalseBranches) + + testFilter(allFalseCond, FalseLiteral) + testJoin(allFalseCond, FalseLiteral) + testDelete(allFalseCond, FalseLiteral) + testUpdate(allFalseCond, FalseLiteral) + } + test("replace None of elseValue inside CaseWhen with PushFoldableIntoBranches") { val allFalseBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> Literal("a"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 1a3201a299e5..71ade84a1003 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -83,7 +83,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P // i.e. removing branches whose conditions are always false assertEquivalent( CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), - CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType))) + CaseWhen(normalBranch :: Nil, None)) } test("remove entire CaseWhen if only the else branch is reachable") { @@ -216,9 +216,23 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) } - test("SPARK-33847: Replace None of elseValue inside CaseWhen to null literal") { + test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { assertEquivalent( - CaseWhen(normalBranch :: Nil, None), - CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType))) + CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, + None), + Literal.create(null, IntegerType)) + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), 1), Literal.create(null, IntegerType)) :: Nil, + None), + Literal.create(null, IntegerType)) + + assertEquivalent( + CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, + Some(Literal.create(null, IntegerType))), + Literal.create(null, IntegerType)) + assertEquivalent( + CaseWhen((GreaterThan('a, 1), Literal(20)) :: (GreaterThan('b, 1), Literal(20)) :: Nil, + Some(Literal(20))), + Literal(20)) } } From 7f3529db3573edc2aa3c9271f66338302d1e1d29 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 23 Dec 2020 08:42:51 +0800 Subject: [PATCH 5/7] fix --- .../ReplaceNullWithFalseInPredicate.scala | 9 ++++++--- .../sql/catalyst/optimizer/expressions.scala | 4 ++-- .../PushFoldableIntoBranchesSuite.scala | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 070c86fe3cda..92401131e8b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -93,9 +93,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { val newBranches = cw.branches.map { case (cond, value) => replaceNullWithFalse(cond) -> replaceNullWithFalse(value) } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - .orElse(if (newBranches.forall(_._2 == FalseLiteral)) Some(FalseLiteral) else None) - CaseWhen(newBranches, newElseValue) + if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) { + FalseLiteral + } else { + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + } case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) case e if e.dataType == BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index c0c8ddbe206f..f01df5e5e676 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -526,8 +526,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - case e @ CaseWhen(branches, elseOpt) - if elseOpt.isEmpty && branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) => + case e @ CaseWhen(branches, None) + if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) => Literal(null, e.dataType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 02307a52ebb8..967d980c7730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -258,4 +258,22 @@ class PushFoldableIntoBranchesSuite EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")), CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) } + + test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)), + Literal.create(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, IntegerType)))), + Literal(2)), + Literal.create(null, BooleanType)) + + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal("str")))).cast(IntegerType), Literal(2)), + Literal.create(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType), + Literal(2)), + Literal.create(null, BooleanType)) + } } From 0933019714d83bab40427b2f5594d710d5416d66 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 23 Dec 2020 17:05:33 +0800 Subject: [PATCH 6/7] fix --- .../PushFoldableIntoBranchesSuite.scala | 23 ++++------- ...ReplaceNullWithFalseInPredicateSuite.scala | 41 ++++++------------- .../optimizer/SimplifyConditionalSuite.scala | 13 +++--- ...ullWithFalseInPredicateEndToEndSuite.scala | 21 +++++++--- 4 files changed, 42 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 967d980c7730..2d826e7b55a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -260,20 +260,13 @@ class PushFoldableIntoBranchesSuite } test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)), - Literal.create(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, IntegerType)))), - Literal(2)), - Literal.create(null, BooleanType)) - - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal("str")))).cast(IntegerType), Literal(2)), - Literal.create(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType), - Literal(2)), - Literal.create(null, BooleanType)) + Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition => + assertEquivalent( + EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)), + Literal.create(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)), + Literal.create(null, BooleanType)) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 5f51f0ddf532..f49e6921fd46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, EqualTo, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Not, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -38,7 +38,6 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, - PushFoldableIntoBranches, ReplaceNullWithFalseInPredicate) :: Nil } @@ -223,14 +222,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null, IntegerType), Literal(3)), FalseLiteral) - val expectedCond = If( - UnresolvedAttribute("i") > Literal(10), - Not(UnresolvedAttribute("i") === Literal(15)), - FalseLiteral) - testFilter(originalCond = condition, expectedCond = expectedCond) - testJoin(originalCond = condition, expectedCond = expectedCond) - testDelete(originalCond = condition, expectedCond = expectedCond) - testUpdate(originalCond = condition, expectedCond = expectedCond) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) + testUpdate(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -242,8 +237,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { TrueLiteral, FalseLiteral) val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) - val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> - CaseWhen(Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), FalseLiteral))) + val expectedCond = + CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen))) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -258,7 +253,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - val expectedCond = Not(UnresolvedAttribute("i") === Literal(15)) + val expectedCond = Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -405,23 +403,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { } test("replace None of elseValue inside CaseWhen if all branches are null") { - val allFalseBranches = Seq( + val allNullBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> Literal.create(null, BooleanType), (UnresolvedAttribute("i") > Literal(40)) -> Literal.create(null, BooleanType)) - val allFalseCond = CaseWhen(allFalseBranches) - - testFilter(allFalseCond, FalseLiteral) - testJoin(allFalseCond, FalseLiteral) - testDelete(allFalseCond, FalseLiteral) - testUpdate(allFalseCond, FalseLiteral) - } - - test("replace None of elseValue inside CaseWhen with PushFoldableIntoBranches") { - val allFalseBranches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> Literal("a"), - (UnresolvedAttribute("i") > Literal(40)) -> Literal("b")) - val allFalseCond = EqualTo(CaseWhen(allFalseBranches), "c") - + val allFalseCond = CaseWhen(allNullBranches) testFilter(allFalseCond, FalseLiteral) testJoin(allFalseCond, FalseLiteral) testDelete(allFalseCond, FalseLiteral) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 71ade84a1003..a4447e7afc17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -217,14 +217,11 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { - assertEquivalent( - CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, - None), - Literal.create(null, IntegerType)) - assertEquivalent( - CaseWhen((GreaterThan(Rand(0), 1), Literal.create(null, IntegerType)) :: Nil, - None), - Literal.create(null, IntegerType)) + Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition => + assertEquivalent( + CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None), + Literal.create(null, IntegerType)) + } assertEquivalent( CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index bdbb741f24bc..739b4052ee90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -27,6 +27,12 @@ import org.apache.spark.sql.types.BooleanType class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSparkSession { import testImplicits._ + private def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = + df.queryExecution.executedPlan match { + case s: LocalTableScanExec => assert(s.rows.isEmpty) + case p => fail(s"$p is not LocalTableScanExec") + } + test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { withTable("t1", "t2") { Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1") @@ -64,11 +70,6 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) } - - def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match { - case s: LocalTableScanExec => assert(s.rows.isEmpty) - case p => fail(s"$p is not LocalTableScanExec") - } } test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { @@ -112,4 +113,14 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared assertNoLiteralNullInPlan(q3) } } + + test("SPARK-33847: replace None of elseValue inside CaseWhen to FalseLiteral") { + withTable("t1") { + Seq((1, 1), (2, 2)).toDF("a", "b").write.saveAsTable("t1") + val t1 = spark.table("t1") + val q1 = t1.filter("(CASE WHEN a > 1 THEN 1 END) = 0") + checkAnswer(q1, Seq.empty) + checkPlanIsEmptyLocalScan(q1) + } + } } From d20fedeac9fcca59faea28da418f546d4fd07e20 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 23 Dec 2020 17:56:28 +0800 Subject: [PATCH 7/7] remove --- .../catalyst/optimizer/SimplifyConditionalSuite.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index a4447e7afc17..1876be21dea4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -222,14 +222,5 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None), Literal.create(null, IntegerType)) } - - assertEquivalent( - CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, - Some(Literal.create(null, IntegerType))), - Literal.create(null, IntegerType)) - assertEquivalent( - CaseWhen((GreaterThan('a, 1), Literal(20)) :: (GreaterThan('b, 1), Literal(20)) :: Nil, - Some(Literal(20))), - Literal(20)) } }