python/pyspark/sql/connect/dataframe.py (2 changes: 1 addition & 1 deletion)

@@ -183,7 +183,7 @@ def head(self, n: int) -> Optional["pandas.DataFrame"]:
         return self.toPandas()
 
     # TODO(martin.grund) fix mypy
-    def join(self, other: "DataFrame", on: Any, how: Any = None) -> "DataFrame":
+    def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> "DataFrame":
         if self._plan is None:
             raise Exception("Cannot join when self._plan is empty.")
         if other._plan is None:
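With this change, `how` is typed as `Optional[str]` and takes PySpark-style join-type strings (CROSS is intentionally excluded; see the review thread below). A minimal usage sketch, reusing the plan-only helpers from this PR's tests (`DataFrame.withPlan`, `Read`, `col`); the table name and join column are illustrative:

```python
# Sketch only: exercises the updated join(other, on, how=Optional[str]) signature
# with the plan-only fixtures used in this PR's tests.
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.functions import col
from pyspark.sql.connect.plan import Read

left = DataFrame.withPlan(Read("table"))
right = DataFrame.withPlan(Read("table"))

# Omitting `how` (or passing None) falls back to an inner join; strings such as
# "left", "outer", "leftsemi" are normalized to proto join types in plan.py.
joined = left.join(right, on=col("name"), how="left")
proto_plan = joined._plan.collect()  # builds the proto plan, as the tests do
```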
python/pyspark/sql/connect/plan.py (32 changes: 28 additions & 4 deletions)

@@ -368,21 +368,45 @@ def __init__(
         left: Optional["LogicalPlan"],
         right: "LogicalPlan",
         on: "ColumnOrString",
-        how: proto.Join.JoinType.ValueType = proto.Join.JoinType.JOIN_TYPE_INNER,
+        how: Optional[str],
     ) -> None:
         super().__init__(left)
         self.left = cast(LogicalPlan, left)
         self.right = right
         self.on = on
         if how is None:
-            how = proto.Join.JoinType.JOIN_TYPE_INNER
-        self.how = how
+            join_type = proto.Join.JoinType.JOIN_TYPE_INNER
+        elif how == "inner":
+            join_type = proto.Join.JoinType.JOIN_TYPE_INNER
+        elif how in ["outer", "full", "fullouter"]:
+            join_type = proto.Join.JoinType.JOIN_TYPE_FULL_OUTER
+        elif how in ["leftouter", "left"]:
+            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
+        elif how in ["rightouter", "right"]:
+            join_type = proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
+        elif how in ["leftsemi", "semi"]:
+            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
+        elif how in ["leftanti", "anti"]:
+            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
+        else:

Inline review comments on the `else:` branch:

Contributor Author: Note that in #38157 the feedback was to not support CROSS join, so it is not listed here. Whenever it makes sense to have CROSS join, adding the support will be easy and fast.

Contributor: Why is that? Please keep in mind that these kinds of decisions increase the pain of porting a PySpark job over to Connect.

Contributor Author: There were differing opinions on whether that is useful. If you think we need it back, let's have a discussion first to reach a consensus.

+            raise NotImplementedError(
+                """
+                Unsupported join type: %s. Supported join types include:
+                "inner", "outer", "full", "fullouter", "full_outer",
+                "leftouter", "left", "left_outer", "rightouter",
+                "right", "right_outer", "leftsemi", "left_semi",
+                "semi", "leftanti", "left_anti", "anti",
+                """
+                % how
+            )
+        self.how = join_type
 
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         rel = proto.Relation()
         rel.join.left.CopyFrom(self.left.plan(session))
         rel.join.right.CopyFrom(self.right.plan(session))
-        rel.join.on.CopyFrom(self.to_attr_or_expression(self.on, session))
+        rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+        rel.join.join_type = self.how
         return rel
 
     def print(self, indent: int = 0) -> str:
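For reference, the elif chain above can be read as a lookup table from the user-facing `how` string to the proto enum. A sketch of that equivalent form, not the PR's code; the names `_JOIN_TYPES` and `to_join_type` are made up, and only the aliases the elif chain actually accepts are listed:

```python
# Illustrative alternative, not the PR's implementation: the same
# string -> proto.Join.JoinType normalization expressed as a lookup table.
from typing import Optional

import pyspark.sql.connect.proto as proto

_JOIN_TYPES = {
    None: proto.Join.JoinType.JOIN_TYPE_INNER,
    "inner": proto.Join.JoinType.JOIN_TYPE_INNER,
    "outer": proto.Join.JoinType.JOIN_TYPE_FULL_OUTER,
    "full": proto.Join.JoinType.JOIN_TYPE_FULL_OUTER,
    "fullouter": proto.Join.JoinType.JOIN_TYPE_FULL_OUTER,
    "leftouter": proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER,
    "left": proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER,
    "rightouter": proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER,
    "right": proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER,
    "leftsemi": proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI,
    "semi": proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI,
    "leftanti": proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI,
    "anti": proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI,
}


def to_join_type(how: Optional[str]) -> int:
    # Unknown strings raise, like the NotImplementedError branch in plan.py.
    if how not in _JOIN_TYPES:
        raise NotImplementedError("Unsupported join type: %s" % how)
    return _JOIN_TYPES[how]
```

Either form behaves the same for the supported aliases; note that the PR's error message also mentions underscore variants such as "left_outer" that the elif chain does not currently accept.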
python/pyspark/sql/tests/connect/test_connect_select_ops.py (18 changes: 17 additions & 1 deletion)

@@ -18,14 +18,30 @@
 from pyspark.sql.connect import DataFrame
 from pyspark.sql.connect.functions import col
 from pyspark.sql.connect.plan import Read, InputValidationError
+import pyspark.sql.connect.proto as proto
 
 
-class SparkConnectSelectOpsSuite(PlanOnlyTestFixture):
+class SparkConnectToProtoSuite(PlanOnlyTestFixture):
     def test_select_with_literal(self):
         df = DataFrame.withPlan(Read("table"))
         self.assertIsNotNone(df.select(col("name"))._plan.collect())
         self.assertRaises(InputValidationError, df.select, "name")
 
+    def test_join_with_join_type(self):
+        df_left = DataFrame.withPlan(Read("table"))
+        df_right = DataFrame.withPlan(Read("table"))
+        for (join_type_str, join_type) in [
+            (None, proto.Join.JoinType.JOIN_TYPE_INNER),
+            ("inner", proto.Join.JoinType.JOIN_TYPE_INNER),
+            ("outer", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
+            ("leftouter", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
+            ("rightouter", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
+            ("leftanti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
+            ("leftsemi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
+        ]:
+            joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.collect()
+            self.assertEqual(joined_df.root.join.join_type, join_type)
+
 
 if __name__ == "__main__":
     import unittest
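The new test covers one alias per proto join type. Below is a hedged sketch of extra cases that could be appended inside `test_join_with_join_type` (not part of this PR); it reuses `df_left`, `df_right`, `proto`, `col`, and `self` from the method body above, and assumes `DataFrame.join` builds the Join plan eagerly so an unsupported alias raises right away:

```python
# Hypothetical additions inside test_join_with_join_type; not in this PR.
# Remaining aliases accepted by the elif chain in plan.py:
for join_type_str, join_type in [
    ("full", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
    ("fullouter", proto.Join.JoinType.JOIN_TYPE_FULL_OUTER),
    ("left", proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER),
    ("right", proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER),
    ("semi", proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI),
    ("anti", proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI),
]:
    joined_df = df_left.join(df_right, on=col("name"), how=join_type_str)._plan.collect()
    self.assertEqual(joined_df.root.join.join_type, join_type)

# Unsupported strings (e.g. "cross", per the review thread) should hit the
# NotImplementedError branch in plan.py.
self.assertRaises(NotImplementedError, df_left.join, df_right, col("name"), "cross")
```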