diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 138e2a131..44f2893a8 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1073,9 +1073,9 @@ def mutate(self, **kwargs) -> "Self":
         Example:
         ```py
          dc.mutate(
-            area=Column("image.height") * Column("image.width"),
-            extension=file_ext(Column("file.name")),
-            dist=cosine_distance(embedding_text, embedding_image)
+             area=Column("image.height") * Column("image.width"),
+             extension=file_ext(Column("file.name")),
+             dist=cosine_distance(embedding_text, embedding_image)
         )
         ```
 
@@ -1086,7 +1086,7 @@ def mutate(self, **kwargs) -> "Self":
         Example:
         ```py
          dc.mutate(
-            newkey=Column("oldkey")
+             newkey=Column("oldkey")
         )
         ```
         """
@@ -1099,7 +1099,7 @@ def mutate(self, **kwargs) -> "Self":
                     "Use a different name for the new column.",
                 )
         for col_name, expr in kwargs.items():
-            if not isinstance(expr, Column) and isinstance(expr.type, NullType):
+            if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
                 raise DataChainColumnError(
                     col_name, f"Cannot infer type with expression {expr}"
                 )
@@ -1111,6 +1111,9 @@ def mutate(self, **kwargs) -> "Self":
                 # renaming existing column
                 for signal in schema.db_signals(name=value.name, as_columns=True):
                     mutated[signal.name.replace(value.name, name, 1)] = signal  # type: ignore[union-attr]
+            elif isinstance(value, Func):
+                # adding new signal
+                mutated[name] = value.get_column(schema)
             else:
                 # adding new signal
                 mutated[name] = value
diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py
index 5b4c5524a..d42ceb539 100644
--- a/src/datachain/lib/func/__init__.py
+++ b/src/datachain/lib/func/__init__.py
@@ -1,14 +1,18 @@
 from .aggregate import any_value, avg, collect, concat, count, max, min, sum
-from .func import Func
+from .func import Func, Window
+from .window import first, row_number
 
 __all__ = [
     "Func",
+    "Window",
     "any_value",
     "avg",
     "collect",
     "concat",
     "count",
+    "first",
     "max",
     "min",
+    "row_number",
     "sum",
 ]
diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py
index 84f9c981c..5738497e9 100644
--- a/src/datachain/lib/func/func.py
+++ b/src/datachain/lib/func/func.py
@@ -3,12 +3,19 @@
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.utils import DataChainColumnError
 from datachain.query.schema import Column, ColumnMeta
+from datachain.sql.types import NullType
 
 if TYPE_CHECKING:
     from datachain import DataType
     from datachain.lib.signal_schema import SignalSchema
 
 
+class Window:
+    def __init__(self, *, partition_by: str, order_by: str) -> None:
+        self.partition_by = partition_by
+        self.order_by = order_by
+
+
 class Func:
     def __init__(
         self,
@@ -16,11 +23,30 @@ def __init__(
         col: Optional[str] = None,
         result_type: Optional["DataType"] = None,
         is_array: bool = False,
+        is_window: bool = False,
+        window: Optional[Window] = None,
     ) -> None:
         self.inner = inner
         self.col = col
         self.result_type = result_type
         self.is_array = is_array
+        self.is_window = is_window
+        self.window = window
+
+    def over(self, window: Window) -> "Func":
+        if not self.is_window:
+            raise DataChainColumnError(
+                str(self.inner), "Window function is not supported"
+            )
+
+        return Func(
+            self.inner,
+            self.col,
+            self.result_type,
+            self.is_array,
+            self.is_window,
+            window,
+        )
 
     @property
     def db_col(self) -> Optional[str]:
@@ -56,6 +82,19 @@ def get_column(
         else:
             func_col = self.inner()
 
+        if self.is_window:
+            if not self.window:
+                raise DataChainColumnError(
+                    str(self.inner), "Window function requires window"
+                )
+            func_col = func_col.over(
+                partition_by=self.window.partition_by,
+                order_by=self.window.order_by,
+            )
+
+        if isinstance(func_col.type, NullType):
+            func_col.type = sql_type
+
         if label:
             func_col = func_col.label(label)
 
diff --git a/src/datachain/lib/func/window.py b/src/datachain/lib/func/window.py
new file mode 100644
index 000000000..f4089c6c5
--- /dev/null
+++ b/src/datachain/lib/func/window.py
@@ -0,0 +1,11 @@
+from sqlalchemy import func as sa_func
+
+from .func import Func
+
+
+def row_number() -> Func:
+    return Func(inner=sa_func.row_number, result_type=int, is_window=True)
+
+
+def first(col: str) -> Func:
+    return Func(inner=sa_func.first_value, col=col, is_window=True)
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index 5cdb1fe5c..9de9d6039 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -27,6 +27,7 @@
 from datachain.lib.convert.unflatten import unflatten_to_json_pos
 from datachain.lib.data_model import DataModel, DataType, DataValue
 from datachain.lib.file import File
+from datachain.lib.func import Func
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import DEFAULT_DELIMITER, Column
@@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema":
                 # changing the type of existing signal, e.g File -> ImageFile
                 del new_values[name]
                 new_values[name] = args_map[name]
+            elif isinstance(value, Func):
+                # adding new signal with function
+                new_values[name] = value.get_result_type(self)
             else:
                 # adding new signal
                 new_values[name] = sql_to_python(value)
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index 2ce0257d5..8faa13a29 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -11,7 +11,6 @@
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
-from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query.batch import (
     Batch,
@@ -25,6 +24,7 @@
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
+    from datachain.lib.signal_schema import SignalSchema
     from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput
 
@@ -172,7 +172,7 @@ def teardown(self):
     def _init(
         self,
         sign: "UdfSignature",
-        params: SignalSchema,
+        params: "SignalSchema",
         func: Optional[Callable],
     ):
         self.params = params
@@ -183,7 +183,7 @@ def _init(
     def _create(
         cls,
         sign: "UdfSignature",
-        params: SignalSchema,
+        params: "SignalSchema",
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 8d2ca157d..999cee4f7 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -2443,3 +2443,71 @@ def test_group_by_error(test_session):
         SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
     ):
         dc.group_by(foo=func.sum("col2"), partition_by="col3")
+
+
+def test_window_row_number(test_session):
+    from datachain import func
+
+    window = func.Window(partition_by="col1", order_by="col2")
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 3, 4, 5, 6],
+            session=test_session,
+        )
+        .mutate(row_number=func.row_number().over(window))
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "col2": "int",
+        "row_number": "int",
+    }
+    assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts(
+        [
+            {"col1": "a", "col2": 1, "row_number": 1},
+            {"col1": "a", "col2": 2, "row_number": 2},
+            {"col1": "b", "col2": 3, "row_number": 1},
+            {"col1": "b", "col2": 4, "row_number": 2},
+            {"col1": "b", "col2": 5, "row_number": 3},
+            {"col1": "c", "col2": 6, "row_number": 1},
+        ],
+        "col1",
+        "col2",
+    )
+
+
+def test_window_first(test_session):
+    from datachain import func
+
+    window = func.Window(partition_by="col1", order_by="col2")
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 3, 4, 5, 6],
+            session=test_session,
+        )
+        .mutate(first=func.first("col2").over(window))
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "col2": "int",
+        "first": "int",
+    }
+    assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts(
+        [
+            {"col1": "a", "col2": 1, "first": 1},
+            {"col1": "a", "col2": 2, "first": 1},
+            {"col1": "b", "col2": 3, "first": 3},
+            {"col1": "b", "col2": 4, "first": 3},
+            {"col1": "b", "col2": 5, "first": 3},
+            {"col1": "c", "col2": 6, "first": 6},
+        ],
+        "col1",
+        "col2",
+    )