From 00e03c8297641082c31e92a4ce63d1439929bdfe Mon Sep 17 00:00:00 2001 From: Vladimir Rudnykh Date: Mon, 2 Dec 2024 13:37:36 +0700 Subject: [PATCH] Remove 'partition_by' requirement from 'group_dy' method (#649) --- src/datachain/func/func.py | 5 +++-- src/datachain/lib/dc.py | 8 ++++---- src/datachain/query/dataset.py | 2 -- tests/unit/lib/test_datachain.py | 8 -------- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index 3f2352eba..7c8d6cd75 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -2,7 +2,8 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from sqlalchemy import BindParameter, ColumnElement, desc +from sqlalchemy import BindParameter, Case, ColumnElement, desc +from sqlalchemy.ext.hybrid import Comparator from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.utils import DataChainColumnError, DataChainParamsError @@ -71,7 +72,7 @@ def _db_cols(self) -> Sequence[ColT]: return ( [ col - if isinstance(col, (Func, BindParameter)) + if isinstance(col, (Func, BindParameter, Case, Comparator)) else ColumnMeta.to_db_name( col.name if isinstance(col, ColumnElement) else col ) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 94f3782db..fbab81afa 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1150,7 +1150,7 @@ def select_except(self, *args: str) -> "Self": def group_by( self, *, - partition_by: Union[str, Func, Sequence[Union[str, Func]]], + partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None, **kwargs: Func, ) -> "Self": """Group rows by specified set of signals and return new signals @@ -1167,10 +1167,10 @@ def group_by( ) ``` """ - if isinstance(partition_by, (str, Func)): + if partition_by is None: + partition_by = [] + elif isinstance(partition_by, (str, Func)): partition_by = [partition_by] - if not partition_by: - raise ValueError("At least one column should be provided for partition_by") partition_by_columns: list[Column] = [] signal_columns: list[Column] = [] diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 07cfab151..effc24caf 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -966,8 +966,6 @@ class SQLGroupBy(SQLClause): def apply_sql_clause(self, query) -> Select: if not self.cols: raise ValueError("No columns to select") - if not self.group_by: - raise ValueError("No columns to group by") subquery = query.subquery() diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 0e42664f2..5802227a4 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2740,14 +2740,6 @@ def test_group_by_error(test_session): session=test_session, ) - with pytest.raises(TypeError): - dc.group_by(cnt=func.count()) - - with pytest.raises( - ValueError, match="At least one column should be provided for partition_by" - ): - dc.group_by(cnt=func.count(), partition_by=()) - with pytest.raises( ValueError, match="At least one column should be provided for group_by" ):