Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,12 @@ setQuantifier
;

relation
: left=relation
((CROSS | joinType) JOIN right=relation joinCriteria?
| NATURAL joinType JOIN right=relation
) #joinRelation
| relationPrimary #relationDefault
: relationPrimary joinRelation*
;

joinRelation
: (CROSS | joinType) JOIN right=relationPrimary joinCriteria?
| NATURAL joinType JOIN right=relationPrimary
;

joinType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {

// Apply CTEs
query.optional(ctx.ctes) {
val ctes = ctx.ctes.namedQuery.asScala.map {
case nCtx =>
val namedQuery = visitNamedQuery(nCtx)
(namedQuery.alias, namedQuery)
val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
val namedQuery = visitNamedQuery(nCtx)
(namedQuery.alias, namedQuery)
}
// Check for duplicate names.
checkDuplicateKeys(ctes, ctx)
Expand Down Expand Up @@ -400,7 +399,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* separated) relations here, these get converted into a single plan by condition-less inner join.
*/
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
val right = plan(relation.relationPrimary)
val join = right.optionalMap(left)(Join(_, _, Inner, None))
withJoinRelations(join, relation)
}
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}

Expand Down Expand Up @@ -526,54 +529,49 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}

/**
* Create a joins between two or more logical plans.
* Create a single relation referenced in a FROM claused. This method is used when a part of the
* join condition is nested, for example:
* {{{
* select * from t1 join (t2 cross join t3) on col1 = col2
* }}}
*/
override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
/** Build a join between two plans. */
def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
val baseJoinType = ctx.joinType match {
case null => Inner
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
withJoinRelations(plan(ctx.relationPrimary), ctx)
}

// Resolve the join type and join condition
val (joinType, condition) = Option(ctx.joinCriteria) match {
case Some(c) if c.USING != null =>
val columns = c.identifier.asScala.map { column =>
UnresolvedAttribute.quoted(column.getText)
}
(UsingJoin(baseJoinType, columns), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case None if ctx.NATURAL != null =>
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
Join(left, right, joinType, condition)
}
/**
* Join one more [[LogicalPlan]]s to the current logical plan.
*/
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
withOrigin(join) {
val baseJoinType = join.joinType match {
case null => Inner
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}

// Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
// first join clause is at the top. However fields of previously referenced tables can be used
// in following join clauses. The tree needs to be reversed in order to make this work.
var result = plan(ctx.left)
var current = ctx
while (current != null) {
current.right match {
case right: JoinRelationContext =>
result = join(current, result, plan(right.left))
current = right
case right =>
result = join(current, result, plan(right))
current = null
// Resolve the join type and join condition
val (joinType, condition) = Option(join.joinCriteria) match {
case Some(c) if c.USING != null =>
val columns = c.identifier.asScala.map { column =>
UnresolvedAttribute.quoted(column.getText)
}
(UsingJoin(baseJoinType, columns), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case None if join.NATURAL != null =>
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
Join(left, plan(join.right), joinType, condition)
}
}
result
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ object ParserUtils {
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
* passed function. The original plan is returned when the context does not exist.
*/
def optionalMap[C <: ParserRuleContext](
ctx: C)(
f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f(ctx, plan)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,54 @@ class PlanParserSuite extends PlanTest {
test("left anti join", LeftAnti, testExistence)
test("anti join", LeftAnti, testExistence)

// Test natural cross join
intercept("select * from a natural cross join b")

// Test natural join with a condition
intercept("select * from a natural join b on a.id = b.id")

// Test multiple consecutive joins
assertEqual(
"select * from a join b join c right join d",
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))

// SPARK-17296
assertEqual(
"select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
table("t1")
.join(table("t2"), Inner)
.join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
.join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
.select(star()))

// Test multiple on clauses.
intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")

// Parenthesis
assertEqual(
"select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
table("t1")
.join(table("t2")
.join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
.select(star()))
assertEqual(
"select * from t1 inner join (t2 inner join t3) on col3 = col2",
table("t1")
.join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
.select(star()))
assertEqual(
"select * from t1 inner join (t2 inner join t3 on col3 = col2)",
table("t1")
.join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
.select(star()))

// Implicit joins.
assertEqual(
"select * from t1, t3 join t2 on t1.col1 = t2.col2",
table("t1")
.join(table("t3"))
.join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
.select(star()))
}

test("sampled relations") {
Expand Down