Skip to content

Commit

Permalink
Remove 'partition_by' requirement from 'group_by' method (#649)
Browse files: browse the repository at this point in the history
  • Loading branch information
dreadatour authored Dec 2, 2024
1 parent bf7d670 commit 00e03c8
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 16 deletions.
5 changes: 3 additions & 2 deletions src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from sqlalchemy import BindParameter, ColumnElement, desc
from sqlalchemy import BindParameter, Case, ColumnElement, desc
from sqlalchemy.ext.hybrid import Comparator

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
Expand Down Expand Up @@ -71,7 +72,7 @@ def _db_cols(self) -> Sequence[ColT]:
return (
[
col
if isinstance(col, (Func, BindParameter))
if isinstance(col, (Func, BindParameter, Case, Comparator))
else ColumnMeta.to_db_name(
col.name if isinstance(col, ColumnElement) else col
)
Expand Down
8 changes: 4 additions & 4 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def select_except(self, *args: str) -> "Self":
def group_by(
self,
*,
partition_by: Union[str, Func, Sequence[Union[str, Func]]],
partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
**kwargs: Func,
) -> "Self":
"""Group rows by specified set of signals and return new signals
Expand All @@ -1167,10 +1167,10 @@ def group_by(
)
```
"""
if isinstance(partition_by, (str, Func)):
if partition_by is None:
partition_by = []
elif isinstance(partition_by, (str, Func)):
partition_by = [partition_by]
if not partition_by:
raise ValueError("At least one column should be provided for partition_by")

partition_by_columns: list[Column] = []
signal_columns: list[Column] = []
Expand Down
2 changes: 0 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,6 @@ class SQLGroupBy(SQLClause):
def apply_sql_clause(self, query) -> Select:
if not self.cols:
raise ValueError("No columns to select")
if not self.group_by:
raise ValueError("No columns to group by")

subquery = query.subquery()

Expand Down
8 changes: 0 additions & 8 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2740,14 +2740,6 @@ def test_group_by_error(test_session):
session=test_session,
)

with pytest.raises(TypeError):
dc.group_by(cnt=func.count())

with pytest.raises(
ValueError, match="At least one column should be provided for partition_by"
):
dc.group_by(cnt=func.count(), partition_by=())

with pytest.raises(
ValueError, match="At least one column should be provided for group_by"
):
Expand Down

0 comments on commit 00e03c8

Please sign in to comment.