diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4696699337c9..cf25c3bcfa8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -385,6 +385,40 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case _ => false } + // If a condition in a branch is previously seen, this branch can be pruned. + // TODO: In fact, if a condition is a sub-condition of the previous one, + // TODO: it can be pruned. This is less strict and can be implemented + // TODO: by decomposing the seen conditions. + private def pruneSeenBranches(branches: Seq[(Expression, Expression)]) + : Option[Seq[(Expression, Expression)]] = { + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) if newBranches.exists(_._1.semanticEquals(branch._1)) => + newBranches + case (newBranches, branch) => newBranches += branch + } + if (newBranches.length < branches.length) { + Some(newBranches) + } else { + None + } + } + + // If the outputs of two adjacent branches are the same, two branches can be combined. + private def combineAdjacentBranches(branches: Seq[(Expression, Expression)]) + : Option[Seq[(Expression, Expression)]] = { + val newBranches = branches.foldLeft(new ArrayBuffer[(Expression, Expression)]()) { + case (newBranches, branch) + if newBranches.nonEmpty && newBranches.last._2.semanticEquals(branch._2) => + newBranches.init += ((Or(newBranches.last._1, branch._1), newBranches.last._2)) + case (newBranches, branch) => newBranches += branch + } + if (newBranches.length < branches.length) { + Some(newBranches) + } else { + None + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue @@ -416,6 +450,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, _) if pruneSeenBranches(branches).nonEmpty => + e.copy(branches = pruneSeenBranches(branches).get) + + case e @ CaseWhen(branches, _) if combineAdjacentBranches(branches).nonEmpty => + e.copy(branches = combineAdjacentBranches(branches).get) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index e210874a55d8..afdbee9a59f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -46,7 +46,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) - private val testRelation = LocalRelation('a.int) + val isNotNullCond = IsNotNull(UnresolvedAttribute("a")) + val isNullCond = IsNull(UnresolvedAttribute("a")) + val notCond = Not(UnresolvedAttribute("c")) test("simplify if") { assertEquivalent( @@ -122,4 +124,54 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { None), CaseWhen(normalBranch :: trueBranch :: Nil, None)) } + + test("remove a branch in CaseWhen if a cond in this branch is previously seen") { + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (GreaterThan(Rand(0), Literal(0.5)), Literal(2)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(4)) :: + (NonFoldableLiteral(true), Literal(5)) :: + (NonFoldableLiteral(false), Literal(6)) :: + (NonFoldableLiteral(false), Literal(7)) :: + Nil, + None), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (GreaterThan(Rand(0), Literal(0.5)), Literal(2)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(4)) :: + (NonFoldableLiteral(false), Literal(6)) :: + Nil, + None) + ) + } + + test("combine two adjacent branches in CaseWhen if they have the same output values") { + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: + (NonFoldableLiteral(true), Literal(3)) :: + (NonFoldableLiteral(false), Literal(4)) :: + Nil, + None), + CaseWhen((Or(GreaterThan(Rand(0), Literal(0.5)), NonFoldableLiteral(true)), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: + (NonFoldableLiteral(false), Literal(4)) :: + Nil, + None) + ) + + // The first two conditions can be combined, and then the optimizer uses rule in `Or` + // to be optimized into `TrueLiteral`. Thus, the entire `CaseWhen` can be removed. + assertEquivalent( + CaseWhen((UnresolvedAttribute("a"), Literal(1)) :: + (Not(UnresolvedAttribute("a")), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(3)) :: + (NonFoldableLiteral(true), Literal(4)) :: + (NonFoldableLiteral(false), Literal(5)) :: + Nil, + None), + Literal(1)) + } }