diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7129c6984cf3..96c1b78a1418 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -176,20 +176,58 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
    * original constraint expressions with the corresponding alias
    */
   protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
     var allConstraints = child.constraints
-    projectList.foreach {
-      case a @ Alias(l: Literal, _) =>
-        allConstraints += EqualNullSafe(a.toAttribute, l)
-      case a @ Alias(e, _) =>
-        // For every alias in `projectList`, replace the reference in constraints by its attribute.
-        allConstraints ++= allConstraints.map(_ transform {
-          case expr: Expression if expr.semanticEquals(e) =>
-            a.toAttribute
+
+    // For each expression, collect its aliases.
+    val aliasMap = projectList.collect {
+      case alias @ Alias(expr, _) if !expr.foldable && expr.deterministic =>
+        (expr.canonicalized, alias)
+    }.groupBy(_._1).mapValues(_.map(_._2))
+    val remainingExpressions = collection.mutable.Set(aliasMap.keySet.toSeq: _*)
+
+    /**
+     * Filtering allConstraints between iterations is necessary because otherwise
+     * collecting the valid constraints could in the worst case take exponential
+     * time and memory: each replaced alias could double the number of constraints,
+     * since we would keep both the original constraint and the aliased copy.
+     */
+    def shouldBeKept(expr: Expression): Boolean = {
+      expr.references.subsetOf(outputSet) ||
+        remainingExpressions.contains(expr.canonicalized) ||
+        (expr.children.nonEmpty && expr.children.forall(shouldBeKept))
+    }
+
+    // Replace expressions with their aliases.
+    for ((expr, aliases) <- aliasMap) {
+      allConstraints ++= allConstraints.flatMap(constraint => {
+        aliases.map(alias => {
+          constraint transform {
+            case e: Expression if e.semanticEquals(expr) =>
+              alias.toAttribute
+          }
         })
-        allConstraints += EqualNullSafe(e, a.toAttribute)
+      })
+
+      remainingExpressions.remove(expr)
+      allConstraints = allConstraints.filter(shouldBeKept)
+    }
+
+    // Add equality between aliases of the same expression.
+    aliasMap.values.foreach(_.combinations(2).foreach {
+      case Seq(a1, a2) =>
+        allConstraints += EqualNullSafe(a1.toAttribute, a2.toAttribute)
+    })
+
+    /**
+     * We keep the child constraints and the equality between original and aliased attributes,
+     * so that [[ConstraintHelper.inferAdditionalConstraints]] has the full information available.
+     */
+    projectList.foreach {
+      case alias @ Alias(expr, _) =>
+        allConstraints += EqualNullSafe(alias.toAttribute, expr)
       case _ => // Don't change.
     }
-    allConstraints
+    allConstraints ++ child.constraints
   }
 
   override protected lazy val validConstraints: ExpressionSet = child.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 5ad748b6113d..9deb816f8b0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans
 
 import java.util.TimeZone
 
+import org.scalatest.PrivateMethodTester
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -28,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
-class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
+class ConstraintPropagationSuite extends SparkFunSuite with PlanTest with PrivateMethodTester {
 
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     resolveColumn(tr.analyze, columnName)
@@ -422,4 +424,92 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
       assert(aliasedRelation.analyze.constraints.isEmpty)
     }
   }
+
+  test("SPARK-33152: infer from child constraint") {
+    val plan = LocalRelation('a.int, 'b.int)
+      .where('a === 'b)
+      .select('a, ('b + 1) as 'b2)
+      .analyze
+
+    verifyConstraints(plan.constraints, ExpressionSet(Seq(
+      IsNotNull(resolveColumn(plan, "a")),
+      resolveColumn(plan, "a") + 1 <=> resolveColumn(plan, "b2")
+    )))
+  }
+
+  test("SPARK-33152: equality constraints from aliases") {
+    val plan1 = LocalRelation('a.int)
+      .select('a as 'a2, 'a as 'b2)
+      .analyze
+
+    verifyConstraints(plan1.constraints, ExpressionSet(Seq(
+      resolveColumn(plan1, "a2") <=> resolveColumn(plan1, "b2")
+    )))
+
+    val plan2 = LocalRelation()
+      .select(rand(0) as 'a2, rand(0) as 'b2)
+      .analyze
+
+    // No equality constraints from non-deterministic expressions.
+    verifyConstraints(plan2.constraints, ExpressionSet(Seq(
+      IsNotNull(resolveColumn(plan2, "a2")),
+      IsNotNull(resolveColumn(plan2, "b2"))
+    )))
+  }
+
+  test("SPARK-33152: Avoid exponential growth of constraints") {
+    val validConstraints = PrivateMethod[ExpressionSet](Symbol("validConstraints"))
+
+    val relation = LocalRelation('a.int, 'b.int, 'c.int)
+      .where('a + 'b + 'c > intToLiteral(0))
+
+    val plan1 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1)
+      .analyze
+
+    assert(plan1.invokePrivate(validConstraints()).size == 11)
+    verifyConstraints(plan1.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan1, "a1")),
+        IsNotNull(resolveColumn(plan1, "b1")),
+        IsNotNull(resolveColumn(plan1, "c1")),
+        resolveColumn(plan1, "a1") +
+          resolveColumn(plan1, "b1") +
+          resolveColumn(plan1, "c1") > 0
+      )))
+
+    val plan2 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1, 'a + 'b + 'c)
+      .analyze
+
+    assert(plan2.invokePrivate(validConstraints()).size == 13)
+    verifyConstraints(plan2.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan2, "a1")),
+        IsNotNull(resolveColumn(plan2, "b1")),
+        IsNotNull(resolveColumn(plan2, "c1")),
+        IsNotNull(resolveColumn(plan2, "((a + b) + c)")),
+        resolveColumn(plan2, "((a + b) + c)") > 0,
+        resolveColumn(plan2, "a1") +
+          resolveColumn(plan2, "b1") +
+          resolveColumn(plan2, "c1") > 0
+      )))
+
+    val plan3 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1, ('a + 'b + 'c) as 'x1)
+      .analyze
+
+    assert(plan3.invokePrivate(validConstraints()).size == 13)
+    verifyConstraints(plan3.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan3, "a1")),
+        IsNotNull(resolveColumn(plan3, "b1")),
+        IsNotNull(resolveColumn(plan3, "c1")),
+        IsNotNull(resolveColumn(plan3, "x1")),
+        resolveColumn(plan3, "x1") > 0,
+        resolveColumn(plan3, "a1") +
+          resolveColumn(plan3, "b1") +
+          resolveColumn(plan3, "c1") > 0
+      )))
+  }
 }