Skip to content

Commit 21be6ff

Browse files
committed
add Aggregation API
1 parent e2a18d4 commit 21be6ff

File tree

5 files changed

+82
-7
lines changed

5 files changed

+82
-7
lines changed

spec/API_specification/dataframe_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"Duration",
4141
"String",
4242
"is_dtype",
43+
"Aggregation",
4344
]
4445

4546

spec/API_specification/dataframe_api/groupby_object.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Protocol
44

55
if TYPE_CHECKING:
66
from .dataframe_object import DataFrame
77

88

9-
__all__ = ['GroupBy']
9+
__all__ = [
10+
"Aggregation",
11+
"GroupBy",
12+
]
1013

1114

1215
class GroupBy:
@@ -51,3 +54,71 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr
5154

5255
def size(self) -> DataFrame:
5356
...
57+
58+
def aggregate(self, *aggregation: Aggregation) -> DataFrame:
59+
"""
60+
Aggregate columns according to given aggregation function.
61+
62+
Examples
63+
--------
64+
>>> df: DataFrame
65+
>>> namespace = df.__dataframe_namespace__()
66+
>>> df.group_by('year').aggregate(
67+
... namespace.Aggregation.sum('l_quantity').rename('sum_qty'),
68+
... namespace.Aggregation.mean('l_quantity').rename('avg_qty'),
69+
... namespace.Aggregation.mean('l_extended_price').rename('avg_price'),
70+
... namespace.Aggregation.mean('l_discount').rename('avg_disc'),
71+
... namespace.Aggregation.size().rename('count_order'),
72+
... )
73+
"""
74+
...
75+
76+
class Aggregation(Protocol):
77+
def rename(self, name: str) -> Aggregation:
78+
"""Assign given name to output of aggregation. """
79+
...
80+
81+
@classmethod
82+
def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
83+
...
84+
85+
@classmethod
86+
def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
87+
...
88+
89+
@classmethod
90+
def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
91+
...
92+
93+
@classmethod
94+
def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
95+
...
96+
97+
@classmethod
98+
def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
99+
...
100+
101+
@classmethod
102+
def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
103+
...
104+
105+
@classmethod
106+
def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
107+
...
108+
109+
@classmethod
110+
def mean(cls, column: str, *, skip_nulls: bool=True) -> Aggregation:
111+
...
112+
113+
@classmethod
114+
def std(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
115+
...
116+
117+
@classmethod
118+
def var(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
119+
...
120+
121+
@classmethod
122+
def size(cls) -> Aggregation:
123+
...
124+

spec/API_specification/dataframe_api/typing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from dataframe_api.column_object import Column
1717
from dataframe_api.dataframe_object import DataFrame
18-
from dataframe_api.groupby_object import GroupBy
18+
from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT
1919

2020
if TYPE_CHECKING:
2121
from .dtypes import (
@@ -147,7 +147,9 @@ def is_null(value: object, /) -> bool:
147147
@staticmethod
148148
def is_dtype(dtype: Any, kind: str | tuple[str, ...]) -> bool:
149149
...
150-
150+
151+
class Aggregation(AggregationT):
152+
...
151153

152154
class SupportsDataFrameAPI(Protocol):
153155
def __dataframe_consortium_standard__(
@@ -163,6 +165,7 @@ def __column_consortium_standard__(
163165

164166

165167
__all__ = [
168+
"Aggregation",
166169
"Column",
167170
"DataFrame",
168171
"DType",

spec/API_specification/examples/tpch/q5.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@ def query(
6565

6666
new_column = (
6767
result.get_column_by_name("l_extendedprice")
68-
* (1 - result.get_column_by_name("l_discount"))
68+
* (result.get_column_by_name("l_discount") * -1 + 1)
6969
).rename("revenue")
7070
result = result.assign(new_column)
71-
result = result.select(["revenue", "n_name"])
72-
result = result.group_by("n_name").sum()
71+
result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue"))
7372

7473
return result.dataframe

spec/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
('py:class', 'Scalar'),
8585
('py:class', 'Bool'),
8686
('py:class', 'optional'),
87+
('py:class', 'Aggregation'),
8788
('py:class', 'NullType'),
8889
('py:class', 'Namespace'),
8990
('py:class', 'SupportsDataFrameAPI'),

0 commit comments

Comments
 (0)