@@ -21,193 +21,14 @@ import org.apache.spark.sql.catalyst.expressions._
2121import org .apache .spark .sql .catalyst .trees .TreeNode
2222import org .apache .spark .sql .types .{DataType , StructType }
2323
24- abstract class QueryPlan [PlanType <: QueryPlan [PlanType ]] extends TreeNode [PlanType ] {
24+ abstract class QueryPlan [PlanType <: QueryPlan [PlanType ]]
25+ extends TreeNode [PlanType ]
26+ with QueryPlanConstraints [PlanType ] {
27+
2528 self : PlanType =>
2629
2730 def output : Seq [Attribute ]
2831
29- /**
30- * Extracts the relevant constraints from a given set of constraints based on the attributes that
31- * appear in the [[outputSet]].
32- */
33- protected def getRelevantConstraints (constraints : Set [Expression ]): Set [Expression ] = {
34- constraints
35- .union(inferAdditionalConstraints(constraints))
36- .union(constructIsNotNullConstraints(constraints))
37- .filter(constraint =>
38- constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
39- constraint.deterministic)
40- }
41-
42- /**
43- * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
44- * non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
45- * returns a constraint of the form `isNotNull(a)`.
46- */
47- private def constructIsNotNullConstraints (constraints : Set [Expression ]): Set [Expression ] = {
48- // First, we propagate constraints from the null intolerant expressions.
49- var isNotNullConstraints : Set [Expression ] = constraints.flatMap(inferIsNotNullConstraints)
50-
51- // Second, we infer additional constraints from non-nullable attributes that are part of the
52- // operator's output
53- val nonNullableAttributes = output.filterNot(_.nullable)
54- isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull ).toSet
55-
56- isNotNullConstraints -- constraints
57- }
58-
59- /**
60- * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions
61- * of constraints.
62- */
63- private def inferIsNotNullConstraints (constraint : Expression ): Seq [Expression ] =
64- constraint match {
65- // When the root is IsNotNull, we can push IsNotNull through the child null intolerant
66- // expressions
67- case IsNotNull (expr) => scanNullIntolerantAttribute(expr).map(IsNotNull (_))
68- // Constraints always return true for all the inputs. That means, null will never be returned.
69- // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child
70- // null intolerant expressions.
71- case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull (_))
72- }
73-
74- /**
75- * Recursively explores the expressions which are null intolerant and returns all attributes
76- * in these expressions.
77- */
78- private def scanNullIntolerantAttribute (expr : Expression ): Seq [Attribute ] = expr match {
79- case a : Attribute => Seq (a)
80- case _ : NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
81- case _ => Seq .empty[Attribute ]
82- }
83-
84- // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
85- // we may avoid producing recursive constraints.
86- private lazy val aliasMap : AttributeMap [Expression ] = AttributeMap (
87- expressions.collect {
88- case a : Alias => (a.toAttribute, a.child)
89- } ++ children.flatMap(_.aliasMap))
90-
91- /**
92- * Infers an additional set of constraints from a given set of equality constraints.
93- * For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
94- * additional constraint of the form `b = 5`.
95- *
96- * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
97- * as they are often useless and can lead to a non-converging set of constraints.
98- */
99- private def inferAdditionalConstraints (constraints : Set [Expression ]): Set [Expression ] = {
100- val constraintClasses = generateEquivalentConstraintClasses(constraints)
101-
102- var inferredConstraints = Set .empty[Expression ]
103- constraints.foreach {
104- case eq @ EqualTo (l : Attribute , r : Attribute ) =>
105- val candidateConstraints = constraints - eq
106- inferredConstraints ++= candidateConstraints.map(_ transform {
107- case a : Attribute if a.semanticEquals(l) &&
108- ! isRecursiveDeduction(r, constraintClasses) => r
109- })
110- inferredConstraints ++= candidateConstraints.map(_ transform {
111- case a : Attribute if a.semanticEquals(r) &&
112- ! isRecursiveDeduction(l, constraintClasses) => l
113- })
114- case _ => // No inference
115- }
116- inferredConstraints -- constraints
117- }
118-
119- /*
120- * Generate a sequence of expression sets from constraints, where each set stores an equivalence
121- * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
122- * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search for all expressions
123- * equal to a selected attribute.
124- */
125- private def generateEquivalentConstraintClasses (
126- constraints : Set [Expression ]): Seq [Set [Expression ]] = {
127- var constraintClasses = Seq .empty[Set [Expression ]]
128- constraints.foreach {
129- case eq @ EqualTo (l : Attribute , r : Attribute ) =>
130- // Transform [[Alias]] to its child.
131- val left = aliasMap.getOrElse(l, l)
132- val right = aliasMap.getOrElse(r, r)
133- // Get the expression set for an equivalence constraint class.
134- val leftConstraintClass = getConstraintClass(left, constraintClasses)
135- val rightConstraintClass = getConstraintClass(right, constraintClasses)
136- if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
137- // Combine the two sets.
138- constraintClasses = constraintClasses
139- .diff(leftConstraintClass :: rightConstraintClass :: Nil ) :+
140- (leftConstraintClass ++ rightConstraintClass)
141- } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
142- // Update equivalence class of `left` expression.
143- constraintClasses = constraintClasses
144- .diff(leftConstraintClass :: Nil ) :+ (leftConstraintClass + right)
145- } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
146- // Update equivalence class of `right` expression.
147- constraintClasses = constraintClasses
148- .diff(rightConstraintClass :: Nil ) :+ (rightConstraintClass + left)
149- } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
150- // Create a new equivalence constraint class since neither expression is present
151- // in any existing class.
152- constraintClasses = constraintClasses :+ Set (left, right)
153- }
154- case _ => // Skip
155- }
156-
157- constraintClasses
158- }
159-
160- /*
161- * Get all expressions equivalent to the selected expression.
162- */
163- private def getConstraintClass (
164- expr : Expression ,
165- constraintClasses : Seq [Set [Expression ]]): Set [Expression ] =
166- constraintClasses.find(_.contains(expr)).getOrElse(Set .empty[Expression ])
167-
168- /*
169- * Check whether replacing by an [[Attribute]] will cause a recursive deduction. Generally it
170- * has a form like `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
171- * Here we first get all expressions equal to `attr` and then check whether at least one of them
172- * is a child of the referenced expression.
173- */
174- private def isRecursiveDeduction (
175- attr : Attribute ,
176- constraintClasses : Seq [Set [Expression ]]): Boolean = {
177- val expr = aliasMap.getOrElse(attr, attr)
178- getConstraintClass(expr, constraintClasses).exists { e =>
179- expr.children.exists(_.semanticEquals(e))
180- }
181- }
182-
183- /**
184- * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
185- * example, if this set contains the expression `a = 2` then that expression is guaranteed to
186- * evaluate to `true` for all rows produced.
187- */
188- lazy val constraints : ExpressionSet = ExpressionSet (getRelevantConstraints(validConstraints))
189-
190- /**
191- * Returns [[constraints]] depending on the config of enabling constraint propagation. If the
192- * flag is disabled, simply returns an empty set of constraints.
193- */
194- private [spark] def getConstraints (constraintPropagationEnabled : Boolean ): ExpressionSet =
195- if (constraintPropagationEnabled) {
196- constraints
197- } else {
198- ExpressionSet (Set .empty)
199- }
200-
201- /**
202- * This method can be overridden by any child class of QueryPlan to specify a set of constraints
203- * based on the given operator's constraint propagation logic. These constraints are then
204- * canonicalized and filtered automatically to contain only those attributes that appear in the
205- * [[outputSet]].
206- *
207- * See [[Canonicalize]] for more details.
208- */
209- protected def validConstraints : Set [Expression ] = Set .empty
210-
21132 /**
21233 * Returns the set of attributes that are output by this node.
21334 */
0 commit comments