Skip to content

Commit

Permalink
Fix renaming object or normal signal with .mutate() (#217)
Browse files Browse the repository at this point in the history
* fix renaming object or normal signal with mutate

* removed print and added more tests

* added test for has object

* refactoring tests

* added one assert in test

* simplifying code, removing has_object

* removing sys from tests

* fixing test
  • Loading branch information
ilongin authored Aug 5, 2024
1 parent 7832e10 commit 1e5178b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 10 deletions.
17 changes: 14 additions & 3 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,19 @@ def mutate(self, **kwargs) -> "Self":
)
```
"""
chain = super().mutate(**kwargs)
chain.signals_schema = self.signals_schema.mutate(kwargs)
mutated = {}
schema = self.signals_schema
for name, value in kwargs.items():
if isinstance(value, Column):
# renaming existing column
for signal in schema.db_signals(name=value.name, as_columns=True):
mutated[signal.name.replace(value.name, name, 1)] = signal
else:
# adding new signal
mutated[name] = value

chain = super().mutate(**mutated)
chain.signals_schema = schema.mutate(kwargs)
return chain

@property
Expand Down Expand Up @@ -1099,7 +1110,7 @@ def subtract( # type: ignore[override]
)
else:
signals = self.signals_schema.resolve(*on).db_signals()
return super()._subtract(other, signals)
return super()._subtract(other, signals) # type: ignore[arg-type]

@classmethod
def from_values(
Expand Down
38 changes: 33 additions & 5 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from datachain.lib.file import File
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import DEFAULT_DELIMITER
from datachain.query.schema import DEFAULT_DELIMITER, Column

if TYPE_CHECKING:
from datachain.catalog import Catalog
Expand Down Expand Up @@ -222,13 +222,30 @@ def row_to_features(
res.append(obj)
return res

def db_signals(
    self, name: Optional[str] = None, as_columns=False
) -> Union[list[str], list[Column]]:
    """
    Return the DB columns for every leaf signal of the schema.

    Each entry of the flat tree that has no subtree contributes one column
    whose name is its path joined with ``DEFAULT_DELIMITER``.

    Args:
        name: When given (and non-empty), keep only the column called
            exactly ``name`` and the columns whose name starts with
            ``name`` followed by ``DEFAULT_DELIMITER``, i.e. the signals
            of that object.
        as_columns: When true, return ``Column`` objects typed through
            ``python_to_sql`` instead of plain strings.

    Returns:
        A list of column names, or a list of ``Column`` objects when
        ``as_columns`` is set.
    """
    signals = []
    for path, _type, has_subtree, _ in self.get_flat_tree():
        if has_subtree:
            # Only leaves map to DB columns.
            continue

        column_name = DEFAULT_DELIMITER.join(path)
        if as_columns:
            signals.append(Column(column_name, python_to_sql(_type)))
        else:
            signals.append(column_name)

    if not name:
        return signals  # type: ignore[return-value]

    prefix = f"{name}{DEFAULT_DELIMITER}"
    return [  # type: ignore[return-value]
        s for s in signals if str(s) == name or str(s).startswith(prefix)
    ]

def resolve(self, *names: str) -> "SignalSchema":
schema = {}
for field in names:
Expand Down Expand Up @@ -282,7 +299,18 @@ def clone_without_file_signals(self) -> "SignalSchema":
return SignalSchema(schema)

def mutate(self, args_map: dict) -> "SignalSchema":
    """
    Return a new schema with signals renamed or added as per ``args_map``.

    Args:
        args_map: Mapping of new signal name to value. A ``Column`` whose
            ``name`` is an existing top-level signal is treated as a rename
            of that signal; any other value is converted with
            ``sql_to_python`` and added (or overwritten) under the new name.

    Returns:
        A new ``SignalSchema``; ``self.values`` is not modified.
    """
    new_values = self.values.copy()

    for name, value in args_map.items():
        if isinstance(value, Column) and value.name in self.values:
            # Renaming an existing signal. Use pop() with a default instead
            # of ``del``: the same source column may be the target of more
            # than one rename in a single call, and on the second pass the
            # key has already been removed from ``new_values``.
            new_values.pop(value.name, None)
            new_values[name] = self.values[value.name]
        else:
            # Adding a new signal.
            new_values.update(sql_to_python({name: value}))

    return SignalSchema(new_values)

def clone_without_sys_signals(self) -> "SignalSchema":
schema = copy.deepcopy(self.values)
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,3 +1304,59 @@ def func(key, val) -> Iterator[tuple[File, str]]:
assert ds.limit(3).gen(res=func).limit(2).count() == 2
assert ds.limit(2).gen(res=func).limit(3).count() == 3
assert ds.limit(3).gen(res=func).limit(10).count() == 9


def test_rename_non_object_column_name_with_mutate(catalog):
    """Renaming a plain column through ``mutate`` replaces the old signal."""
    expected_ids = [1, 2, 3]

    chain = DataChain.from_values(ids=[1, 2, 3]).mutate(my_ids=Column("ids"))

    assert chain.signals_schema.values == {"my_ids": int}
    assert list(chain.order_by("my_ids").collect("my_ids")) == expected_ids

    chain.save("mutated")

    # Load the saved dataset again and check that the rename was persisted.
    reloaded = DataChain(name="mutated")
    reloaded_values = reloaded.signals_schema.values
    assert reloaded_values.get("my_ids") is int
    assert "ids" not in reloaded_values
    assert list(reloaded.order_by("my_ids").collect("my_ids")) == expected_ids


def test_rename_object_column_name_with_mutate(catalog):
    """Aliasing one field of an object signal adds a new top-level signal."""
    expected_names = ["a", "b", "c"]
    files = [
        File(name=file_name, size=file_size)
        for file_name, file_size in zip(["a", "b", "c"], [1, 2, 3])
    ]

    chain = DataChain.from_values(file=files, ids=[1, 2, 3])
    chain = chain.mutate(fname=Column("file.name"))

    assert list(chain.order_by("fname").collect("fname")) == expected_names
    assert chain.signals_schema.values == {
        "file": File,
        "ids": int,
        "fname": str,
    }

    # Check that the new signal persists after saving.
    chain.save("mutated")

    reloaded = DataChain(name="mutated")
    reloaded_values = reloaded.signals_schema.values
    assert reloaded_values.get("file") is File
    assert reloaded_values.get("ids") is int
    assert reloaded_values.get("fname") is str
    assert list(reloaded.order_by("fname").collect("fname")) == expected_names


def test_rename_object_name_with_mutate(catalog):
    """Renaming a whole object signal moves all of its fields to the new name."""
    expected_names = ["a", "b", "c"]
    files = [
        File(name=file_name, size=file_size)
        for file_name, file_size in zip(["a", "b", "c"], [1, 2, 3])
    ]

    chain = DataChain.from_values(file=files, ids=[1, 2, 3])
    chain = chain.mutate(my_file=Column("file"))

    collected = list(chain.order_by("my_file.name").collect("my_file.name"))
    assert collected == expected_names
    assert chain.signals_schema.values == {"my_file": File, "ids": int}

    chain.save("mutated")

    # Load the saved dataset again and check that the rename was persisted.
    reloaded = DataChain(name="mutated")
    reloaded_values = reloaded.signals_schema.values
    assert reloaded_values.get("my_file") is File
    assert reloaded_values.get("ids") is int
    assert "file" not in reloaded_values
    collected = list(reloaded.order_by("my_file.name").collect("my_file.name"))
    assert collected == expected_names
31 changes: 29 additions & 2 deletions tests/unit/lib/test_signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from datachain import DataModel
from datachain import Column, DataModel
from datachain.lib.convert.flatten import flatten
from datachain.lib.file import File
from datachain.lib.signal_schema import (
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_print_types():
assert SignalSchema._type_to_str(t) == v


def test_bd_signals():
def test_db_signals():
spec = {"name": str, "age": float, "fr": MyType2}
lst = list(SignalSchema(spec).db_signals())

Expand All @@ -264,6 +264,33 @@ def test_bd_signals():
]


def test_db_signals_filtering_by_name():
    """``db_signals(name=...)`` returns only the columns of the named signal."""
    schema = SignalSchema({"name": str, "age": float, "fr": MyType2})

    # Filter name -> expected DB columns, checked in this order.
    expected_by_name = {
        "fr": ["fr__name", "fr__deep__aa", "fr__deep__bb"],
        "name": ["name"],
        "missing": [],
    }

    for signal_name, expected in expected_by_name.items():
        assert list(schema.db_signals(name=signal_name)) == expected


def test_db_signals_as_columns():
    """``db_signals(as_columns=True)`` yields typed ``Column`` objects."""
    schema = SignalSchema({"name": str, "age": float, "fr": MyType2})
    columns = list(schema.db_signals(as_columns=True))

    assert all(isinstance(col, Column) for col in columns)

    expected = [
        ("name", String),
        ("age", Float),
        ("fr__name", String),
        ("fr__deep__aa", Int64),
        ("fr__deep__bb", String),
    ]
    actual = [(col.name, type(col.type)) for col in columns]
    assert actual == expected


def test_row_to_objs():
spec = {"name": str, "age": float, "fr": MyType2}
schema = SignalSchema(spec)
Expand Down

0 comments on commit 1e5178b

Please sign in to comment.