diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 138e2a131..44f2893a8 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1073,9 +1073,9 @@ def mutate(self, **kwargs) -> "Self": Example: ```py dc.mutate( - area=Column("image.height") * Column("image.width"), - extension=file_ext(Column("file.name")), - dist=cosine_distance(embedding_text, embedding_image) + area=Column("image.height") * Column("image.width"), + extension=file_ext(Column("file.name")), + dist=cosine_distance(embedding_text, embedding_image) ) ``` @@ -1086,7 +1086,7 @@ def mutate(self, **kwargs) -> "Self": Example: ```py dc.mutate( - newkey=Column("oldkey") + newkey=Column("oldkey") ) ``` """ @@ -1099,7 +1099,7 @@ def mutate(self, **kwargs) -> "Self": "Use a different name for the new column.", ) for col_name, expr in kwargs.items(): - if not isinstance(expr, Column) and isinstance(expr.type, NullType): + if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType): raise DataChainColumnError( col_name, f"Cannot infer type with expression {expr}" ) @@ -1111,6 +1111,9 @@ def mutate(self, **kwargs) -> "Self": # renaming existing column for signal in schema.db_signals(name=value.name, as_columns=True): mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr] + elif isinstance(value, Func): + # adding new signal + mutated[name] = value.get_column(schema) else: # adding new signal mutated[name] = value diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py index 5b4c5524a..45a7d31e9 100644 --- a/src/datachain/lib/func/__init__.py +++ b/src/datachain/lib/func/__init__.py @@ -1,14 +1,21 @@ from .aggregate import any_value, avg, collect, concat, count, max, min, sum -from .func import Func +from .func import Func, Window +from .window import dense_rank, first, percent_rank, rank, row_number __all__ = [ "Func", + "Window", "any_value", "avg", "collect", "concat", "count", + "dense_rank", + "first", "max", "min", + "percent_rank", + "rank", + "row_number", "sum", ] diff --git a/src/datachain/lib/func/aggregate.py b/src/datachain/lib/func/aggregate.py index cfe04beb6..ee5a7b63d 100644 --- a/src/datachain/lib/func/aggregate.py +++ b/src/datachain/lib/func/aggregate.py @@ -16,7 +16,7 @@ def sum(col: str) -> Func: def avg(col: str) -> Func: - return Func(inner=dc_func.aggregate.avg, col=col) + return Func(inner=dc_func.aggregate.avg, col=col, result_type=float) def min(col: str) -> Func: diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py index 84f9c981c..4b11c50fe 100644 --- a/src/datachain/lib/func/func.py +++ b/src/datachain/lib/func/func.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, Callable, Optional +from sqlalchemy import desc + from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.utils import DataChainColumnError from datachain.query.schema import Column, ColumnMeta @@ -9,6 +11,13 @@ from datachain.lib.signal_schema import SignalSchema +class Window: + def __init__(self, *, partition_by: str, order_by: str, desc: bool = False) -> None: + self.partition_by = partition_by + self.order_by = order_by + self.desc = desc + + class Func: def __init__( self, @@ -16,11 +25,30 @@ def __init__( col: Optional[str] = None, result_type: Optional["DataType"] = None, is_array: bool = False, + is_window: bool = False, + window: Optional[Window] = None, ) -> None: self.inner = inner self.col = col self.result_type = result_type self.is_array = is_array + self.is_window = is_window + self.window = window + + def over(self, window: Window) -> "Func": + if not self.is_window: + raise DataChainColumnError( + str(self.inner), "Window function is not supported" + ) + + return Func( + self.inner, + self.col, + self.result_type, + self.is_array, + self.is_window, + window, + ) @property def db_col(self) -> Optional[str]: @@ -56,6 +84,22 @@ def get_column( else: func_col = self.inner() + if self.is_window: + if not self.window: + raise DataChainColumnError( + str(self.inner), "Window function requires window" + ) + func_col = func_col.over( + partition_by=self.window.partition_by, + order_by=( + desc(self.window.order_by) + if self.window.desc + else self.window.order_by + ), + ) + + func_col.type = sql_type + if label: func_col = func_col.label(label) diff --git a/src/datachain/lib/func/window.py b/src/datachain/lib/func/window.py new file mode 100644 index 000000000..e0788eaaa --- /dev/null +++ b/src/datachain/lib/func/window.py @@ -0,0 +1,25 @@ +from sqlalchemy import func as sa_func + +from datachain.sql import functions as dc_func + +from .func import Func + + +def row_number() -> Func: + return Func(inner=sa_func.row_number, result_type=int, is_window=True) + + +def rank() -> Func: + return Func(inner=sa_func.rank, result_type=int, is_window=True) + + +def dense_rank() -> Func: + return Func(inner=dc_func.dense_rank, result_type=int, is_window=True) + + +def percent_rank() -> Func: + return Func(inner=dc_func.percent_rank, result_type=float, is_window=True) + + +def first(col: str) -> Func: + return Func(inner=sa_func.first_value, col=col, is_window=True) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 5cdb1fe5c..9de9d6039 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -27,6 +27,7 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos from datachain.lib.data_model import DataModel, DataType, DataValue from datachain.lib.file import File +from datachain.lib.func import Func from datachain.lib.model_store import ModelStore from datachain.lib.utils import DataChainParamsError from datachain.query.schema import DEFAULT_DELIMITER, Column @@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema": # changing the type of existing signal, e.g File -> ImageFile del new_values[name] new_values[name] = args_map[name] + elif isinstance(value, Func): + # adding new signal with function + new_values[name] = value.get_result_type(self) else: # adding new signal new_values[name] = sql_to_python(value) diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index 2ce0257d5..8faa13a29 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -11,7 +11,6 @@ from datachain.lib.convert.flatten import flatten from datachain.lib.data_model import DataValue from datachain.lib.file import File -from datachain.lib.signal_schema import SignalSchema from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError from datachain.query.batch import ( Batch, @@ -25,6 +24,7 @@ from typing_extensions import Self from datachain.catalog import Catalog + from datachain.lib.signal_schema import SignalSchema from datachain.lib.udf_signature import UdfSignature from datachain.query.batch import RowsOutput @@ -172,7 +172,7 @@ def teardown(self): def _init( self, sign: "UdfSignature", - params: SignalSchema, + params: "SignalSchema", func: Optional[Callable], ): self.params = params @@ -183,7 +183,7 @@ def _init( def _create( cls, sign: "UdfSignature", - params: SignalSchema, + params: "SignalSchema", ) -> "Self": if isinstance(sign.func, AbstractUDF): if not isinstance(sign.func, cls): # type: ignore[unreachable] diff --git a/src/datachain/sql/functions/__init__.py b/src/datachain/sql/functions/__init__.py index c8d4ef0de..17d67832c 100644 --- a/src/datachain/sql/functions/__init__.py +++ b/src/datachain/sql/functions/__init__.py @@ -4,6 +4,7 @@ from .aggregate import avg from .conditional import greatest, least from .random import rand +from .window import dense_rank, percent_rank count = func.count sum = func.sum @@ -14,12 +15,14 @@ "array", "avg", "count", + "dense_rank", "func", "greatest", "least", "max", "min", "path", + "percent_rank", "rand", "string", "sum", diff --git a/src/datachain/sql/functions/window.py b/src/datachain/sql/functions/window.py new file mode 100644 index 000000000..c8bfbee82 --- /dev/null +++ b/src/datachain/sql/functions/window.py @@ -0,0 +1,9 @@ +from sqlalchemy.sql.functions import GenericFunction + + +class dense_rank(GenericFunction): # noqa: N801 + inherit_cache = True + + +class percent_rank(GenericFunction): # noqa: N801 + inherit_cache = True diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 8d2ca157d..686d3d824 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2131,7 +2131,7 @@ def test_group_by_int(test_session): "cnt": "int", "cnt_col": "int", "sum": "int", - "avg": "int", + "avg": "float", "min": "int", "max": "int", "value": "int", @@ -2443,3 +2443,114 @@ def test_group_by_error(test_session): SignalResolvingError, match="cannot resolve signal name 'col3': is not found" ): dc.group_by(foo=func.sum("col2"), partition_by="col3") + + +@pytest.mark.parametrize("desc", [True, False]) +def test_window_functions(test_session, desc): + from datachain import func + + window = func.Window(partition_by="col1", order_by="col2", desc=desc) + + ds = ( + DataChain.from_values( + col1=["a", "a", "b", "b", "b", "c"], + col2=[1, 2, 3, 4, 5, 6], + session=test_session, + ) + .mutate( + row_number=func.row_number().over(window), + rank=func.rank().over(window), + dense_rank=func.dense_rank().over(window), + percent_rank=func.percent_rank().over(window), + first=func.first("col2").over(window), + ) + .save("my-ds") + ) + + assert ds.signals_schema.serialize() == { + "col1": "str", + "col2": "int", + "row_number": "int", + "rank": "int", + "dense_rank": "int", + "percent_rank": "float", + "first": "int", + } + assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts( + [ + { + "col1": "a", + "col2": 1, + "row_number": 2 if desc else 1, + "rank": 2 if desc else 1, + "dense_rank": 2 if desc else 1, + "percent_rank": 1.0 if desc else 0.0, + "first": 2 if desc else 1, + }, + { + "col1": "a", + "col2": 2, + "row_number": 1 if desc else 2, + "rank": 1 if desc else 2, + "dense_rank": 1 if desc else 2, + "percent_rank": 0.0 if desc else 1.0, + "first": 2 if desc else 1, + }, + { + "col1": "b", + "col2": 3, + "row_number": 3 if desc else 1, + "rank": 3 if desc else 1, + "dense_rank": 3 if desc else 1, + "percent_rank": 1.0 if desc else 0.0, + "first": 5 if desc else 3, + }, + { + "col1": "b", + "col2": 4, + "row_number": 2, + "rank": 2, + "dense_rank": 2, + "percent_rank": 0.5, + "first": 5 if desc else 3, + }, + { + "col1": "b", + "col2": 5, + "row_number": 1 if desc else 3, + "rank": 1 if desc else 3, + "dense_rank": 1 if desc else 3, + "percent_rank": 0.0 if desc else 1.0, + "first": 5 if desc else 3, + }, + { + "col1": "c", + "col2": 6, + "row_number": 1, + "rank": 1, + "dense_rank": 1, + "percent_rank": 0.0, + "first": 6, + }, + ], + "col1", + "col2", + ) + + +def test_window_error(test_session): + from datachain import func + + window = func.Window(partition_by="col1", order_by="col2") + + dc = DataChain.from_values( + col1=["a", "a", "b", "b", "b", "c"], + col2=[1, 2, 3, 4, 5, 6], + session=test_session, + ) + + with pytest.raises(DataChainColumnError, match="Window function requires window"): + dc.mutate(first=func.first("col2")) + + with pytest.raises(DataChainColumnError, match="Window function is not supported"): + dc.mutate(first=func.sum("col2").over(window))