Skip to content

Commit

Permalink
Implement window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Oct 16, 2024
1 parent 95675c5 commit c24c925
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 9 deletions.
13 changes: 8 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,9 +1073,9 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
)
```
Expand All @@ -1086,7 +1086,7 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
newkey=Column("oldkey")
newkey=Column("oldkey")
)
```
"""
Expand All @@ -1099,7 +1099,7 @@ def mutate(self, **kwargs) -> "Self":
"Use a different name for the new column.",
)
for col_name, expr in kwargs.items():
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
raise DataChainColumnError(
col_name, f"Cannot infer type with expression {expr}"
)
Expand All @@ -1111,6 +1111,9 @@ def mutate(self, **kwargs) -> "Self":
# renaming existing column
for signal in schema.db_signals(name=value.name, as_columns=True):
mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr]
elif isinstance(value, Func):
# adding new signal
mutated[name] = value.get_column(schema)
else:
# adding new signal
mutated[name] = value
Expand Down
6 changes: 5 additions & 1 deletion src/datachain/lib/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func
from .func import Func, Window
from .window import first, row_number

__all__ = [
"Func",
"Window",
"any_value",
"avg",
"collect",
"concat",
"count",
"first",
"max",
"min",
"row_number",
"sum",
]
39 changes: 39 additions & 0 deletions src/datachain/lib/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,50 @@
from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta
from datachain.sql.types import NullType

if TYPE_CHECKING:
from datachain import DataType
from datachain.lib.signal_schema import SignalSchema


class Window:
def __init__(self, *, partition_by: str, order_by: str) -> None:
self.partition_by = partition_by
self.order_by = order_by


class Func:
def __init__(
self,
inner: Callable,
col: Optional[str] = None,
result_type: Optional["DataType"] = None,
is_array: bool = False,
is_window: bool = False,
window: Optional[Window] = None,
) -> None:
self.inner = inner
self.col = col
self.result_type = result_type
self.is_array = is_array
self.is_window = is_window
self.window = window

def over(self, window: Window) -> "Func":
if not self.is_window:
raise DataChainColumnError(

Check warning on line 38 in src/datachain/lib/func/func.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/func/func.py#L38

Added line #L38 was not covered by tests
str(self.inner), "Window function is not supported"
)

return Func(
self.inner,
self.col,
self.result_type,
self.is_array,
self.is_window,
window,
)

@property
def db_col(self) -> Optional[str]:
Expand Down Expand Up @@ -56,6 +82,19 @@ def get_column(
else:
func_col = self.inner()

if self.is_window:
if not self.window:
raise DataChainColumnError(

Check warning on line 87 in src/datachain/lib/func/func.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/func/func.py#L87

Added line #L87 was not covered by tests
str(self.inner), "Window function requires window"
)
func_col = func_col.over(
partition_by=self.window.partition_by,
order_by=self.window.order_by,
)

if isinstance(func_col.type, NullType):
func_col.type = sql_type

if label:
func_col = func_col.label(label)

Expand Down
11 changes: 11 additions & 0 deletions src/datachain/lib/func/window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy import func as sa_func

from .func import Func


def row_number() -> Func:
return Func(inner=sa_func.row_number, result_type=int, is_window=True)


def first(col: str) -> Func:
return Func(inner=sa_func.first_value, col=col, is_window=True)
4 changes: 4 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from datachain.lib.convert.unflatten import unflatten_to_json_pos
from datachain.lib.data_model import DataModel, DataType, DataValue
from datachain.lib.file import File
from datachain.lib.func import Func
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import DEFAULT_DELIMITER, Column
Expand Down Expand Up @@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema":
# changing the type of existing signal, e.g File -> ImageFile
del new_values[name]
new_values[name] = args_map[name]
elif isinstance(value, Func):
# adding new signal with function
new_values[name] = value.get_result_type(self)
else:
# adding new signal
new_values[name] = sql_to_python(value)
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
Expand All @@ -25,6 +24,7 @@
from typing_extensions import Self

from datachain.catalog import Catalog
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf_signature import UdfSignature
from datachain.query.batch import RowsOutput

Expand Down Expand Up @@ -172,7 +172,7 @@ def teardown(self):
def _init(
self,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
func: Optional[Callable],
):
self.params = params
Expand All @@ -183,7 +183,7 @@ def _init(
def _create(
cls,
sign: "UdfSignature",
params: SignalSchema,
params: "SignalSchema",
) -> "Self":
if isinstance(sign.func, AbstractUDF):
if not isinstance(sign.func, cls): # type: ignore[unreachable]
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,3 +2443,69 @@ def test_group_by_error(test_session):
SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
):
dc.group_by(foo=func.sum("col2"), partition_by="col3")


def test_window_row_number(test_session):
from datachain import func

window = func.Window(partition_by="col1", order_by="col2")

ds = (
DataChain.from_values(
col1=["a", "a", "b", "b", "b", "c"],
col2=[1, 2, 3, 4, 5, 6],
session=test_session,
)
.mutate(row_number=func.row_number().over(window))
.save("my-ds")
)

assert ds.signals_schema.serialize() == {
"col1": "str",
"col2": "int",
"row_number": "int",
}
assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts(
[
{"col1": "a", "col2": 1, "row_number": 1},
{"col1": "a", "col2": 2, "row_number": 2},
{"col1": "b", "col2": 3, "row_number": 1},
{"col1": "b", "col2": 4, "row_number": 2},
{"col1": "b", "col2": 5, "row_number": 3},
{"col1": "c", "col2": 6, "row_number": 1},
],
"col1",
"col2",
)


def test_window_first(test_session):
from datachain import func

window = func.Window(partition_by="col1", order_by="col2")

ds = (
DataChain.from_values(
col1=["a", "a", "b", "b", "b", "c"],
col2=[1, 2, 3, 4, 5, 6],
session=test_session,
)
.group_by(
first=func.first("col2").over(window),
partition_by="col1",
)
.save("my-ds")
)

assert ds.signals_schema.serialize() == {
"col1": "str",
"first": "int",
}
assert sorted_dicts(ds.to_records(), "col1") == sorted_dicts(
[
{"col1": "a", "first": 1},
{"col1": "b", "first": 3},
{"col1": "c", "first": 6},
],
"col1",
)

0 comments on commit c24c925

Please sign in to comment.