Skip to content

Commit a18c0de

Browse files
committed
Improve error messages in func when a complex column is passed
1 parent cdf8ee9 commit a18c0de

File tree

4 files changed

+84
-3
lines changed

4 files changed

+84
-3
lines changed

src/datachain/func/func.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
from datachain.lib.convert.python_to_sql import python_to_sql
99
from datachain.lib.convert.sql_to_python import sql_to_python
10+
from datachain.lib.model_store import ModelStore
1011
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
1112
from datachain.query.schema import Column, ColumnMeta
1213
from datachain.sql.functions import numeric
@@ -415,6 +416,21 @@ def get_column(
415416
label: str | None = None,
416417
table: "TableClause | None" = None,
417418
) -> Column:
419+
# Guard against using complex (pydantic) object columns in SQL funcs that
420+
# don't support them. Suggest using leaf fields or UDFs instead.
421+
if signals_schema and self._db_cols:
422+
for arg in self._db_cols:
423+
# _db_cols normalizes known columns to strings; skip non-string args
424+
if not isinstance(arg, str):
425+
continue
426+
t_with_sub = signals_schema.get_column_type(arg, with_subtree=True)
427+
if ModelStore.is_pydantic(t_with_sub):
428+
raise DataChainParamsError(
429+
f"Function {self.name} doesn't support complex object "
430+
f"columns like '{arg}'. Use a leaf field (e.g., "
431+
f"'{arg}.path') or use UDFs to operate on complex objects."
432+
)
433+
418434
col_type = self.get_result_type(signals_schema)
419435
sql_type = python_to_sql(col_type)
420436

@@ -434,6 +450,7 @@ def get_col(col: ColT, string_as_literal=False) -> ColT:
434450
return col
435451

436452
cols = [get_col(col) for col in self._db_cols]
453+
437454
kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
438455
func_col = self.inner(*cols, *self.args, **kwargs)
439456

@@ -470,9 +487,14 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
470487
if isinstance(col, ColumnElement) and not hasattr(col, "name"):
471488
return sql_to_python(col)
472489

473-
return signals_schema.get_column_type(
474-
col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
475-
)
490+
name = col.name if isinstance(col, ColumnElement) else col # type: ignore[assignment]
491+
from datachain.lib.signal_schema import SignalResolvingError
492+
493+
try:
494+
return signals_schema.get_column_type(name) # type: ignore[arg-type]
495+
except SignalResolvingError:
496+
# Fall back to a subtree lookup for top-level complex objects (models)
497+
return signals_schema.get_column_type(name, with_subtree=True) # type: ignore[arg-type]
476498

477499

478500
def _truediv(a, b):

tests/func/functions/test_aggregate.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1+
import pytest
2+
13
import datachain as dc
24
from datachain import func
5+
from datachain.lib.utils import DataChainParamsError
36

47

58
def test_aggregate_avg(test_session):
@@ -169,3 +172,31 @@ class Data(dc.DataModel):
169172
assert all(isinstance(x[3], int) and x[3] in (20, 40, 60) for x in ds)
170173
assert all(isinstance(x[4], float) and x[4] in (2.0, 4.0, 6.0) for x in ds)
171174
assert all(isinstance(x[5], str) and x[5] in ("x", "y", "z") for x in ds)
175+
176+
177+
def test_funcs_disallow_complex_object_collect(test_session):
178+
class Rec(dc.DataModel):
179+
i: int
180+
181+
ds = dc.read_values(
182+
id=(1, 2),
183+
rec=(Rec(i=1), Rec(i=2)),
184+
session=test_session,
185+
)
186+
187+
with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
188+
ds.group_by(res=func.collect("rec"), partition_by="id")
189+
190+
191+
def test_funcs_disallow_complex_object_min(test_session):
192+
class Rec(dc.DataModel):
193+
i: int
194+
195+
ds = dc.read_values(
196+
id=(1, 2),
197+
rec=(Rec(i=1), Rec(i=2)),
198+
session=test_session,
199+
)
200+
201+
with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
202+
ds.group_by(res=func.min("rec"), partition_by="id")

tests/func/functions/test_conditional.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1+
import pytest
2+
13
import datachain as dc
24
from datachain import func
5+
from datachain.lib.utils import DataChainParamsError
36
from tests.utils import skip_if_not_sqlite
47

58

@@ -174,3 +177,13 @@ class Data(dc.DataModel):
174177
(20, 2.0, 20, 2.0, 40, 4.0, 40, 3.5, 40, 1.5),
175178
(30, 3.0, 25, 2.5, 60, 6.0, 50, 3.5, 60, 1.5),
176179
]
180+
181+
182+
def test_conditional_funcs_disallow_complex_object_greatest(test_session):
183+
class Rec(dc.DataModel):
184+
i: int
185+
186+
ds = dc.read_values(id=[1, 2], rec=[Rec(i=1), Rec(i=2)], session=test_session)
187+
188+
with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
189+
ds.mutate(t=func.greatest("rec", 1))

tests/unit/test_func.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from sqlalchemy import Label
33

44
import datachain as dc
5+
from datachain import func
56
from datachain.func import (
67
and_,
78
bit_hamming_distance,
@@ -18,6 +19,7 @@
1819
from datachain.func.random import rand
1920
from datachain.func.string import length as strlen
2021
from datachain.lib.signal_schema import SignalSchema
22+
from datachain.lib.utils import DataChainParamsError
2123
from datachain.sql.sqlite.base import (
2224
sqlite_bit_hamming_distance,
2325
sqlite_byte_hamming_distance,
@@ -69,6 +71,19 @@ def test_get_column():
6971
assert col.name == "rand"
7072

7173

74+
def test_get_column_disallow_complex_object_in_sql_funcs():
75+
class Rec(dc.DataModel):
76+
i: int
77+
78+
schema = SignalSchema({"rec": Rec})
79+
80+
with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
81+
func.min("rec").get_column(schema)
82+
83+
with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
84+
func.collect("rec").get_column(schema)
85+
86+
7287
def test_add():
7388
rnd1, rnd2 = rand(), rand()
7489

0 commit comments

Comments
 (0)