diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 91e674444..e59109504 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -829,8 +829,19 @@ def mutate(self, **kwargs) -> "Self":
         )
         ```
         """
-        chain = super().mutate(**kwargs)
-        chain.signals_schema = self.signals_schema.mutate(kwargs)
+        mutated = {}
+        schema = self.signals_schema
+        for name, value in kwargs.items():
+            if isinstance(value, Column):
+                # renaming existing column
+                for signal in schema.db_signals(name=value.name, as_columns=True):
+                    mutated[signal.name.replace(value.name, name, 1)] = signal
+            else:
+                # adding new signal
+                mutated[name] = value
+
+        chain = super().mutate(**mutated)
+        chain.signals_schema = schema.mutate(kwargs)
         return chain
 
     @property
@@ -1099,7 +1110,7 @@ def subtract(  # type: ignore[override]
             )
         else:
             signals = self.signals_schema.resolve(*on).db_signals()
-        return super()._subtract(other, signals)
+        return super()._subtract(other, signals)  # type: ignore[arg-type]
 
     @classmethod
     def from_values(
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index 59f305c5f..338afd06d 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -25,7 +25,7 @@
 from datachain.lib.file import File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
-from datachain.query.schema import DEFAULT_DELIMITER
+from datachain.query.schema import DEFAULT_DELIMITER, Column
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -222,13 +222,30 @@ def row_to_features(
                 res.append(obj)
         return res
 
-    def db_signals(self) -> list[str]:
-        return [
+    def db_signals(
+        self, name: Optional[str] = None, as_columns: bool = False
+    ) -> Union[list[str], list[Column]]:
+        """
+        Return leaf DB column names, or typed Column objects if `as_columns` is set.
+        If `name` is given, keep only that signal and the columns nested under it.
+        """
+        signals = [
             DEFAULT_DELIMITER.join(path)
-            for path, _, has_subtree, _ in self.get_flat_tree()
+            if not as_columns
+            else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type))
+            for path, _type, has_subtree, _ in self.get_flat_tree()
             if not has_subtree
         ]
 
+        if name:
+            signals = [
+                s
+                for s in signals
+                if str(s) == name or str(s).startswith(f"{name}{DEFAULT_DELIMITER}")
+            ]
+
+        return signals  # type: ignore[return-value]
+
     def resolve(self, *names: str) -> "SignalSchema":
         schema = {}
         for field in names:
@@ -282,7 +299,18 @@ def clone_without_file_signals(self) -> "SignalSchema":
         return SignalSchema(schema)
 
     def mutate(self, args_map: dict) -> "SignalSchema":
-        return SignalSchema(self.values | sql_to_python(args_map))
+        new_values = self.values.copy()
+
+        for name, value in args_map.items():
+            if isinstance(value, Column) and value.name in self.values:
+                # renaming existing signal; pop() tolerates a source renamed twice
+                new_values.pop(value.name, None)
+                new_values[name] = self.values[value.name]
+            else:
+                # adding new signal
+                new_values.update(sql_to_python({name: value}))
+
+        return SignalSchema(new_values)
 
     def clone_without_sys_signals(self) -> "SignalSchema":
         schema = copy.deepcopy(self.values)
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 4ea9650c7..a126beefb 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -1304,3 +1304,59 @@ def func(key, val) -> Iterator[tuple[File, str]]:
     assert ds.limit(3).gen(res=func).limit(2).count() == 2
     assert ds.limit(2).gen(res=func).limit(3).count() == 3
     assert ds.limit(3).gen(res=func).limit(10).count() == 9
+
+
+def test_rename_non_object_column_name_with_mutate(catalog):
+    ds = DataChain.from_values(ids=[1, 2, 3])
+    ds = ds.mutate(my_ids=Column("ids"))
+
+    assert ds.signals_schema.values == {"my_ids": int}
+    assert list(ds.order_by("my_ids").collect("my_ids")) == [1, 2, 3]
+
+    ds.save("mutated")
+
+    ds = DataChain(name="mutated")
+    assert ds.signals_schema.values.get("my_ids") is int
+    assert "ids" not in ds.signals_schema.values
+    assert list(ds.order_by("my_ids").collect("my_ids")) == [1, 2, 3]
+
+
+def test_rename_object_column_name_with_mutate(catalog):
+    names = ["a", "b", "c"]
+    sizes = [1, 2, 3]
+    files = [File(name=name, size=size) for name, size in zip(names, sizes)]
+
+    ds = DataChain.from_values(file=files, ids=[1, 2, 3])
+    ds = ds.mutate(fname=Column("file.name"))
+
+    assert list(ds.order_by("fname").collect("fname")) == ["a", "b", "c"]
+    assert ds.signals_schema.values == {"file": File, "ids": int, "fname": str}
+
+    # check that persist after saving
+    ds.save("mutated")
+
+    ds = DataChain(name="mutated")
+    assert ds.signals_schema.values.get("file") is File
+    assert ds.signals_schema.values.get("ids") is int
+    assert ds.signals_schema.values.get("fname") is str
+    assert list(ds.order_by("fname").collect("fname")) == ["a", "b", "c"]
+
+
+def test_rename_object_name_with_mutate(catalog):
+    names = ["a", "b", "c"]
+    sizes = [1, 2, 3]
+    files = [File(name=name, size=size) for name, size in zip(names, sizes)]
+
+    ds = DataChain.from_values(file=files, ids=[1, 2, 3])
+    ds = ds.mutate(my_file=Column("file"))
+
+    assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]
+    assert ds.signals_schema.values == {"my_file": File, "ids": int}
+
+    ds.save("mutated")
+
+    ds = DataChain(name="mutated")
+    assert ds.signals_schema.values.get("my_file") is File
+    assert ds.signals_schema.values.get("ids") is int
+    assert "file" not in ds.signals_schema.values
+    assert list(ds.order_by("my_file.name").collect("my_file.name")) == ["a", "b", "c"]
diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py
index 3df1764b6..2be6bf3cd 100644
--- a/tests/unit/lib/test_signal_schema.py
+++ b/tests/unit/lib/test_signal_schema.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from datachain import DataModel
+from datachain import Column, DataModel
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.file import File
 from datachain.lib.signal_schema import (
@@ -251,7 +251,7 @@ def test_print_types():
         assert SignalSchema._type_to_str(t) == v
 
 
-def test_bd_signals():
+def test_db_signals():
     spec = {"name": str, "age": float, "fr": MyType2}
     lst = list(SignalSchema(spec).db_signals())
 
@@ -264,6 +264,33 @@ def test_bd_signals():
     ]
 
 
+def test_db_signals_filtering_by_name():
+    schema = SignalSchema({"name": str, "age": float, "fr": MyType2})
+
+    assert list(schema.db_signals(name="fr")) == [
+        "fr__name",
+        "fr__deep__aa",
+        "fr__deep__bb",
+    ]
+    assert list(schema.db_signals(name="name")) == ["name"]
+    assert list(schema.db_signals(name="missing")) == []
+
+
+def test_db_signals_as_columns():
+    spec = {"name": str, "age": float, "fr": MyType2}
+    lst = list(SignalSchema(spec).db_signals(as_columns=True))
+
+    assert all(isinstance(s, Column) for s in lst)
+
+    assert [(c.name, type(c.type)) for c in lst] == [
+        ("name", String),
+        ("age", Float),
+        ("fr__name", String),
+        ("fr__deep__aa", Int64),
+        ("fr__deep__bb", String),
+    ]
+
+
 def test_row_to_objs():
     spec = {"name": str, "age": float, "fr": MyType2}
     schema = SignalSchema(spec)