Skip to content

Commit bbcd26f

Browse files
committed
[SPARK-21091][SQL] Move constraint code into QueryPlanConstraints
1 parent dccc0aa commit bbcd26f

File tree

2 files changed

+210
-183
lines changed

2 files changed

+210
-183
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 4 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -21,193 +21,14 @@ import org.apache.spark.sql.catalyst.expressions._
2121
import org.apache.spark.sql.catalyst.trees.TreeNode
2222
import org.apache.spark.sql.types.{DataType, StructType}
2323

24-
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
24+
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
25+
extends TreeNode[PlanType]
26+
with QueryPlanConstraints[PlanType] {
27+
2528
self: PlanType =>
2629

2730
def output: Seq[Attribute]
2831

29-
/**
30-
* Extracts the relevant constraints from a given set of constraints based on the attributes that
31-
* appear in the [[outputSet]].
32-
*/
33-
protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
34-
constraints
35-
.union(inferAdditionalConstraints(constraints))
36-
.union(constructIsNotNullConstraints(constraints))
37-
.filter(constraint =>
38-
constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
39-
constraint.deterministic)
40-
}
41-
42-
/**
43-
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
44-
* non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
45-
* returns a constraint of the form `isNotNull(a)`
46-
*/
47-
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
48-
// First, we propagate constraints from the null intolerant expressions.
49-
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)
50-
51-
// Second, we infer additional constraints from non-nullable attributes that are part of the
52-
// operator's output
53-
val nonNullableAttributes = output.filterNot(_.nullable)
54-
isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet
55-
56-
isNotNullConstraints -- constraints
57-
}
58-
59-
/**
60-
* Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
61-
* of constraints.
62-
*/
63-
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] =
64-
constraint match {
65-
// When the root is IsNotNull, we can push IsNotNull through the child null intolerant
66-
// expressions
67-
case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_))
68-
// Constraints always return true for all the inputs. That means, null will never be returned.
69-
// Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
70-
// null intolerant expressions.
71-
case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_))
72-
}
73-
74-
/**
75-
* Recursively explores the expressions which are null intolerant and returns all attributes
76-
* in these expressions.
77-
*/
78-
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
79-
case a: Attribute => Seq(a)
80-
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
81-
case _ => Seq.empty[Attribute]
82-
}
83-
84-
// Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
85-
// we may avoid producing recursive constraints.
86-
private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
87-
expressions.collect {
88-
case a: Alias => (a.toAttribute, a.child)
89-
} ++ children.flatMap(_.aliasMap))
90-
91-
/**
92-
* Infers an additional set of constraints from a given set of equality constraints.
93-
* For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
94-
* additional constraint of the form `b = 5`.
95-
*
96-
* [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
97-
* as they are often useless and can lead to a non-converging set of constraints.
98-
*/
99-
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
100-
val constraintClasses = generateEquivalentConstraintClasses(constraints)
101-
102-
var inferredConstraints = Set.empty[Expression]
103-
constraints.foreach {
104-
case eq @ EqualTo(l: Attribute, r: Attribute) =>
105-
val candidateConstraints = constraints - eq
106-
inferredConstraints ++= candidateConstraints.map(_ transform {
107-
case a: Attribute if a.semanticEquals(l) &&
108-
!isRecursiveDeduction(r, constraintClasses) => r
109-
})
110-
inferredConstraints ++= candidateConstraints.map(_ transform {
111-
case a: Attribute if a.semanticEquals(r) &&
112-
!isRecursiveDeduction(l, constraintClasses) => l
113-
})
114-
case _ => // No inference
115-
}
116-
inferredConstraints -- constraints
117-
}
118-
119-
/*
120-
* Generate a sequence of expression sets from constraints, where each set stores an equivalence
121-
* class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
122-
* expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
123-
* to a selected attribute.
124-
*/
125-
private def generateEquivalentConstraintClasses(
126-
constraints: Set[Expression]): Seq[Set[Expression]] = {
127-
var constraintClasses = Seq.empty[Set[Expression]]
128-
constraints.foreach {
129-
case eq @ EqualTo(l: Attribute, r: Attribute) =>
130-
// Transform [[Alias]] to its child.
131-
val left = aliasMap.getOrElse(l, l)
132-
val right = aliasMap.getOrElse(r, r)
133-
// Get the expression set for an equivalence constraint class.
134-
val leftConstraintClass = getConstraintClass(left, constraintClasses)
135-
val rightConstraintClass = getConstraintClass(right, constraintClasses)
136-
if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
137-
// Combine the two sets.
138-
constraintClasses = constraintClasses
139-
.diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
140-
(leftConstraintClass ++ rightConstraintClass)
141-
} else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
142-
// Update equivalence class of `left` expression.
143-
constraintClasses = constraintClasses
144-
.diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
145-
} else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
146-
// Update equivalence class of `right` expression.
147-
constraintClasses = constraintClasses
148-
.diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
149-
} else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
150-
// Create new equivalence constraint class since neither expression is present
151-
// in any classes.
152-
constraintClasses = constraintClasses :+ Set(left, right)
153-
}
154-
case _ => // Skip
155-
}
156-
157-
constraintClasses
158-
}
159-
160-
/*
161-
* Get all expressions equivalent to the selected expression.
162-
*/
163-
private def getConstraintClass(
164-
expr: Expression,
165-
constraintClasses: Seq[Set[Expression]]): Set[Expression] =
166-
constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
167-
168-
/*
169-
* Check whether replacing by an [[Attribute]] will cause a recursive deduction. Generally it
170-
* has a form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
171-
* Here we first get all expressions equal to `attr` and then check whether at least one of them
172-
* is a child of the referenced expression.
173-
*/
174-
private def isRecursiveDeduction(
175-
attr: Attribute,
176-
constraintClasses: Seq[Set[Expression]]): Boolean = {
177-
val expr = aliasMap.getOrElse(attr, attr)
178-
getConstraintClass(expr, constraintClasses).exists { e =>
179-
expr.children.exists(_.semanticEquals(e))
180-
}
181-
}
182-
183-
/**
184-
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
185-
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
186-
* evaluate to `true` for all rows produced.
187-
*/
188-
lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
189-
190-
/**
191-
* Returns [[constraints]] depending on the config of enabling constraint propagation. If the
192-
* flag is disabled, simply returns an empty set of constraints.
193-
*/
194-
private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
195-
if (constraintPropagationEnabled) {
196-
constraints
197-
} else {
198-
ExpressionSet(Set.empty)
199-
}
200-
201-
/**
202-
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
203-
* based on the given operator's constraint propagation logic. These constraints are then
204-
* canonicalized and filtered automatically to contain only those attributes that appear in the
205-
* [[outputSet]].
206-
*
207-
* See [[Canonicalize]] for more details.
208-
*/
209-
protected def validConstraints: Set[Expression] = Set.empty
210-
21132
/**
21233
* Returns the set of attributes that are output by this node.
21334
*/

0 commit comments

Comments
 (0)