Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,12 @@ setQuantifier
;

relation
: left=relation
(joinType JOIN right=relation joinCriteria?
| NATURAL joinType JOIN right=relation
) #joinRelation
| relationPrimary #relationDefault
: relationPrimary joinRelation*
;

joinRelation
: (joinType) JOIN right=relationPrimary joinCriteria?
| NATURAL joinType JOIN right=relationPrimary
;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think NATURAL CROSS JOIN is invalid, so perhaps we should not include CROSS in joinType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have a point there. Let me update that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move the code around. I have added a check in the AstBuilder (spark side of the parser) to catch this.


joinType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {

// Apply CTEs
query.optional(ctx.ctes) {
val ctes = ctx.ctes.namedQuery.asScala.map {
case nCtx =>
val namedQuery = visitNamedQuery(nCtx)
(namedQuery.alias, namedQuery)
val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
val namedQuery = visitNamedQuery(nCtx)
(namedQuery.alias, namedQuery)
}
// Check for duplicate names.
checkDuplicateKeys(ctes, ctx)
Expand Down Expand Up @@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* separated) relations here, these get converted into a single plan by condition-less inner join.
*/
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
val right = plan(relation.relationPrimary)
val join = right.optionalMap(left)(Join(_, _, Inner, None))
withJoinRelations(join, relation)
}
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}

Expand Down Expand Up @@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}

/**
* Create a joins between two or more logical plans.
* Create a single relation referenced in a FROM claused. This method is used when a part of the
* join condition is nested, for example:
* {{{
* select * from t1 join (t2 cross join t3) on col1 = col2
* }}}
*/
override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
/** Build a join between two plans. */
def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
val baseJoinType = ctx.joinType match {
case null => Inner
case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
withJoinRelations(plan(ctx.relationPrimary), ctx)
}

// Resolve the join type and join condition
val (joinType, condition) = Option(ctx.joinCriteria) match {
case Some(c) if c.USING != null =>
val columns = c.identifier.asScala.map { column =>
UnresolvedAttribute.quoted(column.getText)
}
(UsingJoin(baseJoinType, columns), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case None if ctx.NATURAL != null =>
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
Join(left, right, joinType, condition)
}
/**
* Join one more [[LogicalPlan]]s to the current logical plan.
*/
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
withOrigin(join) {
val baseJoinType = join.joinType match {
case null => Inner
case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}

// Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
// first join clause is at the top. However fields of previously referenced tables can be used
// in following join clauses. The tree needs to be reversed in order to make this work.
var result = plan(ctx.left)
var current = ctx
while (current != null) {
current.right match {
case right: JoinRelationContext =>
result = join(current, result, plan(right.left))
current = right
case right =>
result = join(current, result, plan(right))
current = null
// Resolve the join type and join condition
val (joinType, condition) = Option(join.joinCriteria) match {
case Some(c) if c.USING != null =>
val columns = c.identifier.asScala.map { column =>
UnresolvedAttribute.quoted(column.getText)
}
(UsingJoin(baseJoinType, columns), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case None if join.NATURAL != null =>
if (baseJoinType == Cross) {
throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
}
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
Join(left, plan(join.right), joinType, condition)
}
}
result
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser

import scala.collection.mutable.StringBuilder

import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.TerminalNode

Expand Down Expand Up @@ -189,9 +189,7 @@ object ParserUtils {
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
* passed function. The original plan is returned when the context does not exist.
*/
def optionalMap[C <: ParserRuleContext](
ctx: C)(
f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f(ctx, plan)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest {
test("left anti join", LeftAnti, testExistence)
test("anti join", LeftAnti, testExistence)

// Test natural cross join
intercept("select * from a natural cross join b")

// Test natural join with a condition
intercept("select * from a natural join b on a.id = b.id")

// Test multiple consecutive joins
assertEqual(
"select * from a join b join c right join d",
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))

// SPARK-17296
assertEqual(
"select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is something like
SELECT * FROM T1 INNER JOIN T2 INNER JOIN T3 ON col3 = col2 ON col3 = col1;
supposed to parse ?
Without your change it returns the following error:
org.apache.spark.sql.AnalysisException: cannot resolve 'col3' given input columns: [col1, col2]; line 1 pos 63
which I don't understand. The following parses though:
SELECT * FROM T1 INNER JOIN T2 INNER JOIN T3 ON col1 = col2 ON col2 = col1
and returns a result

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify, it looks like your patch will disallow both queries at the parser level. Could you add a test that enforces this ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I have added a test.

table("t1")
.join(table("t2"), Cross)
.join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
.join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
.select(star()))

// Test multiple on clauses.
intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, let's also add a test somewhere for
SELECT * FROM T1 INNER JOIN (T2 INNER JOIN T3 ON col3 = col2) ON col3 = col1
SELECT * FROM T1 INNER JOIN (T2 INNER JOIN T3) ON col3 = col2
SELECT * FROM T1 INNER JOIN (T2 INNER JOIN T3 ON col3 = col2)

This looks good to me.


// Parenthesis
assertEqual(
"select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
table("t1")
.join(table("t2")
.join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
.select(star()))
assertEqual(
"select * from t1 inner join (t2 inner join t3) on col3 = col2",
table("t1")
.join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
.select(star()))
assertEqual(
"select * from t1 inner join (t2 inner join t3 on col3 = col2)",
table("t1")
.join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
.select(star()))

// Implicit joins.
assertEqual(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, LGTM

"select * from t1, t3 join t2 on t1.col1 = t2.col2",
table("t1")
.join(table("t3"))
.join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
.select(star()))
}

test("sampled relations") {
Expand Down