diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index bb74282cfda20..31215b4da792b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -183,7 +183,7 @@ def head(self, n: int) -> Optional["pandas.DataFrame"]: return self.toPandas() # TODO(martin.grund) fix mypu - def join(self, other: "DataFrame", on: Any, how: Any = None) -> "DataFrame": + def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> "DataFrame": if self._plan is None: raise Exception("Cannot join when self._plan is empty.") if other._plan is None: diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 67ed6b964fa19..486778b9d3749 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -368,21 +368,45 @@ def __init__( left: Optional["LogicalPlan"], right: "LogicalPlan", on: "ColumnOrString", - how: proto.Join.JoinType.ValueType = proto.Join.JoinType.JOIN_TYPE_INNER, + how: Optional[str], ) -> None: super().__init__(left) self.left = cast(LogicalPlan, left) self.right = right self.on = on if how is None: - how = proto.Join.JoinType.JOIN_TYPE_INNER - self.how = how + join_type = proto.Join.JoinType.JOIN_TYPE_INNER + elif how == "inner": + join_type = proto.Join.JoinType.JOIN_TYPE_INNER + elif how in ["outer", "full", "fullouter"]: + join_type = proto.Join.JoinType.JOIN_TYPE_FULL_OUTER + elif how in ["leftouter", "left"]: + join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER + elif how in ["rightouter", "right"]: + join_type = proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER + elif how in ["leftsemi", "semi"]: + join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI + elif how in ["leftanti", "anti"]: + join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI + else: + raise NotImplementedError( + """ + Unsupported join type: %s. Supported join types include: + "inner", "outer", "full", "fullouter", "full_outer", + "leftouter", "left", "left_outer", "rightouter", + "right", "right_outer", "leftsemi", "left_semi", + "semi", "leftanti", "left_anti", "anti", + """ + % how + ) + self.how = join_type def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: rel = proto.Relation() rel.join.left.CopyFrom(self.left.plan(session)) rel.join.right.CopyFrom(self.right.plan(session)) - rel.join.on.CopyFrom(self.to_attr_or_expression(self.on, session)) + rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session)) + rel.join.join_type = self.how return rel def print(self, indent: int = 0) -> str: diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index f7f164c11db49..37a64abcc5edf 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -18,14 +18,30 @@ from pyspark.sql.connect import DataFrame from pyspark.sql.connect.functions import col from pyspark.sql.connect.plan import Read, InputValidationError +import pyspark.sql.connect.proto as proto -class SparkConnectSelectOpsSuite(PlanOnlyTestFixture): +class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_literal(self): df = DataFrame.withPlan(Read("table")) self.assertIsNotNone(df.select(col("name"))._plan.collect()) self.assertRaises(InputValidationError, df.select, "name") + def test_join_with_join_type(self): + df_left = DataFrame.withPlan(Read("table")) + df_right = DataFrame.withPlan(Read("table")) + for (join_type_str, join_type) in [ + (None, proto.Join.JoinType.JOIN_TYPE_INNER), + ("inner", proto.Join.JoinType.JOIN_TYPE_INNER), + ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER), + ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER), + ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER), + ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI), + ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI), + ]: + joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.collect() + self.assertEqual(joined_df.root.join.join_type, join_type) + if __name__ == "__main__": import unittest