Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,23 @@ trait ConstraintHelper {

/**
 * Infers an additional set of constraints from a given set of equality constraints.
 *
 * This method performs two main types of inference:
 *  1. Attribute-to-attribute: if an operator has constraints of the form (`a = 5`, `a = b`),
 *     this returns an additional constraint of the form `b = 5`.
 *  2. Constant propagation: if the constraints contain both an equality between an attribute
 *     and a literal and an equality over a more complex expression that uses that attribute,
 *     such as `a = 5` and `b = a + 3`, it infers `b = 5 + 3` by substituting the literal for
 *     the attribute. The substituted expression is not folded here (the result is `5 + 3`,
 *     not `8`).
 *
 * `IsNotNull` constraints are excluded from the input of both steps. A substituted equality
 * whose two sides are both foldable and evaluate to equal values (e.g. `8 = 5 + 3`) is
 * dropped rather than returned.
 *
 * @param constraints The set of input constraints
 * @return The constraints inferred by the two steps above, minus any that are already
 *         present in `constraints`
 */
def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
// Accumulates everything inferred by the two steps below.
var inferredConstraints = ExpressionSet()
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])

// Step 1: Infer attribute-to-attribute equalities
predicates.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
// Also remove EqualNullSafe with the same l and r to avoid Once strategy's idempotence
Expand All @@ -77,6 +87,43 @@ trait ConstraintHelper {
inferredConstraints ++= replaceConstraints(predicates - eq - EqualNullSafe(l, r), l, r)
case _ => // No inference
}

// Step 2: Infer by constant substitution (e.g., a = 5, b = a + 3 => b = 5 + 3)
// Collect every `attribute = literal` / `literal = attribute` predicate as a pair of
// ((attribute, literal), originalPredicate), normalising the attribute to the first slot.
val equalityPredicates = predicates.toSeq.flatMap {
case e @ EqualTo(left: AttributeReference, right: Literal) => Some(((left, right), e))
case e @ EqualTo(left: Literal, right: AttributeReference) => Some(((right, left), e))
case _ => None
}
if (equalityPredicates.nonEmpty) {
// Lookup from attribute to the literal it is known to equal.
// NOTE(review): if the same attribute is equated to two different literals, which one
// AttributeMap keeps is not visible from here -- confirm.
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
// The defining `attribute = literal` predicates themselves; these must not be rewritten.
val predicateSet = equalityPredicates.map(_._2).toSet
// Replaces every attribute reference that has a known literal value with that literal and
// leaves all other attribute references untouched.
def replaceConstantsInExpression(expression: Expression) = expression transform {
case a: AttributeReference =>
constantsMap.get(a) match {
case Some(literal) => literal
case None => a
}
}
predicates.foreach { cond =>
val replaced = cond transform {
// attribute equality is handled above, no need to replace
case e @ EqualTo(_: Attribute, _: Attribute) => e
// Equalities between a cast and a bare attribute are likewise returned unchanged.
case e @ EqualTo(_: Cast, _: Attribute) => e
case e @ EqualTo(_: Attribute, _: Cast) => e

// Any other equality gets its known attributes replaced by literals. The defining
// equalities in `predicateSet` are skipped so that `a = 5` is not turned into `5 = 5`.
case e @ EqualTo(_, _) if !predicateSet.contains(e) => replaceConstantsInExpression(e)
}
// Avoid inferring tautologies like 1 = 1
// NOTE(review): `eval()` runs during inference; if evaluating either side can throw
// (e.g. an arithmetic error on the substituted literals), the exception would propagate
// from here -- confirm whether that is acceptable or should be guarded.
// NOTE(review): `==` compares the raw evaluated values, so two null results compare equal
// and the predicate is treated as a tautology -- confirm this is intended.
val isTautology = replaced match {
case EqualTo(left: Expression, right: Expression) if left.foldable && right.foldable =>
left.eval() == right.eval()
case _ => false
}
// A predicate that the substitution left unchanged is still a member of `constraints`,
// so the `contains` check is what prevents it from being reported as newly inferred.
if (!constraints.contains(replaced) && !isTautology) {
inferredConstraints += replaced
}
}
}
// Return only constraints that are new relative to the input.
inferredConstraints -- constraints
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("filter: filter out constraints in condition with complex expression") {
  // Predicates written by the user.
  val aIsOne = $"a" === 1
  val bIsAPlusTwo = $"b" === $"a" + 2
  // Predicate expected to be inferred by substituting `a = 1` into `b = a + 2`; the
  // sum is left unfolded.
  val bIsOnePlusTwo = $"b" === Add(1, 2)

  val inputPlan = testRelation.where(aIsOne && bIsAPlusTwo).analyze
  val expectedCondition =
    IsNotNull($"a") && IsNotNull($"b") && aIsOne && bIsAPlusTwo && bIsOnePlusTwo
  val expectedPlan = testRelation.where(expectedCondition).analyze

  comparePlans(Optimize.execute(inputPlan), expectedPlan)
}

test("single inner join: filter out values on either side on equi-join keys") {
val x = testRelation.subquery("x")
val y = testRelation.subquery("y")
Expand Down Expand Up @@ -213,6 +221,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(Optimize.execute(original.analyze), correct.analyze)
}

test("single inner join: infer constraints in condition with complex expressions") {
  val x = testRelation.subquery("x")
  val y = testRelation.subquery("y")
  // Join key on the right side is a complex expression over the left side's attribute.
  val joinCondition = "y.b".attr === "x.a".attr + 2

  // With `x.a = 1` on the left input, the right input is expected to receive the
  // substituted filter `b = 1 + 2` (plus the matching IsNotNull).
  testConstraintsAfterJoin(
    x.where($"a" === 1),
    y,
    x.where(IsNotNull($"a") && $"a" === 1),
    y.where(IsNotNull($"b") && $"b" === Add(1, 2)),
    Inner,
    Some(joinCondition))
}

test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
val x = testRelation.subquery("x")
val y = testRelation.subquery("y")
Expand Down