Skip to content

Commit 4ac3d1e

Browse files
committed
prune casewhen branch
1 parent 0a0f68b commit 4ac3d1e

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,29 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
416416
// these branches can be pruned away
417417
val (h, t) = branches.span(_._1 != TrueLiteral)
418418
CaseWhen( h :+ t.head, None)
419+
420+
case e @ CaseWhen(branches, _) =>
  // Walk the branches in order and keep only those that can still influence the result.
  // The survivors are accumulated in a Vector because every step of the fold may append to
  // the accumulator, look at its last element, or replace its last element; on a List each
  // of those operations is linear in the accumulator size.
  val prunedBranches =
    branches.foldLeft(Vector.empty[(Expression, Expression)]) { (kept, branch) =>
      val (cond, value) = branch
      if (kept.exists(_._1.semanticEquals(cond))) {
        // If a condition in a branch is previously seen, this branch can be pruned.
        // TODO: In fact, if a condition is a sub-condition of the previous one,
        // TODO: it can be pruned. This is less strict and can be implemented
        // TODO: by decomposing seen conditions.
        kept
      } else {
        kept.lastOption match {
          case Some((prevCond, prevValue)) if prevValue.semanticEquals(value) =>
            // If the outputs of two adjacent branches are the same, the two branches are
            // combined into one branch guarded by the disjunction of their conditions.
            kept.init :+ ((Or(prevCond, cond), prevValue))
          case _ =>
            kept :+ branch
        }
      }
    }
  // Only build a new CaseWhen when at least one branch was pruned or merged; otherwise the
  // original expression instance is returned untouched.
  if (prunedBranches.length < branches.length) {
    e.copy(branches = prunedBranches)
  } else {
    e
  }
419442
}
420443
}
421444
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4646
private val unreachableBranch = (FalseLiteral, Literal(20))
4747
private val nullBranch = (Literal.create(null, NullType), Literal(30))
4848

49-
private val testRelation = LocalRelation('a.int)
49+
// Predicate fixtures built on the unresolved attributes `a` and `c`.
// NOTE(review): no test in the visible part of this suite references these values, and unlike
// the neighbouring branch fixtures they are not declared `private` — confirm both are intended.
val isNotNullCond = IsNotNull(UnresolvedAttribute("a"))
50+
val isNullCond = IsNull(UnresolvedAttribute("a"))
51+
val notCond = Not(UnresolvedAttribute("c"))
5052

5153
test("simplify if") {
5254
assertEquivalent(
@@ -122,4 +124,54 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
122124
None),
123125
CaseWhen(normalBranch :: trueBranch :: Nil, None))
124126
}
127+
128+
test("remove a branch in CaseWhen if a cond in this branch is previously seen") {
  // Helpers are `def`s so that every use builds a fresh expression instance, exactly as the
  // fully inlined form of this test does.
  def randAboveHalf = GreaterThan(Rand(0), Literal(0.5))
  def randBelowHalf = LessThan(Rand(1), Literal(0.5))

  // Input: the Rand(0) condition and both NonFoldableLiteral conditions each appear twice.
  val original = CaseWhen(
    Seq(
      (randAboveHalf, Literal(1)),
      (randAboveHalf, Literal(2)),
      (NonFoldableLiteral(true), Literal(3)),
      (randBelowHalf, Literal(4)),
      (NonFoldableLiteral(true), Literal(5)),
      (NonFoldableLiteral(false), Literal(6)),
      (NonFoldableLiteral(false), Literal(7))),
    None)

  // Expected: the repeated NonFoldableLiteral branches (values 5 and 7) are gone, while both
  // Rand(0) branches are retained.
  val expected = CaseWhen(
    Seq(
      (randAboveHalf, Literal(1)),
      (randAboveHalf, Literal(2)),
      (NonFoldableLiteral(true), Literal(3)),
      (randBelowHalf, Literal(4)),
      (NonFoldableLiteral(false), Literal(6))),
    None)

  assertEquivalent(original, expected)
}
148+
149+
test("combine two adjacent branches in CaseWhen if they have the same output values") {
  // Helpers are `def`s so that every use builds a fresh expression instance, exactly as the
  // fully inlined form of this test does.
  def randAboveHalf = GreaterThan(Rand(0), Literal(0.5))
  def randBelowHalf = LessThan(Rand(1), Literal(0.5))

  // Branches 1-2 share the output 1 and branches 3-4 share the output 3, so each adjacent
  // pair is expected to collapse into one branch guarded by the `Or` of its two conditions.
  val pairwiseInput = CaseWhen(
    Seq(
      (randAboveHalf, Literal(1)),
      (NonFoldableLiteral(true), Literal(1)),
      (randBelowHalf, Literal(3)),
      (NonFoldableLiteral(true), Literal(3)),
      (NonFoldableLiteral(false), Literal(4))),
    None)
  val pairwiseExpected = CaseWhen(
    Seq(
      (Or(randAboveHalf, NonFoldableLiteral(true)), Literal(1)),
      (Or(randBelowHalf, NonFoldableLiteral(true)), Literal(3)),
      (NonFoldableLiteral(false), Literal(4))),
    None)
  assertEquivalent(pairwiseInput, pairwiseExpected)

  // The first two conditions (`a` and `NOT a`) are combined into a single `Or`, which the
  // optimizer is then expected to simplify to `TrueLiteral`. With an always-true first
  // branch, the whole `CaseWhen` is expected to reduce to that branch's value.
  val complementaryInput = CaseWhen(
    Seq(
      (UnresolvedAttribute("a"), Literal(1)),
      (Not(UnresolvedAttribute("a")), Literal(1)),
      (randBelowHalf, Literal(3)),
      (NonFoldableLiteral(true), Literal(4)),
      (NonFoldableLiteral(false), Literal(5))),
    None)
  assertEquivalent(complementaryInput, Literal(1))
}
125177
}

0 commit comments

Comments
 (0)