diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 1cebba204..a3278b232 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1634,12 +1634,12 @@ def compare( added: bool = True, deleted: bool = True, modified: bool = True, - unchanged: bool = False, + same: bool = False, status_col: Optional[str] = None, ) -> "DataChain": """Comparing two chains by identifying rows that are added, deleted, modified - or unchanged. Result is the new chain that has additional column with possible - values: `A`, `D`, `M`, `U` representing added, deleted, modified and unchanged + or same. Result is the new chain that has additional column with possible + values: `A`, `D`, `M`, `U` representing added, deleted, modified and same rows respectively. Note that if only one "status" is asked, by setting proper flags, this additional column is not created as it would have only one value for all rows. Beside additional diff column, new chain has schema of the chain @@ -1652,20 +1652,20 @@ def compare( `right_on` parameter has to specify the columns for the other chain. This value is used to find corresponding row in other dataset. If not found there, row is considered as added (or removed if vice versa), and - if found then row can be either modified or unchanged. + if found then row can be either modified or same. right_on: Optional column or list of columns for the `other` to match. compare: Column or list of columns to compare on. If both chains have the same columns then this column is enough for the compare. Otherwise, `right_compare` parameter has to specify the columns for the other - chain. This value is used to see if row is modified or unchanged. If + chain. This value is used to see if row is modified or same. If not set, all columns will be used for comparison right_compare: Optional column or list of columns for the `other` to compare to. added (bool): Whether to return added rows in resulting chain. deleted (bool): Whether to return deleted rows in resulting chain. modified (bool): Whether to return modified rows in resulting chain. - unchanged (bool): Whether to return unchanged rows in resulting chain. + same (bool): Whether to return unchanged rows in resulting chain. status_col (str): Name of the new column that is created in resulting chain representing diff status. @@ -1679,7 +1679,7 @@ def compare( added=True, deleted=True, modified=True, - unchanged=True, + same=True, status_col="diff" ) ``` @@ -1696,7 +1696,7 @@ def compare( added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, status_col=status_col, ) @@ -1708,7 +1708,7 @@ def diff( added: bool = True, modified: bool = True, deleted: bool = False, - unchanged: bool = False, + same: bool = False, status_col: Optional[str] = None, ) -> "DataChain": """Similar to `.compare()`, which is more generic method to calculate difference @@ -1724,12 +1724,12 @@ def diff( `right_on` parameter has to specify the file signal for the other chain. This value is used to find corresponding row in other dataset. If not found there, row is considered as added (or removed if vice versa), and - if found then row can be either modified or unchanged. + if found then row can be either modified or same. right_on: Optional file signal for the `other` to match. added (bool): Whether to return added rows in resulting chain. deleted (bool): Whether to return deleted rows in resulting chain. modified (bool): Whether to return modified rows in resulting chain. - unchanged (bool): Whether to return unchanged rows in resulting chain. + same (bool): Whether to return unchanged rows in resulting chain. status_col (str): Optional name of the new column that is created in resulting chain representing diff status. @@ -1742,7 +1742,7 @@ def diff( added=True, deleted=True, modified=True, - unchanged=True, + same=True, status_col="diff" ) ``` @@ -1769,7 +1769,7 @@ def get_file_signals(file: str, signals): added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, status_col=status_col, ) diff --git a/src/datachain/lib/diff.py b/src/datachain/lib/diff.py index 6a46395b0..45c6cf644 100644 --- a/src/datachain/lib/diff.py +++ b/src/datachain/lib/diff.py @@ -26,11 +26,11 @@ def compare( # noqa: PLR0912, PLR0915, C901 added: bool = True, deleted: bool = True, modified: bool = True, - unchanged: bool = True, + same: bool = True, status_col: Optional[str] = None, ) -> "DataChain": """Comparing two chains by identifying rows that are added, deleted, modified - or unchanged""" + or same""" dialect = left._query.dialect rname = "right_" @@ -67,9 +67,9 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: "'compare' and 'right_compare' must be have the same length" ) - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): raise ValueError( - "At least one of added, deleted, modified, unchanged flags must be set" + "At least one of added, deleted, modified, same flags must be set" ) # we still need status column for internal implementation even if not @@ -94,7 +94,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: elif not compare and len(cols) != len(right_cols): # here we will mark all rows that are not added or deleted as modified since # there was no explicit list of compare columns provided (meaning we need - # to check all columns to determine if row is modified or unchanged), but + # to check all columns to determine if row is modified or same), but # the number of columns on left and right is not the same (one of the chains # have additional column) compare = None @@ -121,14 +121,14 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: ] ) diff_cond.append((modified_cond, "M")) - if unchanged and compare: - unchanged_cond = sa.and_( + if same and compare: + same_cond = sa.and_( *[ C(c) == C(f"{_rprefix(c, rc)}{rc}") for c, rc in zip(compare, right_compare) # type: ignore[arg-type] ] ) - diff_cond.append((unchanged_cond, "U")) + diff_cond.append((same_cond, "S")) diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col) diff.type = String() @@ -181,7 +181,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: if not deleted: res = left_right_merge_select - elif deleted and not any([added, modified, unchanged]): + elif deleted and not any([added, modified, same]): res = right_left_merge_select else: res = left_right_merge_select.union(right_left_merge_select) diff --git a/tests/unit/lib/test_diff.py b/tests/unit/lib/test_diff.py index 182052dce..a19bd6a1a 100644 --- a/tests/unit/lib/test_diff.py +++ b/tests/unit/lib/test_diff.py @@ -10,10 +10,10 @@ @pytest.mark.parametrize("added", (True, False)) @pytest.mark.parametrize("deleted", (True, False)) @pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("same", (True, False)) @pytest.mark.parametrize("status_col", ("diff", None)) @pytest.mark.parametrize("save", (True, False)) -def test_compare(test_session, added, deleted, modified, unchanged, status_col, save): +def test_compare(test_session, added, deleted, modified, same, status_col, save): ds1 = DataChain.from_values( id=[1, 2, 4], name=["John1", "Doe", "Andy"], @@ -26,19 +26,19 @@ def test_compare(test_session, added, deleted, modified, unchanged, status_col, session=test_session, ).save("ds2") - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): with pytest.raises(ValueError) as exc_info: diff = ds1.compare( ds2, added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, on=["id"], status_col=status_col, ) assert str(exc_info.value) == ( - "At least one of added, deleted, modified, unchanged flags must be set" + "At least one of added, deleted, modified, same flags must be set" ) return @@ -47,7 +47,7 @@ def test_compare(test_session, added, deleted, modified, unchanged, status_col, added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, on=["id"], status_col="diff", ) @@ -63,8 +63,8 @@ def test_compare(test_session, added, deleted, modified, unchanged, status_col, expected.append(("A", 2, "Doe")) if deleted: expected.append(("D", 3, "Mark")) - if unchanged: - expected.append(("U", 4, "Andy")) + if same: + expected.append(("S", 4, "Andy")) collect_fields = ["diff", "id", "name"] if not status_col: @@ -91,25 +91,25 @@ def test_compare_with_from_dataset(test_session): ds1 = DataChain.from_dataset("ds1") ds2 = DataChain.from_dataset("ds2") - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert list(diff.order_by("id").collect("diff", "id", "name")) == [ ("M", 1, "John1"), ("A", 2, "Doe"), ("D", 3, "Mark"), - ("U", 4, "Andy"), + ("S", 4, "Andy"), ] @pytest.mark.parametrize("added", (True, False)) @pytest.mark.parametrize("deleted", (True, False)) @pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("same", (True, False)) @pytest.mark.parametrize("right_name", ("other_name", "name")) def test_compare_with_explicit_compare_fields( - test_session, added, deleted, modified, unchanged, right_name + test_session, added, deleted, modified, same, right_name ): - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): pytest.skip("This case is tested in another test") ds1 = DataChain.from_values( @@ -136,7 +136,7 @@ def test_compare_with_explicit_compare_fields( added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, status_col="diff", ) @@ -156,8 +156,8 @@ def test_compare_with_explicit_compare_fields( "Seattle", ) ) - if unchanged: - expected.append(("U", 4, "Andy", "San Francisco")) + if same: + expected.append(("S", 4, "Andy", "San Francisco")) collect_fields = ["diff", "id", "name", "city"] assert list(diff.order_by("id").collect(*collect_fields)) == expected @@ -166,11 +166,11 @@ def test_compare_with_explicit_compare_fields( @pytest.mark.parametrize("added", (True, False)) @pytest.mark.parametrize("deleted", (True, False)) @pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("same", (True, False)) def test_compare_different_left_right_on_columns( - test_session, added, deleted, modified, unchanged + test_session, added, deleted, modified, same ): - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): pytest.skip("This case is tested in another test") ds1 = DataChain.from_values( @@ -190,7 +190,7 @@ def test_compare_different_left_right_on_columns( added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, on=["id"], right_on=["other_id"], status_col="diff", @@ -199,8 +199,8 @@ def test_compare_different_left_right_on_columns( int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) expected = [] - if unchanged: - expected.append(("U", 4, "Andy")) + if same: + expected.append(("S", 4, "Andy")) if added: expected.append(("A", 2, "Doe")) if modified: @@ -215,12 +215,12 @@ def test_compare_different_left_right_on_columns( @pytest.mark.parametrize("added", (True, False)) @pytest.mark.parametrize("deleted", (True, False)) @pytest.mark.parametrize("modified", (True, False)) -@pytest.mark.parametrize("unchanged", (True, False)) +@pytest.mark.parametrize("same", (True, False)) @pytest.mark.parametrize("on_self", (True, False)) def test_compare_on_equal_datasets( - test_session, added, deleted, modified, unchanged, on_self + test_session, added, deleted, modified, same, on_self ): - if not any([added, deleted, modified, unchanged]): + if not any([added, deleted, modified, same]): pytest.skip("This case is tested in another test") ds1 = DataChain.from_values( @@ -243,18 +243,18 @@ def test_compare_on_equal_datasets( added=added, deleted=deleted, modified=modified, - unchanged=unchanged, + same=same, on=["id"], status_col="diff", ) - if not unchanged: + if not same: expected = [] else: expected = [ - ("U", 1, "John"), - ("U", 2, "Doe"), - ("U", 3, "Andy"), + ("S", 1, "John"), + ("S", 2, "Doe"), + ("S", 3, "Andy"), ] collect_fields = ["diff", "id", "name"] @@ -275,14 +275,14 @@ def test_compare_multiple_columns(test_session): session=test_session, ).save("ds2") - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( [ {"diff": "M", "id": 1, "name": "John", "city": "London"}, {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, {"diff": "D", "id": 3, "name": "Mark", "city": "Berlin"}, - {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, + {"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"}, ], "id", ) @@ -302,14 +302,14 @@ def test_compare_multiple_match_columns(test_session): session=test_session, ).save("ds2") - diff = ds1.compare(ds2, unchanged=True, on=["id", "name"], status_col="diff") + diff = ds1.compare(ds2, same=True, on=["id", "name"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( [ {"diff": "M", "id": 1, "name": "John", "city": "London"}, {"diff": "A", "id": 2, "name": "Doe", "city": "New York"}, {"diff": "D", "id": 3, "name": "John", "city": "Berlin"}, - {"diff": "U", "id": 4, "name": "Andy", "city": "Tokyo"}, + {"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"}, ], "id", ) @@ -330,7 +330,7 @@ def test_compare_additional_column_on_left(test_session): string_default = String.default_value(test_session.catalog.warehouse.db.dialect) - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( [ @@ -356,7 +356,7 @@ def test_compare_additional_column_on_right(test_session): session=test_session, ).save("ds2") - diff = ds1.compare(ds2, unchanged=True, on=["id"], status_col="diff") + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( [ @@ -433,7 +433,7 @@ def test_diff(test_session, status_col): added=True, deleted=True, modified=True, - unchanged=True, + same=True, on="file", status_col=status_col, ) @@ -442,7 +442,7 @@ def test_diff(test_session, status_col): ("M", fs1_updated, 1), ("A", fs2, 2), ("D", fs3, 3), - ("U", fs4, 4), + ("S", fs4, 4), ] collect_fields = ["diff", "file", "score"] @@ -476,7 +476,7 @@ class Nested(BaseModel): added=True, deleted=True, modified=True, - unchanged=True, + same=True, on="nested.file", status_col=status_col, ) @@ -485,7 +485,7 @@ class Nested(BaseModel): ("M", fs1_updated, 1), ("A", fs2, 2), ("D", fs3, 3), - ("U", fs4, 4), + ("S", fs4, 4), ] collect_fields = ["diff", "nested", "score"]