Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,58 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
*/
protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
var allConstraints = child.constraints
projectList.foreach {
case a @ Alias(l: Literal, _) =>
allConstraints += EqualNullSafe(a.toAttribute, l)
case a @ Alias(e, _) =>
// For every alias in `projectList`, replace the reference in constraints by its attribute.
allConstraints ++= allConstraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute

// For each expression collect its aliases
val aliasMap = projectList.collect {
case alias @ Alias(expr, _) if !expr.foldable && expr.deterministic =>
(expr.canonicalized, alias)
}.groupBy(_._1).mapValues(_.map(_._2))
val remainingExpressions = collection.mutable.Set(aliasMap.keySet.toSeq: _*)

/**
* Filtering allConstraints between each iteration is necessary, because
* otherwise collecting valid constraints could in the worst case have exponential
* time and memory complexity. Each replaced alias could double the number of constraints,
* because we would keep both the original constraint and the one with alias.
*/
def shouldBeKept(expr: Expression): Boolean = {
expr.references.subsetOf(outputSet) ||
remainingExpressions.contains(expr.canonicalized) ||
(expr.children.nonEmpty && expr.children.forall(shouldBeKept))
}

// Replace expressions with aliases
for ((expr, aliases) <- aliasMap) {
allConstraints ++= allConstraints.flatMap(constraint => {
aliases.map(alias => {
constraint transform {
case e: Expression if e.semanticEquals(expr) =>
alias.toAttribute
}
})
allConstraints += EqualNullSafe(e, a.toAttribute)
})

remainingExpressions.remove(expr)
allConstraints = allConstraints.filter(shouldBeKept)
}

// Equality between aliases for the same expression
aliasMap.values.foreach(_.combinations(2).foreach {
case Seq(a1, a2) =>
allConstraints += EqualNullSafe(a1.toAttribute, a2.toAttribute)
})

/**
* We keep the child constraints and equality between original and aliased attributes,
* so [[ConstraintHelper.inferAdditionalConstraints]] would have the full information available.
*/
projectList.foreach {
case alias @ Alias(expr, _) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we only need to handle `a @ Alias(l: Literal, _)` here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might be right that only the literal aliases are currently used, but all aliases were kept in the previous code (lines 180 & 187), and if somebody wants to improve inferAdditionalConstraints later, they might need these.
Removing non-literal aliases would be only a marginal performance improvement, so I would rather keep the existing behavior.

allConstraints += EqualNullSafe(alias.toAttribute, expr)
case _ => // Don't change.
}

allConstraints
allConstraints ++ child.constraints
}

override protected lazy val validConstraints: ExpressionSet = child.constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans

import java.util.TimeZone

import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
Expand All @@ -28,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}

class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
class ConstraintPropagationSuite extends SparkFunSuite with PlanTest with PrivateMethodTester {

// Convenience overload for a not-yet-analyzed LocalRelation: analyzes the
// relation first and then delegates to the other `resolveColumn` overload
// (defined outside this excerpt) with the analyzed plan and the column name.
private def resolveColumn(tr: LocalRelation, columnName: String): Expression = {
  val analyzedRelation = tr.analyze
  resolveColumn(analyzedRelation, columnName)
}
Expand Down Expand Up @@ -422,4 +424,92 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
assert(aliasedRelation.analyze.constraints.isEmpty)
}
}

// Filters on `a = b`, projects `a` and `b + 1 AS b2`, and checks that the
// resulting constraints are exactly `a IS NOT NULL` and `a + 1 <=> b2`.
test("SPARK-33152: infer from child constraint") {
  val filtered = LocalRelation('a.int, 'b.int).where('a === 'b)
  val plan = filtered.select('a, ('b + 1) as 'b2).analyze

  val actual = plan.constraints

  // Resolve the output columns once and reuse them in the expectations.
  val colA = resolveColumn(plan, "a")
  val colB2 = resolveColumn(plan, "b2")
  val expected = ExpressionSet(Seq(
    IsNotNull(colA),
    colA + 1 <=> colB2
  ))

  verifyConstraints(actual, expected)
}

