diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 087a4dca2f70a..25bc4e8a16b18 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -94,16 +94,17 @@ message Filter {
 message Join {
   Relation left = 1;
   Relation right = 2;
-  Expression on = 3;
-  JoinType how = 4;
+  Expression join_condition = 3;
+  JoinType join_type = 4;
 
   enum JoinType {
     JOIN_TYPE_UNSPECIFIED = 0;
     JOIN_TYPE_INNER = 1;
-    JOIN_TYPE_OUTER = 2;
+    JOIN_TYPE_FULL_OUTER = 2;
     JOIN_TYPE_LEFT_OUTER = 3;
     JOIN_TYPE_RIGHT_OUTER = 4;
-    JOIN_TYPE_ANTI = 5;
+    JOIN_TYPE_LEFT_ANTI = 5;
+    JOIN_TYPE_LEFT_SEMI = 6;
   }
 }
 
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index d54d5b404410e..234b423a80316 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect
 import scala.collection.JavaConverters._
 
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 
 /**
@@ -51,6 +52,21 @@ package object dsl {
             .build()
         ).build()
     }
+
+    def join(
+        otherPlan: proto.Relation,
+        joinType: JoinType = JoinType.JOIN_TYPE_INNER,
+        condition: Option[proto.Expression] = None): proto.Relation = {
+      val relation = proto.Relation.newBuilder()
+      val join = proto.Join.newBuilder()
+      join.setLeft(logicalPlan)
+        .setRight(otherPlan)
+        .setJoinType(joinType)
+      if (condition.isDefined) {
+        join.setJoinCondition(condition.get)
+      }
+      relation.setJoin(join).build()
+    }
   }
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index fa9dd18d3bfa5..e3bb7e2932273 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.{Since, Unstable}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.{expressions, plans}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.types._
 
@@ -223,15 +223,30 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
   private def transformJoin(rel: proto.Join): LogicalPlan = {
     assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+    val joinCondition =
+      if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None
+
     logical.Join(
       left = transformRelation(rel.getLeft),
      right = transformRelation(rel.getRight),
-      // TODO(SPARK-40534) Support additional join types and configuration.
-      joinType = plans.Inner,
-      condition = Some(transformExpression(rel.getOn)),
+      // An unset proto3 enum field reads as its zero value, never null.
+      joinType = transformJoinType(rel.getJoinType),
+      condition = joinCondition,
       hint = logical.JoinHint.NONE)
   }
 
+  private def transformJoinType(t: proto.Join.JoinType): JoinType = {
+    t match {
+      case proto.Join.JoinType.JOIN_TYPE_INNER => Inner
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI => LeftAnti
+      case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER => FullOuter
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER => LeftOuter
+      case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER => RightOuter
+      case proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI => LeftSemi
+      case _ => throw InvalidPlanInput(s"Join type $t is not supported")
+    }
+  }
+
   private def transformSort(rel: proto.Sort): LogicalPlan = {
     assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.")
     logical.Sort(
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index e1a658fb57b27..37d80e01f72b4 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -161,7 +161,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
       proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
     intercept[AssertionError](transform(incompleteJoin))
-    // Cartesian Product not supported.
+    // Join type JOIN_TYPE_UNSPECIFIED is not supported.
    intercept[InvalidPlanInput] {
       val simpleJoin = proto.Relation.newBuilder
         .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel))
         .build()
@@ -185,7 +185,12 @@
 
     val simpleJoin = proto.Relation.newBuilder
       .setJoin(
-        proto.Join.newBuilder.setLeft(readRel).setRight(readRel).setOn(joinCondition).build())
+        proto.Join.newBuilder
+          .setLeft(readRel)
+          .setRight(readRel)
+          .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
+          .setJoinCondition(joinCondition)
+          .build())
       .build()
 
     val res = transform(simpleJoin)
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 6eab50a0a2bdc..4f3f0fea387e0 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.connect.planner
 
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 
 /**
@@ -32,8 +33,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
   lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
 
+  lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
+
   lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
 
+  lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
+
   test("Basic select") {
     val connectPlan = {
       // TODO: Scala only allows one implicit per scope so we keep proto implicit imports in
@@ -46,6 +51,36 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
+  test("Basic joins with different join types") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.join(connectTestRelation2))
+    }
+    val sparkPlan = sparkTestRelation.join(sparkTestRelation2)
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+
+    val connectPlan2 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.join(connectTestRelation2, condition = None))
+    }
+    val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None)
+    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
+    for ((protoJoinType, catalystJoinType) <- Seq(
+        (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter),
+        (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter),
+        (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter),
+        (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti),
+        (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi),
+        (JoinType.JOIN_TYPE_INNER, Inner))) {
+      val connectPlan3 = {
+        import org.apache.spark.sql.connect.dsl.plans._
+        transform(connectTestRelation.join(connectTestRelation2, protoJoinType))
+      }
+      val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, catalystJoinType)
+      comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
+    }
+  }
+
   private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
     // TODO: set data types for each local relation attribute once proto supports data types.
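
A note on the `transformJoin` hunk above: the planner can pass `rel.getJoinType` straight to `transformJoinType` because generated Java protobuf enum getters never return null. The standalone sketch below (`JoinTypeDefaultCheck` is a hypothetical name; it only assumes the generated `spark-connect` proto classes are on the classpath) shows that an unset `join_type` reads as the zero value, `JOIN_TYPE_UNSPECIFIED`, which `transformJoinType` rejects with `InvalidPlanInput` rather than silently defaulting to an inner join, matching the updated test comment.

```scala
// Minimal sketch; assumes the generated proto classes from connector/connect
// are on the classpath.
import org.apache.spark.connect.proto

object JoinTypeDefaultCheck {
  def main(args: Array[String]): Unit = {
    // join_type is deliberately left unset.
    val join = proto.Join.newBuilder().build()
    // A proto3 enum getter returns the zero value for an unset field, never null.
    assert(join.getJoinType == proto.Join.JoinType.JOIN_TYPE_UNSPECIFIED)
    println(join.getJoinType) // prints JOIN_TYPE_UNSPECIFIED
  }
}
```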
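For orientation, here is a hedged sketch of the round trip the new suites exercise: the DSL `join` helper builds a `proto.Relation` carrying the join type, and the planner translates it into the corresponding Catalyst join. `JoinRoundTrip` is a hypothetical name; the sketch assumes an active `SparkSession` and that `SparkConnectPlanner` exposes a `transform(): LogicalPlan` entry point, which is how the suites' `transform` helper appears to drive it.

```scala
// Hypothetical usage sketch; assumes a running SparkSession and that
// SparkConnectPlanner.transform() returns the translated Catalyst plan.
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.planner.SparkConnectPlanner

object JoinRoundTrip {
  def leftSemiJoin(
      left: proto.Relation,
      right: proto.Relation,
      session: SparkSession): LogicalPlan = {
    // The DSL added in this patch attaches join(...) to proto.Relation.
    import org.apache.spark.sql.connect.dsl.plans._
    val protoPlan = left.join(right, JoinType.JOIN_TYPE_LEFT_SEMI)
    // transformJoinType maps JOIN_TYPE_LEFT_SEMI to Catalyst's LeftSemi.
    new SparkConnectPlanner(protoPlan, session).transform()
  }
}
```

Swapping `JOIN_TYPE_LEFT_SEMI` for any other value in the enum yields the matching Catalyst join type; `JOIN_TYPE_UNSPECIFIED` raises `InvalidPlanInput`.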