Skip to content

Commit 8f6f91d

Browse files
committed
Address comment. Consider more general case.
1 parent 9e1c315 commit 8f6f91d

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
8989
CombineFilters,
9090
CombineLimits,
9191
CombineUnions,
92-
// Push down Filters again after combination
93-
PushDownPredicate,
9492
// Constant folding and strength reduction
9593
NullPropagation,
9694
FoldablePropagation,
@@ -588,15 +586,33 @@ object CombineUnions extends Rule[LogicalPlan] {
588586
* one conjunctive predicate.
589587
*/
590588
object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
589+
private def toCNF(predicate: Expression): Expression = {
590+
val disjunctives = splitDisjunctivePredicates(predicate)
591+
var finalPredicates = splitConjunctivePredicates(disjunctives.head)
592+
disjunctives.tail.foreach { cond =>
593+
val predicates = new ArrayBuffer[Expression]()
594+
splitConjunctivePredicates(cond).map { p =>
595+
predicates ++= finalPredicates.map(Or(_, p))
596+
}
597+
finalPredicates = predicates.toSeq
598+
}
599+
finalPredicates.reduce(And)
600+
}
591601
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
592602
case Filter(fc, nf @ Filter(nc, grandChild)) =>
593-
(ExpressionSet(splitConjunctivePredicates(fc)) --
594-
ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
603+
val fcCNF = toCNF(fc)
604+
val ncCNF = toCNF(nc)
605+
val combinedFilter = (ExpressionSet(splitConjunctivePredicates(fcCNF)) --
606+
ExpressionSet(splitConjunctivePredicates(ncCNF))).reduceOption(And) match {
595607
case Some(ac) =>
596608
Filter(And(nc, ac), grandChild)
597609
case None =>
598610
nf
599611
}
612+
// [[Filter]] can't pushdown through another [[Filter]]. Once they are combined,
613+
// [[BooleanSimplification]] rule will possibly simplify the predicate to the form that
614+
// will not be able to pushdown. So we pushdown the combined [[Filter]] immediately.
615+
PushDownPredicate(combinedFilter)
600616
}
601617
}
602618

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class FilterPushdownSuite extends PlanTest {
3636
Batch("Filter Pushdown", FixedPoint(10),
3737
PushDownPredicate,
3838
CombineFilters,
39-
PushDownPredicate,
4039
BooleanSimplification,
4140
PushPredicateThroughJoin,
4241
CollapseProject) :: Nil
@@ -193,6 +192,26 @@ class FilterPushdownSuite extends PlanTest {
193192
comparePlans(optimized, correctAnswer)
194193
}
195194

195+
test("disjunctive predicates which are able to pushdown should be pushed down after converted") {
196+
// (('a === 2) || ('c > 10 || 'a === 3)) can't be pushdown due to the disjunctive form.
197+
// However, its conjunctive normal form can be pushdown.
198+
val originalQuery = testRelation
199+
.select('a, 'b, ('c + 1) as 'cc)
200+
.groupBy('a)('a, count('cc) as 'c)
201+
.where('c > 10)
202+
.where(('a === 2) || ('c > 10 && 'a === 3))
203+
204+
val optimized = Optimize.execute(originalQuery.analyze)
205+
val correctAnswer =
206+
testRelation
207+
.where('a === 2 || 'a === 3)
208+
.select('a, 'b, ('c + 1) as 'cc)
209+
.groupBy('a)('a, count('cc) as 'c)
210+
.where('c > 10).analyze
211+
212+
comparePlans(optimized, correctAnswer)
213+
}
214+
196215
test("joins: push to either side") {
197216
val x = testRelation.subquery('x)
198217
val y = testRelation.subquery('y)

0 commit comments

Comments
 (0)