Skip to content

Commit

Permalink
Implement window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Oct 17, 2024
1 parent 95675c5 commit f60ab59
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 11 deletions.
13 changes: 8 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,9 +1073,9 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
)
```
Expand All @@ -1086,7 +1086,7 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
newkey=Column("oldkey")
newkey=Column("oldkey")
)
```
"""
Expand All @@ -1099,7 +1099,7 @@ def mutate(self, **kwargs) -> "Self":
"Use a different name for the new column.",
)
for col_name, expr in kwargs.items():
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
raise DataChainColumnError(
col_name, f"Cannot infer type with expression {expr}"
)
Expand All @@ -1111,6 +1111,9 @@ def mutate(self, **kwargs) -> "Self":
# renaming existing column
for signal in schema.db_signals(name=value.name, as_columns=True):
mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr]
elif isinstance(value, Func):
# adding new signal
mutated[name] = value.get_column(schema)
else:
# adding new signal
mutated[name] = value
Expand Down
8 changes: 7 additions & 1 deletion src/datachain/lib/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func
from .func import Func, Window
from .window import dense_rank, first, rank, row_number

__all__ = [
"Func",
"Window",
"any_value",
"avg",
"collect",
"concat",
"count",
"dense_rank",
"first",
"max",
"min",
"rank",
"row_number",
"sum",
]
2 changes: 1 addition & 1 deletion src/datachain/lib/func/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def sum(col: str) -> Func:


def avg(col: str) -> Func:
return Func(inner=dc_func.aggregate.avg, col=col)
return Func(inner=dc_func.aggregate.avg, col=col, result_type=float)


def min(col: str) -> Func:
Expand Down
44 changes: 44 additions & 0 deletions src/datachain/lib/func/func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Callable, Optional

from sqlalchemy import desc

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta
Expand All @@ -9,18 +11,44 @@
from datachain.lib.signal_schema import SignalSchema


class Window:
def __init__(self, *, partition_by: str, order_by: str, desc: bool = False) -> None:
self.partition_by = partition_by
self.order_by = order_by
self.desc = desc


class Func:
def __init__(
self,
inner: Callable,
col: Optional[str] = None,
result_type: Optional["DataType"] = None,
is_array: bool = False,
is_window: bool = False,
window: Optional[Window] = None,
) -> None:
self.inner = inner
self.col = col
self.result_type = result_type
self.is_array = is_array
self.is_window = is_window
self.window = window

def over(self, window: Window) -> "Func":
if not self.is_window:
raise DataChainColumnError(
str(self.inner), "Window function is not supported"
)

return Func(
self.inner,
self.col,
self.result_type,
self.is_array,
self.is_window,
window,
)

@property
def db_col(self) -> Optional[str]:
Expand Down Expand Up @@ -56,6 +84,22 @@ def get_column(
else:
func_col = self.inner()

if self.is_window:
if not self.window:
raise DataChainColumnError(
str(self.inner), "Window function requires window"
)
func_col = func_col.over(
partition_by=self.window.partition_by,
order_by=(
desc(self.window.order_by)
if self.window.desc
else self.window.order_by
),
)

func_col.type = sql_type

if label:
func_col = func_col.label(label)

Expand Down
19 changes: 19 additions & 0 deletions src/datachain/lib/func/window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from sqlalchemy import func as sa_func

from .func import Func


def row_number() -> Func:
return Func(inner=sa_func.row_number, result_type=int, is_window=True)


def rank() -> Func:
return Func(inner=sa_func.rank, result_type=int, is_window=True)


def dense_rank() -> Func:
return Func(inner=sa_func.dense_rank, result_type=int, is_window=True)


def first(col: str) -> Func:
return Func(inner=sa_func.first_value, col=col, is_window=True)
4 changes: 4 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from datachain.lib.convert.unflatten import unflatten_to_json_pos
from datachain.lib.data_model import DataModel, DataType, DataValue
from datachain.lib.file import File
from datachain.lib.func import Func
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import DEFAULT_DELIMITER, Column
Expand Down Expand Up @@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema":
# changing the type of existing signal, e.g File -> ImageFile
del new_values[name]
new_values[name] = args_map[name]
elif isinstance(value, Func):
# adding new signal with function
new_values[name] = value.get_result_type(self)
else:
# adding new signal
new_values[name] = sql_to_python(value)
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
Expand All @@ -25,6 +24,7 @@
from typing_extensions import Self

from datachain.catalog import Catalog
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf_signature import UdfSignature
from datachain.query.batch import RowsOutput

Expand Down Expand Up @@ -172,7 +172,7 @@ def teardown(self):
def _init(
self,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
func: Optional[Callable],
):
self.params = params
Expand All @@ -183,7 +183,7 @@ def _init(
def _create(
cls,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
) -> "Self":
if isinstance(sign.func, AbstractUDF):
if not isinstance(sign.func, cls): # type: ignore[unreachable]
Expand Down
105 changes: 104 additions & 1 deletion tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,7 +2131,7 @@ def test_group_by_int(test_session):
"cnt": "int",
"cnt_col": "int",
"sum": "int",
"avg": "int",
"avg": "float",
"min": "int",
"max": "int",
"value": "int",
Expand Down Expand Up @@ -2443,3 +2443,106 @@ def test_group_by_error(test_session):
SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
):
dc.group_by(foo=func.sum("col2"), partition_by="col3")


@pytest.mark.parametrize("desc", [True, False])
def test_window_functions(test_session, desc):
from datachain import func

window = func.Window(partition_by="col1", order_by="col2", desc=desc)

ds = (
DataChain.from_values(
col1=["a", "a", "b", "b", "b", "c"],
col2=[1, 2, 3, 4, 5, 6],
session=test_session,
)
.mutate(
row_number=func.row_number().over(window),
rank=func.rank().over(window),
dense_rank=func.dense_rank().over(window),
first=func.first("col2").over(window),
)
.save("my-ds")
)

assert ds.signals_schema.serialize() == {
"col1": "str",
"col2": "int",
"row_number": "int",
"rank": "int",
"dense_rank": "int",
"first": "int",
}
assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts(
[
{
"col1": "a",
"col2": 1,
"row_number": 2 if desc else 1,
"rank": 2 if desc else 1,
"dense_rank": 2 if desc else 1,
"first": 2 if desc else 1,
},
{
"col1": "a",
"col2": 2,
"row_number": 1 if desc else 2,
"rank": 1 if desc else 2,
"dense_rank": 1 if desc else 2,
"first": 2 if desc else 1,
},
{
"col1": "b",
"col2": 3,
"row_number": 3 if desc else 1,
"rank": 3 if desc else 1,
"dense_rank": 3 if desc else 1,
"first": 5 if desc else 3,
},
{
"col1": "b",
"col2": 4,
"row_number": 2,
"rank": 2,
"dense_rank": 2,
"first": 5 if desc else 3,
},
{
"col1": "b",
"col2": 5,
"row_number": 1 if desc else 3,
"rank": 1 if desc else 3,
"dense_rank": 1 if desc else 3,
"first": 5 if desc else 3,
},
{
"col1": "c",
"col2": 6,
"row_number": 1,
"rank": 1,
"dense_rank": 1,
"first": 6,
},
],
"col1",
"col2",
)


def test_window_error(test_session):
from datachain import func

window = func.Window(partition_by="col1", order_by="col2")

dc = DataChain.from_values(
col1=["a", "a", "b", "b", "b", "c"],
col2=[1, 2, 3, 4, 5, 6],
session=test_session,
)

with pytest.raises(DataChainColumnError, match="Window function requires window"):
dc.mutate(first=func.first("col2"))

with pytest.raises(DataChainColumnError, match="Window function is not supported"):
dc.mutate(first=func.sum("col2").over(window))

0 comments on commit f60ab59

Please sign in to comment.