diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index bd374dcf814e6..e3a7e8c73355d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -37,6 +37,7 @@
     Expression,
     LiteralExpression,
     SQLExpression,
+    ScalarFunctionExpression,
 )
 from pyspark.sql.types import (
     StructType,
@@ -48,25 +49,13 @@
     from pyspark.sql.connect.client import RemoteSparkSession


-class GroupingFrame(object):
-
-    MeasuresType = Union[Sequence[Tuple["ExpressionOrString", str]], Dict[str, str]]
-    OptMeasuresType = Optional[MeasuresType]
-
+class GroupedData(object):
     def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
         self._df = df
         self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]

-    def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame":
-
-        # Normalize the dictionary into a list of tuples.
-        if isinstance(exprs, Dict):
-            measures = list(exprs.items())
-        elif isinstance(exprs, List):
-            measures = exprs
-        else:
-            measures = []
-
+    def agg(self, measures: Sequence[Expression]) -> "DataFrame":
+        assert len(measures) > 0, "exprs should not be empty"
         res = DataFrame.withPlan(
             plan.Aggregate(
                 child=self._df._plan,
@@ -77,23 +66,27 @@ def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame":
         )
         return res

-    def _map_cols_to_dict(self, fun: str, cols: List[Union[Column, str]]) -> Dict[str, str]:
-        return {x if isinstance(x, str) else x.name(): fun for x in cols}
+    def _map_cols_to_expression(
+        self, fun: str, col: Union[Expression, str]
+    ) -> Sequence[Expression]:
+        return [
+            ScalarFunctionExpression(fun, Column(col)) if isinstance(col, str) else col,
+        ]

-    def min(self, *cols: Union[Column, str]) -> "DataFrame":
-        expr = self._map_cols_to_dict("min", list(cols))
+    def min(self, col: Union[Expression, str]) -> "DataFrame":
+        expr = self._map_cols_to_expression("min", col)
         return self.agg(expr)

-    def max(self, *cols: Union[Column, str]) -> "DataFrame":
-        expr = self._map_cols_to_dict("max", list(cols))
+    def max(self, col: Union[Expression, str]) -> "DataFrame":
+        expr = self._map_cols_to_expression("max", col)
         return self.agg(expr)

-    def sum(self, *cols: Union[Column, str]) -> "DataFrame":
-        expr = self._map_cols_to_dict("sum", list(cols))
+    def sum(self, col: Union[Expression, str]) -> "DataFrame":
+        expr = self._map_cols_to_expression("sum", col)
         return self.agg(expr)

     def count(self) -> "DataFrame":
-        return self.agg([(LiteralExpression(1), "count")])
+        return self.agg([ScalarFunctionExpression("count", LiteralExpression(1))])


 class DataFrame(object):
@@ -162,8 +155,18 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":

         return DataFrame.withPlan(plan.Project(self._plan, *sql_expr), session=self._session)

-    def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
-        return self.groupBy().agg(exprs)
+    def agg(self, *exprs: Union[Expression, Dict[str, str]]) -> "DataFrame":
+        if not exprs:
+            raise ValueError("Argument 'exprs' must not be empty")
+
+        if len(exprs) == 1 and isinstance(exprs[0], dict):
+            measures = [ScalarFunctionExpression(f, Column(e)) for e, f in exprs[0].items()]
+            return self.groupBy().agg(measures)
+        else:
+            # other expressions
+            assert all(isinstance(c, Expression) for c in exprs), "all exprs should be Expression"
+            exprs = cast(Tuple[Expression, ...], exprs)
+            return self.groupBy().agg(exprs)

     def alias(self, alias: str) -> "DataFrame":
         return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)
@@ -208,7 +211,7 @@ def sparkSession(self) -> "RemoteSparkSession":

     def count(self) -> int:
         """Returns the number of rows in the data frame"""
-        pdd = self.agg([(LiteralExpression(1), "count")]).toPandas()
+        pdd = self.agg(ScalarFunctionExpression("count", LiteralExpression(1))).toPandas()
         if pdd is None:
             raise Exception("Empty result")
         return pdd.iloc[0, 0]
@@ -340,8 +343,8 @@ def first(self) -> Optional[Row]:
         """
         return self.head()

-    def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
-        return GroupingFrame(self, *cols)
+    def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
+        return GroupedData(self, *cols)

     @overload
     def head(self) -> Optional[Row]:
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 8aadc3dc4fa5c..853b1a6dc0e1a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -20,7 +20,6 @@
     List,
     Optional,
     Sequence,
-    Tuple,
     Union,
     cast,
     TYPE_CHECKING,
@@ -558,29 +557,19 @@ def _repr_html_(self) -> str:


 class Aggregate(LogicalPlan):
-    MeasureType = Tuple["ExpressionOrString", str]
-    MeasuresType = Sequence[MeasureType]
-    OptMeasuresType = Optional[MeasuresType]
-
     def __init__(
         self,
         child: Optional["LogicalPlan"],
         grouping_cols: List[Column],
-        measures: OptMeasuresType,
+        measures: Sequence[Expression],
     ) -> None:
         super().__init__(child)
         self.grouping_cols = grouping_cols
-        self.measures = measures if measures is not None else []
+        self.measures = measures

-    def _convert_measure(self, m: MeasureType, session: "RemoteSparkSession") -> proto.Expression:
-        exp, fun = m
+    def _convert_measure(self, m: Expression, session: "RemoteSparkSession") -> proto.Expression:
         proto_expr = proto.Expression()
-        measure = proto_expr.unresolved_function
-        measure.parts.append(fun)
-        if type(exp) is str:
-            measure.arguments.append(self.unresolved_attr(exp))
-        else:
-            measure.arguments.append(cast(Expression, exp).to_plan(session))
+        proto_expr.CopyFrom(m.to_plan(session))
         return proto_expr

     def plan(self, session: "RemoteSparkSession") -> proto.Relation:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fa27b6099413e..76e28159abb03 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -451,6 +451,13 @@ def test_alias(self) -> None:
             self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
         self.assertIn("(this, is, not)", str(exc.exception))

+    def test_agg_with_two_agg_exprs(self):
+        # SPARK-41230: test dataframe.agg()
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
+            self.spark.read.table(self.tbl_name).agg({"name": "min", "id": "max"}).toPandas(),
+        )
+

 class ChannelBuilderTests(ReusedPySparkTestCase):
     def test_invalid_connection_strings(self):
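
For reviewers, a minimal usage sketch of the reworked `agg()` surface (not part of the patch). The table name, column names, and the `spark` session binding are hypothetical, and the import path for `Column`/`ScalarFunctionExpression` is assumed to match the import block in `dataframe.py` above:

```python
# Hypothetical sketch exercising the two agg() forms introduced by this patch.
# Assumes a Spark Connect session bound to `spark` and a registered table
# "employees" with columns "name" and "id"; import path is assumed.
from pyspark.sql.connect.column import Column, ScalarFunctionExpression

df = spark.read.table("employees")

# Dict form: {column: function} is expanded by DataFrame.agg() into
# ScalarFunctionExpression(function, Column(column)) measures.
df.agg({"name": "min", "id": "max"}).toPandas()

# Expression form: ScalarFunctionExpression instances are forwarded
# unchanged to GroupedData.agg().
df.agg(
    ScalarFunctionExpression("min", Column("name")),
    ScalarFunctionExpression("max", Column("id")),
).toPandas()

# GroupedData helpers now take a single column (string or Expression);
# a string is wrapped via _map_cols_to_expression().
df.groupBy("name").max("id").toPandas()
```

The dict form is the behavior pinned against the non-Connect DataFrame API by `test_agg_with_two_agg_exprs` above.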