-
Notifications
You must be signed in to change notification settings - Fork 326
Use a join for upsert deduplication #1685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
71b474c
ee94650
d5217cb
caa9c57
385b760
ad0bd9d
a452d83
8326db4
73a7fe0
690824f
33eead0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,51 +59,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols | |
""" | ||
Return a table with rows that need to be updated in the target table based on the join columns. | ||
|
||
When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred. | ||
Only matched rows that have an actual change to a non-key column value will be returned in the final output. | ||
The table is joined on the identifier columns, and then checked if there are any updated rows. | ||
Those are selected and everything is renamed correctly. | ||
""" | ||
all_columns = set(source_table.column_names) | ||
join_cols_set = set(join_cols) | ||
non_key_cols = all_columns - join_cols_set | ||
|
||
non_key_cols = list(all_columns - join_cols_set) | ||
if has_duplicate_rows(target_table, join_cols): | ||
raise ValueError("Target table has duplicate rows, aborting upsert") | ||
|
||
if len(target_table) == 0: | ||
# When the target table is empty, there is nothing to update :) | ||
return source_table.schema.empty_table() | ||
|
||
match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols]) | ||
|
||
matching_source_rows = source_table.filter(match_expr) | ||
|
||
rows_to_update = [] | ||
|
||
for index in range(matching_source_rows.num_rows): | ||
source_row = matching_source_rows.slice(index, 1) | ||
|
||
target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols]) | ||
|
||
matching_target_row = target_table.filter(target_filter) | ||
|
||
if matching_target_row.num_rows > 0: | ||
needs_update = False | ||
|
||
for non_key_col in non_key_cols: | ||
source_value = source_row.column(non_key_col)[0].as_py() | ||
target_value = matching_target_row.column(non_key_col)[0].as_py() | ||
|
||
if source_value != target_value: | ||
needs_update = True | ||
break | ||
|
||
if needs_update: | ||
rows_to_update.append(source_row) | ||
|
||
if rows_to_update: | ||
rows_to_update_table = pa.concat_tables(rows_to_update) | ||
else: | ||
rows_to_update_table = source_table.schema.empty_table() | ||
|
||
common_columns = set(source_table.column_names).intersection(set(target_table.column_names)) | ||
rows_to_update_table = rows_to_update_table.select(list(common_columns)) | ||
|
||
return rows_to_update_table | ||
diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. De Morgan's law in the wild 🥇 |
||
|
||
return ( | ||
source_table | ||
# We already know that the schema is compatible, this is to fix large_ types | ||
.cast(target_table.schema) | ||
.join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: should we add a check here? Since we only check whether source_table has duplicates, the target_table might still produce duplicates. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great catch! Since we've already filtered the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Included a test 👍 |
||
.filter(diff_expr) | ||
.drop_columns([f"{col}-rhs" for col in non_key_cols]) | ||
.rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh this is a dictionary! https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.rename_columns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, only the non-join columns get postfixed :) |
||
# Finally cast to the original schema since it doesn't carry nullability: | ||
# https://github.com/apache/arrow/issues/45557 | ||
).cast(target_table.schema) |
Uh oh!
There was an error while loading. Please reload this page.