Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Window functions #515

Merged
merged 8 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,9 +1073,9 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
)
```
Expand All @@ -1086,7 +1086,7 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
newkey=Column("oldkey")
newkey=Column("oldkey")
dreadatour marked this conversation as resolved.
Show resolved Hide resolved
)
```
"""
Expand All @@ -1099,7 +1099,7 @@ def mutate(self, **kwargs) -> "Self":
"Use a different name for the new column.",
)
for col_name, expr in kwargs.items():
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
raise DataChainColumnError(
col_name, f"Cannot infer type with expression {expr}"
)
Expand All @@ -1111,6 +1111,9 @@ def mutate(self, **kwargs) -> "Self":
# renaming existing column
for signal in schema.db_signals(name=value.name, as_columns=True):
mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr]
elif isinstance(value, Func):
# adding new signal
mutated[name] = value.get_column(schema)
else:
# adding new signal
mutated[name] = value
Expand Down
8 changes: 7 additions & 1 deletion src/datachain/lib/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func
from .func import Func, Window
from .window import dense_rank, first, rank, row_number

__all__ = [
"Func",
"Window",
dreadatour marked this conversation as resolved.
Show resolved Hide resolved
"any_value",
"avg",
"collect",
"concat",
"count",
"dense_rank",
"first",
"max",
"min",
"rank",
"row_number",
"sum",
]
2 changes: 1 addition & 1 deletion src/datachain/lib/func/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def sum(col: str) -> Func:


def avg(col: str) -> Func:
    """Build an average aggregate over the column named ``col``.

    The result type is pinned to ``float`` instead of being derived from the
    source column, so averaging an integer column is reported as a float.

    Args:
        col: Name of the column to average.

    Returns:
        A ``Func`` wrapping ``dc_func.aggregate.avg`` for ``col``.
    """
    # Only one return is kept: the variant without ``result_type`` preceded
    # this one and made it unreachable.
    return Func(inner=dc_func.aggregate.avg, col=col, result_type=float)


def min(col: str) -> Func:
Expand Down
56 changes: 49 additions & 7 deletions src/datachain/lib/func/func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Callable, Optional

from sqlalchemy import desc

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta
Expand All @@ -9,18 +11,44 @@
from datachain.lib.signal_schema import SignalSchema


class Window:
    """Window specification attached to a window function.

    The three values are stored as given; nothing is validated or resolved
    against a schema here.

    Args:
        partition_by: Name of the column that splits rows into partitions.
        order_by: Name of the column that orders rows inside a partition.
        desc: When true, ``order_by`` is meant to be applied in descending
            order. Defaults to ascending.
    """

    def __init__(self, *, partition_by: str, order_by: str, desc: bool = False) -> None:
        # All arguments are keyword-only so call sites stay self-describing.
        self.partition_by = partition_by
        self.order_by = order_by
        self.desc = desc


class Func:
def __init__(
    self,
    inner: Callable,
    col: Optional[str] = None,
    result_type: Optional["DataType"] = None,
    is_array: bool = False,
    is_window: bool = False,
    window: Optional[Window] = None,
) -> None:
    """Record the pieces needed to build the function column later.

    Nothing is evaluated here; every argument is stored unchanged on the
    instance under the attribute of the same name.

    Args:
        inner: Callable that produces the underlying function expression.
        col: Optional name of the column the function is applied to.
        result_type: Optional explicit type of the function result.
        is_array: Flag stored as ``self.is_array``.
        is_window: Flag marking this as a window function.
        window: Optional window specification for a window function.
    """
    # What to call and on which column.
    self.inner = inner
    self.col = col

    # Typing information for the result.
    self.result_type = result_type
    self.is_array = is_array

    # Window-function state.
    self.is_window = is_window
    self.window = window

def over(self, window: Window) -> "Func":
    """Return a copy of this function bound to ``window``.

    The current instance is left untouched; a new ``Func`` carrying the same
    settings plus the given window is returned.

    Args:
        window: Window specification to attach.

    Returns:
        A new ``Func`` identical to this one except for ``window``.

    Raises:
        DataChainColumnError: If this function was not created as a window
            function (``is_window`` is false).
    """
    if not self.is_window:
        raise DataChainColumnError(
            str(self.inner), "Window function is not supported"
        )

    # Keyword arguments keep the copy correct even if the constructor's
    # parameter order changes.
    return Func(
        inner=self.inner,
        col=self.col,
        result_type=self.result_type,
        is_array=self.is_array,
        is_window=self.is_window,
        window=window,
    )

@property
def db_col(self) -> Optional[str]:
Expand All @@ -33,12 +61,10 @@ def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
return list[col_type] if self.is_array else col_type # type: ignore[valid-type]

def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
col_type = self.db_col_type(signals_schema)

if self.result_type:
return self.result_type

if col_type:
if col_type := self.db_col_type(signals_schema):
return col_type

raise DataChainColumnError(
Expand All @@ -49,15 +75,31 @@ def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
def get_column(
self, signals_schema: "SignalSchema", label: Optional[str] = None
) -> Column:
col_type = self.get_result_type(signals_schema)
sql_type = python_to_sql(col_type)

if self.col:
if label == "collect":
print(label)
col_type = self.get_result_type(signals_schema)
col = Column(self.db_col, python_to_sql(col_type))
col = Column(self.db_col, sql_type)
func_col = self.inner(col)
else:
func_col = self.inner()

if self.is_window:
if not self.window:
raise DataChainColumnError(
str(self.inner), "Window function requires window"
dreadatour marked this conversation as resolved.
Show resolved Hide resolved
)
func_col = func_col.over(
partition_by=self.window.partition_by,
order_by=(
desc(self.window.order_by)
if self.window.desc
else self.window.order_by
),
)

func_col.type = sql_type

if label:
func_col = func_col.label(label)

Expand Down
19 changes: 19 additions & 0 deletions src/datachain/lib/func/window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from sqlalchemy import func as sa_func

from .func import Func


def row_number() -> Func:
    """Create a window ``Func`` wrapping SQLAlchemy's ``row_number``.

    The function takes no column and declares an ``int`` result.
    """
    return Func(
        inner=sa_func.row_number,
        is_window=True,
        result_type=int,
    )


def rank() -> Func:
    """Create a window ``Func`` wrapping SQLAlchemy's ``rank``.

    The function takes no column and declares an ``int`` result.
    """
    return Func(
        inner=sa_func.rank,
        is_window=True,
        result_type=int,
    )


def dense_rank() -> Func:
    """Create a window ``Func`` wrapping SQLAlchemy's ``dense_rank``.

    The function takes no column and declares an ``int`` result.
    """
    return Func(
        inner=sa_func.dense_rank,
        is_window=True,
        result_type=int,
    )


def first(col: str) -> Func:
    """Create a window ``Func`` wrapping SQLAlchemy's ``first_value``.

    No explicit result type is passed here.

    Args:
        col: Name of the column whose first value is taken.
    """
    return Func(
        inner=sa_func.first_value,
        is_window=True,
        col=col,
    )
4 changes: 4 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from datachain.lib.convert.unflatten import unflatten_to_json_pos
from datachain.lib.data_model import DataModel, DataType, DataValue
from datachain.lib.file import File
from datachain.lib.func import Func
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import DEFAULT_DELIMITER, Column
Expand Down Expand Up @@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema":
# changing the type of existing signal, e.g File -> ImageFile
del new_values[name]
new_values[name] = args_map[name]
elif isinstance(value, Func):
# adding new signal with function
new_values[name] = value.get_result_type(self)
else:
# adding new signal
new_values[name] = sql_to_python(value)
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
Expand All @@ -25,6 +24,7 @@
from typing_extensions import Self

from datachain.catalog import Catalog
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf_signature import UdfSignature
from datachain.query.batch import RowsOutput

Expand Down Expand Up @@ -172,7 +172,7 @@ def teardown(self):
def _init(
self,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
func: Optional[Callable],
):
self.params = params
Expand All @@ -183,7 +183,7 @@ def _init(
def _create(
cls,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
) -> "Self":
if isinstance(sign.func, AbstractUDF):
if not isinstance(sign.func, cls): # type: ignore[unreachable]
Expand Down
105 changes: 104 additions & 1 deletion tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,7 +2131,7 @@ def test_group_by_int(test_session):
"cnt": "int",
"cnt_col": "int",
"sum": "int",
"avg": "int",
"avg": "float",
"min": "int",
"max": "int",
"value": "int",
Expand Down Expand Up @@ -2443,3 +2443,106 @@ def test_group_by_error(test_session):
SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
):
dc.group_by(foo=func.sum("col2"), partition_by="col3")


@pytest.mark.parametrize("desc", [True, False])
def test_window_functions(test_session, desc):
    """Check row_number/rank/dense_rank/first over a window in both orders."""
    from datachain import func

    window = func.Window(partition_by="col1", order_by="col2", desc=desc)

    chain = DataChain.from_values(
        col1=["a", "a", "b", "b", "b", "c"],
        col2=[1, 2, 3, 4, 5, 6],
        session=test_session,
    )
    mutated = chain.mutate(
        row_number=func.row_number().over(window),
        rank=func.rank().over(window),
        dense_rank=func.dense_rank().over(window),
        first=func.first("col2").over(window),
    )
    ds = mutated.save("my-ds")

    expected_schema = {
        "col1": "str",
        "col2": "int",
        "row_number": "int",
        "rank": "int",
        "dense_rank": "int",
        "first": "int",
    }
    assert ds.signals_schema.serialize() == expected_schema

    def expected_row(col1, col2, position, first):
        # col2 has no ties inside a partition, so the three ranking
        # functions are expected to yield the same position.
        return {
            "col1": col1,
            "col2": col2,
            "row_number": position,
            "rank": position,
            "dense_rank": position,
            "first": first,
        }

    # First value per partition depends on the ordering direction.
    first_a = 2 if desc else 1
    first_b = 5 if desc else 3

    expected = [
        expected_row("a", 1, 2 if desc else 1, first_a),
        expected_row("a", 2, 1 if desc else 2, first_a),
        expected_row("b", 3, 3 if desc else 1, first_b),
        expected_row("b", 4, 2, first_b),
        expected_row("b", 5, 1 if desc else 3, first_b),
        expected_row("c", 6, 1, 6),
    ]

    actual = sorted_dicts(ds.to_records(), "col1", "col2")
    assert actual == sorted_dicts(expected, "col1", "col2")


def test_window_error(test_session):
    """Check both window misuse errors raised through ``mutate``.

    Covers a window function used without a window, and ``over()`` applied to
    a function that is not a window function.
    """
    from datachain import func

    window = func.Window(partition_by="col1", order_by="col2")

    dc = DataChain.from_values(
        col1=["a", "a", "b", "b", "b", "c"],
        col2=[1, 2, 3, 4, 5, 6],
        session=test_session,
    )

    # Window function without ``.over(...)``.
    with pytest.raises(DataChainColumnError, match="Window function requires window"):
        dc.mutate(first=func.first("col2"))

    # ``.over(...)`` on a plain aggregate.
    with pytest.raises(DataChainColumnError, match="Window function is not supported"):
        dc.mutate(first=func.sum("col2").over(window))