Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added DataChain.diff() #718

Merged
merged 11 commits into from
Dec 20, 2024
89 changes: 81 additions & 8 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,12 +1634,12 @@ def compare(
added: bool = True,
deleted: bool = True,
modified: bool = True,
unchanged: bool = False,
same: bool = False,
status_col: Optional[str] = None,
) -> "DataChain":
"""Comparing two chains by identifying rows that are added, deleted, modified
or unchanged. Result is the new chain that has additional column with possible
values: `A`, `D`, `M`, `U` representing added, deleted, modified and unchanged
or same. Result is the new chain that has additional column with possible
values: `A`, `D`, `M`, `U` representing added, deleted, modified and same
rows respectively. Note that if only one "status" is asked, by setting proper
flags, this additional column is not created as it would have only one value
for all rows. Beside additional diff column, new chain has schema of the chain
Expand All @@ -1652,20 +1652,20 @@ def compare(
`right_on` parameter has to specify the columns for the other chain.
This value is used to find corresponding row in other dataset. If not
found there, row is considered as added (or removed if vice versa), and
if found then row can be either modified or unchanged.
if found then row can be either modified or same.
right_on: Optional column or list of columns
for the `other` to match.
compare: Column or list of columns to compare on. If both chains have
the same columns then this column is enough for the compare. Otherwise,
`right_compare` parameter has to specify the columns for the other
chain. This value is used to see if row is modified or unchanged. If
chain. This value is used to see if row is modified or same. If
not set, all columns will be used for comparison
right_compare: Optional column or list of columns
for the `other` to compare to.
added (bool): Whether to return added rows in resulting chain.
deleted (bool): Whether to return deleted rows in resulting chain.
modified (bool): Whether to return modified rows in resulting chain.
unchanged (bool): Whether to return unchanged rows in resulting chain.
same (bool): Whether to return unchanged rows in resulting chain.
status_col (str): Name of the new column that is created in resulting chain
representing diff status.

Expand All @@ -1679,7 +1679,7 @@ def compare(
added=True,
deleted=True,
modified=True,
unchanged=True,
same=True,
status_col="diff"
)
```
Expand All @@ -1696,7 +1696,80 @@ def compare(
added=added,
deleted=deleted,
modified=modified,
unchanged=unchanged,
same=same,
status_col=status_col,
)

def diff(
self,
other: "DataChain",
on: str = "file",
right_on: Optional[str] = None,
added: bool = True,
modified: bool = True,
deleted: bool = False,
same: bool = False,
status_col: Optional[str] = None,
) -> "DataChain":
"""Similar to `.compare()`, which is more generic method to calculate difference
between two chains. Unlike `.compare()`, this method works only on those chains
that have `File` object, or it's derivatives, in it. File `source` and `path`
are used for matching, and file `version` and `etag` for comparing, while in
`.compare()` user needs to provide arbitrary columns for matching and comparing.

Parameters:
other: Chain to calculate diff from.
on: File signal to match on. If both chains have the
same file signal then this column is enough for the match. Otherwise,
`right_on` parameter has to specify the file signal for the other chain.
This value is used to find corresponding row in other dataset. If not
found there, row is considered as added (or removed if vice versa), and
if found then row can be either modified or same.
right_on: Optional file signal for the `other` to match.
added (bool): Whether to return added rows in resulting chain.
deleted (bool): Whether to return deleted rows in resulting chain.
modified (bool): Whether to return modified rows in resulting chain.
same (bool): Whether to return unchanged rows in resulting chain.
status_col (str): Optional name of the new column that is created in
resulting chain representing diff status.

Example:
```py
diff = images.diff(
new_images,
on="file",
right_on="other_file",
added=True,
deleted=True,
modified=True,
same=True,
status_col="diff"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's empty by default. The same in the description.

)
```
"""
on_file_signals = ["source", "path"]
compare_file_signals = ["version", "etag"]

def get_file_signals(file: str, signals):
return [f"{file}.{c}" for c in signals]

right_on = right_on or on

on_cols = get_file_signals(on, on_file_signals)
right_on_cols = get_file_signals(right_on, on_file_signals)
compare_cols = get_file_signals(on, compare_file_signals)
right_compare_cols = get_file_signals(right_on, compare_file_signals)

return self.compare(
other,
on_cols,
right_on=right_on_cols,
compare=compare_cols,
right_compare=right_compare_cols,
added=added,
deleted=deleted,
modified=modified,
same=same,
status_col=status_col,
)

Expand Down
37 changes: 18 additions & 19 deletions src/datachain/lib/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def compare( # noqa: PLR0912, PLR0915, C901
added: bool = True,
deleted: bool = True,
modified: bool = True,
unchanged: bool = False,
same: bool = True,
status_col: Optional[str] = None,
) -> "DataChain":
"""Comparing two chains by identifying rows that are added, deleted, modified
or unchanged"""
or same"""
dialect = left._query.dialect

rname = "right_"
Expand Down Expand Up @@ -67,9 +67,9 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
"'compare' and 'right_compare' must be have the same length"
)

if not any([added, deleted, modified, unchanged]):
if not any([added, deleted, modified, same]):
raise ValueError(
"At least one of added, deleted, modified, unchanged flags must be set"
"At least one of added, deleted, modified, same flags must be set"
)

# we still need status column for internal implementation even if not
Expand All @@ -94,7 +94,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
elif not compare and len(cols) != len(right_cols):
# here we will mark all rows that are not added or deleted as modified since
# there was no explicit list of compare columns provided (meaning we need
# to check all columns to determine if row is modified or unchanged), but
# to check all columns to determine if row is modified or same), but
# the number of columns on left and right is not the same (one of the chains
# have additional column)
compare = None
Expand All @@ -121,14 +121,14 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
]
)
diff_cond.append((modified_cond, "M"))
if unchanged and compare:
unchanged_cond = sa.and_(
if same and compare:
same_cond = sa.and_(
*[
C(c) == C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((unchanged_cond, "U"))
diff_cond.append((same_cond, "S"))

diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
diff.type = String()
Expand All @@ -155,23 +155,22 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
*[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)] # noqa: E711
)
)

def _default_val(chain: "DataChain", col: str):
col_type = chain._query.column_types[col] # type: ignore[index]
val = sa.literal(col_type.default_value(dialect)).label(col)
val.type = col_type()
return val

right_left_merge_select = right_left_merge._query.select(
*(
[C(c) for c in right_left_merge.signals_schema.db_signals("sys")]
+ [
C(c) # type: ignore[misc]
if c == rc
else sa.literal(
left._query.column_types[c].default_value(dialect) # type: ignore[index]
).label(c)
C(c) if c == rc else _default_val(left, c)
for c, rc in zip(on, right_on)
]
+ [
C(c) # type: ignore[misc]
if c in right_cols
else sa.literal(
left._query.column_types[c].default_value(dialect) # type: ignore[index]
).label(c) # type: ignore[arg-type]
C(c) if c in right_cols else _default_val(left, c) # type: ignore[arg-type]
for c in cols
if c not in on
]
Expand All @@ -181,7 +180,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:

if not deleted:
res = left_right_merge_select
elif deleted and not any([added, modified, unchanged]):
elif deleted and not any([added, modified, same]):
res = right_left_merge_select
else:
res = left_right_merge_select.union(right_left_merge_select)
Expand Down
Loading
Loading