
Use a join for upsert deduplication #1685

Merged (11 commits) on Feb 21, 2025
7 changes: 7 additions & 0 deletions pyiceberg/table/__init__.py
@@ -1170,6 +1170,13 @@ def upsert(
if upsert_util.has_duplicate_rows(df, join_cols):
raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)
matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()
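For context on the new check: Iceberg stores timestamps at microsecond precision, so nanosecond-precision Arrow data only passes when the downcast flag is set. A minimal, PyArrow-only sketch of the downcast itself (illustrative values; the real check lives in `_check_pyarrow_schema_compatible`):

```python
import pyarrow as pa

# Nanosecond-precision data, as pandas conversions often produce
ns_table = pa.table({"ts": pa.array([1_700_000_000_000_000_000], type=pa.timestamp("ns"))})

# Downcasting to microseconds makes it compatible with an Iceberg timestamp column
us_table = ns_table.cast(pa.schema([pa.field("ts", pa.timestamp("us"))]))
print(us_table.schema)  # ts: timestamp[us]
```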
57 changes: 18 additions & 39 deletions pyiceberg/table/upsert_util.py
@@ -59,51 +59,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
"""
Return a table with rows that need to be updated in the target table based on the join columns.

When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
Only matched rows that have an actual change to a non-key column value will be returned in the final output.
The source and target tables are joined on the identifier columns, the joined rows are filtered down
to those where a non-key column actually changed, and the suffixed columns are renamed back to the original schema.
"""
all_columns = set(source_table.column_names)
join_cols_set = set(join_cols)
non_key_cols = all_columns - join_cols_set

non_key_cols = list(all_columns - join_cols_set)
if has_duplicate_rows(target_table, join_cols):
raise ValueError("Target table has duplicate rows, aborting upsert")

if len(target_table) == 0:
# When the target table is empty, there is nothing to update :)
return source_table.schema.empty_table()

match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols])

matching_source_rows = source_table.filter(match_expr)

rows_to_update = []

for index in range(matching_source_rows.num_rows):
source_row = matching_source_rows.slice(index, 1)

target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols])

matching_target_row = target_table.filter(target_filter)

if matching_target_row.num_rows > 0:
needs_update = False

for non_key_col in non_key_cols:
source_value = source_row.column(non_key_col)[0].as_py()
target_value = matching_target_row.column(non_key_col)[0].as_py()

if source_value != target_value:
needs_update = True
break

if needs_update:
rows_to_update.append(source_row)

if rows_to_update:
rows_to_update_table = pa.concat_tables(rows_to_update)
else:
rows_to_update_table = source_table.schema.empty_table()

common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
rows_to_update_table = rows_to_update_table.select(list(common_columns))

return rows_to_update_table
diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols])
Contributor:
De Morgan's law in the wild 🥇
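A toy check of the equivalence that comment refers to, using the same `-lhs`/`-rhs` suffix convention (hypothetical column names): a matched row needs an update iff not all non-key columns are equal, which De Morgan turns into "any column differs".

```python
import functools
import operator

import pyarrow as pa
import pyarrow.compute as pc

# A pretend join result with two non-key columns, v and w
joined = pa.table(
    {
        "v-lhs": [1, 2, 3],
        "v-rhs": [1, 5, 3],
        "w-lhs": ["a", "b", "c"],
        "w-rhs": ["a", "b", "x"],
    }
)
cols = ["v", "w"]

any_diff = functools.reduce(operator.or_, [pc.field(f"{c}-lhs") != pc.field(f"{c}-rhs") for c in cols])
not_all_equal = ~functools.reduce(operator.and_, [pc.field(f"{c}-lhs") == pc.field(f"{c}-rhs") for c in cols])

# Both predicates keep the same rows (rows 2 and 3, where v or w changed)
assert joined.filter(any_diff) == joined.filter(not_all_equal)
```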


return (
source_table
# We already know that the schema is compatible; this cast just normalizes large_ types (e.g. large_string vs string)
.cast(target_table.schema)
.join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
Contributor:
nit: should we add coalesce_keys=True here to avoid duplicates in the resulting join table?

Since we only check whether source_table has duplicates, the target_table might produce duplicates.

Contributor (author):
Great catch! Since we've already filtered the target_table, I think we could also do the check there; it isn't that expensive anymore.

Contributor (author):
Included a test 👍
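A quick illustration of the reviewer's point with hypothetical toy tables: a duplicated key on the target side makes the inner join fan out, so the same source row would be emitted twice; that is exactly what the has_duplicate_rows guard prevents.

```python
import pyarrow as pa

source = pa.table({"city": ["Drachten"], "inhabitants": [45505]})
target = pa.table({"city": ["Drachten", "Drachten"], "inhabitants": [45019, 45019]})

# The single source row matches both target rows
joined = source.join(target, keys=["city"], join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
print(joined.num_rows)  # 2
```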

.filter(diff_expr)
.drop_columns([f"{col}-rhs" for col in non_key_cols])
.rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
Contributor:
Oh, this is a dictionary! https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.rename_columns
And the non-join columns will be ignored by create_match_filter.

Contributor (author):
Yes, only the non-join columns get postfixed :)
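The mapping form of rename_columns leaves unmentioned columns alone, so the identity entries for the join columns are harmless. A small sketch (note: the dict overload only exists in newer PyArrow releases; older ones accept only a full list of names):

```python
import pyarrow as pa

t = pa.table({"city": ["Drachten"], "inhabitants-lhs": [45505]})

# Only the suffixed column is renamed; "city" keeps its name
renamed = t.rename_columns({"inhabitants-lhs": "inhabitants"})
print(renamed.column_names)  # ['city', 'inhabitants']
```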

# Finally, cast back to the original schema, since the join result doesn't carry nullability:
# https://github.com/apache/arrow/issues/45557
).cast(target_table.schema)
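Putting the new path together, a usage sketch with toy data (assumes a pyiceberg build that includes this change): Amsterdam is unchanged, so only the Drachten row comes back as an update.

```python
import pyarrow as pa

from pyiceberg.table.upsert_util import get_rows_to_update

source = pa.table({"city": ["Drachten", "Amsterdam"], "inhabitants": [45505, 921402]})
target = pa.table({"city": ["Drachten", "Amsterdam"], "inhabitants": [45019, 921402]})

rows = get_rows_to_update(source, target, join_cols=["city"])
print(rows.to_pylist())  # [{'city': 'Drachten', 'inhabitants': 45505}]
```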
42 changes: 42 additions & 0 deletions tests/table/test_upsert.py
@@ -427,6 +427,48 @@ def test_create_match_filter_single_condition() -> None:
)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "inhabitants", IntegerType(), required=True),
# Mark city as the identifier field, also known as the primary key
identifier_field_ids=[1],
)

tbl = catalog.create_table(identifier, schema=schema)

arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("inhabitants", pa.int32(), nullable=False),
]
)

# Write some data
df = pa.Table.from_pylist(
[
{"city": "Drachten", "inhabitants": 45019},
{"city": "Drachten", "inhabitants": 45019},
],
schema=arrow_schema,
)
tbl.append(df)

df = pa.Table.from_pylist(
[
# Would update the inhabitants count, but the duplicate keys in the target abort the upsert
{"city": "Drachten", "inhabitants": 45505},
],
schema=arrow_schema,
)

with pytest.raises(ValueError, match="Target table has duplicate rows, aborting upsert"):
_ = tbl.upsert(df)


def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
identifier = "default.test_upsert_without_identifier_fields"
_drop_table(catalog, identifier)