Skip to content

Commit

Permalink
[SPARK-21979][SQL] Improve QueryPlanConstraints framework
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Improve QueryPlanConstraints framework, make it robust and simple.
In #15319, constraints for expressions like `a = f(b, c)` were resolved.
However, for expressions like
```scala
a = f(b, c) && c = g(a, b)
```
The current QueryPlanConstraints framework will produce non-converging constraints.
Essentially, the problem is caused by having both the name and the child of an alias in the same constraint set. We infer constraints and push them down as predicates in filters; later, these predicates are propagated as constraints again, and so on.
Using only the alias names resolves these problems. The size of the constraint set is reduced without losing any information, because we can always recover the inferred constraints on the children of aliases when pushing down filters.

Also, the `EqualNullSafe` constraint between an alias's name and its child, added when propagating aliases, is meaningless:
```scala
allConstraints += EqualNullSafe(e, a.toAttribute)
```
It just produces redundant constraints.

## How was this patch tested?

Unit test

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #19201 from gengliangwang/QueryPlanConstraints.
  • Loading branch information
gengliangwang authored and gatorsmile committed Sep 12, 2017
1 parent c5f9b89 commit 1a98574
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ abstract class UnaryNode extends LogicalPlan {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
})
allConstraints += EqualNullSafe(e, a.toAttribute)
case _ => // Don't change.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*
* [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
* as they are often useless and can lead to a non-converging set of constraints.
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
val constraintClasses = generateEquivalentConstraintClasses(constraints)

val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
var inferredConstraints = Set.empty[Expression]
constraints.foreach {
aliasedConstraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
val candidateConstraints = constraints - eq
inferredConstraints ++= candidateConstraints.map(_ transform {
case a: Attribute if a.semanticEquals(l) &&
!isRecursiveDeduction(r, constraintClasses) => r
})
inferredConstraints ++= candidateConstraints.map(_ transform {
case a: Attribute if a.semanticEquals(r) &&
!isRecursiveDeduction(l, constraintClasses) => l
})
val candidateConstraints = aliasedConstraints - eq
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case _ => // No inference
}
inferredConstraints -- constraints
}

/**
* Generate a sequence of expression sets from constraints, where each set stores an equivalence
* class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
* expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
* to an selected attribute.
* Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
* Thus non-converging inference can be prevented.
* E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions.
* Also, the size of constraints is reduced without losing any information.
* When the inferred filters are pushed down the operators that generate the alias,
* the alias names used in filters are replaced by the aliased expressions.
*/
private def generateEquivalentConstraintClasses(
constraints: Set[Expression]): Seq[Set[Expression]] = {
var constraintClasses = Seq.empty[Set[Expression]]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
// Transform [[Alias]] to its child.
val left = aliasMap.getOrElse(l, l)
val right = aliasMap.getOrElse(r, r)
// Get the expression set for an equivalence constraint class.
val leftConstraintClass = getConstraintClass(left, constraintClasses)
val rightConstraintClass = getConstraintClass(right, constraintClasses)
if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
// Combine the two sets.
constraintClasses = constraintClasses
.diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
(leftConstraintClass ++ rightConstraintClass)
} else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
// Update equivalence class of `left` expression.
constraintClasses = constraintClasses
.diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
} else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
// Update equivalence class of `right` expression.
constraintClasses = constraintClasses
.diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
} else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
// Create new equivalence constraint class since neither expression presents
// in any classes.
constraintClasses = constraintClasses :+ Set(left, right)
}
case _ => // Skip
private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
: Set[Expression] = {
val attributesInEqualTo = constraints.flatMap {
case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
case _ => Nil
}

constraintClasses
}

/**
* Get all expressions equivalent to the selected expression.
*/
private def getConstraintClass(
expr: Expression,
constraintClasses: Seq[Set[Expression]]): Set[Expression] =
constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])

/**
* Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
* has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
* Here we first get all expressions equal to `attr` and then check whether at least one of them
* is a child of the referenced expression.
*/
private def isRecursiveDeduction(
attr: Attribute,
constraintClasses: Seq[Set[Expression]]): Boolean = {
val expr = aliasMap.getOrElse(attr, attr)
getConstraintClass(expr, constraintClasses).exists { e =>
expr.children.exists(_.semanticEquals(e))
var aliasedConstraints = constraints
attributesInEqualTo.foreach { a =>
if (aliasMap.contains(a)) {
val child = aliasMap.get(a).get
aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
}
}
aliasedConstraints
}

/**
 * Rewrites each constraint in `constraints`, substituting `destination` for every
 * sub-expression that is semantically equal to `source`.
 *
 * @param constraints the constraint expressions to rewrite
 * @param source the expression to search for; matches are decided by `semanticEquals`
 * @param destination the attribute put in place of each matching sub-expression
 * @return the set obtained by mapping every constraint through `transform`
 */
private def replaceConstraints(
constraints: Set[Expression],
source: Expression,
destination: Attribute): Set[Expression] = constraints.map(_ transform {
case e: Expression if e.semanticEquals(source) => destination
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
.where(IsNotNull('a) && IsNotNull('b) &&'a === 'b)
.select('a, 'b.as('d)).as("t")
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
Expand All @@ -176,24 +176,48 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
&& Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b)))
&& 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
&& 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
&& 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
&& 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
.select('int_col, 'd, 'a).as("t")
.join(t2
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
.join(
t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
'a === Coalesce(Seq('a, 'a))),
Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("inner join with EqualTo expressions containing part of each other: don't generate " +
"constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

// We should prevent `c = Coalesce(a, b)` and `a = Coalesce(b, c)` from recursively creating
// complicated constraints through the constraint inference procedure.
// Here `d` aliases Coalesce(b, c) and `e` aliases Coalesce(a, b), so the filter
// `a = d && c = e` makes each equality refer to the other's left-hand attribute.
val originalQuery = t1
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
.where('a === 'd && 'c === 'e)
.join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
.analyze
// Expected plan: the filter sits below the projection with `d` and `e` written as their
// Coalesce children, and both join inputs get only IsNotNull('a) and IsNotNull('c);
// no further Coalesce-derived predicates appear.
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
'c === Coalesce(Seq('a, 'b)))
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
.join(t2.where(IsNotNull('a) && IsNotNull('c)),
Inner,
Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("generate correct filters for alias that don't produce recursive constraints") {
val t1 = testRelation.subquery('t1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
verifyConstraints(aliasedRelation.analyze.constraints,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))

Expand Down

0 comments on commit 1a98574

Please sign in to comment.