diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 580133dd971b..2c12524ab6cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -754,11 +754,15 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { + private val emptyAttrMap = new AttributeMap[Attribute](Map.empty) + /** * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight( + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, AttributeMap[Attribute]) = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -805,10 +809,10 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - right + (right, emptyAttrMap) case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - right transformUp { + val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { @@ -818,6 +822,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } + (newRight, attributeRewrites) } } @@ -921,12 +926,18 @@ class Analyzer( failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") // To resolve duplicate expression IDs for Join and Intersect - case j @ Join(left, right, _, _) if !j.duplicateResolved => - j.copy(right = dedupRight(left, right)) + case j @ Join(left, right, _, condition) if !j.duplicateResolved => + val (dedupedRight, attributeRewrites) = dedupRight(left, right) + val changedCondition = condition.map(_.transform { + case attr: Attribute if attr.resolved => dedupAttr(attr, attributeRewrites) + }) + j.copy(right = dedupedRight, condition = changedCondition) case i @ Intersect(left, right, _) if !i.duplicateResolved => - i.copy(right = dedupRight(left, right)) + val (dedupedRight, _) = dedupRight(left, right) + i.copy(right = dedupedRight) case e @ Except(left, right, _) if !e.duplicateResolved => - e.copy(right = dedupRight(left, right)) + val (dedupedRight, _) = dedupRight(left, right) + e.copy(right = dedupedRight) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e6b30f9956da..970e00c50a3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,4 +295,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } + + test("SPARK-25150: Attribute deduplication handles attributes in join condition properly") { + val a = spark.range(1, 5) + val b = spark.range(10) + val c = b.filter($"id" % 2 === 0) + + val r = a.join(b, a("id") === b("id"), "inner").join(c, a("id") === c("id"), "inner") + + checkAnswer(r, Row(2, 2, 2) :: Row(4, 4, 4) :: Nil) + } }