From 20c73b292e7eacd637ddca6021c867c6f59c1d5b Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Fri, 20 Dec 2024 16:43:12 +0100 Subject: [PATCH] Added `DataChain.diff()` (#718) Added `DataChain.diff()` --- src/datachain/lib/dc.py | 89 +++++- src/datachain/lib/diff.py | 37 ++- tests/unit/lib/test_datachain.py | 406 ------------------------- tests/unit/lib/test_diff.py | 498 +++++++++++++++++++++++++++++++ 4 files changed, 597 insertions(+), 433 deletions(-) create mode 100644 tests/unit/lib/test_diff.py diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 315fb4d9e..a3278b232 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1634,12 +1634,12 @@ def compare( added: bool = True, deleted: bool = True, modified: bool = True, - unchanged: bool = False, + same: bool = False, status_col: Optional[str] = None, ) -> "DataChain": """Comparing two chains by identifying rows that are added, deleted, modified - or unchanged. Result is the new chain that has additional column with possible - values: `A`, `D`, `M`, `U` representing added, deleted, modified and unchanged + or same. Result is the new chain that has additional column with possible + values: `A`, `D`, `M`, `U` representing added, deleted, modified and same rows respectively. Note that if only one "status" is asked, by setting proper flags, this additional column is not created as it would have only one value for all rows. Beside additional diff column, new chain has schema of the chain @@ -1652,20 +1652,20 @@ def compare( `right_on` parameter has to specify the columns for the other chain. This value is used to find corresponding row in other dataset. If not found there, row is considered as added (or removed if vice versa), and - if found then row can be either modified or unchanged. + if found then row can be either modified or same. right_on: Optional column or list of columns for the `other` to match. compare: Column or list of columns to compare on. If both chains have the same columns then this column is enough for the compare. Otherwise, `right_compare` parameter has to specify the columns for the other - chain. This value is used to see if row is modified or unchanged. If + chain. This value is used to see if row is modified or same. If not set, all columns will be used for comparison right_compare: Optional column or list of columns for the `other` to compare to. added (bool): Whether to return added rows in resulting chain. deleted (bool): Whether to return deleted rows in resulting chain. modified (bool): Whether to return modified rows in resulting chain. - unchanged (bool): Whether to return unchanged rows in resulting chain. + same (bool): Whether to return unchanged rows in resulting chain. status_col (str): Name of the new column that is created in resulting chain representing diff status. @@ -1679,7 +1679,7 @@ def compare( added=True, deleted=True, modified=True, - unchanged=True, + same=True, status_col="diff" ) ``` @@ -1696,7 +1696,80 @@ def compare( added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, + status_col=status_col, + ) + + def diff( + self, + other: "DataChain", + on: str = "file", + right_on: Optional[str] = None, + added: bool = True, + modified: bool = True, + deleted: bool = False, + same: bool = False, + status_col: Optional[str] = None, + ) -> "DataChain": + """Similar to `.compare()`, which is more generic method to calculate difference + between two chains. Unlike `.compare()`, this method works only on those chains + that have `File` object, or it's derivatives, in it. File `source` and `path` + are used for matching, and file `version` and `etag` for comparing, while in + `.compare()` user needs to provide arbitrary columns for matching and comparing. + + Parameters: + other: Chain to calculate diff from. + on: File signal to match on. If both chains have the + same file signal then this column is enough for the match. Otherwise, + `right_on` parameter has to specify the file signal for the other chain. + This value is used to find corresponding row in other dataset. If not + found there, row is considered as added (or removed if vice versa), and + if found then row can be either modified or same. + right_on: Optional file signal for the `other` to match. + added (bool): Whether to return added rows in resulting chain. + deleted (bool): Whether to return deleted rows in resulting chain. + modified (bool): Whether to return modified rows in resulting chain. + same (bool): Whether to return unchanged rows in resulting chain. + status_col (str): Optional name of the new column that is created in + resulting chain representing diff status. + + Example: + ```py + diff = images.diff( + new_images, + on="file", + right_on="other_file", + added=True, + deleted=True, + modified=True, + same=True, + status_col="diff" + ) + ``` + """ + on_file_signals = ["source", "path"] + compare_file_signals = ["version", "etag"] + + def get_file_signals(file: str, signals): + return [f"{file}.{c}" for c in signals] + + right_on = right_on or on + + on_cols = get_file_signals(on, on_file_signals) + right_on_cols = get_file_signals(right_on, on_file_signals) + compare_cols = get_file_signals(on, compare_file_signals) + right_compare_cols = get_file_signals(right_on, compare_file_signals) + + return self.compare( + other, + on_cols, + right_on=right_on_cols, + compare=compare_cols, + right_compare=right_compare_cols, + added=added, + deleted=deleted, + modified=modified, + same=same, status_col=status_col, ) diff --git a/src/datachain/lib/diff.py b/src/datachain/lib/diff.py index c8ae39303..cb2f79b77 100644 --- a/src/datachain/lib/diff.py +++ b/src/datachain/lib/diff.py @@ -26,11 +26,11 @@ def compare( # noqa: PLR0912, PLR0915, C901 added: bool = True, deleted: bool = True, modified: bool = True, - unchanged: bool = False, + same: bool = True, status_col: Optional[str] = None, ) -> "DataChain": """Comparing two chains by identifying rows that are added, deleted, modified - or unchanged""" + or same""" dialect = left._query.dialect rname = "right_" @@ -67,9 +67,9 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: "'compare' and 'right_compare' must be have the same length" ) - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): raise ValueError( - "At least one of added, deleted, modified, unchanged flags must be set" + "At least one of added, deleted, modified, same flags must be set" ) # we still need status column for internal implementation even if not @@ -94,7 +94,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: elif not compare and len(cols) != len(right_cols): # here we will mark all rows that are not added or deleted as modified since # there was no explicit list of compare columns provided (meaning we need - # to check all columns to determine if row is modified or unchanged), but + # to check all columns to determine if row is modified or same), but # the number of columns on left and right is not the same (one of the chains # have additional column) compare = None @@ -121,14 +121,14 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: ] ) diff_cond.append((modified_cond, "M")) - if unchanged and compare: - unchanged_cond = sa.and_( + if same and compare: + same_cond = sa.and_( *[ C(c) == C(f"{_rprefix(c, rc)}{rc}") for c, rc in zip(compare, right_compare) # type: ignore[arg-type] ] ) - diff_cond.append((unchanged_cond, "U")) + diff_cond.append((same_cond, "S")) diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col) diff.type = String() @@ -155,23 +155,22 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: *[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)] # noqa: E711 ) ) + + def _default_val(chain: "DataChain", col: str): + col_type = chain._query.column_types[col] # type: ignore[index] + val = sa.literal(col_type.default_value(dialect)).label(col) + val.type = col_type() + return val + right_left_merge_select = right_left_merge._query.select( *( [C(c) for c in right_left_merge.signals_schema.db_signals("sys")] + [ - C(c) # type: ignore[misc] - if c == rc - else sa.literal( - left._query.column_types[c].default_value(dialect) # type: ignore[index] - ).label(c) + C(c) if c == rc else _default_val(left, c) for c, rc in zip(on, right_on) ] + [ - C(c) # type: ignore[misc] - if c in right_cols - else sa.literal( - left._query.column_types[c].default_value(dialect) # type: ignore[index] - ).label(c) # type: ignore[arg-type] + C(c) if c in right_cols else _default_val(left, c) # type: ignore[arg-type] for c in cols if c not in on ] @@ -181,7 +180,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: if not deleted: res = left_right_merge_select - elif deleted and not any([added, modified, unchanged]): + elif deleted and not any([added, modified, same]): res = right_left_merge_select else: res = left_right_merge_select.union(right_left_merge_select) diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index af138c914..2535935df 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2944,409 +2944,3 @@ def test_window_error(test_session): ), ): dc.mutate(first=func.sum("col2").over(window)) - - -@pytest.mark.parametrize("added", (True, False)) -@pytest.mark.parametrize("deleted", (True, False)) -@pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) -@pytest.mark.parametrize("status_col", ("diff", None)) -@pytest.mark.parametrize("save", (True, False)) -def test_compare(test_session, added, deleted, modified, unchanged, status_col, save): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John1", "Doe", "Andy"], - session=test_session, - ).save("ds1") - - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "Mark", "Andy"], - session=test_session, - ).save("ds2") - - if not any([added, deleted, modified, unchanged]): - with pytest.raises(ValueError) as exc_info: - diff = ds1.compare( - ds2, - added=added, - deleted=deleted, - modified=modified, - unchanged=unchanged, - on=["id"], - status_col=status_col, - ) - assert str(exc_info.value) == ( - "At least one of added, deleted, modified, unchanged flags must be set" - ) - return - - diff = ds1.compare( - ds2, - added=added, - deleted=deleted, - modified=modified, - unchanged=unchanged, - on=["id"], - status_col="diff", - ) - - if save: - diff.save("diff") - diff = DataChain.from_dataset("diff") - - expected = [] - if modified: - expected.append(("M", 1, "John1")) - if added: - expected.append(("A", 2, "Doe")) - if deleted: - expected.append(("D", 3, "Mark")) - if unchanged: - expected.append(("U", 4, "Andy")) - - collect_fields = ["diff", "id", "name"] - if not status_col: - expected = [row[1:] for row in expected] - collect_fields = collect_fields[1:] - - assert list(diff.order_by("id").collect(*collect_fields)) == expected - - -def test_compare_with_from_dataset(test_session): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John1", "Doe", "Andy"], - session=test_session, - ).save("ds1") - - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "Mark", "Andy"], - session=test_session, - ).save("ds2") - - # this adds sys columns to ds1 and ds2 - ds1 = DataChain.from_dataset("ds1") - ds2 = DataChain.from_dataset("ds2") - - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") - - assert list(diff.order_by("id").collect("diff", "id", "name")) == [ - ("M", 1, "John1"), - ("A", 2, "Doe"), - ("D", 3, "Mark"), - ("U", 4, "Andy"), - ] - - -@pytest.mark.parametrize("added", (True,)) -@pytest.mark.parametrize("deleted", (True,)) -@pytest.mark.parametrize("modified", (True,)) -@pytest.mark.parametrize("unchanged", (True,)) -@pytest.mark.parametrize("right_name", ("other_name",)) -def test_compare_with_explicit_compare_fields( - test_session, added, deleted, modified, unchanged, right_name -): - if not any([added, deleted, modified, unchanged]): - pytest.skip("This case is tested in another test") - - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John1", "Doe", "Andy"], - city=["New York", "Boston", "San Francisco"], - session=test_session, - ).save("ds1") - - ds2_data = { - "id": [1, 3, 4], - "city": ["Washington", "Seattle", "Miami"], - f"{right_name}": ["John", "Mark", "Andy"], - "session": test_session, - } - - ds2 = DataChain.from_values(**ds2_data).save("ds2") - - diff = ds1.compare( - ds2, - on=["id"], - compare=["name"], - right_compare=[right_name], - added=added, - deleted=deleted, - modified=modified, - unchanged=unchanged, - status_col="diff", - ) - - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - - expected = [] - if modified: - expected.append(("M", 1, "John1", "New York")) - if added: - expected.append(("A", 2, "Doe", "Boston")) - if deleted: - expected.append( - ( - "D", - 3, - string_default if right_name == "other_name" else "Mark", - "Seattle", - ) - ) - if unchanged: - expected.append(("U", 4, "Andy", "San Francisco")) - - collect_fields = ["diff", "id", "name", "city"] - assert list(diff.order_by("id").collect(*collect_fields)) == expected - - -@pytest.mark.parametrize("added", (True, False)) -@pytest.mark.parametrize("deleted", (True, False)) -@pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) -def test_compare_different_left_right_on_columns( - test_session, added, deleted, modified, unchanged -): - if not any([added, deleted, modified, unchanged]): - pytest.skip("This case is tested in another test") - - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John1", "Doe", "Andy"], - session=test_session, - ).save("ds1") - - ds2 = DataChain.from_values( - other_id=[1, 3, 4], - name=["John", "Mark", "Andy"], - session=test_session, - ).save("ds2") - - diff = ds1.compare( - ds2, - added=added, - deleted=deleted, - modified=modified, - unchanged=unchanged, - on=["id"], - right_on=["other_id"], - status_col="diff", - ) - - int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) - - expected = [] - if unchanged: - expected.append(("U", 4, "Andy")) - if added: - expected.append(("A", 2, "Doe")) - if modified: - expected.append(("M", 1, "John1")) - if deleted: - expected.append(("D", int_default, "Mark")) - - collect_fields = ["diff", "id", "name"] - assert list(diff.order_by("name").collect(*collect_fields)) == expected - - -@pytest.mark.parametrize("added", (True, False)) -@pytest.mark.parametrize("deleted", (True, False)) -@pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) -@pytest.mark.parametrize("on_self", (True, False)) -def test_compare_on_equal_datasets( - test_session, added, deleted, modified, unchanged, on_self -): - if not any([added, deleted, modified, unchanged]): - pytest.skip("This case is tested in another test") - - ds1 = DataChain.from_values( - id=[1, 2, 3], - name=["John", "Doe", "Andy"], - session=test_session, - ).save("ds1") - - if on_self: - ds2 = ds1 - else: - ds2 = DataChain.from_values( - id=[1, 2, 3], - name=["John", "Doe", "Andy"], - session=test_session, - ).save("ds2") - - diff = ds1.compare( - ds2, - added=added, - deleted=deleted, - modified=modified, - unchanged=unchanged, - on=["id"], - status_col="diff", - ) - - if not unchanged: - expected = [] - else: - expected = [ - ("U", 1, "John"), - ("U", 2, "Doe"), - ("U", 3, "Andy"), - ] - - collect_fields = ["diff", "id", "name"] - assert list(diff.order_by("id").collect(*collect_fields)) == expected - - -def test_compare_multiple_columns(test_session): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John", "Doe", "Andy"], - city=["London", "New York", "Tokyo"], - session=test_session, - ).save("ds1") - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "Mark", "Andy"], - city=["Paris", "Berlin", "Tokyo"], - session=test_session, - ).save("ds2") - - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") - - assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( - [ - {"diff": "M", "id": 1, "name": "John", "city": "London"}, - {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, - {"diff": "D", "id": 3, "name": "Mark", "city": "Berlin"}, - {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, - ], - "id", - ) - - -def test_compare_multiple_match_columns(test_session): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John", "Doe", "Andy"], - city=["London", "New York", "Tokyo"], - session=test_session, - ).save("ds1") - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "John", "Andy"], - city=["Paris", "Berlin", "Tokyo"], - session=test_session, - ).save("ds2") - - diff = ds1.compare(ds2, unchanged=True, on=["id", "name"], status_col="diff") - - assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( - [ - {"diff": "M", "id": 1, "name": "John", "city": "London"}, - {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, - {"diff": "D", "id": 3, "name": "John", "city": "Berlin"}, - {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, - ], - "id", - ) - - -def test_compare_additional_column_on_left(test_session): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John", "Doe", "Andy"], - city=["London", "New York", "Tokyo"], - session=test_session, - ).save("ds1") - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "Mark", "Andy"], - session=test_session, - ).save("ds2") - - string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") - - assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( - [ - {"diff": "M", "id": 1, "name": "John", "city": "London"}, - {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, - {"diff": "D", "id": 3, "name": "Mark", "city": string_default}, - {"diff": "M", "id": 4, "name": "Andy", "city": "Tokyo"}, - ], - "id", - ) - - -def test_compare_additional_column_on_right(test_session): - ds1 = DataChain.from_values( - id=[1, 2, 4], - name=["John", "Doe", "Andy"], - session=test_session, - ).save("ds1") - ds2 = DataChain.from_values( - id=[1, 3, 4], - name=["John", "Mark", "Andy"], - city=["London", "New York", "Tokyo"], - session=test_session, - ).save("ds2") - - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") - - assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( - [ - {"diff": "M", "id": 1, "name": "John"}, - {"diff": "A", "id": 2, "name": "Doe"}, - {"diff": "D", "id": 3, "name": "Mark"}, - {"diff": "M", "id": 4, "name": "Andy"}, - ], - "id", - ) - - -def test_compare_missing_on(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") - - with pytest.raises(ValueError) as exc_info: - ds1.compare(ds2, on=None) - - assert str(exc_info.value) == "'on' must be specified" - - -def test_compare_right_on_wrong_length(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") - - with pytest.raises(ValueError) as exc_info: - ds1.compare(ds2, on=["id"], right_on=["id", "name"]) - - assert str(exc_info.value) == "'on' and 'right_on' must be have the same length" - - -def test_compare_right_compare_defined_but_not_compare(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") - - with pytest.raises(ValueError) as exc_info: - ds1.compare(ds2, on=["id"], right_compare=["name"]) - - assert str(exc_info.value) == ( - "'compare' must be defined if 'right_compare' is defined" - ) - - -def test_compare_right_compare_wrong_length(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") - - with pytest.raises(ValueError) as exc_info: - ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"]) - - assert str(exc_info.value) == ( - "'compare' and 'right_compare' must be have the same length" - ) diff --git a/tests/unit/lib/test_diff.py b/tests/unit/lib/test_diff.py new file mode 100644 index 000000000..a19bd6a1a --- /dev/null +++ b/tests/unit/lib/test_diff.py @@ -0,0 +1,498 @@ +import pytest +from pydantic import BaseModel + +from datachain.lib.dc import DataChain +from datachain.lib.file import File +from datachain.sql.types import Int64, String +from tests.utils import sorted_dicts + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("same", (True, False)) +@pytest.mark.parametrize("status_col", ("diff", None)) +@pytest.mark.parametrize("save", (True, False)) +def test_compare(test_session, added, deleted, modified, same, status_col, save): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + if not any([added, deleted, modified, same]): + with pytest.raises(ValueError) as exc_info: + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + same=same, + on=["id"], + status_col=status_col, + ) + assert str(exc_info.value) == ( + "At least one of added, deleted, modified, same flags must be set" + ) + return + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + same=same, + on=["id"], + status_col="diff", + ) + + if save: + diff.save("diff") + diff = DataChain.from_dataset("diff") + + expected = [] + if modified: + expected.append(("M", 1, "John1")) + if added: + expected.append(("A", 2, "Doe")) + if deleted: + expected.append(("D", 3, "Mark")) + if same: + expected.append(("S", 4, "Andy")) + + collect_fields = ["diff", "id", "name"] + if not status_col: + expected = [row[1:] for row in expected] + collect_fields = collect_fields[1:] + + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +def test_compare_with_from_dataset(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + # this adds sys columns to ds1 and ds2 + ds1 = DataChain.from_dataset("ds1") + ds2 = DataChain.from_dataset("ds2") + + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") + + assert list(diff.order_by("id").collect("diff", "id", "name")) == [ + ("M", 1, "John1"), + ("A", 2, "Doe"), + ("D", 3, "Mark"), + ("S", 4, "Andy"), + ] + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("same", (True, False)) +@pytest.mark.parametrize("right_name", ("other_name", "name")) +def test_compare_with_explicit_compare_fields( + test_session, added, deleted, modified, same, right_name +): + if not any([added, deleted, modified, same]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + city=["New York", "Boston", "San Francisco"], + session=test_session, + ).save("ds1") + + ds2_data = { + "id": [1, 3, 4], + "city": ["Washington", "Seattle", "Miami"], + f"{right_name}": ["John", "Mark", "Andy"], + "session": test_session, + } + + ds2 = DataChain.from_values(**ds2_data).save("ds2") + + diff = ds1.compare( + ds2, + on=["id"], + compare=["name"], + right_compare=[right_name], + added=added, + deleted=deleted, + modified=modified, + same=same, + status_col="diff", + ) + + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + + expected = [] + if modified: + expected.append(("M", 1, "John1", "New York")) + if added: + expected.append(("A", 2, "Doe", "Boston")) + if deleted: + expected.append( + ( + "D", + 3, + string_default if right_name == "other_name" else "Mark", + "Seattle", + ) + ) + if same: + expected.append(("S", 4, "Andy", "San Francisco")) + + collect_fields = ["diff", "id", "name", "city"] + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("same", (True, False)) +def test_compare_different_left_right_on_columns( + test_session, added, deleted, modified, same +): + if not any([added, deleted, modified, same]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John1", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + ds2 = DataChain.from_values( + other_id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + same=same, + on=["id"], + right_on=["other_id"], + status_col="diff", + ) + + int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) + + expected = [] + if same: + expected.append(("S", 4, "Andy")) + if added: + expected.append(("A", 2, "Doe")) + if modified: + expected.append(("M", 1, "John1")) + if deleted: + expected.append(("D", int_default, "Mark")) + + collect_fields = ["diff", "id", "name"] + assert list(diff.order_by("name").collect(*collect_fields)) == expected + + +@pytest.mark.parametrize("added", (True, False)) +@pytest.mark.parametrize("deleted", (True, False)) +@pytest.mark.parametrize("modified", (True, False)) +@pytest.mark.parametrize("same", (True, False)) +@pytest.mark.parametrize("on_self", (True, False)) +def test_compare_on_equal_datasets( + test_session, added, deleted, modified, same, on_self +): + if not any([added, deleted, modified, same]): + pytest.skip("This case is tested in another test") + + ds1 = DataChain.from_values( + id=[1, 2, 3], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds1") + + if on_self: + ds2 = ds1 + else: + ds2 = DataChain.from_values( + id=[1, 2, 3], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds2") + + diff = ds1.compare( + ds2, + added=added, + deleted=deleted, + modified=modified, + same=same, + on=["id"], + status_col="diff", + ) + + if not same: + expected = [] + else: + expected = [ + ("S", 1, "John"), + ("S", 2, "Doe"), + ("S", 3, "Andy"), + ] + + collect_fields = ["diff", "id", "name"] + assert list(diff.order_by("id").collect(*collect_fields)) == expected + + +def test_compare_multiple_columns(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + city=["Paris", "Berlin", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "Mark", "city": "Berlin"}, + {"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_multiple_match_columns(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "John", "Andy"], + city=["Paris", "Berlin", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, same=True, on=["id", "name"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "John", "city": "Berlin"}, + {"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_additional_column_on_left(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + session=test_session, + ).save("ds2") + + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John", "city": "London"}, + {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, + {"diff": "D", "id": 3, "name": "Mark", "city": string_default}, + {"diff": "M", "id": 4, "name": "Andy", "city": "Tokyo"}, + ], + "id", + ) + + +def test_compare_additional_column_on_right(test_session): + ds1 = DataChain.from_values( + id=[1, 2, 4], + name=["John", "Doe", "Andy"], + session=test_session, + ).save("ds1") + ds2 = DataChain.from_values( + id=[1, 3, 4], + name=["John", "Mark", "Andy"], + city=["London", "New York", "Tokyo"], + session=test_session, + ).save("ds2") + + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") + + assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( + [ + {"diff": "M", "id": 1, "name": "John"}, + {"diff": "A", "id": 2, "name": "Doe"}, + {"diff": "D", "id": 3, "name": "Mark"}, + {"diff": "M", "id": 4, "name": "Andy"}, + ], + "id", + ) + + +def test_compare_missing_on(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=None) + + assert str(exc_info.value) == "'on' must be specified" + + +def test_compare_right_on_wrong_length(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], right_on=["id", "name"]) + + assert str(exc_info.value) == "'on' and 'right_on' must be have the same length" + + +def test_compare_right_compare_defined_but_not_compare(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], right_compare=["name"]) + + assert str(exc_info.value) == ( + "'compare' must be defined if 'right_compare' is defined" + ) + + +def test_compare_right_compare_wrong_length(test_session): + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + + with pytest.raises(ValueError) as exc_info: + ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"]) + + assert str(exc_info.value) == ( + "'compare' and 'right_compare' must be have the same length" + ) + + +@pytest.mark.parametrize("status_col", ("diff", None)) +def test_diff(test_session, status_col): + fs1 = File(source="s1", path="p1", version="2", etag="e2") + fs1_updated = File(source="s1", path="p1", version="1", etag="e1") + fs2 = File(source="s2", path="p2", version="1", etag="e1") + fs3 = File(source="s3", path="p3", version="1", etag="e1") + fs4 = File(source="s4", path="p4", version="1", etag="e1") + + ds1 = DataChain.from_values( + file=[fs1_updated, fs2, fs4], score=[1, 2, 4], session=test_session + ) + ds2 = DataChain.from_values( + file=[fs1, fs3, fs4], score=[1, 3, 4], session=test_session + ) + + diff = ds1.diff( + ds2, + added=True, + deleted=True, + modified=True, + same=True, + on="file", + status_col=status_col, + ) + + expected = [ + ("M", fs1_updated, 1), + ("A", fs2, 2), + ("D", fs3, 3), + ("S", fs4, 4), + ] + + collect_fields = ["diff", "file", "score"] + if not status_col: + expected = [row[1:] for row in expected] + collect_fields = collect_fields[1:] + + assert list(diff.order_by("file.source").collect(*collect_fields)) == expected + + +@pytest.mark.parametrize("status_col", ("diff", None)) +def test_diff_nested(test_session, status_col): + class Nested(BaseModel): + file: File + + fs1 = Nested(file=File(source="s1", path="p1", version="2", etag="e2")) + fs1_updated = Nested(file=File(source="s1", path="p1", version="1", etag="e1")) + fs2 = Nested(file=File(source="s2", path="p2", version="1", etag="e1")) + fs3 = Nested(file=File(source="s3", path="p3", version="1", etag="e1")) + fs4 = Nested(file=File(source="s4", path="p4", version="1", etag="e1")) + + ds1 = DataChain.from_values( + nested=[fs1_updated, fs2, fs4], score=[1, 2, 4], session=test_session + ) + ds2 = DataChain.from_values( + nested=[fs1, fs3, fs4], score=[1, 3, 4], session=test_session + ) + + diff = ds1.diff( + ds2, + added=True, + deleted=True, + modified=True, + same=True, + on="nested.file", + status_col=status_col, + ) + + expected = [ + ("M", fs1_updated, 1), + ("A", fs2, 2), + ("D", fs3, 3), + ("S", fs4, 4), + ] + + collect_fields = ["diff", "nested", "score"] + if not status_col: + expected = [row[1:] for row in expected] + collect_fields = collect_fields[1:] + + assert ( + list(diff.order_by("nested.file.source").collect(*collect_fields)) == expected + )