Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added DataChain.diff() #718

Merged
merged 11 commits into from
Dec 20, 2024
71 changes: 71 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,77 @@ def compare(
status_col=status_col,
)

def diff(
self,
other: "DataChain",
on: str = "file",
right_on: Optional[str] = None,
added: bool = True,
deleted: bool = True,
modified: bool = True,
unchanged: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think diff() should have only added=True and modified=True. the rest should be False.

We need to optimize it for delta update use case which process only new and changed files.

status_col: Optional[str] = None,
) -> "DataChain":
"""Similar as .compare() but for file based chains, i.e. those that have
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please mention the similarity to compare() in the end of the description. We cannot assum that user knows compare() alread.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added better docs

File object, or it's derivatives, in it. For matching file `source` and
`path` are used, and for comparing file `version` and `etag`.

Parameters:
other: Chain to calculate diff from.
on: File signal to match on. If both chains have the
same file signal then this column is enough for the match. Otherwise,
`right_on` parameter has to specify the file signal for the other chain.
This value is used to find corresponding row in other dataset. If not
found there, row is considered as added (or removed if vice versa), and
if found then row can be either modified or unchanged.
right_on: Optional file signal for the `other` to match.
added (bool): Whether to return added rows in resulting chain.
deleted (bool): Whether to return deleted rows in resulting chain.
modified (bool): Whether to return modified rows in resulting chain.
unchanged (bool): Whether to return unchanged rows in resulting chain.
status_col (str): Name of the new column that is created in resulting chain
representing diff status.

Example:
```py
diff = images.diff(
new_images,
on="file",
right_on="other_file",
added=True,
deleted=True,
modified=True,
unchanged=True,
status_col="diff"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's empty by default. The same in the description.

)
```
"""
on_file_signals = ["source", "path"]
compare_file_signals = ["version", "etag"]

def get_file_signals(file: str, signals):
return [f"{file}.{c}" for c in signals]

right_on = right_on or on

on_cols = get_file_signals(on, on_file_signals)
right_on_cols = get_file_signals(right_on, on_file_signals)
compare_cols = get_file_signals(on, compare_file_signals)
right_compare_cols = get_file_signals(right_on, compare_file_signals)

return self.compare(
other,
on_cols,
right_on=right_on_cols,
compare=compare_cols,
right_compare=right_compare_cols,
added=added,
deleted=deleted,
modified=modified,
unchanged=unchanged,
status_col=status_col,
)

@classmethod
def from_values(
cls,
Expand Down
105 changes: 105 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3350,3 +3350,108 @@ def test_compare_right_compare_wrong_length(test_session):
assert str(exc_info.value) == (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test file is massive 😱

Could you please extract compare and diff test to a separate file like test_diff.py. See test_merge.py as an example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"'compare' and 'right_compare' must be have the same length"
)


@pytest.mark.parametrize("added", (True, False))
@pytest.mark.parametrize("deleted", (True, False))
@pytest.mark.parametrize("modified", (True, False))
@pytest.mark.parametrize("unchanged", (True, False))
@pytest.mark.parametrize("status_col", ("diff", None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test looks overcomplicated 🙂

Is it possible to separate this to small tests instead of large parametrized one?

Also, I don't see value in all modified, deleted, unchanged statuses since compare() tests have to cover this, not diff(). The same for status_col. 2-3 simple tests should cover diff() pretty well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified it

def test_diff(test_session, added, deleted, modified, unchanged, status_col):
if not any([added, deleted, modified, unchanged]):
pytest.skip("This case is tested in another test")

fs1 = File(source="s1", path="p1", version="2", etag="e2")
fs1_updated = File(source="s1", path="p1", version="1", etag="e1")
fs2 = File(source="s2", path="p2", version="1", etag="e1")
fs3 = File(source="s3", path="p3", version="1", etag="e1")
fs4 = File(source="s4", path="p4", version="1", etag="e1")

ds1 = DataChain.from_values(
file=[fs1_updated, fs2, fs4], score=[1, 2, 4], session=test_session
)
ds2 = DataChain.from_values(
file=[fs1, fs3, fs4], score=[1, 3, 4], session=test_session
)

diff = ds1.diff(
ds2,
added=added,
deleted=deleted,
modified=modified,
unchanged=unchanged,
on="file",
status_col=status_col,
)

expected = []
if modified:
expected.append(("M", fs1_updated, 1))
if added:
expected.append(("A", fs2, 2))
if deleted:
expected.append(("D", fs3, 3))
if unchanged:
expected.append(("U", fs4, 4))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 I was sure that we decided to get rid of U and use S (Same) instead, didn't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will change it in my followup PR. I forgot to change it in that main one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to S .. also changed flag name from unchanged to same


collect_fields = ["diff", "file", "score"]
if not status_col:
expected = [row[1:] for row in expected]
collect_fields = collect_fields[1:]

assert list(diff.order_by("file.source").collect(*collect_fields)) == expected


@pytest.mark.parametrize("added", (True, False))
@pytest.mark.parametrize("deleted", (True, False))
@pytest.mark.parametrize("modified", (True, False))
@pytest.mark.parametrize("unchanged", (True, False))
@pytest.mark.parametrize("status_col", ("diff", None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here - to complicated to unit test and it test mostly compare(), not diff().
I'd appreciate it if you could simplify it.

def test_diff_nested(test_session, added, deleted, modified, unchanged, status_col):
if not any([added, deleted, modified, unchanged]):
pytest.skip("This case is tested in another test")

class Nested(BaseModel):
file: File

fs1 = Nested(file=File(source="s1", path="p1", version="2", etag="e2"))
fs1_updated = Nested(file=File(source="s1", path="p1", version="1", etag="e1"))
fs2 = Nested(file=File(source="s2", path="p2", version="1", etag="e1"))
fs3 = Nested(file=File(source="s3", path="p3", version="1", etag="e1"))
fs4 = Nested(file=File(source="s4", path="p4", version="1", etag="e1"))

ds1 = DataChain.from_values(
nested=[fs1_updated, fs2, fs4], score=[1, 2, 4], session=test_session
)
ds2 = DataChain.from_values(
nested=[fs1, fs3, fs4], score=[1, 3, 4], session=test_session
)

diff = ds1.diff(
ds2,
added=added,
deleted=deleted,
modified=modified,
unchanged=unchanged,
on="nested.file",
status_col=status_col,
)

expected = []
if modified:
expected.append(("M", fs1_updated, 1))
if added:
expected.append(("A", fs2, 2))
if deleted:
expected.append(("D", fs3, 3))
if unchanged:
expected.append(("U", fs4, 4))

collect_fields = ["diff", "nested", "score"]
if not status_col:
expected = [row[1:] for row in expected]
collect_fields = collect_fields[1:]

assert (
list(diff.order_by("nested.file.source").collect(*collect_fields)) == expected
)
Loading