@@ -94,16 +94,17 @@ message Filter {
 message Join {
   Relation left = 1;
   Relation right = 2;
-  Expression on = 3;
-  JoinType how = 4;
+  Expression join_condition = 3;
+  JoinType join_type = 4;

   enum JoinType {
     JOIN_TYPE_UNSPECIFIED = 0;
     JOIN_TYPE_INNER = 1;
-    JOIN_TYPE_OUTER = 2;
+    JOIN_TYPE_FULL_OUTER = 2;
     JOIN_TYPE_LEFT_OUTER = 3;
     JOIN_TYPE_RIGHT_OUTER = 4;
-    JOIN_TYPE_ANTI = 5;
+    JOIN_TYPE_LEFT_ANTI = 5;
     JOIN_TYPE_LEFT_SEMI = 6;
   }
 }
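Renaming `on`/`how` to `join_condition`/`join_type` changes the generated builder API accordingly. A minimal sketch of building the new message shape (the standalone `innerJoin` helper is illustrative, not part of this change; the generated setters are the ones exercised in the planner tests below):

```scala
import org.apache.spark.connect.proto

// Illustrative helper (not part of this change): builds a Join relation with
// the renamed fields. `left`, `right` and `condition` are assumed to be
// pre-built proto messages.
def innerJoin(
    left: proto.Relation,
    right: proto.Relation,
    condition: proto.Expression): proto.Relation = {
  proto.Relation.newBuilder()
    .setJoin(
      proto.Join.newBuilder()
        .setLeft(left)
        .setRight(right)
        .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER) // was `how`
        .setJoinCondition(condition)) // was `on`
    .build()
}
```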

@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect
 import scala.collection.JavaConverters._

 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

 /**
@@ -51,6 +52,21 @@ package object dsl {
             .build()
         ).build()
       }
+
+      def join(
+          otherPlan: proto.Relation,
+          joinType: JoinType = JoinType.JOIN_TYPE_INNER,
+          condition: Option[proto.Expression] = None): proto.Relation = {
+        val relation = proto.Relation.newBuilder()
+        val join = proto.Join.newBuilder()
+        join.setLeft(logicalPlan)
+          .setRight(otherPlan)
+          .setJoinType(joinType)
+        if (condition.isDefined) {
+          join.setJoinCondition(condition.get)
+        }
+        relation.setJoin(join).build()
+      }
     }
   }
 }
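A quick usage sketch of the new DSL method: the implicit class wraps a `proto.Relation`, so any two pre-built relations can be joined directly (the `example` wrapper is illustrative; the scoped import mirrors the test suites further down):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.connect.dsl.plans._

// Illustrative usage: `left` and `right` are any pre-built proto.Relation values.
def example(left: proto.Relation, right: proto.Relation): Unit = {
  val inner = left.join(right) // defaults: JOIN_TYPE_INNER, no condition
  val leftOuter = left.join(right, JoinType.JOIN_TYPE_LEFT_OUTER)
}
```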
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.{Since, Unstable}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.{expressions, plans}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.plans.logical

Contributor: Can you please keep the package import? It makes it easier to read where the specific classes come from, in particular when they have similar names. For example, it is easier to read logical.JoinType and proto.JoinType.

Contributor Author: It is actually still there :). The IDE seems to auto-pack things into import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}

+import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.types._

@@ -223,15 +223,30 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

   private def transformJoin(rel: proto.Join): LogicalPlan = {
     assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+    val joinCondition =
+      if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None
+
     logical.Join(
       left = transformRelation(rel.getLeft),
       right = transformRelation(rel.getRight),
-      // TODO(SPARK-40534) Support additional join types and configuration.
-      joinType = plans.Inner,
-      condition = Some(transformExpression(rel.getOn)),
+      joinType = transformJoinType(
+        if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER),
+      condition = joinCondition,
       hint = logical.JoinHint.NONE)
   }

+  private def transformJoinType(t: proto.Join.JoinType): JoinType = {
+    t match {
+      case proto.Join.JoinType.JOIN_TYPE_INNER => Inner
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI => LeftAnti
+      case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER => FullOuter
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER => LeftOuter
+      case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER => RightOuter
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI => LeftSemi
+      case _ => throw InvalidPlanInput(s"Join type ${t} is not supported")
+    }
+  }
+
   private def transformSort(rel: proto.Sort): LogicalPlan = {
     assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.")
     logical.Sort(
@@ -161,7 +161,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
       proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
     intercept[AssertionError](transform(incompleteJoin))

-    // Cartesian Product not supported.
+    // Join type JOIN_TYPE_UNSPECIFIED is not supported.
     intercept[InvalidPlanInput] {
       val simpleJoin = proto.Relation.newBuilder
         .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel))
@@ -185,7 +185,12 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

     val simpleJoin = proto.Relation.newBuilder
       .setJoin(
-        proto.Join.newBuilder.setLeft(readRel).setRight(readRel).setOn(joinCondition).build())
+        proto.Join.newBuilder
+          .setLeft(readRel)
+          .setRight(readRel)
+          .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
+          .setJoinCondition(joinCondition)
+          .build())
       .build()

     val res = transform(simpleJoin)
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.connect.planner

 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

/**
@@ -32,8 +33,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {

   lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))

+  lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
+
   lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)

+  lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
+
   test("Basic select") {
     val connectPlan = {
       // TODO: Scala only allows one implicit per scope so we keep proto implicit imports in

@@ -46,6 +51,36 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }

test("Basic joins with different join types") {
val connectPlan = {

Contributor: Can you add a test case for unspecified as well, to see that we catch the error?

Contributor Author (@amaliujia, Oct 10, 2022): I discussed with @cloud-fan and we decided to remove both JoinType.CROSS and JoinType.Unspecified. The proto is our API; we should make the proto itself less ambiguous. For example, any proto submitted to the server should not contain JoinType.Unspecified. Either it sets a join type with explicit semantics (inner, left outer, etc.), or the type is not set, which Spark by default treats as an inner join.

Contributor: The reason for unspecified is not the proto contract but the language behavior of the different auto-generated targets. To avoid issues with defaults, the recommendation in the typical proto style guides is to always have the first element of an enum be unspecified. Cc @cloud-fan

Contributor: If the proto is an API, I'd say the join type is a required field and clients must set the join type in the join plan. For the Python client, its DataFrame API can omit the join type, and the Python client should use INNER as the default join type.

Contributor: There are two layers to this discussion, one on the proto infra level and one on the API level. I'm fine with the API-level decision. My point referred to the recommendations for using protos.

Contributor: Ah, if this is the style guide of protobuf, let's keep it.

Contributor: And the server can simply fail if it sees the UNSPECIFIED join type?

Contributor: Yes, exactly.

+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.join(connectTestRelation2))
+    }
+    val sparkPlan = sparkTestRelation.join(sparkTestRelation2)
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+
+    val connectPlan2 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.join(connectTestRelation2, condition = None))
+    }
+    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None)
+    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
+    for ((t, y) <- Seq(
+        (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
+        (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
+        (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
+        (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
+        (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
+        (JoinType.JOIN_TYPE_INNER, Inner))) {
+      val connectPlan3 = {
+        import org.apache.spark.sql.connect.dsl.plans._
+        transform(connectTestRelation.join(connectTestRelation2, t))
+      }
+      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y)
+      comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
+    }
+  }
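Following up on the thread above: a Join message built without an explicit join type carries JOIN_TYPE_UNSPECIFIED (proto3 defaults an enum field to its first value), and transformJoinType rejects it with InvalidPlanInput. A minimal sketch of such a check, mirroring the existing intercept case in SparkConnectPlannerSuite (it assumes that suite's `readRel` relation and `transform` helper):

```scala
// Sketch, assuming the suite's `readRel` and `transform` helpers: a Join
// built without setJoinType defaults to JOIN_TYPE_UNSPECIFIED, which the
// planner rejects with InvalidPlanInput.
intercept[InvalidPlanInput] {
  val joinWithoutType = proto.Relation.newBuilder
    .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel))
    .build()
  transform(joinWithoutType)
}
```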

   private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
     // TODO: set data types for each local relation attribute once proto supports data types.