diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 3836a37ed..dc38601c1 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1008,6 +1008,9 @@ def group_by( """Group rows by specified set of signals and return new signals with aggregated values. + The supported functions: + count(), sum(), avg(), min(), max(), any_value(), collect(), concat() + Example: ```py chain = chain.group_by( @@ -1071,13 +1074,22 @@ def mutate(self, **kwargs) -> "Self": Filename: name(), parent(), file_stem(), file_ext() Array: length(), sip_hash_64(), euclidean_distance(), cosine_distance() + Window: row_number(), rank(), dense_rank(), first() Example: ```py dc.mutate( - area=Column("image.height") * Column("image.width"), - extension=file_ext(Column("file.name")), - dist=cosine_distance(embedding_text, embedding_image) + area=Column("image.height") * Column("image.width"), + extension=file_ext(Column("file.name")), + dist=cosine_distance(embedding_text, embedding_image) + ) + ``` + + Window function example: + ```py + window = func.window(partition_by="file.parent", order_by="file.size") + dc.mutate( + row_number=func.row_number().over(window), ) ``` @@ -1088,7 +1100,7 @@ def mutate(self, **kwargs) -> "Self": Example: ```py dc.mutate( - newkey=Column("oldkey") + newkey=Column("oldkey") ) ``` """ @@ -1101,7 +1113,7 @@ def mutate(self, **kwargs) -> "Self": "Use a different name for the new column.", ) for col_name, expr in kwargs.items(): - if not isinstance(expr, Column) and isinstance(expr.type, NullType): + if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType): raise DataChainColumnError( col_name, f"Cannot infer type with expression {expr}" ) @@ -1113,6 +1125,9 @@ def mutate(self, **kwargs) -> "Self": # renaming existing column for signal in schema.db_signals(name=value.name, as_columns=True): mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr] + elif isinstance(value, Func): + # adding new signal 
+ mutated[name] = value.get_column(schema) else: # adding new signal mutated[name] = value diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py index 5b4c5524a..ba6f08027 100644 --- a/src/datachain/lib/func/__init__.py +++ b/src/datachain/lib/func/__init__.py @@ -1,5 +1,18 @@ -from .aggregate import any_value, avg, collect, concat, count, max, min, sum -from .func import Func +from .aggregate import ( + any_value, + avg, + collect, + concat, + count, + dense_rank, + first, + max, + min, + rank, + row_number, + sum, +) +from .func import Func, window __all__ = [ "Func", @@ -8,7 +21,12 @@ "collect", "concat", "count", + "dense_rank", + "first", "max", "min", + "rank", + "row_number", "sum", + "window", ] diff --git a/src/datachain/lib/func/aggregate.py b/src/datachain/lib/func/aggregate.py index cfe04beb6..00ae0077a 100644 --- a/src/datachain/lib/func/aggregate.py +++ b/src/datachain/lib/func/aggregate.py @@ -8,35 +8,346 @@ def count(col: Optional[str] = None) -> Func: - return Func(inner=sa_func.count, col=col, result_type=int) + """ + Returns the COUNT aggregate SQL function for the given column name. + + The COUNT function returns the number of rows in a table. + + Args: + col (str, optional): The name of the column for which to count rows. + If not provided, it defaults to counting all rows. + + Returns: + Func: A Func object that represents the COUNT aggregate function. + + Example: + ```py + dc.group_by( + count=func.count(), + partition_by="signal.category", + ) + ``` + + Notes: + - Result column will always be of type int. + """ + return Func("count", inner=sa_func.count, col=col, result_type=int) def sum(col: str) -> Func: - return Func(inner=sa_func.sum, col=col) + """ + Returns the SUM aggregate SQL function for the given column name. + + The SUM function returns the total sum of a numeric column in a table. + It sums up all the values for the specified column. 
+ + Args: + col (str): The name of the column for which to calculate the sum. + + Returns: + Func: A Func object that represents the SUM aggregate function. + + Example: + ```py + dc.group_by( + files_size=func.sum("file.size"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `sum` function should be used on numeric columns. + - Result column type will be the same as the input column type. + """ + return Func("sum", inner=sa_func.sum, col=col) def avg(col: str) -> Func: - return Func(inner=dc_func.aggregate.avg, col=col) + """ + Returns the AVG aggregate SQL function for the given column name. + + The AVG function returns the average of a numeric column in a table. + It calculates the mean of all values in the specified column. + + Args: + col (str): The name of the column for which to calculate the average. + + Returns: + Func: A Func object that represents the AVG aggregate function. + + Example: + ```py + dc.group_by( + avg_file_size=func.avg("file.size"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `avg` function should be used on numeric columns. + - Result column will always be of type float. + """ + return Func("avg", inner=dc_func.aggregate.avg, col=col, result_type=float) def min(col: str) -> Func: - return Func(inner=sa_func.min, col=col) + """ + Returns the MIN aggregate SQL function for the given column name. + + The MIN function returns the smallest value in the specified column. + It can be used on both numeric and non-numeric columns to find the minimum value. + + Args: + col (str): The name of the column for which to find the minimum value. + + Returns: + Func: A Func object that represents the MIN aggregate function. + + Example: + ```py + dc.group_by( + smallest_file=func.min("file.size"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `min` function can be used with numeric, date, and string columns. + - Result column will have the same type as the input column. 
+ """ + return Func("min", inner=sa_func.min, col=col) def max(col: str) -> Func: - return Func(inner=sa_func.max, col=col) + """ + Returns the MAX aggregate SQL function for the given column name. + + The MAX function returns the largest value in the specified column. + It can be used on both numeric and non-numeric columns to find the maximum value. + + Args: + col (str): The name of the column for which to find the maximum value. + + Returns: + Func: A Func object that represents the MAX aggregate function. + + Example: + ```py + dc.group_by( + largest_file=func.max("file.size"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `max` function can be used with numeric, date, and string columns. + - Result column will have the same type as the input column. + """ + return Func("max", inner=sa_func.max, col=col) def any_value(col: str) -> Func: - return Func(inner=dc_func.aggregate.any_value, col=col) + """ + Returns the ANY_VALUE aggregate SQL function for the given column name. + + The ANY_VALUE function returns an arbitrary value from the specified column. + It is useful when you do not care which particular value is returned, + as long as it comes from one of the rows in the group. + + Args: + col (str): The name of the column from which to return an arbitrary value. + + Returns: + Func: A Func object that represents the ANY_VALUE aggregate function. + + Example: + ```py + dc.group_by( + file_example=func.any_value("file.name"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `any_value` function can be used with any type of column. + - Result column will have the same type as the input column. + - The result of `any_value` is non-deterministic, + meaning it may return different values for different executions. 
+ """ + return Func("any_value", inner=dc_func.aggregate.any_value, col=col) def collect(col: str) -> Func: - return Func(inner=dc_func.aggregate.collect, col=col, is_array=True) + """ + Returns the COLLECT aggregate SQL function for the given column name. + + The COLLECT function gathers all values from the specified column + into an array or similar structure. It is useful for combining values from a column + into a collection, often for further processing or aggregation. + + Args: + col (str): The name of the column from which to collect values. + + Returns: + Func: A Func object that represents the COLLECT aggregate function. + + Example: + ```py + dc.group_by( + signals=func.collect("signal"), + partition_by="signal.category", + ) + ``` + + Notes: + - The `collect` function can be used with numeric and string columns. + - Result column will have an array type. + """ + return Func("collect", inner=dc_func.aggregate.collect, col=col, is_array=True) def concat(col: str, separator="") -> Func: + """ + Returns the CONCAT aggregate SQL function for the given column name. + + The CONCAT function concatenates values from the specified column + into a single string. It is useful for merging text values from multiple rows + into a single combined value. + + Args: + col (str): The name of the column from which to concatenate values. + separator (str, optional): The separator to use between concatenated values. + Defaults to an empty string. + + Returns: + Func: A Func object that represents the CONCAT aggregate function. + + Example: + ```py + dc.group_by( + files=func.concat("file.name", separator=", "), + partition_by="signal.category", + ) + ``` + + Notes: + - The `concat` function can be used with string columns. + - Result column will have a string type. 
+ """ + def inner(arg): return dc_func.aggregate.group_concat(arg, separator) - return Func(inner=inner, col=col, result_type=str) + return Func("concat", inner=inner, col=col, result_type=str) + + +def row_number() -> Func: + """ + Returns the ROW_NUMBER window function for SQL queries. + + The ROW_NUMBER function assigns a unique sequential integer to rows + within a partition of a result set, starting from 1 for the first row + in each partition. It is commonly used to generate row numbers within + partitions or ordered results. + + Returns: + Func: A Func object that represents the ROW_NUMBER window function. + + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + row_number=func.row_number().over(window), + ) + ``` + + Note: + - The result column will always be of type int. + """ + return Func("row_number", inner=sa_func.row_number, result_type=int, is_window=True) + + +def rank() -> Func: + """ + Returns the RANK window function for SQL queries. + + The RANK function assigns a rank to each row within a partition of a result set, + with gaps in the ranking for ties. Rows with equal values receive the same rank, + and the next rank is skipped (i.e., if two rows are ranked 1, + the next row is ranked 3). + + Returns: + Func: A Func object that represents the RANK window function. + + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + rank=func.rank().over(window), + ) + ``` + + Notes: + - The result column will always be of type int. + - The RANK function differs from ROW_NUMBER in that rows with the same value + in the ordering column(s) receive the same rank. + """ + return Func("rank", inner=sa_func.rank, result_type=int, is_window=True) + + +def dense_rank() -> Func: + """ + Returns the DENSE_RANK window function for SQL queries. 
+ + The DENSE_RANK function assigns a rank to each row within a partition + of a result set, without gaps in the ranking for ties. Rows with equal values + receive the same rank, but the next rank is assigned consecutively + (i.e., if two rows are ranked 1, the next row will be ranked 2). + + Returns: + Func: A Func object that represents the DENSE_RANK window function. + + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + dense_rank=func.dense_rank().over(window), + ) + ``` + + Notes: + - The result column will always be of type int. + - The DENSE_RANK function differs from RANK in that it does not leave gaps + in the ranking for tied values. + """ + return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True) + + +def first(col: str) -> Func: + """ + Returns the FIRST_VALUE window function for SQL queries. + + The FIRST_VALUE function returns the first value in an ordered set of values + within a partition. The first value is determined by the specified order + and can be useful for retrieving the leading value in a group of rows. + + Args: + col (str): The name of the column from which to retrieve the first value. + + Returns: + Func: A Func object that represents the FIRST_VALUE window function. + + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + first_file=func.first("file.name").over(window), + ) + ``` + + Notes: + - The result of `first` will always reflect the value of the first row + in the specified order. + - The result column will have the same type as the input column. 
+ """ + return Func("first", inner=sa_func.first_value, col=col, is_window=True) diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py index ef4f3781e..3e7373d52 100644 --- a/src/datachain/lib/func/func.py +++ b/src/datachain/lib/func/func.py @@ -1,7 +1,10 @@ +from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional +from sqlalchemy import desc + from datachain.lib.convert.python_to_sql import python_to_sql -from datachain.lib.utils import DataChainColumnError +from datachain.lib.utils import DataChainColumnError, DataChainParamsError from datachain.query.schema import Column, ColumnMeta if TYPE_CHECKING: @@ -9,18 +12,89 @@ from datachain.lib.signal_schema import SignalSchema +@dataclass +class Window: + """Represents a window specification for SQL window functions.""" + + partition_by: str + order_by: str + desc: bool = False + + +def window(partition_by: str, order_by: str, desc: bool = False) -> Window: + """ + Defines a window specification for SQL window functions. + + The `window` function specifies how to partition and order the result set + for the associated window function. It is used to define the scope of the rows + that the window function will operate on. + + Args: + partition_by (str): The column name by which to partition the result set. + Rows with the same value in the partition column + will be grouped together for the window function. + order_by (str): The column name by which to order the rows + within each partition. This determines the sequence in which + the window function is applied. + desc (bool, optional): If True, the rows will be ordered in descending order. + Defaults to False, which orders the rows + in ascending order. + + Returns: + Window: A Window object representing the window specification. 
+ + Example: + ```py + window = func.window(partition_by="signal.category", order_by="created_at") + dc.mutate( + row_number=func.row_number().over(window), + ) + ``` + """ + return Window( + ColumnMeta.to_db_name(partition_by), + ColumnMeta.to_db_name(order_by), + desc, + ) + + class Func: + """Represents a function to be applied to a column in a SQL query.""" + def __init__( self, + name: str, inner: Callable, col: Optional[str] = None, result_type: Optional["DataType"] = None, is_array: bool = False, + is_window: bool = False, + window: Optional[Window] = None, ) -> None: + self.name = name self.inner = inner self.col = col self.result_type = result_type self.is_array = is_array + self.is_window = is_window + self.window = window + + def __str__(self) -> str: + return self.name + "()" + + def over(self, window: Window) -> "Func": + if not self.is_window: + raise DataChainParamsError(f"{self} doesn't support window (over())") + + return Func( + "over", + self.inner, + self.col, + self.result_type, + self.is_array, + self.is_window, + window, + ) @property def db_col(self) -> Optional[str]: @@ -33,31 +107,45 @@ def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: return list[col_type] if self.is_array else col_type # type: ignore[valid-type] def get_result_type(self, signals_schema: "SignalSchema") -> "DataType": - col_type = self.db_col_type(signals_schema) - if self.result_type: return self.result_type - if col_type: + if col_type := self.db_col_type(signals_schema): return col_type raise DataChainColumnError( - str(self.inner), + str(self), "Column name is required to infer result type", ) def get_column( self, signals_schema: "SignalSchema", label: Optional[str] = None ) -> Column: + col_type = self.get_result_type(signals_schema) + sql_type = python_to_sql(col_type) + if self.col: - if label == "collect": - print(label) - col_type = self.get_result_type(signals_schema) - col = Column(self.db_col, python_to_sql(col_type)) + col = 
Column(self.db_col, sql_type) func_col = self.inner(col) else: func_col = self.inner() + if self.is_window: + if not self.window: + raise DataChainParamsError( + f"Window function {self} requires over() clause with a window spec", + ) + func_col = func_col.over( + partition_by=self.window.partition_by, + order_by=( + desc(self.window.order_by) + if self.window.desc + else self.window.order_by + ), + ) + + func_col.type = sql_type + if label: func_col = func_col.label(label) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 5cdb1fe5c..9de9d6039 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -27,6 +27,7 @@ from datachain.lib.convert.unflatten import unflatten_to_json_pos from datachain.lib.data_model import DataModel, DataType, DataValue from datachain.lib.file import File +from datachain.lib.func import Func from datachain.lib.model_store import ModelStore from datachain.lib.utils import DataChainParamsError from datachain.query.schema import DEFAULT_DELIMITER, Column @@ -494,6 +495,9 @@ def mutate(self, args_map: dict) -> "SignalSchema": # changing the type of existing signal, e.g File -> ImageFile del new_values[name] new_values[name] = args_map[name] + elif isinstance(value, Func): + # adding new signal with function + new_values[name] = value.get_result_type(self) else: # adding new signal new_values[name] = sql_to_python(value) diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index 2ce0257d5..8faa13a29 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -11,7 +11,6 @@ from datachain.lib.convert.flatten import flatten from datachain.lib.data_model import DataValue from datachain.lib.file import File -from datachain.lib.signal_schema import SignalSchema from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError from datachain.query.batch import ( Batch, @@ -25,6 +24,7 @@ from typing_extensions import Self from 
datachain.catalog import Catalog + from datachain.lib.signal_schema import SignalSchema from datachain.lib.udf_signature import UdfSignature from datachain.query.batch import RowsOutput @@ -172,7 +172,7 @@ def teardown(self): def _init( self, sign: "UdfSignature", - params: SignalSchema, + params: "SignalSchema", func: Optional[Callable], ): self.params = params @@ -183,7 +183,7 @@ def _init( def _create( cls, sign: "UdfSignature", - params: SignalSchema, + params: "SignalSchema", ) -> "Self": if isinstance(sign.func, AbstractUDF): if not isinstance(sign.func, cls): # type: ignore[unreachable] diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 1b4ff705a..de84de3ba 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -1337,7 +1337,7 @@ class FileInfo(DataModel): path: str = "" name: str = "" - def file_info(file: File) -> DataModel: + def file_info(file: File) -> FileInfo: full_path = file.source.rstrip("/") + "/" + file.path rel_path = posixpath.relpath(full_path, src_uri) path_parts = rel_path.split("/", 1) @@ -1374,6 +1374,101 @@ def file_info(file: File) -> DataModel: ) +@pytest.mark.parametrize("partition_by", ["file_info.path", "file_info__path"]) +@pytest.mark.parametrize("order_by", ["file_info.name", "file_info__name"]) +def test_window_signals(cloud_test_catalog, partition_by, order_by): + from datachain import func + + session = cloud_test_catalog.session + src_uri = cloud_test_catalog.src_uri + + class FileInfo(DataModel): + path: str = "" + name: str = "" + + def file_info(file: File) -> FileInfo: + full_path = file.source.rstrip("/") + "/" + file.path + rel_path = posixpath.relpath(full_path, src_uri) + path_parts = rel_path.split("/", 1) + return FileInfo( + path=path_parts[0] if len(path_parts) > 1 else "", + name=path_parts[1] if len(path_parts) > 1 else path_parts[0], + ) + + window = func.window(partition_by=partition_by, order_by=order_by, desc=True) + + ds = ( + 
DataChain.from_storage(src_uri, session=session) + .map(file_info, params=["file"], output={"file_info": FileInfo}) + .mutate(row_number=func.row_number().over(window)) + .save("my-ds") + ) + + results = {} + for r in ds.to_records(): + filename = ( + r["file_info__path"] + "/" + r["file_info__name"] + if r["file_info__path"] + else r["file_info__name"] + ) + results[filename] = r["row_number"] + + assert results == { + "cats/cat2": 1, + "cats/cat1": 2, + "description": 1, + "dogs/others/dog4": 1, + "dogs/dog3": 2, + "dogs/dog2": 3, + "dogs/dog1": 4, + } + + +def test_window_signals_random(cloud_test_catalog): + from datachain import func + + session = cloud_test_catalog.session + src_uri = cloud_test_catalog.src_uri + + class FileInfo(DataModel): + path: str = "" + name: str = "" + + def file_info(file: File) -> FileInfo: + full_path = file.source.rstrip("/") + "/" + file.path + rel_path = posixpath.relpath(full_path, src_uri) + path_parts = rel_path.split("/", 1) + return FileInfo( + path=path_parts[0] if len(path_parts) > 1 else "", + name=path_parts[1] if len(path_parts) > 1 else path_parts[0], + ) + + window = func.window(partition_by="file_info.path", order_by="sys.rand") + + ds = ( + DataChain.from_storage(src_uri, session=session) + .map(file_info, params=["file"], output={"file_info": FileInfo}) + .mutate(row_number=func.row_number().over(window)) + .filter(C("row_number") < 3) + .select_except("row_number") + .save("my-ds") + ) + + results = {} + for r in ds.to_records(): + results.setdefault(r["file_info__path"], []).append(r["file_info__name"]) + + assert results[""] == ["description"] + assert sorted(results["cats"]) == sorted(["cat1", "cat2"]) + + assert len(results["dogs"]) == 2 + all_dogs = ["dog1", "dog2", "dog3", "others/dog4"] + for dog in results["dogs"]: + assert dog in all_dogs + all_dogs.remove(dog) + assert len(all_dogs) == 2 + + def test_to_from_csv_remote(cloud_test_catalog_upload): ctc = cloud_test_catalog_upload path = 
f"{ctc.src_uri}/test.csv" diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 8d2ca157d..9c88eecdf 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1,5 +1,6 @@ import datetime import math +import re from collections.abc import Generator, Iterator from unittest.mock import ANY @@ -2131,7 +2132,7 @@ def test_group_by_int(test_session): "cnt": "int", "cnt_col": "int", "sum": "int", - "avg": "int", + "avg": "float", "min": "int", "max": "int", "value": "int", @@ -2263,6 +2264,8 @@ def test_group_by_str(test_session): .group_by( cnt=func.count(), cnt_col=func.count("col2"), + min=func.min("col2"), + max=func.max("col2"), concat=func.concat("col2"), concat_sep=func.concat("col2", separator=","), value=func.any_value("col2"), @@ -2276,6 +2279,8 @@ def test_group_by_str(test_session): "col1": "str", "cnt": "int", "cnt_col": "int", + "min": "str", + "max": "str", "concat": "str", "concat_sep": "str", "value": "str", @@ -2287,6 +2292,8 @@ def test_group_by_str(test_session): "col1": "a", "cnt": 2, "cnt_col": 2, + "min": "1", + "max": "2", "concat": "12", "concat_sep": "1,2", "value": ANY_VALUE("1", "2"), @@ -2296,6 +2303,8 @@ def test_group_by_str(test_session): "col1": "b", "cnt": 3, "cnt_col": 3, + "min": "3", + "max": "5", "concat": "345", "concat_sep": "3,4,5", "value": ANY_VALUE("3", "4", "5"), @@ -2305,6 +2314,8 @@ def test_group_by_str(test_session): "col1": "c", "cnt": 1, "cnt_col": 1, + "min": "6", + "max": "6", "concat": "6", "concat_sep": "6", "value": ANY_VALUE("6"), @@ -2443,3 +2454,116 @@ def test_group_by_error(test_session): SignalResolvingError, match="cannot resolve signal name 'col3': is not found" ): dc.group_by(foo=func.sum("col2"), partition_by="col3") + + +@pytest.mark.parametrize("desc", [True, False]) +def test_window_functions(test_session, desc): + from datachain import func + + window = func.window(partition_by="col1", order_by="col2", desc=desc) + + ds = ( + 
DataChain.from_values( + col1=["a", "a", "b", "b", "b", "c"], + col2=[1, 2, 3, 4, 5, 6], + session=test_session, + ) + .mutate( + row_number=func.row_number().over(window), + rank=func.rank().over(window), + dense_rank=func.dense_rank().over(window), + first=func.first("col2").over(window), + ) + .save("my-ds") + ) + + assert ds.signals_schema.serialize() == { + "col1": "str", + "col2": "int", + "row_number": "int", + "rank": "int", + "dense_rank": "int", + "first": "int", + } + assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts( + [ + { + "col1": "a", + "col2": 1, + "row_number": 2 if desc else 1, + "rank": 2 if desc else 1, + "dense_rank": 2 if desc else 1, + "first": 2 if desc else 1, + }, + { + "col1": "a", + "col2": 2, + "row_number": 1 if desc else 2, + "rank": 1 if desc else 2, + "dense_rank": 1 if desc else 2, + "first": 2 if desc else 1, + }, + { + "col1": "b", + "col2": 3, + "row_number": 3 if desc else 1, + "rank": 3 if desc else 1, + "dense_rank": 3 if desc else 1, + "first": 5 if desc else 3, + }, + { + "col1": "b", + "col2": 4, + "row_number": 2, + "rank": 2, + "dense_rank": 2, + "first": 5 if desc else 3, + }, + { + "col1": "b", + "col2": 5, + "row_number": 1 if desc else 3, + "rank": 1 if desc else 3, + "dense_rank": 1 if desc else 3, + "first": 5 if desc else 3, + }, + { + "col1": "c", + "col2": 6, + "row_number": 1, + "rank": 1, + "dense_rank": 1, + "first": 6, + }, + ], + "col1", + "col2", + ) + + +def test_window_error(test_session): + from datachain import func + + window = func.window(partition_by="col1", order_by="col2") + + dc = DataChain.from_values( + col1=["a", "a", "b", "b", "b", "c"], + col2=[1, 2, 3, 4, 5, 6], + session=test_session, + ) + + with pytest.raises( + DataChainParamsError, + match=re.escape( + "Window function first() requires over() clause with a window spec", + ), + ): + dc.mutate(first=func.first("col2")) + + with pytest.raises( + DataChainParamsError, + match=re.escape( + "sum() doesn't support window 
(over())", + ), + ): + dc.mutate(first=func.sum("col2").over(window))