Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
from datachain.query.schema import Column, ColumnMeta
from datachain.sql.functions import numeric
Expand Down Expand Up @@ -415,6 +416,20 @@ def get_column(
label: str | None = None,
table: "TableClause | None" = None,
) -> Column:
# Guard against using complex (pydantic) object columns in SQL funcs
if signals_schema and self._db_cols:
for arg in self._db_cols:
# _db_cols normalizes known columns to strings; skip non-string args
if not isinstance(arg, str):
continue
t_with_sub = signals_schema.get_column_type(arg, with_subtree=True)
if ModelStore.is_pydantic(t_with_sub):
raise DataChainParamsError(
f"Function {self.name} doesn't support complex object "
f"columns like '{arg}'. Use a leaf field (e.g., "
f"'{arg}.path') or use UDFs to operate on complex objects."
)

col_type = self.get_result_type(signals_schema)
sql_type = python_to_sql(col_type)

Expand All @@ -434,6 +449,7 @@ def get_col(col: ColT, string_as_literal=False) -> ColT:
return col

cols = [get_col(col) for col in self._db_cols]

kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
func_col = self.inner(*cols, *self.args, **kwargs)

Expand Down Expand Up @@ -470,9 +486,8 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
if isinstance(col, ColumnElement) and not hasattr(col, "name"):
return sql_to_python(col)

return signals_schema.get_column_type(
col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
)
name = col.name if isinstance(col, ColumnElement) else col # type: ignore[assignment]
return signals_schema.get_column_type(name) # type: ignore[arg-type]


def _truediv(a, b):
Expand Down
31 changes: 31 additions & 0 deletions tests/func/functions/test_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

import datachain as dc
from datachain import func
from datachain.lib.utils import DataChainParamsError


def test_aggregate_avg(test_session):
Expand Down Expand Up @@ -169,3 +172,31 @@ class Data(dc.DataModel):
assert all(isinstance(x[3], int) and x[3] in (20, 40, 60) for x in ds)
assert all(isinstance(x[4], float) and x[4] in (2.0, 4.0, 6.0) for x in ds)
assert all(isinstance(x[5], str) and x[5] in ("x", "y", "z") for x in ds)


def test_funcs_disallow_complex_object_collect(test_session):
class Rec(dc.DataModel):
i: int

ds = dc.read_values(
id=(1, 2),
rec=(Rec(i=1), Rec(i=2)),
session=test_session,
)

with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
ds.group_by(res=func.collect("rec"), partition_by="id")


def test_funcs_disallow_complex_object_min(test_session):
class Rec(dc.DataModel):
i: int

ds = dc.read_values(
id=(1, 2),
rec=(Rec(i=1), Rec(i=2)),
session=test_session,
)

with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
ds.group_by(res=func.min("rec"), partition_by="id")
13 changes: 13 additions & 0 deletions tests/func/functions/test_conditional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

import datachain as dc
from datachain import func
from datachain.lib.utils import DataChainParamsError
from tests.utils import skip_if_not_sqlite


Expand Down Expand Up @@ -174,3 +177,13 @@ class Data(dc.DataModel):
(20, 2.0, 20, 2.0, 40, 4.0, 40, 3.5, 40, 1.5),
(30, 3.0, 25, 2.5, 60, 6.0, 50, 3.5, 60, 1.5),
]


def test_conditional_funcs_disallow_complex_object_greatest(test_session):
class Rec(dc.DataModel):
i: int

ds = dc.read_values(id=[1, 2], rec=[Rec(i=1), Rec(i=2)], session=test_session)

with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
ds.mutate(t=func.greatest("rec", 1))
4 changes: 2 additions & 2 deletions tests/unit/test_datachain_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ def agg_persons(persons):
.sample(10)
.offset(2)
.limit(5)
.group_by(age_avg=func.avg("persons.age"), partition_by="persons.name")
.group_by(age_avg=func.avg("persons.ages"), partition_by="persons.name")
.select("persons.name", "age_avg")
.subtract(
players_chain,
on=["persons.name"],
right_on=["player.name"],
)
.hash()
) == "2c8d3fffade5574a418c45545f4c821cbe734f648cfcbfa55843673796bc35eb"
) == "ff0ab3df5e69f5e4f14ee7ddbeeddecfa1f237540b83ba5166ca3671625c6d4d"


def test_diff(test_session):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy import Label

import datachain as dc
from datachain import func
from datachain.func import (
and_,
bit_hamming_distance,
Expand All @@ -18,6 +19,7 @@
from datachain.func.random import rand
from datachain.func.string import length as strlen
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.utils import DataChainParamsError
from datachain.sql.sqlite.base import (
sqlite_bit_hamming_distance,
sqlite_byte_hamming_distance,
Expand Down Expand Up @@ -69,6 +71,19 @@ def test_get_column():
assert col.name == "rand"


def test_get_column_disallow_complex_object_in_sql_funcs():
class Rec(dc.DataModel):
i: int

schema = SignalSchema({"rec": Rec})

with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
func.min("rec").get_column(schema)

with pytest.raises(DataChainParamsError, match="doesn't support complex object"):
func.collect("rec").get_column(schema)


def test_add():
rnd1, rnd2 = rand(), rand()

Expand Down