
Cherrypick for 4.6.25.0-GA [KE-43557][SPARK-33152][SQL] Improve the performance of constraint propagation for Project and Aggregate (Kyligence#739) (Kyligence#748)
Mrhs121 authored Apr 28, 2024
1 parent 1d40d6e commit 6bb3121
Showing 2 changed files with 140 additions and 12 deletions.
@@ -181,20 +181,58 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
    */
   protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
     var allConstraints = child.constraints
-    projectList.foreach {
-      case a @ Alias(l: Literal, _) =>
-        allConstraints += EqualNullSafe(a.toAttribute, l)
-      case a @ Alias(e, _) if e.deterministic =>
-        // For every alias in `projectList`, replace the reference in constraints by its attribute.
-        allConstraints ++= allConstraints.map(_ transform {
-          case expr: Expression if expr.semanticEquals(e) =>
-            a.toAttribute
-        })
-        allConstraints += EqualNullSafe(e, a.toAttribute)
-      case _ => // Don't change.
-    }
-
-    allConstraints
+
+    // For each expression collect its aliases
+    val aliasMap = projectList.collect {
+      case alias @ Alias(expr, _) if !expr.foldable && expr.deterministic =>
+        (expr.canonicalized, alias)
+    }.groupBy(_._1).mapValues(_.map(_._2))
+    val remainingExpressions = collection.mutable.Set(aliasMap.keySet.toSeq: _*)
+
+    /**
+     * https://issues.apache.org/jira/browse/SPARK-33152
+     * Filtering allConstraints between each iteration is necessary, because
+     * otherwise collecting valid constraints could in the worst case have exponential
+     * time and memory complexity. Each replaced alias could double the number of constraints,
+     * because we would keep both the original constraint and the one with alias.
+     */
+    def shouldBeKept(expr: Expression): Boolean = {
+      expr.references.subsetOf(outputSet) ||
+        remainingExpressions.contains(expr.canonicalized) ||
+        (expr.children.nonEmpty && expr.children.forall(shouldBeKept))
+    }
+
+    // Replace expressions with aliases
+    for ((expr, aliases) <- aliasMap) {
+      allConstraints ++= allConstraints.flatMap(constraint => {
+        aliases.map(alias => {
+          constraint transform {
+            case e: Expression if e.semanticEquals(expr) =>
+              alias.toAttribute
+          }
+        })
+      })
+
+      remainingExpressions.remove(expr)
+      allConstraints = allConstraints.filter(shouldBeKept)
+    }
+
+    // Equality between aliases for the same expression
+    aliasMap.values.foreach(_.combinations(2).foreach {
+      case Seq(a1, a2) =>
+        allConstraints += EqualNullSafe(a1.toAttribute, a2.toAttribute)
+    })
+
+    /**
+     * We keep the child constraints and equality between original and aliased attributes,
+     * so [[ConstraintHelper.inferAdditionalConstraints]] would have the full information available.
+     */
+    projectList.foreach {
+      case alias @ Alias(expr, _) =>
+        allConstraints += EqualNullSafe(alias.toAttribute, expr)
+      case _ => // Don't change.
+    }
+    allConstraints ++ child.constraints
   }
 
   override protected lazy val validConstraints: ExpressionSet = child.constraints
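The Scaladoc in the new code explains why allConstraints is filtered after every alias substitution: keeping both the original constraint and its aliased copy can double the working set once per alias. The following plain-Scala toy model (no Spark dependencies; the object name, the name-set encoding of constraints, and the alias list are invented purely for illustration) sketches that growth and the effect of a shouldBeKept-style filter:

// Toy model of the constraint growth described in the SPARK-33152 comment above.
// A "constraint" is modelled only by the set of names it references.
object ConstraintGrowthSketch {
  type Constraint = Set[String]

  def main(args: Array[String]): Unit = {
    val aliases = Seq("a" -> "a1", "b" -> "b1", "c" -> "c1", "d" -> "d1")
    val outputs = aliases.map(_._2).toSet
    val seed: Constraint = aliases.map(_._1).toSet // one constraint over a, b, c, d

    // Old behaviour: every alias substitution keeps both the original constraint
    // and the rewritten copy, so the set can double per alias.
    var unfiltered = Set(seed)
    for ((expr, alias) <- aliases) {
      unfiltered ++= unfiltered.map(c => if (c(expr)) c - expr + alias else c)
    }

    // New behaviour (mimicking shouldBeKept): after each substitution, keep only
    // constraints whose references are output attributes or not-yet-processed expressions.
    val remaining = scala.collection.mutable.Set(aliases.map(_._1): _*)
    var filtered = Set(seed)
    for ((expr, alias) <- aliases) {
      filtered ++= filtered.map(c => if (c(expr)) c - expr + alias else c)
      remaining -= expr
      filtered = filtered.filter(_.forall(n => outputs(n) || remaining(n)))
    }

    println(s"without filtering: ${unfiltered.size} constraints") // 16 = 2^4
    println(s"with filtering:    ${filtered.size} constraints")   // 1
  }
}

With four aliases the unfiltered set grows to 2^4 = 16 constraints, while the filtered one ends with a single fully rewritten constraint; the "Avoid exponential growth of constraints" test added below pins down the same idea against the real code by asserting validConstraints sizes of 11 and 13 rather than an exponential count.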
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans
 
 import java.util.TimeZone
 
+import org.scalatest.PrivateMethodTester
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -28,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
-class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
+class ConstraintPropagationSuite extends SparkFunSuite with PlanTest with PrivateMethodTester{
 
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     resolveColumn(tr.analyze, columnName)
@@ -422,4 +424,92 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
       assert(aliasedRelation.analyze.constraints.isEmpty)
     }
   }
+
+  test("SPARK-33152: infer from child constraint") {
+    val plan = LocalRelation('a.int, 'b.int)
+      .where('a === 'b)
+      .select('a, ('b + 1) as 'b2)
+      .analyze
+
+    verifyConstraints(plan.constraints, ExpressionSet(Seq(
+      IsNotNull(resolveColumn(plan, "a")),
+      resolveColumn(plan, "a") + 1 <=> resolveColumn(plan, "b2")
+    )))
+  }
+
+  test("SPARK-33152: equality constraints from aliases") {
+    val plan1 = LocalRelation('a.int)
+      .select('a as 'a2, 'a as 'b2)
+      .analyze
+
+    verifyConstraints(plan1.constraints, ExpressionSet(Seq(
+      resolveColumn(plan1, "a2") <=> resolveColumn(plan1, "b2")
+    )))
+
+    val plan2 = LocalRelation()
+      .select(rand(0) as 'a2, rand(0) as 'b2)
+      .analyze
+
+    // No equality from non-deterministic expressions
+    verifyConstraints(plan2.constraints, ExpressionSet(Seq(
+      IsNotNull(resolveColumn(plan2, "a2")),
+      IsNotNull(resolveColumn(plan2, "b2"))
+    )))
+  }
+
+  test("SPARK-33152: Avoid exponential growth of constraints") {
+    val validConstraints = PrivateMethod[ExpressionSet](Symbol("validConstraints"))
+
+    val relation = LocalRelation('a.int, 'b.int, 'c.int)
+      .where('a + 'b + 'c > intToLiteral(0))
+
+    val plan1 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1)
+      .analyze
+
+    assert(plan1.invokePrivate(validConstraints()).size == 11)
+    verifyConstraints(plan1.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan1, "a1")),
+        IsNotNull(resolveColumn(plan1, "b1")),
+        IsNotNull(resolveColumn(plan1, "c1")),
+        resolveColumn(plan1, "a1")
+          + resolveColumn(plan1, "b1")
+          + resolveColumn(plan1, "c1") > 0
+      )))
+
+    val plan2 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1, 'a + 'b + 'c)
+      .analyze
+
+    assert(plan2.invokePrivate(validConstraints()).size == 13)
+    verifyConstraints(plan2.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan2, "a1")),
+        IsNotNull(resolveColumn(plan2, "b1")),
+        IsNotNull(resolveColumn(plan2, "c1")),
+        IsNotNull(resolveColumn(plan2, "((a + b) + c)")),
+        resolveColumn(plan2, "((a + b) + c)") > 0,
+        resolveColumn(plan2, "a1")
+          + resolveColumn(plan2, "b1")
+          + resolveColumn(plan2, "c1") > 0
+      )))
+
+    val plan3 = relation
+      .select('a as 'a1, 'b as 'b1, 'c as 'c1, ('a + 'b + 'c) as 'x1)
+      .analyze
+
+    assert(plan3.invokePrivate(validConstraints()).size == 13)
+    verifyConstraints(plan3.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan3, "a1")),
+        IsNotNull(resolveColumn(plan3, "b1")),
+        IsNotNull(resolveColumn(plan3, "c1")),
+        IsNotNull(resolveColumn(plan3, "x1")),
+        resolveColumn(plan3, "x1") > 0,
+        resolveColumn(plan3, "a1")
+          + resolveColumn(plan3, "b1")
+          + resolveColumn(plan3, "c1") > 0
+      )))
+  }
 }
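Outside the suite, the effect of the patch can be eyeballed from a Scala REPL or a small driver with spark-catalyst and its test DSL on the classpath. This is a hedged sketch, not part of the commit: the object name is invented, and it simply replays the plan from the "infer from child constraint" test above and prints the propagated constraints.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Hypothetical helper for illustration only, not part of the commit.
object InspectConstraints {
  def main(args: Array[String]): Unit = {
    // Same plan as the "SPARK-33152: infer from child constraint" test:
    // a filter on a = b followed by a projection that aliases b + 1 as b2.
    val plan = LocalRelation('a.int, 'b.int)
      .where('a === 'b)
      .select('a, ('b + 1) as 'b2)
      .analyze

    // With this patch the child constraints are rewritten onto the projected output
    // (e.g. isnotnull(a) and (a + 1) <=> b2) without the exponential intermediate blow-up.
    println(plan.constraints)
  }
}

Assuming the standard Spark build layout, the new tests themselves can be run with something like build/sbt "catalyst/testOnly *ConstraintPropagationSuite".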
