Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 32 additions & 29 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Expression,
LiteralExpression,
SQLExpression,
ScalarFunctionExpression,
)
from pyspark.sql.types import (
StructType,
Expand All @@ -48,25 +49,13 @@
from pyspark.sql.connect.client import RemoteSparkSession


class GroupingFrame(object):

MeasuresType = Union[Sequence[Tuple["ExpressionOrString", str]], Dict[str, str]]
OptMeasuresType = Optional[MeasuresType]

class GroupedData(object):
def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
self._df = df
self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]

def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame":

# Normalize the dictionary into a list of tuples.
if isinstance(exprs, Dict):
measures = list(exprs.items())
elif isinstance(exprs, List):
measures = exprs
else:
measures = []

def agg(self, measures: Sequence[Expression]) -> "DataFrame":
assert len(measures) > 0, "exprs should not be empty"
res = DataFrame.withPlan(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In OSS we have an assert here

>>> df.groupBy("state").agg()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/martin.grund/Development/spark/python/pyspark/sql/group.py", line 162, in agg
    assert exprs, "exprs should not be empty"
AssertionError: exprs should not be empty

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added.

plan.Aggregate(
child=self._df._plan,
Expand All @@ -77,23 +66,27 @@ def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame":
)
return res

def _map_cols_to_dict(self, fun: str, cols: List[Union[Column, str]]) -> Dict[str, str]:
return {x if isinstance(x, str) else x.name(): fun for x in cols}
def _map_cols_to_expression(
self, fun: str, col: Union[Expression, str]
) -> Sequence[Expression]:
return [
ScalarFunctionExpression(fun, Column(col)) if isinstance(col, str) else col,
]

def min(self, *cols: Union[Column, str]) -> "DataFrame":
expr = self._map_cols_to_dict("min", list(cols))
def min(self, col: Union[Expression, str]) -> "DataFrame":
expr = self._map_cols_to_expression("min", col)
return self.agg(expr)

def max(self, *cols: Union[Column, str]) -> "DataFrame":
expr = self._map_cols_to_dict("max", list(cols))
def max(self, col: Union[Expression, str]) -> "DataFrame":
expr = self._map_cols_to_expression("max", col)
return self.agg(expr)

def sum(self, *cols: Union[Column, str]) -> "DataFrame":
expr = self._map_cols_to_dict("sum", list(cols))
def sum(self, col: Union[Expression, str]) -> "DataFrame":
expr = self._map_cols_to_expression("sum", col)
return self.agg(expr)

def count(self) -> "DataFrame":
return self.agg([(LiteralExpression(1), "count")])
return self.agg([ScalarFunctionExpression("count", LiteralExpression(1))])


class DataFrame(object):
Expand Down Expand Up @@ -162,8 +155,18 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":

return DataFrame.withPlan(plan.Project(self._plan, *sql_expr), session=self._session)

def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
return self.groupBy().agg(exprs)
def agg(self, *exprs: Union[Expression, Dict[str, str]]) -> "DataFrame":
if not exprs:
raise ValueError("Argument 'exprs' must not be empty")

if len(exprs) == 1 and isinstance(exprs[0], dict):
measures = [ScalarFunctionExpression(f, Column(e)) for e, f in exprs[0].items()]
return self.groupBy().agg(measures)
else:
# other expressions
assert all(isinstance(c, Expression) for c in exprs), "all exprs should be Expression"
exprs = cast(Tuple[Expression, ...], exprs)
return self.groupBy().agg(exprs)

def alias(self, alias: str) -> "DataFrame":
return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)
Expand Down Expand Up @@ -208,7 +211,7 @@ def sparkSession(self) -> "RemoteSparkSession":

def count(self) -> int:
"""Returns the number of rows in the data frame"""
pdd = self.agg([(LiteralExpression(1), "count")]).toPandas()
pdd = self.agg(ScalarFunctionExpression("count", LiteralExpression(1))).toPandas()
if pdd is None:
raise Exception("Empty result")
return pdd.iloc[0, 0]
Expand Down Expand Up @@ -340,8 +343,8 @@ def first(self) -> Optional[Row]:
"""
return self.head()

def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
return GroupingFrame(self, *cols)
def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
return GroupedData(self, *cols)

@overload
def head(self) -> Optional[Row]:
Expand Down
19 changes: 4 additions & 15 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
List,
Optional,
Sequence,
Tuple,
Union,
cast,
TYPE_CHECKING,
Expand Down Expand Up @@ -558,29 +557,19 @@ def _repr_html_(self) -> str:


class Aggregate(LogicalPlan):
MeasureType = Tuple["ExpressionOrString", str]
MeasuresType = Sequence[MeasureType]
OptMeasuresType = Optional[MeasuresType]

def __init__(
self,
child: Optional["LogicalPlan"],
grouping_cols: List[Column],
measures: OptMeasuresType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no longer a way to call this with empty measures?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a sequence now and I think it can be len=0 which is empty measures?

measures: Sequence[Expression],
) -> None:
super().__init__(child)
self.grouping_cols = grouping_cols
self.measures = measures if measures is not None else []
self.measures = measures

def _convert_measure(self, m: MeasureType, session: "RemoteSparkSession") -> proto.Expression:
exp, fun = m
def _convert_measure(self, m: Expression, session: "RemoteSparkSession") -> proto.Expression:
proto_expr = proto.Expression()
measure = proto_expr.unresolved_function
measure.parts.append(fun)
if type(exp) is str:
measure.arguments.append(self.unresolved_attr(exp))
else:
measure.arguments.append(cast(Expression, exp).to_plan(session))
proto_expr.CopyFrom(m.to_plan(session))
return proto_expr

def plan(self, session: "RemoteSparkSession") -> proto.Relation:
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@ def test_alias(self) -> None:
self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
self.assertIn("(this, is, not)", str(exc.exception))

def test_agg_with_two_agg_exprs(self):
# SPARK-41230: test dataframe.agg()
self.assert_eq(
self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
)


class ChannelBuilderTests(ReusedPySparkTestCase):
def test_invalid_connection_strings(self):
Expand Down