From 6d72c4c24e15361d4cc2c0556698b59b25822202 Mon Sep 17 00:00:00 2001
From: Jaromir Vanek
Date: Wed, 22 Oct 2025 18:08:43 -0700
Subject: [PATCH] [SPARK-53996][SQL] Improve InferFiltersFromConstraints to infer filters from complex join expressions

---
Review notes (this free-text area between the three-dash line and the
diffstat is not part of the commit message):

* QueryPlanConstraints.scala: inferAdditionalConstraints keeps its existing
  attribute-to-attribute pass ("Step 1") and adds "Step 2", which collects
  EqualTo(AttributeReference, Literal) predicates (either operand order) into
  an AttributeMap and substitutes those literals into the EqualTo
  sub-expressions of the other predicates.
* Step 2 leaves Attribute = Attribute, Cast = Attribute and Attribute = Cast
  equalities, and the attribute = literal predicates themselves, unchanged.
* A rewritten predicate is not added when `constraints` already contains it,
  or when it is an EqualTo whose two foldable sides evaluate to equal values.
* NOTE(review): the tautology check calls eval() on both sides during
  inference; confirm that a foldable expression which fails at evaluation
  cannot reach it.
* NOTE(review): predicateSet is built with .toSet rather than as an
  ExpressionSet; confirm that plain equality is the intended membership test.
* InferFiltersFromConstraintsSuite.scala: adds one filter test and one inner
  join test; both expect the inferred predicate `b = Add(1, 2)`.

 .../plans/logical/QueryPlanConstraints.scala | 51 ++++++++++++++++++-
 .../InferFiltersFromConstraintsSuite.scala   | 25 +++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index ef035eba5922c..360012075c8db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -57,13 +57,23 @@ trait ConstraintHelper {
 
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
-   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
-   * additional constraint of the form `b = 5`.
+   *
+   * This method performs two main types of inference:
+   * 1. Attribute-to-attribute: For example, if an operator has constraints
+   *    of the form (`a = 5`, `a = b`), this returns an additional constraint of the form `b = 5`.
+   * 2. Constant propagation: If the constraints contain both an equality to a constant and a
+   *    complex expression, such as `a = 5` and `b = a + 3`, it will infer `b = 5 + 3`
+   *    by substituting the constant into the expression.
+   *
+   * @param constraints The set of input constraints
+   * @return A new set of inferred constraints
    */
   def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
     var inferredConstraints = ExpressionSet()
     // IsNotNull should be constructed by `constructIsNotNullConstraints`.
     val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
+
+    // Step 1: Infer attribute-to-attribute equalities
     predicates.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         // Also remove EqualNullSafe with the same l and r to avoid Once strategy's idempotence
@@ -77,6 +87,43 @@ trait ConstraintHelper {
         inferredConstraints ++= replaceConstraints(predicates - eq - EqualNullSafe(l, r), l, r)
       case _ => // No inference
     }
+
+    // Step 2: Infer by constant substitution (e.g., a = 5, b = a + 3 => b = 5 + 3)
+    val equalityPredicates = predicates.toSeq.flatMap {
+      case e @ EqualTo(left: AttributeReference, right: Literal) => Some(((left, right), e))
+      case e @ EqualTo(left: Literal, right: AttributeReference) => Some(((right, left), e))
+      case _ => None
+    }
+    if (equalityPredicates.nonEmpty) {
+      val constantsMap = AttributeMap(equalityPredicates.map(_._1))
+      val predicateSet = equalityPredicates.map(_._2).toSet
+      def replaceConstantsInExpression(expression: Expression) = expression transform {
+        case a: AttributeReference =>
+          constantsMap.get(a) match {
+            case Some(literal) => literal
+            case None => a
+          }
+      }
+      predicates.foreach { cond =>
+        val replaced = cond transform {
+          // attribute equality is handled above, no need to replace
+          case e @ EqualTo(_: Attribute, _: Attribute) => e
+          case e @ EqualTo(_: Cast, _: Attribute) => e
+          case e @ EqualTo(_: Attribute, _: Cast) => e
+
+          case e @ EqualTo(_, _) if !predicateSet.contains(e) => replaceConstantsInExpression(e)
+        }
+        // Avoid inferring tautologies like 1 = 1
+        val isTautology = replaced match {
+          case EqualTo(left: Expression, right: Expression) if left.foldable && right.foldable =>
+            left.eval() == right.eval()
+          case _ => false
+        }
+        if (!constraints.contains(replaced) && !isTautology) {
+          inferredConstraints += replaced
+        }
+      }
+    }
     inferredConstraints -- constraints
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index d8d8a2b333bcd..c32023744aa8b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -63,6 +63,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("filter: filter out constraints in condition with complex expression") {
+    val originalQuery = testRelation.where($"a" === 1 && $"b" === $"a" + 2).analyze
+    val correctAnswer = testRelation.where(IsNotNull($"a") && IsNotNull($"b") &&
+      $"a" === 1 && $"b" === $"a" + 2 && $"b" === Add(1, 2)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("single inner join: filter out values on either side on equi-join keys") {
     val x = testRelation.subquery("x")
     val y = testRelation.subquery("y")
@@ -213,6 +221,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     comparePlans(Optimize.execute(original.analyze), correct.analyze)
   }
 
+  test("single inner join: infer constraints in condition with complex expressions") {
+    val leftRelation = testRelation.subquery("x")
+    val rightRelation = testRelation.subquery("y")
+
+    val left = leftRelation.where($"a" === 1)
+    val right = rightRelation
+
+    testConstraintsAfterJoin(
+      left,
+      right,
+      leftRelation.where(IsNotNull($"a") && $"a" === 1),
+      rightRelation.where(IsNotNull($"b") && $"b" === Add(1, 2)),
+      Inner,
+      Some("y.b".attr === "x.a".attr + 2)
+    )
+  }
+
   test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
     val x = testRelation.subquery("x")
     val y = testRelation.subquery("y")