@@ -43,7 +43,8 @@ case class SimpleCatalystConf(
override val starSchemaDetection: Boolean = false,
override val warehousePath: String = "/user/hive/warehouse",
override val sessionLocalTimeZone: String = TimeZone.getDefault().getID,
override val maxNestedViewDepth: Int = 100)
override val maxNestedViewDepth: Int = 100,
override val constraintPropagationEnabled: Boolean = true)
extends SQLConf {

override def clone(): SimpleCatalystConf = this.copy()
@@ -83,12 +83,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
// Operator push down
PushProjectionThroughUnion,
ReorderJoin(conf),
EliminateOuterJoin,
EliminateOuterJoin(conf),
PushPredicateThroughJoin,
PushDownPredicate,
LimitPushDown(conf),
ColumnPruning,
InferFiltersFromConstraints,
InferFiltersFromConstraints(conf),
// Operator combine
CollapseRepartition,
CollapseProject,
@@ -107,7 +107,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
PruneFilters,
PruneFilters(conf),
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
@@ -615,8 +615,16 @@ object CollapseWindow extends Rule[LogicalPlan] {
* Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
* LeftSemi joins.
*/
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case class InferFiltersFromConstraints(conf: CatalystConf)
extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) {
inferFilters(plan)
} else {
plan
}


private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, child) =>
val newFilters = filter.constraints --
(child.constraints ++ splitConjunctivePredicates(condition))
@@ -705,7 +713,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
* 2) by substituting a dummy empty relation when the filter will always evaluate to `false`.
* 3) by eliminating the always-true conditions given the constraints on the child's output.
*/
object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// If the filter condition always evaluate to true, remove the filter.
case Filter(Literal(true, BooleanType), child) => child
@@ -718,7 +726,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
case f @ Filter(fc, p: LogicalPlan) =>
val (prunedPredicates, remainingPredicates) =
splitConjunctivePredicates(fc).partition { cond =>
cond.deterministic && p.constraints.contains(cond)
cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond)
}
if (prunedPredicates.isEmpty) {
f
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation}
import org.apache.spark.sql.catalyst.plans._
@@ -439,7 +440,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe
*
* This rule should be executed before pushing down the Filter
*/
object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {

/**
* Returns whether the expression returns null or false when all inputs are nulls.
@@ -455,7 +456,8 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
}

private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints
val conditions = splitConjunctivePredicates(filter.condition) ++
filter.getConstraints(conf.constraintPropagationEnabled)
val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))

@@ -186,6 +186,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))

/**
* Returns [[constraints]] when constraint propagation is enabled; otherwise simply returns an
* empty constraint set.
*/
private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
if (constraintPropagationEnabled) {
constraints
} else {
ExpressionSet(Set.empty)
}

/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
* based on the given operator's constraint propagation logic. These constraints are then
@@ -187,6 +187,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled")
.internal()
.doc("When true, the query optimizer will infer and propagate data constraints in the query " +
"plan to optimize them. Constraint propagation can sometimes be computationally expensive" +
"for certain kinds of query plans (such as those with a large number of predicates and " +
"aliases) which might negatively impact overall runtime.")
.booleanConf
.createWithDefault(true)
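
As a rough usage illustration (an editor's sketch, not part of this diff; the `spark` SparkSession handle is assumed), the new internal conf can be toggled at runtime like any other SQL conf:

```scala
// Hypothetical sketch: turn constraint propagation off around queries whose plans are
// expensive to analyze (many predicates and aliases), then restore the default.
spark.conf.set("spark.sql.constraintPropagation.enabled", "false")
// ... run the affected queries ...
spark.conf.set("spark.sql.constraintPropagation.enabled", "true")
```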

val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
.doc("When true, the Parquet data source merges schemas collected from all data files, " +
"otherwise the schema is picked from the summary file or a random data file " +
@@ -887,6 +896,8 @@ class SQLConf extends Serializable with Logging {

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)

def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -30,15 +30,16 @@ import org.apache.spark.sql.catalyst.rules._
class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val conf = SimpleCatalystConf(caseSensitiveAnalysis = true)
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)),
NullPropagation(conf),
ConstantFolding,
BooleanSimplification,
SimplifyBinaryComparison,
PruneFilters) :: Nil
PruneFilters(conf)) :: Nil
}

val nullableRelation = LocalRelation('a.int.withNullability(true))
@@ -30,14 +30,15 @@ import org.apache.spark.sql.catalyst.rules._
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val conf = SimpleCatalystConf(caseSensitiveAnalysis = true)
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)),
NullPropagation(conf),
ConstantFolding,
BooleanSimplification,
PruneFilters) :: Nil
PruneFilters(conf)) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
@@ -31,7 +32,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints,
InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true)),
CombineFilters) :: Nil
}

object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
val batches =
Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true,
constraintPropagationEnabled = false)),
CombineFilters) :: Nil
}

@@ -201,4 +212,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("No inferred filter when constraint propagation is disabled") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}
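
For contrast with the new negative test above, here is an editor's sketch (not part of the diff) of the enabled path, built with the same Catalyst test DSL; the inferred predicates named in the comments are indicative rather than verbatim optimizer output:

```scala
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int)
val query = relation.where('a === 1 && 'a === 'b).analyze

// Enabled (the default): the rule may add inferred predicates such as
// isnotnull(a), isnotnull(b) and b = 1 on top of the existing condition.
val enabledRule = InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true))
val withInferredFilters = enabledRule(query)

// Disabled: the plan comes back untouched, which is what the new test asserts.
val disabledRule = InferFiltersFromConstraints(
  SimpleCatalystConf(caseSensitiveAnalysis = true, constraintPropagationEnabled = false))
assert(disabledRule(query) == query)
```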
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -31,7 +32,17 @@ class OuterJoinEliminationSuite extends PlanTest {
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
Batch("Outer Join Elimination", Once,
EliminateOuterJoin,
EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true)),
[Review thread on this line]
Member: Can we add a test for outer join elimination as well?
Member (PR author): Added a test.
PushPredicateThroughJoin) :: Nil
}

object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
Batch("Outer Join Elimination", Once,
EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true,
constraintPropagationEnabled = false)),
PushPredicateThroughJoin) :: Nil
}

@@ -231,4 +242,21 @@

comparePlans(optimized, correctAnswer)
}

test("no outer join elimination if constraint propagation is disabled") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

// If constraint propagation is enabled, constraints such as "x.b != null" and "y.d != null"
// are inferred from the predicate "x.b + y.d >= 3".
// When it is disabled, no null-filtering predicate can be derived for the left or right plan,
// so the outer join will not be eliminated.
val originalQuery =
x.join(y, FullOuter, Option("x.a".attr === "y.d".attr))
.where("x.b".attr + "y.d".attr >= 3)

val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze)

comparePlans(optimized, originalQuery.analyze)
}
}
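
For comparison, an editor's sketch (not part of the diff) of the enabled case, reusing the relations referenced in the test above and the suite's default `Optimize` executor:

```scala
// Hypothetical sketch: with constraint propagation enabled (the default), "x.b IS NOT NULL"
// and "y.d IS NOT NULL" are inferred from "x.b + y.d >= 3", so both sides of the full outer
// join are known to filter out nulls and EliminateOuterJoin can rewrite it to an inner join.
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)
val originalQuery =
  x.join(y, FullOuter, Option("x.a".attr === "y.d".attr))
    .where("x.b".attr + "y.d".attr >= 3)
val optimized = Optimize.execute(originalQuery.analyze)  // the FullOuter join becomes Inner
```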
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans._
@@ -33,7 +34,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters,
PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)),
PropagateEmptyRelation) :: Nil
}

Expand All @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters) :: Nil
PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil
}

val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -33,7 +34,19 @@ class PruneFiltersSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters,
PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)),
PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}

object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true,
constraintPropagationEnabled = false)),
PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}
@@ -133,4 +146,29 @@ class PruneFiltersSuite extends PlanTest {
val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze
comparePlans(optimized, correctAnswer)
}

test("No pruning when constraint propagation is disabled") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)

val query = tr1
.where("tr1.a".attr > 10 || "tr1.c".attr < 10)
.join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))

val queryWithUselessFilter =
query.where(
("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
'd.attr < 100)

val optimized =
OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze)
// When constraint propagation is disabled, the useless filter is not pruned; it is pushed
// down instead. Because the `CombineFilters` rule runs only once, the optimized plan ends up
// with redundant, duplicated filters.
val correctAnswer = tr1
.where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10)
.join(tr2.where('d.attr < 100).where('d.attr < 100),
Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
comparePlans(optimized, correctAnswer)
}
}
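
And an editor's sketch (not part of the diff) of the corresponding enabled case under the suite's default `Optimize` executor; the expected outcome is described in a comment rather than asserted, since the exact plan shape depends on the full rule batch:

```scala
// Hypothetical sketch: with constraint propagation enabled, the predicates in the outer
// filter are already guaranteed by the children's constraints, so PruneFilters drops the
// redundant filter instead of letting it be pushed down and duplicated.
val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
val query = tr1
  .where("tr1.a".attr > 10 || "tr1.c".attr < 10)
  .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
val queryWithUselessFilter = query.where(
  ("tr1.a".attr > 10 || "tr1.c".attr < 10) && 'd.attr < 100)
val pruned = Optimize.execute(queryWithUselessFilter.analyze)
// Expected: roughly equivalent to query.analyze, i.e. no duplicated filters remain.
```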
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -34,7 +35,7 @@ class SetOperationSuite extends PlanTest {
CombineUnions,
PushProjectionThroughUnion,
PushDownPredicate,
PruneFilters) :: Nil
PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -397,4 +397,22 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c")))))
}

test("enable/disable constraint propagation") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
val filterRelation = tr.where('a.attr > 10)

verifyConstraints(
filterRelation.analyze.getConstraints(constraintPropagationEnabled = true),
filterRelation.analyze.constraints)

assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)

val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
.groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)

verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true),
aliasedRelation.analyze.constraints)
assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
}
}