Skip to content

Add GroupBy.aggregate (and tpch-1 query to examples) #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/API_specification/dataframe_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .typing import DType, Scalar

__all__ = [
"Aggregation",
"Bool",
"Column",
"DataFrame",
Expand Down
77 changes: 76 additions & 1 deletion spec/API_specification/dataframe_api/groupby_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .dataframe_object import DataFrame


__all__ = ['GroupBy']
__all__ = [
"Aggregation",
"GroupBy",
]


class GroupBy(Protocol):
Expand Down Expand Up @@ -51,3 +54,75 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr

def size(self) -> DataFrame:
...

def aggregate(self, *aggregation: Aggregation) -> DataFrame:
"""
Aggregate columns according to given aggregation function.

Examples
--------
>>> df: DataFrame
>>> namespace = df.__dataframe_namespace__()
>>> df.group_by('year').aggregate(
... namespace.Aggregation.sum('l_quantity').rename('sum_qty'),
... namespace.Aggregation.mean('l_quantity').rename('avg_qty'),
... namespace.Aggregation.mean('l_extended_price').rename('avg_price'),
... namespace.Aggregation.mean('l_discount').rename('avg_disc'),
... namespace.Aggregation.size().rename('count_order'),
... )
"""
...

class Aggregation(Protocol):
def rename(self, name: str) -> Aggregation:
"""
Assign given name to output of aggregation.

If not called, the column's name will be used as the output name.
"""
...

@classmethod
def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
...

@classmethod
def mean(cls, column: str, *, skip_nulls: bool=True) -> Aggregation:
...

@classmethod
def std(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
...

@classmethod
def var(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
...

@classmethod
def size(cls) -> Aggregation:
...

7 changes: 5 additions & 2 deletions spec/API_specification/dataframe_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from dataframe_api.column_object import Column
from dataframe_api.dataframe_object import DataFrame
from dataframe_api.groupby_object import GroupBy
from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT

if TYPE_CHECKING:
from .dtypes import (
Expand Down Expand Up @@ -112,6 +112,8 @@ def __init__(
class String():
...

Aggregation: AggregationT

def concat(self, dataframes: Sequence[DataFrame]) -> DataFrame:
...

Expand Down Expand Up @@ -146,7 +148,7 @@ def is_null(self, value: object, /) -> bool:

def is_dtype(self, dtype: Any, kind: str | tuple[str, ...]) -> bool:
...

def date(self, year: int, month: int, day: int) -> Scalar:
...

Expand All @@ -164,6 +166,7 @@ def __column_consortium_standard__(


__all__ = [
"Aggregation",
"Column",
"DataFrame",
"DType",
Expand Down
37 changes: 37 additions & 0 deletions spec/API_specification/examples/tpch/q1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
from dataframe_api.typing import SupportsDataFrameAPI


def query(lineitem_raw: SupportsDataFrameAPI) -> Any:
lineitem = lineitem_raw.__dataframe_consortium_standard__()
namespace = lineitem.__dataframe_namespace__()

mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2)
lineitem = lineitem.assign(
(
lineitem.get_column_by_name("l_extended_price")
* (1 - lineitem.get_column_by_name("l_discount"))
).rename("l_disc_price"),
(
lineitem.get_column_by_name("l_extended_price")
* (1 - lineitem.get_column_by_name("l_discount"))
* (1 + lineitem.get_column_by_name("l_tax"))
).rename("l_charge"),
)
result = (
lineitem.filter(mask)
.group_by("l_returnflag", "l_linestatus")
.aggregate(
namespace.Aggregation.sum("l_quantity").rename("sum_qty"),
namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"),
namespace.Aggregation.sum("l_disc_price").rename("sum_disc_price"),
namespace.Aggregation.sum("change").rename("sum_charge"),
namespace.Aggregation.mean("l_quantity").rename("avg_qty"),
namespace.Aggregation.mean("l_discount").rename("avg_disc"),
namespace.Aggregation.size().rename("count_order"),
)
.sort("l_returnflag", "l_linestatus")
)
return result.dataframe
3 changes: 1 addition & 2 deletions spec/API_specification/examples/tpch/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def query(
* (1 - result.get_column_by_name("l_discount"))
).rename("revenue")
result = result.assign(new_column)
result = result.select("revenue", "n_name")
result = result.group_by("n_name").sum()
result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue"))

return result.dataframe
1 change: 1 addition & 0 deletions spec/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
('py:class', 'Scalar'),
('py:class', 'Bool'),
('py:class', 'optional'),
('py:class', 'Aggregation'),
('py:class', 'NullType'),
('py:class', 'Namespace'),
('py:class', 'SupportsDataFrameAPI'),
Expand Down