// Covers two projections that alias the same expression twice:
//  - a deterministic attribute, where `a2 <=> b2` is the expected constraint;
//  - `rand(0)`, where only the two IS NOT NULL constraints are expected.
test("SPARK-33152: equality constraints from aliases") {
  // Case 1: the same attribute is exposed under two different names.
  val attributeAliasedTwice = LocalRelation('a.int)
    .select('a as 'a2, 'a as 'b2)
    .analyze
  val attributeConstraints = attributeAliasedTwice.constraints
  val aliasesAreEqual =
    resolveColumn(attributeAliasedTwice, "a2") <=> resolveColumn(attributeAliasedTwice, "b2")
  verifyConstraints(attributeConstraints, ExpressionSet(Seq(aliasesAreEqual)))

  // Case 2: no equality from non-deterministic expressions.
  val randAliasedTwice = LocalRelation()
    .select(rand(0) as 'a2, rand(0) as 'b2)
    .analyze
  val randConstraints = randAliasedTwice.constraints
  val onlyNotNulls: Seq[Expression] =
    Seq("a2", "b2").map(name => IsNotNull(resolveColumn(randAliasedTwice, name)))
  verifyConstraints(randConstraints, ExpressionSet(onlyNotNulls))
}

// Projects a relation filtered by `a + b + c > 0` in three ways and, for each,
// asserts both the size of the private `validConstraints` set and the final
// `constraints` of the analyzed plan.
test("SPARK-33152: Avoid exponential growth of constraints") {
  // Handle used to call the private `validConstraints` member reflectively.
  val validConstraintsMethod = PrivateMethod[ExpressionSet](Symbol("validConstraints"))

  val filtered = LocalRelation('a.int, 'b.int, 'c.int)
    .where('a + 'b + 'c > intToLiteral(0))

  // IS NOT NULL constraints for the given output columns of `plan`, in order.
  def notNullsOf(plan: LogicalPlan, names: String*): Seq[Expression] =
    names.map(name => IsNotNull(resolveColumn(plan, name)))

  // The constraint `a1 + b1 + c1 > 0` expressed on the aliased columns of `plan`.
  def aliasedSumIsPositive(plan: LogicalPlan): Expression =
    resolveColumn(plan, "a1") + resolveColumn(plan, "b1") + resolveColumn(plan, "c1") > 0

  // Variant 1: only the three renamed columns are projected.
  val renamedOnly = filtered
    .select('a as 'a1, 'b as 'b1, 'c as 'c1)
    .analyze

  assert(renamedOnly.invokePrivate(validConstraintsMethod()).size == 11)
  verifyConstraints(renamedOnly.constraints,
    ExpressionSet(
      notNullsOf(renamedOnly, "a1", "b1", "c1") :+ aliasedSumIsPositive(renamedOnly)))

  // Variant 2: the original sum is projected as well, without an explicit alias.
  val withUnnamedSum = filtered
    .select('a as 'a1, 'b as 'b1, 'c as 'c1, 'a + 'b + 'c)
    .analyze

  assert(withUnnamedSum.invokePrivate(validConstraintsMethod()).size == 13)
  verifyConstraints(withUnnamedSum.constraints,
    ExpressionSet(
      notNullsOf(withUnnamedSum, "a1", "b1", "c1", "((a + b) + c)") ++ Seq(
        resolveColumn(withUnnamedSum, "((a + b) + c)") > 0,
        aliasedSumIsPositive(withUnnamedSum))))

  // Variant 3: the original sum is projected under the alias `x1`.
  val withNamedSum = filtered
    .select('a as 'a1, 'b as 'b1, 'c as 'c1, ('a + 'b + 'c) as 'x1)
    .analyze

  assert(withNamedSum.invokePrivate(validConstraintsMethod()).size == 13)
  verifyConstraints(withNamedSum.constraints,
    ExpressionSet(
      notNullsOf(withNamedSum, "a1", "b1", "c1", "x1") ++ Seq(
        resolveColumn(withNamedSum, "x1") > 0,
        aliasedSumIsPositive(withNamedSum))))
}
}