From cf95639d6ea9f5027a4a4612088e5cc67c6d3470 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Dec 2014 13:42:39 +0800 Subject: [PATCH 1/5] Adds an optimization rule for filter normalization --- .../sql/catalyst/optimizer/Optimizer.scala | 26 +++++++++ .../optimizer/NormalizeFiltersSuite.scala | 53 +++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 806c1394eb15..8e34f52fca1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,6 +40,7 @@ object DefaultOptimizer extends Optimizer { ConstantFolding, LikeSimplification, BooleanSimplification, + NormalizeFilters, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -347,6 +348,31 @@ object CombineFilters extends Rule[LogicalPlan] { } } +/** + * Normalizes conjuctions and disjunctions to eliminate common factors. + */ +object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(predicate: Predicate, child) => + f.copy(condition = normalizedPredicate(predicate).reduce(And)) + } + + def normalizedPredicate(predicate: Expression): Seq[Expression] = predicate match { + // a || a => a + case Or(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil + // a && a => a + case And(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil + // (a || b || c || ...) && (a || b || d || ...) => a && b && (c || d || ...) + case Or(lhs, rhs) => + val lhsSet = splitConjunctivePredicates(lhs).toSet + val rhsSet = splitConjunctivePredicates(rhs).toSet + val commonPredicates = lhsSet & rhsSet + val otherPredicates = (lhsSet | rhsSet) &~ commonPredicates + otherPredicates.reduceOption(Or).getOrElse(Literal(true)) :: commonPredicates.toList + case _ => predicate :: Nil + } +} + /** * Removes filters that can be evaluated trivially. This is done either by eliding the filter for * cases where it will always evaluate to `true`, or substituting a dummy empty relation when the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala new file mode 100644 index 000000000000..865ab6dc1d38 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -0,0 +1,53 @@ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class NormalizeFiltersSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq( + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators), + Batch("NormalizeFilters", FixedPoint(100), + NormalizeFilters, + SimplifyFilters)) + } + + val relation = LocalRelation('a.int, 'b.int, 'c.string) + + def checkExpression(original: Expression, expected: Expression): Unit = { + val actual = Optimize(relation.where(original)).collect { case f: Filter => f.condition }.head + val result = (actual, expected) match { + case (And(l1, r1), And(l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (Or (l1, r1), Or (l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (lhs, rhs) => lhs fastEquals rhs + } + + assert(result, s"$actual isn't equivalent to $expected") + } + + test("a && a => a") { + checkExpression('a === 1 && 'a === 1, 'a === 1) + } + + test("a || a => a") { + checkExpression('a === 1 || 'a === 1, 'a === 1) + } + + test("(a && b) || (a && c)") { + checkExpression( + ('a === 1 && 'a < 10) || ('a > 2 && 'a === 1), + ('a === 1) && ('a < 10 || 'a > 2)) + + checkExpression( + ('a < 1 && 'b > 2 && 'c.isNull) || ('a < 1 && 'c === "hello" && 'b > 2), + ('c.isNull || 'c === "hello") && 'a < 1 && 'b > 2) + } +} From 2abbf8ed77d617675a59c0560d8774617f9b1bb7 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Dec 2014 14:08:49 +0800 Subject: [PATCH 2/5] Forgot our sacred Apache licence header... --- .../optimizer/NormalizeFiltersSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala index 865ab6dc1d38..3d2ba0802316 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators From 5d54349a99fb0521c92f9cd409d59e0945a8203b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Dec 2014 14:11:53 +0800 Subject: [PATCH 3/5] Fixes typo in comment --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8e34f52fca1b..536054607957 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -362,7 +362,7 @@ object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper { case Or(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil // a && a => a case And(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil - // (a || b || c || ...) && (a || b || d || ...) => a && b && (c || d || ...) + // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) case Or(lhs, rhs) => val lhsSet = splitConjunctivePredicates(lhs).toSet val rhsSet = splitConjunctivePredicates(rhs).toSet From 4ab3a58fe8a86bc8f08fa0007d88022b3021e0e6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Dec 2014 15:44:48 +0800 Subject: [PATCH 4/5] Fixes test failure, adds more tests --- .../sql/catalyst/expressions/predicates.scala | 9 +++++++- .../sql/catalyst/optimizer/Optimizer.scala | 22 ++++++++++++------- .../optimizer/NormalizeFiltersSuite.scala | 4 +++- .../columnar/PartitionBatchPruningSuite.scala | 10 ++++++--- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 94b6fb084d38..cb5ff6795986 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType @@ -48,6 +47,14 @@ trait PredicateHelper { } } + protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case Or(cond1, cond2) => + splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2) + case other => other :: Nil + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when is is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 536054607957..c4db0f108011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -349,7 +349,7 @@ object CombineFilters extends Rule[LogicalPlan] { } /** - * Normalizes conjuctions and disjunctions to eliminate common factors. + * Normalizes conjunctions and disjunctions to eliminate common factors. */ object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -358,17 +358,23 @@ object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper { } def normalizedPredicate(predicate: Expression): Seq[Expression] = predicate match { - // a || a => a - case Or(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil - // a && a => a - case And(lhs, rhs) if lhs fastEquals rhs => lhs :: Nil + // a && a && a ... => a + case p @ And(e, _) if splitConjunctivePredicates(p).distinct.size == 1 => e :: Nil + + // a || a || a ... => a + case p @ Or(e, _) if splitDisjunctivePredicates(p).distinct.size == 1 => e :: Nil + // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) case Or(lhs, rhs) => val lhsSet = splitConjunctivePredicates(lhs).toSet val rhsSet = splitConjunctivePredicates(rhs).toSet - val commonPredicates = lhsSet & rhsSet - val otherPredicates = (lhsSet | rhsSet) &~ commonPredicates - otherPredicates.reduceOption(Or).getOrElse(Literal(true)) :: commonPredicates.toList + val common = lhsSet.intersect(rhsSet) + + (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + case _ => predicate :: Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala index 3d2ba0802316..85e2682a11c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -52,13 +52,15 @@ class NormalizeFiltersSuite extends PlanTest { test("a && a => a") { checkExpression('a === 1 && 'a === 1, 'a === 1) + checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1) } test("a || a => a") { checkExpression('a === 1 || 'a === 1, 'a === 1) + checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1) } - test("(a && b) || (a && c)") { + test("(a && b) || (a && c) => a && (b || c)") { checkExpression( ('a === 1 && 'a < 10) || ('a > 2 && 'a === 1), ('a === 1) && ('a < 10 || 'a > 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 82afa31a99a7..1915c25392f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -105,7 +105,9 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be test(query) { val schemaRdd = sql(query) - assertResult(expectedQueryResult.toArray, "Wrong query result") { + val queryExecution = schemaRdd.queryExecution + + assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { schemaRdd.collect().map(_.head).toArray } @@ -113,8 +115,10 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution") + assert( + readPartitions === expectedReadPartitions, + s"Wrong number of read partitions: $queryExecution") } } } From caca56024026d8211cf55eba2da95279a6b000bd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 25 Dec 2014 17:22:47 +0800 Subject: [PATCH 5/5] Moves filter normalization into BooleanSimplification rule --- .../sql/catalyst/optimizer/Optimizer.scala | 59 ++++++++----------- .../optimizer/NormalizeFiltersSuite.scala | 2 +- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c4db0f108011..d82fb85b80d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,7 +40,6 @@ object DefaultOptimizer extends Optimizer { ConstantFolding, LikeSimplification, BooleanSimplification, - NormalizeFilters, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -295,11 +294,16 @@ object OptimizeIn extends Rule[LogicalPlan] { } /** - * Simplifies boolean expressions where the answer can be determined without evaluating both sides. + * Simplifies boolean expressions: + * + * 1. Simplifies expressions whose answer can be determined without evaluating both sides. + * 2. Eliminates / extracts common factors. + * 3. Removes `Not` operator. + * * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus * is only safe when evaluations of expressions does not result in side effects. */ -object BooleanSimplification extends Rule[LogicalPlan] { +object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case and @ And(left, right) => @@ -308,7 +312,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(true, BooleanType)) => l case (Literal(false, BooleanType), _) => Literal(false) case (_, Literal(false, BooleanType)) => Literal(false) - case (_, _) => and + // a && a && a ... => a + case _ if splitConjunctivePredicates(and).distinct.size == 1 => left + case _ => and } case or @ Or(left, right) => @@ -317,7 +323,19 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, Literal(true, BooleanType)) => Literal(true) case (Literal(false, BooleanType), r) => r case (l, Literal(false, BooleanType)) => l - case (_, _) => or + // a || a || a ... => a + case _ if splitDisjunctivePredicates(or).distinct.size == 1 => left + // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) + case _ => + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + + (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(And) } case not @ Not(exp) => @@ -348,37 +366,6 @@ object CombineFilters extends Rule[LogicalPlan] { } } -/** - * Normalizes conjunctions and disjunctions to eliminate common factors. - */ -object NormalizeFilters extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(predicate: Predicate, child) => - f.copy(condition = normalizedPredicate(predicate).reduce(And)) - } - - def normalizedPredicate(predicate: Expression): Seq[Expression] = predicate match { - // a && a && a ... => a - case p @ And(e, _) if splitConjunctivePredicates(p).distinct.size == 1 => e :: Nil - - // a || a || a ... => a - case p @ Or(e, _) if splitDisjunctivePredicates(p).distinct.size == 1 => e :: Nil - - // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) - case Or(lhs, rhs) => - val lhsSet = splitConjunctivePredicates(lhs).toSet - val rhsSet = splitConjunctivePredicates(rhs).toSet - val common = lhsSet.intersect(rhsSet) - - (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) - .reduceOption(Or) - .map(_ :: common.toList) - .getOrElse(common.toList) - - case _ => predicate :: Nil - } -} - /** * Removes filters that can be evaluated trivially. This is done either by eliding the filter for * cases where it will always evaluate to `true`, or substituting a dummy empty relation when the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala index 85e2682a11c5..906300d8336c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -33,7 +33,7 @@ class NormalizeFiltersSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateAnalysisOperators), Batch("NormalizeFilters", FixedPoint(100), - NormalizeFilters, + BooleanSimplification, SimplifyFilters)) }