diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 76724e7bbdb76..62567cc2bda60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -899,14 +899,72 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // +- SubqueryAlias t1, `t1` // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { + // SPARK-35080: The same issue can happen to correlated equality predicates when + // they do not guarantee one-to-one mapping between inner and outer attributes. + // For example: + // Table: + // t1(a, b): [(0, 6), (1, 5), (2, 4)] + // t2(c): [(6)] + // + // Query: + // SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2 + // + // Original subquery plan: + // Aggregate [count(1)] + // +- Filter ((a + b) = outer(c)) + // +- LocalRelation [a, b] + // + // Plan after pulling up correlated predicates: + // Aggregate [a, b] [count(1), a, b] + // +- LocalRelation [a, b] + // + // Plan after rewrite: + // Project [c, count(1)] + // +- Join LeftOuter ((a + b) = c) + // :- LocalRelation [c] + // +- Aggregate [a, b] [count(1), a, b] + // +- LocalRelation [a, b] + // + // The right hand side of the join transformed from the subquery will output + // count(1) | a | b + // 1 | 0 | 6 + // 1 | 1 | 5 + // 1 | 2 | 4 + // and the plan after rewrite will give the original query incorrect results. 
+ def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = { + if (predicates.nonEmpty) { // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + failAnalysis("Correlated column is not allowed in predicate " + + s"${predicates.map(_.sql).mkString}:\n$p") } } - var foundNonEqualCorrelatedPred: Boolean = false + def containsAttribute(e: Expression): Boolean = { + e.find(_.isInstanceOf[Attribute]).isDefined + } + + // Given a correlated predicate, check if it is either a non-equality predicate or + // an equality predicate that does not guarantee one-to-one mapping between inner and + // outer attributes. When the correlated predicate does not contain any attribute + // (i.e. only has outer references), it is supported and should return false. e.g.: + // (a = outer(c)) -> false + // (outer(c) = outer(d)) -> false + // (a > outer(c)) -> true + // (a + b = outer(c)) -> true + // The last one is true because there can be multiple combinations of (a, b) that + // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0) + // and (-1, 1) can make the predicate evaluate to true. + def isUnsupportedPredicate(condition: Expression): Boolean = condition match { + // Only allow an equality condition with one side being an attribute and the other + // side being an expression without attributes from the inner query. Note that + // OuterReference is a leaf node and will not be found here. + case Equality(_: Attribute, b) => containsAttribute(b) + case Equality(a, _: Attribute) => containsAttribute(a) + case e @ Equality(_, _) => containsAttribute(e) + case _ => true + } + + val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression] // Simplify the predicates before validating any unsupported correlation patterns in the plan. 
AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { @@ -949,22 +1007,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // The other operator is Join. Filter can be anywhere in a correlated subquery. case f: Filter => val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } + unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate) failOnInvalidOuterReference(f) // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains - // only equality correlated predicates. + // only supported correlated equality predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. case a: Aggregate => failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a) // Join can host correlated expressions. case j @ Join(left, right, joinType, _, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index aecbf241e3947..8ea84a484d570 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -767,4 +767,28 @@ class AnalysisErrorSuite extends AnalysisTest { "using ordinal position or wrap it in first() (or first_value) if you don't care " + "which value you get." 
:: Nil) } + + test("SPARK-35080: Unsupported correlated equality predicates in subquery") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + val t1 = LocalRelation(a, b) + val t2 = LocalRelation(c) + val conditions = Seq( + (abs($"a") === $"c", "abs(a) = outer(c)"), + (abs($"a") <=> $"c", "abs(a) <=> outer(c)"), + ($"a" + 1 === $"c", "(a + 1) = outer(c)"), + ($"a" + $"b" === $"c", "(a + b) = outer(c)"), + ($"a" + $"c" === $"b", "(a + outer(c)) = b"), + (And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(a AS INT) = outer(c)")) + conditions.foreach { case (cond, msg) => + val plan = Project( + ScalarSubquery( + Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil, + Filter(cond, t1)) + ).as("sub") :: Nil, + t2) + assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out index 0a5958b2a7694..7d21715fbaa8a 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out @@ -100,6 +100,15 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query schema -struct +struct<> -- !query output -two +org.apache.spark.sql.AnalysisException +Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)): +Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS udf(max(udf(v)))#x] ++- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string)) + +- SubqueryAlias t2 + +- View (`t2`, [k#x,v#x]) + +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x] + +- Project [k#x, v#x] + +- SubqueryAlias t2 + +- LocalRelation 
[k#x, v#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index e30ca0cf309ce..bb6b402e8156d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -554,7 +554,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") } assert(msg1.getMessage.contains( - "Correlated column is not allowed in a non-equality predicate:")) + "Correlated column is not allowed in predicate (l2.a < outer(l1.a))")) } test("disjunctive correlated scalar subquery") { @@ -1827,4 +1827,13 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row(0, 1, 1) :: Row(1, 2, null) :: Nil) } } + + test("SPARK-35080: correlated equality predicates contain only outer references") { + withTempView("t") { + Seq((0, 1), (1, 1)).toDF("c1", "c2").createOrReplaceTempView("t") + checkAnswer( + sql("select c1, c2, (select count(*) from l where c1 = c2) from t"), + Row(0, 1, 0) :: Row(1, 1, 8) :: Nil) + } + } }