From 41311faf74d0d7e46e8fe1075b4ae87b4c40d82a Mon Sep 17 00:00:00 2001 From: yzhou2001 Date: Wed, 23 Dec 2015 09:08:34 -0800 Subject: [PATCH] fix for Spark-10838/11576 --- .../sql/catalyst/analysis/Analyzer.scala | 8 +++-- .../apache/spark/sql/DataFrameJoinSuite.scala | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c396546b4c005..e0128f4a713dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -389,7 +389,7 @@ class Analyzer( a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => + case j @ Join(left, right, _, condition) if !j.selfJoinResolved => val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") @@ -437,7 +437,11 @@ class Analyzer( case a: Attribute => attributeRewrites.get(a).getOrElse(a) } } - j.copy(right = newRight) + val newCondition = condition.map(_ transformUp { + case a: AttributeReference => attributeRewrites.get(a).getOrElse(a) + }) + + j.copy(right = newRight, condition = newCondition) } // When resolve `SortOrder`s in Sort based on child, don't report errors as diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 39a65413bd592..e5619d28a56ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -76,6 +76,37 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { 
Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil)
   }
 
+  test("Spark-11576: Complex self join") {
+    val df1 = Seq((1, 3), (2, 1)).toDF("keyCol1", "keyCol2")
+    val df2 = Seq((1, 4), (2, 1)).toDF("keyCol1", "keyCol3")
+    // df3 keeps df1's keyCol1 alongside keyCol3, which only df2 provides.
+    val df3 = df1.join(df2, df1("keyCol1") === df2("keyCol1")).select(df1("keyCol1"), $"keyCol3")
+    // Joining df3 back to df1: only df3's row (2, 1) finds a partner, df1's row (1, 3).
+    checkAnswer(df3.join(df1, df3("keyCol3") === df1("keyCol1")), Row(2, 1, 1, 3) :: Nil)
+  }
+
+  test("Spark-10838: Complex Self-Join") {
+    val df1 = Seq(Tuple1("1")).toDF("col_a")
+    val df2 = Seq(Tuple2("1", "1")).toDF("col_a", "col_b")
+    val df3 = df1.join(df2, df1("col_a") === df2("col_a")).select(df1("col_a"), $"col_b")
+    // Right side is df2 itself, one of the inputs df3 was built from.
+    checkAnswer(
+      df3.join(df2, df3("col_b") === df2("col_a")),
+      Row("1", "1", "1", "1") :: Nil
+    )
+    // Right side is df2 under an alias; the join column is looked up on the aliased DataFrame.
+    val df4 = df2.as("df4")
+    checkAnswer(
+      df3.join(df4, df3("col_b") === df4("col_a")),
+      Row("1", "1", "1", "1") :: Nil
+    )
+    // Right side is aliased inline; the join column is referenced by its qualified name.
+    checkAnswer(
+      df3.join(df2.as("df4"), df3("col_b") === $"df4.col_a"),
+      Row("1", "1", "1", "1") :: Nil
+    )
+  }
+
   test("join - self join") {
     val df1 = testData.select(testData("key")).as('df1)
     val df2 = testData.select(testData("key")).as('df2)