From 64e3320a248a3970b2ea3b006d026375b12f6ef6 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 9 Mar 2016 15:44:11 -0800 Subject: [PATCH 1/2] Infer additional constraints from attribute equality --- .../spark/sql/catalyst/plans/QueryPlan.scala | 20 +++++++++++++++++++ .../plans/ConstraintPropagationSuite.scala | 14 +++++++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 371d72ef5af0..6f1f410905b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -32,6 +32,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { constraints + .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) @@ -61,6 +62,25 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }.foldLeft(Set.empty[Expression])(_ union _.toSet) } + /** + * Infers an additional set of constraints from a given set of equality constraints. 
+ * For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5` + */ + private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints.map { + case eq @ EqualTo(l: Attribute, r: Attribute) => + (constraints -- Set(eq)).map(_ transform { + case a: Attribute if a.semanticEquals(l) => r + }).union( + (constraints -- Set(eq)).map(_ transform { + case a: Attribute if a.semanticEquals(r) => l + })) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _) -- constraints + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For * example, if this set contains the expression `a = 2` then that expression is guaranteed to diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 868ad934daf1..8f01f8905204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -158,6 +158,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, tr1.resolveQuoted("a", caseInsensitiveResolution).get === tr2.resolveQuoted("a", caseInsensitiveResolution).get, + tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) @@ -203,4 +204,17 @@ class ConstraintPropagationSuite extends SparkFunSuite { .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints.isEmpty) } + + test("infer additional 
constraints in filters") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr + .where('a.attr > 10 && 'a.attr === 'b.attr) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "b") > 10, + resolveColumn(tr, "a") === resolveColumn(tr, "b"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b"))))) + } } From c57db5b4a3950ba5e653ff3b8cdfa33e157305ef Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 9 Mar 2016 23:34:37 -0800 Subject: [PATCH 2/2] Reynold's comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 6f1f410905b7..0a3d4056f64c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -68,17 +68,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * additional constraint of the form `b = 5` */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - constraints.map { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - (constraints -- Set(eq)).map(_ transform { + inferredConstraints ++= (constraints - eq).map(_ transform { case a: Attribute if a.semanticEquals(l) => r - }).union( - (constraints -- Set(eq)).map(_ transform { - case a: Attribute if a.semanticEquals(r) => l - })) - case _ => - Set.empty[Expression] - }.foldLeft(Set.empty[Expression])(_ union _) -- constraints + }) + inferredConstraints ++= (constraints - eq).map(_ transform { + case a: Attribute if a.semanticEquals(r) => l + }) + case _ => // No inference + } + inferredConstraints -- constraints 
} /**