41 changes: 37 additions & 4 deletions python/pyspark/sql/connect/dataframe.py
@@ -157,11 +157,44 @@ def coalesce(self, num_partitions: int) -> "DataFrame":
    def describe(self, cols: List[ColumnRef]) -> Any:
        ...

    def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
        """Return a new :class:`DataFrame` with duplicate rows removed,
        optionally deduplicating only on a subset of the columns.

        .. versionadded:: 3.4.0

        Parameters
        ----------
        subset : List[str], optional
            List of column names to use for duplicate comparison (default: all columns).

        Returns
        -------
        :class:`DataFrame`
            DataFrame without duplicated rows.
        """
        if subset is None:
            return DataFrame.withPlan(
                plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
            )
        else:
            return DataFrame.withPlan(
                plan.Deduplicate(child=self._plan, column_names=subset), session=self._session
            )

    def distinct(self) -> "DataFrame":
        """Returns all distinct rows."""
        all_cols = self.columns
        gf = self.groupBy(*all_cols)
        return gf.agg()
        """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.

        .. versionadded:: 3.4.0

        Returns
        -------
        :class:`DataFrame`
            DataFrame with distinct rows.
        """
        return DataFrame.withPlan(
            plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
        )

    def drop(self, *cols: "ColumnOrString") -> "DataFrame":
        all_cols = self.columns
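
For reference, a minimal usage sketch of the two methods added above (not part of the diff; it assumes an already-connected Spark Connect session `spark` exposing `readTable` as in the tests below, and a hypothetical table name):

# Sketch: both entry points build the same Deduplicate plan node.
df = spark.readTable(table_name="people")  # hypothetical table

df.distinct()                          # deduplicate on all columns
df.dropDuplicates()                    # equivalent: all_columns_as_keys=True
df.dropDuplicates(["name", "height"])  # subset: column_names=["name", "height"]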
39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -327,6 +327,45 @@ def _repr_html_(self) -> str:
"""


class Deduplicate(LogicalPlan):
    def __init__(
        self,
        child: Optional["LogicalPlan"],
        all_columns_as_keys: bool = False,
        column_names: Optional[List[str]] = None,
    ) -> None:
        super().__init__(child)
        self.all_columns_as_keys = all_columns_as_keys
        self.column_names = column_names

    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
        assert self._child is not None
        plan = proto.Relation()
        # Wire the child relation into the Deduplicate message; without this
        # the serialized plan would lose its input.
        plan.deduplicate.input.CopyFrom(self._child.plan(session))
        plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
        if self.column_names is not None:
            plan.deduplicate.column_names.extend(self.column_names)
        return plan

    def print(self, indent: int = 0) -> str:
        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
        return (
            f"{' ' * indent}<Deduplicate all_columns_as_keys={self.all_columns_as_keys} "
            f"column_names={self.column_names}>\n{c_buf}"
        )

    def _repr_html_(self) -> str:
        return f"""
        <ul>
            <li>
                <b>Deduplicate</b><br />
                all_columns_as_keys: {self.all_columns_as_keys} <br />
                column_names: {self.column_names} <br />
                {self._child_repr_()}
            </li>
        </ul>
        """


class Sort(LogicalPlan):
    def __init__(
        self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str]
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -72,6 +72,25 @@ def test_sample(self):
        self.assertEqual(plan.root.sample.with_replacement, True)
        self.assertEqual(plan.root.sample.seed.seed, -1)

    def test_deduplicate(self):
        df = self.connect.readTable(table_name=self.tbl_name)

        distinct_plan = df.distinct()._plan.to_proto(self.connect)
        self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True)
        self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0)

        deduplicate_on_all_columns_plan = df.dropDuplicates()._plan.to_proto(self.connect)
        self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True)
        self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0)

        deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", "height"])._plan.to_proto(
            self.connect
        )
        self.assertEqual(
            deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False
        )
        self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2)
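
A possible follow-up assertion for the subset case (a sketch, not in the diff): check the exact serialized names rather than just their count.

    def test_deduplicate_subset_names(self):
        # Sketch: the subset column names should survive serialization in order.
        df = self.connect.readTable(table_name=self.tbl_name)
        plan = df.dropDuplicates(["name", "height"])._plan.to_proto(self.connect)
        self.assertEqual(list(plan.root.deduplicate.column_names), ["name", "height"])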

    def test_relation_alias(self):
        df = self.connect.readTable(table_name=self.tbl_name)
        plan = df.alias("table_alias")._plan.to_proto(self.connect)