Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1239,35 +1239,38 @@ class Analyzer(
*/
object ResolveNaturalJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Should not skip unresolved nodes because natural join is always unresolved.
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
// find common column names from both sides, should be treated like usingColumns
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)

// Add joinPairs to joinConditions
val newCondition = (condition ++ joinPairs.map {
case (l, r) => EqualTo(l, r)
}).reduceLeftOption(And)
}).reduceOption(And)

// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
// we should only keep unique columns(depends on joinType) for joinCols

// the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in full outer join, joinCols should be non-null if there is.
val joinedCols = joinPairs.map {
case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
}
joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case _ =>
case Inner =>
rightKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,7 @@ case object LeftSemi extends JoinType {
}

case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
override def sql: String = "NATURAL " + tpe.sql
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
lazy val aNotNull = a.notNull
lazy val bNotNull = b.notNull
lazy val cNotNull = c.notNull
lazy val r1 = LocalRelation(a, b)
lazy val r2 = LocalRelation(a, c)
lazy val r1 = LocalRelation(b, a)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switching the order here more properly test the reordering of output projection.

lazy val r2 = LocalRelation(c, a)
lazy val r3 = LocalRelation(aNotNull, bNotNull)
lazy val r4 = LocalRelation(bNotNull, cNotNull)
lazy val r4 = LocalRelation(cNotNull, bNotNull)

test("natural inner join") {
val plan = r1.join(r2, NaturalJoin(Inner), None)
Expand Down