diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index fd8c7cac1c..f60ac1e3ee 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import itertools import uuid from abc import ABC, abstractmethod from datetime import datetime @@ -466,6 +467,46 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl return base_metadata.model_copy(update=metadata_updates) +@_apply_table_update.register(RemoveSnapshotsUpdate) +def _(update: RemoveSnapshotsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + for remove_snapshot_id in update.snapshot_ids: + if not any(snapshot.snapshot_id == remove_snapshot_id for snapshot in base_metadata.snapshots): + raise ValueError(f"Snapshot with snapshot id {remove_snapshot_id} does not exist: {base_metadata.snapshots}") + + snapshots = [ + ( + snapshot.model_copy(update={"parent_snapshot_id": None}) + if snapshot.parent_snapshot_id in update.snapshot_ids + else snapshot + ) + for snapshot in base_metadata.snapshots + if snapshot.snapshot_id not in update.snapshot_ids + ] + snapshot_log = [ + snapshot_log_entry + for snapshot_log_entry in base_metadata.snapshot_log + if snapshot_log_entry.snapshot_id not in update.snapshot_ids + ] + + remove_ref_updates = ( + RemoveSnapshotRefUpdate(ref_name=ref_name) + for ref_name, ref in base_metadata.refs.items() + if ref.snapshot_id in update.snapshot_ids + ) + remove_statistics_updates = ( + RemoveStatisticsUpdate(statistics_file.snapshot_id) + for statistics_file in base_metadata.statistics + if statistics_file.snapshot_id in update.snapshot_ids + ) + updates = itertools.chain(remove_ref_updates, remove_statistics_updates) + new_metadata = base_metadata + for upd in updates: + new_metadata = _apply_table_update(upd, new_metadata, context) + + context.add_update(update) + return new_metadata.model_copy(update={"snapshots": snapshots, "snapshot_log": snapshot_log}) + + @_apply_table_update.register(RemoveSnapshotRefUpdate) def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.ref_name not in base_metadata.refs: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 4836c7bbad..69bbab527e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -79,6 +79,7 @@ AssertTableUUID, RemovePropertiesUpdate, RemoveSnapshotRefUpdate, + RemoveSnapshotsUpdate, RemoveStatisticsUpdate, SetDefaultSortOrderUpdate, SetPropertiesUpdate, @@ -795,6 +796,42 @@ def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: def test_update_remove_snapshots(table_v2: Table) -> None: + REMOVE_SNAPSHOT = 3051729675574597004 + KEEP_SNAPSHOT = 3055729675574597004 + # assert fixture data to easily understand the test assumptions + assert len(table_v2.metadata.snapshots) == 2 + assert len(table_v2.metadata.snapshot_log) == 2 + assert len(table_v2.metadata.refs) == 2 + update = RemoveSnapshotsUpdate(snapshot_ids=[REMOVE_SNAPSHOT]) + new_metadata = update_table_metadata(table_v2.metadata, (update,)) + assert len(new_metadata.snapshots) == 1 + assert new_metadata.snapshots[0].snapshot_id == KEEP_SNAPSHOT + assert new_metadata.snapshots[0].parent_snapshot_id is None + assert new_metadata.current_snapshot_id == KEEP_SNAPSHOT + assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms + assert len(new_metadata.snapshot_log) == 1 + assert new_metadata.snapshot_log[0].snapshot_id == KEEP_SNAPSHOT + assert len(new_metadata.refs) == 1 + assert new_metadata.refs["main"].snapshot_id == KEEP_SNAPSHOT + + +def test_update_remove_snapshots_doesnt_exist(table_v2: Table) -> None: + update = RemoveSnapshotsUpdate( + snapshot_ids=[123], + ) + with pytest.raises(ValueError, match="Snapshot with snapshot id 123 does not exist"): + update_table_metadata(table_v2.metadata, (update,)) + + +def test_update_remove_snapshots_remove_current_snapshot_id(table_v2: Table) -> None: + update = RemoveSnapshotsUpdate(snapshot_ids=[3055729675574597004]) + new_metadata = update_table_metadata(table_v2.metadata, (update,)) + assert len(new_metadata.refs) == 1 + assert new_metadata.refs["test"].snapshot_id == 3051729675574597004 + assert new_metadata.current_snapshot_id is None + + +def test_update_remove_snapshot_ref(table_v2: Table) -> None: # assert fixture data to easily understand the test assumptions assert len(table_v2.metadata.refs) == 2 update = RemoveSnapshotRefUpdate(ref_name="test")