diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 909ec9080208..25bfcac9fe6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -158,7 +158,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
    * for all conflicting attributes.
    */
   private def dedupRight(left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
-    val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+    val conflictingAttributes = (left.outputSet ++ left.references).intersect(right.outputSet)
     logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
       s"between $left and $right")
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 2b687b103129..e40d98381102 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -606,7 +606,9 @@ case class Join(
     }
   }
 
-  def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+  def duplicateResolved: Boolean = {
+    (left.outputSet ++ left.references).intersect(right.outputSet).isEmpty
+  }
 
   // Joins are only resolved if they don't introduce ambiguous expression ids.
   // NaturalJoin should be ready for resolution only if everything else is resolved here
@@ -1938,7 +1940,10 @@ case class LateralJoin(
     joinType: JoinType,
     condition: Option[Expression]) extends UnaryNode {
 
-  require(Seq(Inner, LeftOuter, Cross).contains(joinType),
+  require(Seq(Inner, LeftOuter, Cross).contains(joinType match {
+    case uj: UsingJoin => uj.tpe
+    case jt: JoinType => jt
+  }),
     s"Unsupported lateral join type $joinType")
 
   override def child: LogicalPlan = left
@@ -1968,7 +1973,8 @@ case class LateralJoin(
 
   override def childrenResolved: Boolean = left.resolved && right.resolved
 
-  def duplicateResolved: Boolean = left.outputSet.intersect(right.plan.outputSet).isEmpty
+  def duplicateResolved: Boolean = (left.outputSet ++ left.references)
+    .intersect(right.plan.outputSet).isEmpty
 
   override lazy val resolved: Boolean = {
     childrenResolved &&
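
A note on the `LateralJoin` change above: `UsingJoin(Inner, Seq("a"))` is a distinct `JoinType` value and is not equal to `Inner`, so the previous membership check rejected lateral joins written with a USING clause even when the underlying join type was supported. A minimal standalone sketch of the unwrapping (toy code, not part of the patch; it relies only on the `UsingJoin.tpe` field the patch itself uses):

```scala
import org.apache.spark.sql.catalyst.plans._

// UsingJoin wraps a base join type; membership checks must compare against
// the unwrapped type, exactly as the new require() does.
val joinType: JoinType = UsingJoin(Inner, Seq("a"))
val baseType = joinType match {
  case uj: UsingJoin => uj.tpe // unwrap USING-clause joins to their base type
  case other => other
}
assert(!Seq(Inner, LeftOuter, Cross).contains(joinType)) // old check rejected this
assert(Seq(Inner, LeftOuter, Cross).contains(baseType))  // new check accepts it
```
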
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DeduplicateRelationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DeduplicateRelationsSuite.scala
new file mode 100644
index 000000000000..aac17730446d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DeduplicateRelationsSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.annotation.tailrec
+
+import org.apache.spark.sql.catalyst.analysis.DeduplicateRelations
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType
+
+class DeduplicateRelationsSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Resolution", FixedPoint(10),
+        DeduplicateRelations) :: Nil
+  }
+
+  val value = AttributeReference("value", IntegerType)()
+  val testRelation = LocalRelation(value)
+
+  test("SPARK-41162: deduplicate referenced expression ids in join") {
+    withSQLConf(SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "error") {
+      val relation = testRelation.select($"value".as("a")).deduplicate()
+      val left = relation.select(($"a" + 1).as("a"))
+      val right = relation
+      val originalQuery = left.join(right, UsingJoin(Inner, Seq("a")))
+      val optimized = Optimize.execute(originalQuery.analyze)
+
+      def exprIds(plan: LogicalPlan): Set[Long] =
+        plan.children.flatMap(exprIds).toSet ++ plan.expressions.map {
+          case ne: NamedExpression => ne.exprId.id
+          case _ => 0L
+        }.toSet
+
+      @tailrec
+      def planDeduplicated(plan: LogicalPlan): Boolean = plan.children match {
+        case Seq(child) => planDeduplicated(child)
+        case children =>
+          // collect the expression ids of each child and index the child position by exprId
+          val childIdxByExprId = children.map(exprIds).zipWithIndex.flatMap {
+            case (set, idx) => set.map(id => (id, idx))
+          }.groupBy(_._1).mapValues(_.map(_._2))
+
+          // each exprId should occur in exactly one child
+          plan.resolved && childIdxByExprId.values.forall(_.length == 1)
+      }
+
+      assert(planDeduplicated(optimized), optimized)
+    }
+  }
+
+  test("SPARK-41162: deduplicate referenced expression ids in lateral join") {
+    withSQLConf(SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "error") {
+      val relation = testRelation.select($"value".as("a")).deduplicate()
+      val left = relation.select(($"a" + 1).as("a"))
+      val right = relation
+      val originalQuery = left.lateralJoin(right, UsingJoin(Inner, Seq("a")))
+      val optimized = Optimize.execute(originalQuery.analyze)
+
+      def children(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
+        case lj: LateralJoin => lj.child :: lj.right.plan :: Nil
+        case p: LogicalPlan => p.children
+      }
+
+      def exprIds(plan: LogicalPlan): Set[Long] =
+        children(plan).flatMap(exprIds).toSet ++ plan.expressions.map {
+          case ne: NamedExpression => ne.exprId.id
+          case _ => 0L
+        }.toSet
+
+      @tailrec
+      def planDeduplicated(plan: LogicalPlan): Boolean = children(plan) match {
+        case Seq(child) => planDeduplicated(child)
+        case children =>
+          // collect the expression ids of each child and index the child position by exprId
+          val childIdxByExprId = children.map(exprIds).zipWithIndex.flatMap {
+            case (set, idx) => set.map(id => (id, idx))
+          }.groupBy(_._1).mapValues(_.map(_._2))
+
+          // each exprId should occur in exactly one child
+          plan.resolved && childIdxByExprId.values.forall(_.length == 1)
+      }
+
+      assert(planDeduplicated(optimized), optimized)
+    }
+  }
+
+  // Note: deduplicating attributes that are already referenced elsewhere breaks those references.
+
+}
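
Both tests above assert the same invariant. A condensed sketch of what `planDeduplicated` encodes, reusing the suite's own `testRelation` and `Optimize` (illustrative only, not a standalone test):

```scala
// After DeduplicateRelations runs, the strengthened duplicateResolved must
// hold for every Join: the right side's output is disjoint from the left
// side's output AND the attributes the left side references.
val relation = testRelation.select($"value".as("a")).deduplicate()
val query = relation.select(($"a" + 1).as("a"))
  .join(relation, UsingJoin(Inner, Seq("a")))
Optimize.execute(query.analyze).foreach {
  case j: Join =>
    assert((j.left.outputSet ++ j.left.references).intersect(j.right.outputSet).isEmpty)
  case _ => // non-join nodes are covered by the per-child exprId partitioning above
}
```
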
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c841bffac8cd..e4f6b4cb40c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -344,6 +344,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }
 
+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, "id", joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1 // unsupported join type, the test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)
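
End to end, the regression test above corresponds to this spark-shell session (assumes a local `SparkSession` bound to `spark`; the row counts are the ones the test asserts, which this self-join scenario previously got wrong):

```scala
import spark.implicits._

val ids = Seq(1, 2, 3).toDF("id").distinct()  // aggregated DataFrame: {1, 2, 3}
val bumped = ids.withColumn("id", $"id" + 1)  // still references ids's `id`: {2, 3, 4}

// With the fix, the right side is deduplicated against the attributes the
// left side references, so the USING condition compares distinct attributes:
assert(bumped.join(ids, "id", "left_semi").count() == 2) // {2, 3} have matches
assert(bumped.join(ids, "id", "left_anti").count() == 1) // only {4} has none
```
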