diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 85f2618ca..315fb4d9e 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1624,6 +1624,82 @@ def subtract( # type: ignore[override] ) return self._evolve(query=self._query.subtract(other._query, signals)) # type: ignore[arg-type] + def compare( + self, + other: "DataChain", + on: Union[str, Sequence[str]], + right_on: Optional[Union[str, Sequence[str]]] = None, + compare: Optional[Union[str, Sequence[str]]] = None, + right_compare: Optional[Union[str, Sequence[str]]] = None, + added: bool = True, + deleted: bool = True, + modified: bool = True, + unchanged: bool = False, + status_col: Optional[str] = None, + ) -> "DataChain": + """Comparing two chains by identifying rows that are added, deleted, modified + or unchanged. Result is the new chain that has additional column with possible + values: `A`, `D`, `M`, `U` representing added, deleted, modified and unchanged + rows respectively. Note that if only one "status" is asked, by setting proper + flags, this additional column is not created as it would have only one value + for all rows. Beside additional diff column, new chain has schema of the chain + on which method was called. + + Parameters: + other: Chain to calculate diff from. + on: Column or list of columns to match on. If both chains have the + same columns then this column is enough for the match. Otherwise, + `right_on` parameter has to specify the columns for the other chain. + This value is used to find corresponding row in other dataset. If not + found there, row is considered as added (or removed if vice versa), and + if found then row can be either modified or unchanged. + right_on: Optional column or list of columns + for the `other` to match. + compare: Column or list of columns to compare on. If both chains have + the same columns then this column is enough for the compare. Otherwise, + `right_compare` parameter has to specify the columns for the other + chain. This value is used to see if row is modified or unchanged. If + not set, all columns will be used for comparison + right_compare: Optional column or list of columns + for the `other` to compare to. + added (bool): Whether to return added rows in resulting chain. + deleted (bool): Whether to return deleted rows in resulting chain. + modified (bool): Whether to return modified rows in resulting chain. + unchanged (bool): Whether to return unchanged rows in resulting chain. + status_col (str): Name of the new column that is created in resulting chain + representing diff status. + + Example: + ```py + diff = persons.diff( + new_persons, + on=["id"], + right_on=["other_id"], + compare=["name"], + added=True, + deleted=True, + modified=True, + unchanged=True, + status_col="diff" + ) + ``` + """ + from datachain.lib.diff import compare as chain_compare + + return chain_compare( + self, + other, + on, + right_on=right_on, + compare=compare, + right_compare=right_compare, + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + status_col=status_col, + ) + @classmethod def from_values( cls, diff --git a/src/datachain/lib/diff.py b/src/datachain/lib/diff.py new file mode 100644 index 000000000..c8ae39303 --- /dev/null +++ b/src/datachain/lib/diff.py @@ -0,0 +1,198 @@ +import random +import string +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import sqlalchemy as sa + +from datachain.lib.signal_schema import SignalSchema +from datachain.query.schema import Column +from datachain.sql.types import String + +if TYPE_CHECKING: + from datachain.lib.dc import DataChain + + +C = Column + + +def compare( # noqa: PLR0912, PLR0915, C901 + left: "DataChain", + right: "DataChain", + on: Union[str, Sequence[str]], + right_on: Optional[Union[str, Sequence[str]]] = None, + compare: Optional[Union[str, Sequence[str]]] = None, + right_compare: Optional[Union[str, Sequence[str]]] = None, + added: bool = True, + deleted: bool = True, + modified: bool = True, + unchanged: bool = False, + status_col: Optional[str] = None, +) -> "DataChain": + """Comparing two chains by identifying rows that are added, deleted, modified + or unchanged""" + dialect = left._query.dialect + + rname = "right_" + + def _rprefix(c: str, rc: str) -> str: + """Returns prefix of right of two companion left - right columns + from merge. If companion columns have the same name then prefix will + be present in right column name, otherwise it won't. + """ + return rname if c == rc else "" + + def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: + return [obj] if isinstance(obj, str) else list(obj) + + if on is None: + raise ValueError("'on' must be specified") + + on = _to_list(on) + if right_on: + right_on = _to_list(right_on) + if len(on) != len(right_on): + raise ValueError("'on' and 'right_on' must be have the same length") + + if compare: + compare = _to_list(compare) + + if right_compare: + if not compare: + raise ValueError("'compare' must be defined if 'right_compare' is defined") + + right_compare = _to_list(right_compare) + if len(compare) != len(right_compare): + raise ValueError( + "'compare' and 'right_compare' must be have the same length" + ) + + if not any([added, deleted, modified, unchanged]): + raise ValueError( + "At least one of added, deleted, modified, unchanged flags must be set" + ) + + # we still need status column for internal implementation even if not + # needed in output + need_status_col = bool(status_col) + status_col = status_col or "diff_" + "".join( + random.choice(string.ascii_letters) # noqa: S311 + for _ in range(10) + ) + + # calculate on and compare column names + right_on = right_on or on + cols = left.signals_schema.clone_without_sys_signals().db_signals() + right_cols = right.signals_schema.clone_without_sys_signals().db_signals() + + on = left.signals_schema.resolve(*on).db_signals() # type: ignore[assignment] + right_on = right.signals_schema.resolve(*right_on).db_signals() # type: ignore[assignment] + if compare: + right_compare = right_compare or compare + compare = left.signals_schema.resolve(*compare).db_signals() # type: ignore[assignment] + right_compare = right.signals_schema.resolve(*right_compare).db_signals() # type: ignore[assignment] + elif not compare and len(cols) != len(right_cols): + # here we will mark all rows that are not added or deleted as modified since + # there was no explicit list of compare columns provided (meaning we need + # to check all columns to determine if row is modified or unchanged), but + # the number of columns on left and right is not the same (one of the chains + # have additional column) + compare = None + right_compare = None + else: + compare = [c for c in cols if c in right_cols] # type: ignore[misc, assignment] + right_compare = compare + + diff_cond = [] + + if added: + added_cond = sa.and_( + *[ + C(c) == None # noqa: E711 + for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)] + ] + ) + diff_cond.append((added_cond, "A")) + if modified and compare: + modified_cond = sa.or_( + *[ + C(c) != C(f"{_rprefix(c, rc)}{rc}") + for c, rc in zip(compare, right_compare) # type: ignore[arg-type] + ] + ) + diff_cond.append((modified_cond, "M")) + if unchanged and compare: + unchanged_cond = sa.and_( + *[ + C(c) == C(f"{_rprefix(c, rc)}{rc}") + for c, rc in zip(compare, right_compare) # type: ignore[arg-type] + ] + ) + diff_cond.append((unchanged_cond, "U")) + + diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col) + diff.type = String() + + left_right_merge = left.merge( + right, on=on, right_on=right_on, inner=False, rname=rname + ) + left_right_merge_select = left_right_merge._query.select( + *( + [C(c) for c in left_right_merge.signals_schema.db_signals("sys")] + + [C(c) for c in on] + + [C(c) for c in cols if c not in on] + + [diff] + ) + ) + + diff_col = sa.literal("D").label(status_col) + diff_col.type = String() + + right_left_merge = right.merge( + left, on=right_on, right_on=on, inner=False, rname=rname + ).filter( + sa.and_( + *[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)] # noqa: E711 + ) + ) + right_left_merge_select = right_left_merge._query.select( + *( + [C(c) for c in right_left_merge.signals_schema.db_signals("sys")] + + [ + C(c) # type: ignore[misc] + if c == rc + else sa.literal( + left._query.column_types[c].default_value(dialect) # type: ignore[index] + ).label(c) + for c, rc in zip(on, right_on) + ] + + [ + C(c) # type: ignore[misc] + if c in right_cols + else sa.literal( + left._query.column_types[c].default_value(dialect) # type: ignore[index] + ).label(c) # type: ignore[arg-type] + for c in cols + if c not in on + ] + + [diff_col] + ) + ) + + if not deleted: + res = left_right_merge_select + elif deleted and not any([added, modified, unchanged]): + res = right_left_merge_select + else: + res = left_right_merge_select.union(right_left_merge_select) + + res = res.filter(C(status_col) != None) # noqa: E711 + + schema = left.signals_schema + if need_status_col: + res = res.select() + schema = SignalSchema({status_col: str}) | schema + else: + res = res.select_except(C(status_col)) + + return left._evolve(query=res, signal_schema=schema) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 567156cb8..073053023 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -1069,6 +1069,7 @@ def __init__( if "sys__id" in self.column_types: self.column_types.pop("sys__id") self.starting_step = QueryStep(self.catalog, name, self.version) + self.dialect = self.catalog.warehouse.db.dialect def __iter__(self): return iter(self.db_results()) diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 2535935df..af138c914 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2944,3 +2944,409 @@ def test_window_error(test_session): ), ): dc.mutate(first=func.sum("col2").over(window)) + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("status_col", ("diff", None)) +@pytest.mark.parametrize("save", (True, False)) +def test_compare(test_session, added, deleted, modified, unchanged, status_col, save): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + if not any([added, deleted, modified, unchanged]): + with pytest.raises(ValueError) as exc_info: + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + on=["id"], + status_col=status_col, + ) + assert str(exc_info.value) == ( + "At least one of added, deleted, modified, unchanged flags must be set" + ) + return + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + on=["id"], + status_col="diff", + ) + + if save: + diff.save("diff") + diff = DataChain.from_dataset("diff") + + expected = [] + if modified: + expected.append(("M", 1, "John1")) + if added: + expected.append(("A", 2, "Doe")) + if deleted: + expected.append(("D", 3, "Mark")) + if unchanged: + expected.append(("U", 4, "Andy")) + + collect_fields = ["diff", "id", "name"] + if not status_col: + expected = [row[1:] for row in expected] + collect_fields = collect_fields[1:] + + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +def test_compare_with_from_dataset(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + # this adds sys columns to ds1 and ds2 + ds1 = DataChain.from_dataset("ds1") + ds2 = DataChain.from_dataset("ds2") + + diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + + assert list(diff.order_by("id").collect("diff", "id", "name")) == [ + ("M", 1, "John1"), + ("A", 2, "Doe"), + ("D", 3, "Mark"), + ("U", 4, "Andy"), + ] + + +@pytest.mark.parametrize("added", (True,)) +@pytest.mark.parametrize("deleted", (True,)) +@pytest.mark.parametrize("modified", (True,)) +@pytest.mark.parametrize("unchanged", (True,)) +@pytest.mark.parametrize("right_name", ("other_name",)) +def test_compare_with_explicit_compare_fields( + test_session, added, deleted, modified, unchanged, right_name +): + if not any([added, deleted, modified, unchanged]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + city=["New York", "Boston", "San Francisco"], + session=test_session, + ).save("ds1") + + ds2_data = { + "id": [1, 3, 4], + "city": ["Washington", "Seattle", "Miami"], + f"{right_name}": ["John", "Mark", "Andy"], + "session": test_session, + } + + ds2 = DataChain.from_values(**ds2_data).save("ds2") + + diff = ds1.compare( + ds2, + on=["id"], + compare=["name"], + right_compare=[right_name], + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + status_col="diff", + ) + + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + + expected = [] + if modified: + expected.append(("M", 1, "John1", "New York")) + if added: + expected.append(("A", 2, "Doe", "Boston")) + if deleted: + expected.append( + ( + "D", + 3, + string_default if right_name == "other_name" else "Mark", + "Seattle", + ) + ) + if unchanged: + expected.append(("U", 4, "Andy", "San Francisco")) + + collect_fields = ["diff", "id", "name", "city"] + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("unchanged", (True, False)) +def test_compare_different_left_right_on_columns( + test_session, added, deleted, modified, unchanged +): + if not any([added, deleted, modified, unchanged]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + other_id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + on=["id"], + right_on=["other_id"], + status_col="diff", + ) + + int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) + + expected = [] + if unchanged: + expected.append(("U", 4, "Andy")) + if added: + expected.append(("A", 2, "Doe")) + if modified: + expected.append(("M", 1, "John1")) + if deleted: + expected.append(("D", int_default, "Mark")) + + collect_fields = ["diff", "id", "name"] + assert list(diff.order_by("name").collect(*collect_fields)) == expected + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("on_self", (True, False)) +def test_compare_on_equal_datasets( + test_session, added, deleted, modified, unchanged, on_self +): + if not any([added, deleted, modified, unchanged]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 3], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + if on_self: + ds2 = ds1 + else: + ds2 = DataChain.from_values( + id=[1, 2, 3], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds2") + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + unchanged=unchanged, + on=["id"], + status_col="diff", + ) + + if not unchanged: + expected = [] + else: + expected = [ + ("U", 1, "John"), + ("U", 2, "Doe"), + ("U", 3, "Andy"), + ] + + collect_fields = ["diff", "id", "name"] + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +def test_compare_multiple_columns(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + city=["Paris", "Berlin", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "Mark", "city": "Berlin"}, + {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_multiple_match_columns(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "John", "Andy"], + city=["Paris", "Berlin", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, unchanged=True, on=["id", "name"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "John", "city": "Berlin"}, + {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_additional_column_on_left(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + + diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "Mark", "city": string_default}, + {"diff": "M", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_additional_column_on_right(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John"}, + {"diff": "A", "id": 2, "name": "Doe"}, + {"diff": "D", "id": 3, "name": "Mark"}, + {"diff": "M", "id": 4, "name": "Andy"}, + ], + "id", + ) + + +def test_compare_missing_on(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=None) + + assert str(exc_info.value) == "'on' must be specified" + + +def test_compare_right_on_wrong_length(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], right_on=["id", "name"]) + + assert str(exc_info.value) == "'on' and 'right_on' must be have the same length" + + +def test_compare_right_compare_defined_but_not_compare(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], right_compare=["name"]) + + assert str(exc_info.value) == ( + "'compare' must be defined if 'right_compare' is defined" + ) + + +def test_compare_right_compare_wrong_length(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"]) + + assert str(exc_info.value) == ( + "'compare' and 'right_compare' must be have the same length" + )