From 0c4268b097c9f19060b3980a646eb0e033ad6992 Mon Sep 17 00:00:00 2001 From: ilongin Date: Wed, 11 Dec 2024 10:47:15 +0100 Subject: [PATCH] added logic of different schemas when compare columns are not provided --- src/datachain/lib/diff.py | 18 +++++++++++++----- tests/unit/lib/test_datachain.py | 2 -- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/datachain/lib/diff.py b/src/datachain/lib/diff.py index f53bd6351..6060db992 100644 --- a/src/datachain/lib/diff.py +++ b/src/datachain/lib/diff.py @@ -89,6 +89,14 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: right_compare = right_compare or compare compare = left.signals_schema.resolve(*compare).db_signals() # type: ignore[assignment] right_compare = right.signals_schema.resolve(*right_compare).db_signals() # type: ignore[assignment] + elif not compare and len(cols) != len(right_cols): + # here we will mark all rows that are not added or deleted as modified since + # there was no explicit list of compare columns provided (meaning we need + # to check all columns to determine if row is modified or unchanged), but + # the number of columns on left and right is not the same (one of the chains + # has an additional column) + compare = None + right_compare = None else: compare = [c for c in cols if c in right_cols] # type: ignore[misc, assignment] right_compare = compare @@ -103,24 +111,24 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]: ] ) diff_cond.append((added_cond, "A")) - if modified: + if modified and compare: modified_cond = sa.or_( *[ C(c) != C(f"{_rprefix(c, rc)}{rc}") - for c, rc in zip(compare, right_compare) + for c, rc in zip(compare, right_compare) # type: ignore[arg-type] ] ) diff_cond.append((modified_cond, "M")) - if unchanged: + if unchanged and compare: unchanged_cond = sa.and_( *[ C(c) == C(f"{_rprefix(c, rc)}{rc}") - for c, rc in zip(compare, right_compare) + for c, rc in zip(compare, right_compare) # type: ignore[arg-type] ] ) 
diff_cond.append((unchanged_cond, "U")) - diff = sa.case(*diff_cond, else_=None).label(status_col) + diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col) left_right_merge = left.merge( right, on=on, right_on=right_on, inner=False, rname=rname diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index ea0ed2929..2ba3e7ca0 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -3214,7 +3214,6 @@ def test_compare_multiple_match_columns(test_session): def test_compare_additional_column_on_left(test_session): - pytest.skip() ds1 = DataChain.from_values( id=[1, 2, 4], name=["John", "Doe", "Andy"], @@ -3241,7 +3240,6 @@ def test_compare_additional_column_on_left(test_session): def test_compare_additional_column_on_right(test_session): - pytest.skip() ds1 = DataChain.from_values( id=[1, 2, 4], name=["John", "Doe", "Andy"],