diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index d84c82ec2a..d81fd4ad07 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -1270,6 +1270,24 @@ with table.update_spec() as update: update.rename_field("bucketed_id", "sharded_id") ``` +## Sort order updates + +Users can update the sort order of an existing table; the new order applies to newly written data. See [sorting](https://iceberg.apache.org/spec/#sorting) for more details. + +Use the `update_sort_order` API on the table to update its sort order. + +Sort orders can only be updated by adding a new sort order. They cannot be deleted or modified. + +### Updating a sort order on a table + +To create a new sort order, you can use either the `asc` or `desc` API depending on whether you want your data sorted in ascending or descending order. Both take the name of the field, the sort order transform, and a null order that describes the order of null values when sorted. + +```python +with table.update_sort_order() as update: + update.desc("event_ts", DayTransform(), NullOrder.NULLS_FIRST) + update.asc("some_field", IdentityTransform(), NullOrder.NULLS_LAST) +``` + ## Table properties Set and remove properties through the `Transaction` API: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9e9de52dee..972efc8c47 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -80,6 +80,7 @@ from pyiceberg.schema import Schema from pyiceberg.table.inspect import InspectTable from pyiceberg.table.locations import LocationProvider, load_location_provider +from pyiceberg.table.maintenance import MaintenanceTable from pyiceberg.table.metadata import ( INITIAL_SEQUENCE_NUMBER, TableMetadata, @@ -87,7 +88,7 @@ from pyiceberg.table.name_mapping import ( NameMapping, ) -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef from pyiceberg.table.snapshots import ( Snapshot, SnapshotLogEntry, @@ -120,6 +121,7 @@ UpdateSnapshot, _FastAppendFiles, ) +from pyiceberg.table.update.sorting import UpdateSortOrder from pyiceberg.table.update.spec import UpdateSpec from pyiceberg.table.update.statistics import UpdateStatistics from pyiceberg.transforms import IdentityTransform @@ -141,12 +143,14 @@ from pyiceberg.utils.properties import property_as_bool if TYPE_CHECKING: + import bodo.pandas as bd import daft import pandas as pd import polars as pl import pyarrow as pa import ray from duckdb import DuckDBPyConnection + from pyiceberg_core.datafusion import IcebergDataFusionTable from pyiceberg.catalog import Catalog @@ -192,6 +196,9 @@ class TableProperties: WRITE_TARGET_FILE_SIZE_BYTES = "write.target-file-size-bytes" WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT = 512 * 1024 * 1024 # 512 MB + WRITE_AVRO_COMPRESSION = "write.avro.compression-codec" + WRITE_AVRO_COMPRESSION_DEFAULT = "gzip" + DEFAULT_WRITE_METRICS_MODE = "write.metadata.metrics.default" DEFAULT_WRITE_METRICS_MODE_DEFAULT = "truncate(16)" @@ -209,6 +216,9 @@ class TableProperties: WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT = True WRITE_DATA_PATH = "write.data.path" + + WRITE_FILE_FORMAT = "write.format.default" + WRITE_FILE_FORMAT_DEFAULT = "parquet" WRITE_METADATA_PATH = "write.metadata.path" DELETE_MODE = "write.delete.mode" @@ -218,7 +228,7 @@ class TableProperties: DEFAULT_NAME_MAPPING = "schema.name-mapping.default" FORMAT_VERSION = "format-version" - DEFAULT_FORMAT_VERSION = 2 + DEFAULT_FORMAT_VERSION: TableVersion = 2 MANIFEST_TARGET_SIZE_BYTES = "commit.manifest.target-size-bytes"
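The new `TableProperties` constants above expose the Avro compression codec and the default write file format as named properties. A minimal sketch of how a caller might set the underlying keys — the catalog name, namespace, and schema here are illustrative only, not part of this change:

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField

# Hypothetical catalog and table identifier, shown only to illustrate the keys.
catalog = load_catalog("default")
table = catalog.create_table(
    "default.events",
    schema=Schema(NestedField(1, "id", LongType(), required=False)),
    properties={
        "write.format.default": "parquet",       # TableProperties.WRITE_FILE_FORMAT
        "write.avro.compression-codec": "gzip",  # TableProperties.WRITE_AVRO_COMPRESSION
    },
)
```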
MANIFEST_TARGET_SIZE_BYTES_DEFAULT = 8 * 1024 * 1024 # 8 MB @@ -291,8 +301,6 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ if self._autocommit: self.commit_transaction() - self._updates = () - self._requirements = () return self @@ -398,7 +406,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE expr = Or(expr, match_partition_expression) return expr - def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles: + def _append_snapshot_producer( + self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH + ) -> _FastAppendFiles: """Determine the append type based on table properties. Args: @@ -411,7 +421,7 @@ def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _Fas TableProperties.MANIFEST_MERGE_ENABLED, TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, ) - update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties) + update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch) return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append() def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: @@ -431,13 +441,29 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive name_mapping=self.table_metadata.name_mapping(), ) - def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot: + def update_sort_order(self, case_sensitive: bool = True) -> UpdateSortOrder: + """Create a new UpdateSortOrder to update the sort order of this table. + + Args: + case_sensitive: If field names are case-sensitive. + + Returns: + A new UpdateSortOrder. + """ + return UpdateSortOrder( + self, + case_sensitive=case_sensitive, + ) + + def update_snapshot( + self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + ) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. Returns: A new UpdateSnapshot """ - return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties) + return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties) def update_statistics(self) -> UpdateStatistics: """ @@ -448,13 +474,14 @@ def update_statistics(self) -> UpdateStatistics: """ return UpdateStatistics(transaction=self) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to a table transaction. Args: df: The Arrow dataframe that will be appended to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the append operation """ try: import pyarrow as pa @@ -466,18 +493,15 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - if unsupported_partitions := [ - field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform - ]: - raise ValueError( - f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." 
- ) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( - self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=self.table_metadata.format_version, ) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = list( @@ -488,7 +512,9 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) for data_file in data_files: append_files.append_data_file(data_file) - def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def dynamic_partition_overwrite( + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + ) -> None: """ Shorthand for overwriting existing partitions with a PyArrow table. @@ -499,6 +525,7 @@ def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[st Args: df: The Arrow dataframe that will be used to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the dynamic partition overwrite operation """ try: import pyarrow as pa @@ -521,7 +548,10 @@ def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[st downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( - self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=self.table_metadata.format_version, ) # If dataframe does not have data, there is no need to overwrite @@ -537,9 +567,9 @@ def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[st partitions_to_overwrite = {data_file.partition for data_file in data_files} delete_filter = self._build_partition_predicate(partition_records=partitions_to_overwrite) - self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties) + self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties, branch=branch) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: append_files.commit_uuid = append_snapshot_commit_uuid for data_file in data_files: append_files.append_data_file(data_file) @@ -550,6 +580,7 @@ def overwrite( overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for adding a table overwrite with a PyArrow table to the transaction. @@ -557,15 +588,16 @@ def overwrite( An overwrite may produce zero or more snapshots based on the operation: - DELETE: In case existing Parquet files can be dropped completely. - - REPLACE: In case existing Parquet files need to be rewritten. + - OVERWRITE: In case existing Parquet files need to be rewritten to drop rows that match the overwrite filter. 
- APPEND: In case new data is being inserted into the table. Args: df: The Arrow dataframe that will be used to overwrite the table overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite - case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive snapshot_properties: Custom properties to be added to the snapshot summary + case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive + branch: Branch Reference to run the overwrite operation """ try: import pyarrow as pa @@ -577,22 +609,24 @@ def overwrite( if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - if unsupported_partitions := [ - field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform - ]: - raise ValueError( - f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}." - ) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( - self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=self.table_metadata.format_version, ) if overwrite_filter != AlwaysFalse(): # Only delete when the filter is != AlwaysFalse - self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties) + self.delete( + delete_filter=overwrite_filter, + case_sensitive=case_sensitive, + snapshot_properties=snapshot_properties, + branch=branch, + ) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = _dataframe_to_data_files( @@ -606,6 +640,7 @@ def delete( delete_filter: Union[str, BooleanExpression], snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for deleting record from a table. @@ -613,12 +648,13 @@ def delete( A delete may produce zero or more snapshots based on the operation: - DELETE: In case existing Parquet files can be dropped completely. - - REPLACE: In case existing Parquet files need to be rewritten + - OVERWRITE: In case existing Parquet files need to be rewritten to drop rows that match the delete filter. 
Args: delete_filter: A boolean expression to delete rows from a table snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive + branch: Branch Reference to run the delete operation """ from pyiceberg.io.pyarrow import ( ArrowScan, @@ -635,7 +671,7 @@ def delete( if isinstance(delete_filter, str): delete_filter = _parse_row_filter(delete_filter) - with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot: + with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot: delete_snapshot.delete_by_predicate(delete_filter, case_sensitive) # Check if there are any files that require an actual rewrite of a data file @@ -643,7 +679,10 @@ def delete( bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive) preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter) - files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files() + file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive) + if branch is not None: + file_scan = file_scan.use_ref(branch) + files = file_scan.plan_files() commit_uuid = uuid.uuid4() counter = itertools.count(0) @@ -685,7 +724,9 @@ def delete( ) if len(replaced_files) > 0: - with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot: + with self.update_snapshot( + snapshot_properties=snapshot_properties, branch=branch + ).overwrite() as overwrite_snapshot: overwrite_snapshot.commit_uuid = commit_uuid for original_data_file, replaced_data_files in replaced_files: overwrite_snapshot.delete_data_file(original_data_file) @@ -695,8 +736,156 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records") + def upsert( + self, + df: pa.Table, + join_cols: Optional[List[str]] = None, + when_matched_update_all: bool = True, + when_not_matched_insert_all: bool = True, + case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, + ) -> UpsertResult: + """Shorthand API for performing an upsert to an iceberg table. + + Args: + + df: The input dataframe to upsert with the table's data. + join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. 
+ when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing + when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table + case_sensitive: Bool indicating if the match should be case-sensitive + branch: Branch Reference to run the upsert operation + + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it + + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it + + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) + + + Returns: + An UpsertResult class (contains details of rows updated and inserted) + """ + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import expression_to_pyarrow + from pyiceberg.table import upsert_util + + if join_cols is None: + join_cols = [] + for field_id in self.table_metadata.schema().identifier_field_ids: + col = self.table_metadata.schema().find_column_name(field_id) + if col is not None: + join_cols.append(col) + else: + raise ValueError(f"Field-ID could not be found: {join_cols}") + + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + + if not when_matched_update_all and not when_not_matched_insert_all: + raise ValueError("no upsert options selected...exiting") + + if upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed") + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=self.table_metadata.format_version, + ) + + # get list of rows that exist so we don't have to load the entire target table + matched_predicate = upsert_util.create_match_filter(df, join_cols) + + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
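+ # (Scanning the committed self._table.metadata instead would miss appends or deletes made earlier in this same transaction.)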
+ + matched_iceberg_record_batches_scan = DataScan( + table_metadata=self.table_metadata, + io=self._table.io, + row_filter=matched_predicate, + case_sensitive=case_sensitive, + ) + + if branch in self.table_metadata.refs: + matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) + + matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() + + batches_to_overwrite = [] + overwrite_predicates = [] + rows_to_insert = df + + for batch in matched_iceberg_record_batches: + rows = pa.Table.from_batches([batch]) + + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) + + if len(rows_to_update) > 0: + # build the match predicate filter + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + + batches_to_overwrite.append(rows_to_update) + overwrite_predicates.append(overwrite_mask_predicate) + + if when_not_matched_insert_all: + expr_match = upsert_util.create_match_filter(rows, join_cols) + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + + # Filter rows per batch. + rows_to_insert = rows_to_insert.filter(~expr_match_arrow) + + update_row_cnt = 0 + insert_row_cnt = 0 + + if batches_to_overwrite: + rows_to_update = pa.concat_tables(batches_to_overwrite) + update_row_cnt = len(rows_to_update) + self.overwrite( + rows_to_update, + overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], + branch=branch, + ) + + if when_not_matched_insert_all: + insert_row_cnt = len(rows_to_insert) + if rows_to_insert: + self.append(rows_to_insert, branch=branch) + + return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def add_files( - self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True + self, + file_paths: List[str], + snapshot_properties: Dict[str, str] = EMPTY_DICT, + check_duplicate_files: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand API for adding files as data files to the table transaction. 
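The transaction-level `upsert` above streams matched rows as record batches instead of materializing the whole match set at once. A minimal usage sketch, assuming a table whose `id` column is usable as a join key and an already-created branch named `audit`:

```python
import pyarrow as pa

# Hypothetical input rows; join_cols may be omitted when identifier-field-ids are set.
df = pa.table({"id": pa.array([1, 2], type=pa.int64()), "name": ["alice", "bob"]})

result = table.upsert(df, join_cols=["id"], branch="audit")
print(f"rows updated: {result.rows_updated}, rows inserted: {result.rows_inserted}")
```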
@@ -716,7 +905,7 @@ def add_files( import pyarrow.compute as pc expr = pc.field("file_path").isin(file_paths) - referenced_files = [file["file_path"] for file in self._table.inspect.files().filter(expr).to_pylist()] + referenced_files = [file["file_path"] for file in self._table.inspect.data_files().filter(expr).to_pylist()] if referenced_files: raise ValueError(f"Cannot add files that are already referenced by table, files: {', '.join(referenced_files)}") @@ -725,12 +914,12 @@ def add_files( self.set_properties( **{TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()} ) - with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: data_files = _parquet_files_to_data_files( table_metadata=self.table_metadata, file_paths=file_paths, io=self._table.io ) for data_file in data_files: - update_snapshot.append_data_file(data_file) + append_files.append_data_file(data_file) def update_spec(self) -> UpdateSpec: """Create a new UpdateSpec to update the partitioning of the table. @@ -774,13 +963,15 @@ def commit_transaction(self) -> Table: updates=self._updates, requirements=self._requirements, ) - return self._table - else: - return self._table + + self._updates = () + self._requirements = () + + return self._table class CreateTableTransaction(Transaction): - """A transaction that involves the creation of a a new table.""" + """A transaction that involves the creation of a new table.""" def _initial_changes(self, table_metadata: TableMetadata) -> None: """Set the initial changes that can reconstruct the initial table metadata when creating the CreateTableTransaction.""" @@ -791,7 +982,7 @@ def _initial_changes(self, table_metadata: TableMetadata) -> None: schema: Schema = table_metadata.schema() self._updates += ( - AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id), + AddSchemaUpdate(schema_=schema), SetCurrentSchemaUpdate(schema_id=-1), ) @@ -825,11 +1016,15 @@ def commit_transaction(self) -> Table: Returns: The table with the updates applied. """ - self._requirements = (AssertCreate(),) - self._table._do_commit( # pylint: disable=W0212 - updates=self._updates, - requirements=self._requirements, - ) + if len(self._updates) > 0: + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=(AssertCreate(),), + ) + + self._updates = () + self._requirements = () + return self._table @@ -907,6 +1102,15 @@ def inspect(self) -> InspectTable: """ return InspectTable(self) + @property + def maintenance(self) -> MaintenanceTable: + """Return the MaintenanceTable object for maintenance. + + Returns: + MaintenanceTable object based on this Table. + """ + return MaintenanceTable(self) + def refresh(self) -> Table: """Refresh the current table metadata. @@ -1113,6 +1317,14 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive name_mapping=self.name_mapping(), ) + def update_sort_order(self, case_sensitive: bool = True) -> UpdateSortOrder: + """Create a new UpdateSortOrder to update the sort order of this table. + + Returns: + A new UpdateSortOrder. 
+ """ + return UpdateSortOrder(transaction=Transaction(self, autocommit=True), case_sensitive=case_sensitive) + def name_mapping(self) -> Optional[NameMapping]: """Return the table's field-id NameMapping.""" return self.metadata.name_mapping() @@ -1124,6 +1336,7 @@ def upsert( when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -1134,6 +1347,7 @@ def upsert( when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table case_sensitive: Bool indicating if the match should be case-sensitive + branch: Branch Reference to run the upsert operation To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids @@ -1159,95 +1373,41 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ - try: - import pyarrow as pa # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - from pyiceberg.io.pyarrow import expression_to_pyarrow - from pyiceberg.table import upsert_util - - if join_cols is None: - join_cols = [] - for field_id in self.schema().identifier_field_ids: - col = self.schema().find_column_name(field_id) - if col is not None: - join_cols.append(col) - else: - raise ValueError(f"Field-ID could not be found: {join_cols}") - - if len(join_cols) == 0: - raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") - - if not when_matched_update_all and not when_not_matched_insert_all: - raise ValueError("no upsert options selected...exiting") - - if upsert_util.has_duplicate_rows(df, join_cols): - raise ValueError("Duplicate rows found in source dataset based on the key columns. 
No upsert executed") - - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_pyarrow_schema_compatible( - self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - - # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() - - update_row_cnt = 0 - insert_row_cnt = 0 - with self.transaction() as tx: - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) - - update_row_cnt = len(rows_to_update) - - if len(rows_to_update) > 0: - # build the match predicate filter - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - - tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) - - if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - rows_to_insert = df.filter(~expr_match_arrow) - - insert_row_cnt = len(rows_to_insert) - - if insert_row_cnt > 0: - tx.append(rows_to_insert) - - return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + return tx.upsert( + df=df, + join_cols=join_cols, + when_matched_update_all=when_matched_update_all, + when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive, + branch=branch, + ) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to the table. Args: df: The Arrow dataframe that will be appended to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the append operation """ with self.transaction() as tx: - tx.append(df=df, snapshot_properties=snapshot_properties) + tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch) - def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def dynamic_partition_overwrite( + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + ) -> None: """Shorthand for dynamic overwriting the table with a PyArrow table. Old partitions are auto detected and replaced with data files created for input arrow table. 
Args: df: The Arrow dataframe that will be used to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the dynamic partition overwrite operation """ with self.transaction() as tx: - tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties) + tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties, branch=branch) def overwrite( self, @@ -1255,6 +1415,7 @@ def overwrite( overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for overwriting the table with a PyArrow table. @@ -1262,7 +1423,7 @@ def overwrite( An overwrite may produce zero or more snapshots based on the operation: - DELETE: In case existing Parquet files can be dropped completely. - - REPLACE: In case existing Parquet files need to be rewritten. + - OVERWRITE: In case existing Parquet files need to be rewritten to drop rows that match the overwrite filter.. - APPEND: In case new data is being inserted into the table. Args: @@ -1271,10 +1432,15 @@ def overwrite( or a boolean expression in case of a partial overwrite snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive + branch: Branch Reference to run the overwrite operation """ with self.transaction() as tx: tx.overwrite( - df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties + df=df, + overwrite_filter=overwrite_filter, + case_sensitive=case_sensitive, + snapshot_properties=snapshot_properties, + branch=branch, ) def delete( @@ -1282,6 +1448,7 @@ def delete( delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand for deleting rows from the table. @@ -1290,12 +1457,19 @@ def delete( delete_filter: The predicate that used to remove rows snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive + branch: Branch Reference to run the delete operation """ with self.transaction() as tx: - tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties) + tx.delete( + delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch + ) def add_files( - self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True + self, + file_paths: List[str], + snapshot_properties: Dict[str, str] = EMPTY_DICT, + check_duplicate_files: bool = True, + branch: Optional[str] = MAIN_BRANCH, ) -> None: """ Shorthand API for adding files as data files to the table. 
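Taken together, the new `branch` parameter on these shorthands lets writes land on a branch while `main` stays untouched; reads must select the ref explicitly. A sketch, assuming an existing branch named `audit` on a table with an `id` column:

```python
import pyarrow as pa

df = pa.table({"id": [1, 2, 3]})

table.append(df, branch="audit")                      # append only to the branch
table.delete(delete_filter="id = 2", branch="audit")  # delete only on the branch

# Plain scans still read main; use a ref to read the branch.
branch_rows = table.scan().use_ref("audit").to_arrow()
```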
@@ -1308,7 +1482,10 @@ def add_files( """ with self.transaction() as tx: tx.add_files( - file_paths=file_paths, snapshot_properties=snapshot_properties, check_duplicate_files=check_duplicate_files + file_paths=file_paths, + snapshot_properties=snapshot_properties, + check_duplicate_files=check_duplicate_files, + branch=branch, ) def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: @@ -1361,6 +1538,16 @@ def to_daft(self) -> daft.DataFrame: return daft.read_iceberg(self) + def to_bodo(self) -> bd.DataFrame: + """Read a bodo DataFrame lazily from this Iceberg table. + + Returns: + bd.DataFrame: Unmaterialized Bodo Dataframe created from the Iceberg table + """ + import bodo.pandas as bd + + return bd.read_iceberg_table(self) + def to_polars(self) -> pl.LazyFrame: """Lazily read from this Apache Iceberg table. @@ -1371,6 +1558,51 @@ def to_polars(self) -> pl.LazyFrame: return pl.scan_iceberg(self) + def __datafusion_table_provider__(self) -> "IcebergDataFusionTable": + """Return the DataFusion table provider PyCapsule interface. + + To support DataFusion features such as push down filtering, this function will return a PyCapsule + interface that conforms to the FFI Table Provider required by DataFusion. From an end user perspective + you should not need to call this function directly. Instead you can use ``register_table_provider`` in + the DataFusion SessionContext. + + Returns: + A PyCapsule DataFusion TableProvider interface. + + Example: + ```python + from datafusion import SessionContext + from pyiceberg.catalog import load_catalog + import pyarrow as pa + catalog = load_catalog("catalog", type="in-memory") + catalog.create_namespace_if_not_exists("default") + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + iceberg_table = catalog.create_table("default.test", schema=data.schema) + iceberg_table.append(data) + ctx = SessionContext() + ctx.register_table_provider("test", iceberg_table) + ctx.table("test").show() + ``` + Results in + ``` + DataFrame() + +---+---+ + | x | y | + +---+---+ + | 1 | 4 | + | 2 | 5 | + | 3 | 6 | + +---+---+ + ``` + """ + from pyiceberg_core.datafusion import IcebergDataFusionTable + + return IcebergDataFusionTable( + identifier=self.name(), + metadata_location=self.metadata_location, + file_io_properties=self.io.properties, + ).__datafusion_table_provider__() + class StaticTable(Table): """Load a table directly from a metadata file (i.e., without using a catalog).""" @@ -1522,7 +1754,14 @@ def to_polars(self) -> pl.DataFrame: ... def update(self: S, **overrides: Any) -> S: """Create a copy of this table scan with updated fields.""" - return type(self)(**{**self.__dict__, **overrides}) + from inspect import signature + + # Extract those attributes that are constructor parameters. We don't use self.__dict__ as the kwargs to the + # constructors because it may contain additional attributes that are not part of the constructor signature. 
+ params = signature(type(self).__init__).parameters.keys() - {"self"} # Skip "self" parameter + kwargs = {param: getattr(self, param) for param in params} # Assume parameters are attributes + + return type(self)(**{**kwargs, **overrides}) def use_ref(self: S, name: str) -> S: if self.snapshot_id: @@ -1672,13 +1911,11 @@ def _build_metrics_evaluator(self) -> Callable[[DataFile], bool]: def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], ResidualEvaluator]: spec = self.table_metadata.specs()[spec_id] + from pyiceberg.expressions.visitors import residual_evaluator_of + # The lambda created here is run in multiple threads. # So we avoid creating _EvaluatorExpression methods bound to a single # shared instance across multiple threads. - # return lambda data_file: (partition_schema, partition_expr, self.case_sensitive)(data_file.partition) - from pyiceberg.expressions.visitors import residual_evaluator_of - - # assert self.row_filter == False return lambda datafile: ( residual_evaluator_of( spec=spec, @@ -1688,7 +1925,8 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu ) ) - def _check_sequence_number(self, min_sequence_number: int, manifest: ManifestFile) -> bool: + @staticmethod + def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) -> bool: """Ensure that no manifests are loaded that contain deletes that are older than the data. Args: @@ -1906,14 +2144,6 @@ def generate_data_file_filename(self, extension: str) -> str: return f"00000-{self.task_id}-{self.write_uuid}.{extension}" -@dataclass(frozen=True) -class AddFileTask: - """Task with the parameters for adding a Parquet file as a DataFile.""" - - file_path: str - partition_field_value: Record - - def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]: """Convert a list files into DataFiles. diff --git a/pyiceberg/table/update/sorting.py b/pyiceberg/table/update/sorting.py new file mode 100644 index 0000000000..a356229f91 --- /dev/null +++ b/pyiceberg/table/update/sorting.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
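+# The builder below accumulates SortFields through asc()/desc(), assembles a
+# SortOrder in _apply() (reusing an existing order_id when the field list
+# matches), and in _commit() emits AddSortOrderUpdate/SetDefaultSortOrderUpdate
+# guarded by an AssertDefaultSortOrderId requirement.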
+from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, UNSORTED_SORT_ORDER, NullOrder, SortDirection, SortField, SortOrder +from pyiceberg.table.update import ( + AddSortOrderUpdate, + AssertDefaultSortOrderId, + SetDefaultSortOrderUpdate, + TableRequirement, + TableUpdate, + UpdatesAndRequirements, + UpdateTableMetadata, +) +from pyiceberg.transforms import Transform + +if TYPE_CHECKING: + from pyiceberg.table import Transaction + + +class UpdateSortOrder(UpdateTableMetadata["UpdateSortOrder"]): + _transaction: Transaction + _last_assigned_order_id: Optional[int] + _case_sensitive: bool + _fields: List[SortField] + + def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None: + super().__init__(transaction) + self._fields: List[SortField] = [] + self._case_sensitive: bool = case_sensitive + self._last_assigned_order_id: Optional[int] = None + + def _column_name_to_id(self, column_name: str) -> int: + """Map the column name to the column field id.""" + return ( + self._transaction.table_metadata.schema() + .find_field( + name_or_id=column_name, + case_sensitive=self._case_sensitive, + ) + .field_id + ) + + def _add_sort_field( + self, + source_id: int, + transform: Transform[Any, Any], + direction: SortDirection, + null_order: NullOrder, + ) -> UpdateSortOrder: + """Add a sort field to the sort order list.""" + self._fields.append( + SortField( + source_id=source_id, + transform=transform, + direction=direction, + null_order=null_order, + ) + ) + return self + + def _reuse_or_create_sort_order_id(self) -> int: + """Return the last assigned sort order id or create a new one.""" + new_sort_order_id = INITIAL_SORT_ORDER_ID + for sort_order in self._transaction.table_metadata.sort_orders: + new_sort_order_id = max(new_sort_order_id, sort_order.order_id) + if sort_order.fields == self._fields: + return sort_order.order_id + elif new_sort_order_id <= sort_order.order_id: + new_sort_order_id = sort_order.order_id + 1 + return new_sort_order_id + + def asc( + self, source_column_name: str, transform: Transform[Any, Any], null_order: NullOrder = NullOrder.NULLS_LAST + ) -> UpdateSortOrder: + """Add a sort field with ascending order.""" + return self._add_sort_field( + source_id=self._column_name_to_id(source_column_name), + transform=transform, + direction=SortDirection.ASC, + null_order=null_order, + ) + + def desc( + self, source_column_name: str, transform: Transform[Any, Any], null_order: NullOrder = NullOrder.NULLS_LAST + ) -> UpdateSortOrder: + """Add a sort field with descending order.""" + return self._add_sort_field( + source_id=self._column_name_to_id(source_column_name), + transform=transform, + direction=SortDirection.DESC, + null_order=null_order, + ) + + def _apply(self) -> SortOrder: + """Return the sort order.""" + if next(iter(self._fields), None) is None: + return UNSORTED_SORT_ORDER + else: + return SortOrder(*self._fields, order_id=self._reuse_or_create_sort_order_id()) + + def _commit(self) -> UpdatesAndRequirements: + """Apply the pending changes and commit.""" + new_sort_order = self._apply() + requirements: Tuple[TableRequirement, ...] = () + updates: Tuple[TableUpdate, ...] 
= () + + if ( + self._transaction.table_metadata.default_sort_order_id != new_sort_order.order_id + and self._transaction.table_metadata.sort_order_by_id(new_sort_order.order_id) is None + ): + self._last_assigned_order_id = new_sort_order.order_id + updates = (AddSortOrderUpdate(sort_order=new_sort_order), SetDefaultSortOrderUpdate(sort_order_id=-1)) + else: + updates = (SetDefaultSortOrderUpdate(sort_order_id=new_sort_order.order_id),) + + required_last_assigned_sort_order_id = self._transaction.table_metadata.default_sort_order_id + requirements = (AssertDefaultSortOrderId(default_sort_order_id=required_last_assigned_sort_order_id),) + + return updates, requirements diff --git a/tests/integration/test_sort_order_update.py b/tests/integration/test_sort_order_update.py new file mode 100644 index 0000000000..548c6692db --- /dev/null +++ b/tests/integration/test_sort_order_update.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name + +import pytest + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.table.sorting import NullOrder, SortDirection, SortField, SortOrder +from pyiceberg.transforms import ( + IdentityTransform, +) + + +def _simple_table(catalog: Catalog, table_schema_simple: Schema, format_version: str) -> Table: + return _create_table_with_schema(catalog, table_schema_simple, format_version) + + +def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table: + tbl_name = "default.test_schema_evolution" + try: + catalog.drop_table(tbl_name) + except NoSuchTableError: + pass + return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version}) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "catalog, format_version", + [ + (pytest.lazy_fixture("session_catalog"), "1"), + (pytest.lazy_fixture("session_catalog_hive"), "1"), + (pytest.lazy_fixture("session_catalog"), "2"), + (pytest.lazy_fixture("session_catalog_hive"), "2"), + ], +) +def test_map_column_name_to_id(catalog: Catalog, format_version: str, table_schema_simple: Schema) -> None: + simple_table = _simple_table(catalog, table_schema_simple, format_version) + for col_name, col_id in {"foo": 1, "bar": 2, "baz": 3}.items(): + assert col_id == simple_table.update_sort_order()._column_name_to_id(col_name) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "catalog, format_version", + [ + (pytest.lazy_fixture("session_catalog"), "1"), + (pytest.lazy_fixture("session_catalog_hive"), "1"), + (pytest.lazy_fixture("session_catalog"), "2"), + (pytest.lazy_fixture("session_catalog_hive"), "2"), + ], +) +def 
test_update_sort_order(catalog: Catalog, format_version: str, table_schema_simple: Schema) -> None: + simple_table = _simple_table(catalog, table_schema_simple, format_version) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).desc( + "bar", IdentityTransform(), NullOrder.NULLS_LAST + ).commit() + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), + SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.DESC, null_order=NullOrder.NULLS_LAST), + order_id=1, + ) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "catalog, format_version", + [ + (pytest.lazy_fixture("session_catalog"), "1"), + (pytest.lazy_fixture("session_catalog_hive"), "1"), + (pytest.lazy_fixture("session_catalog"), "2"), + (pytest.lazy_fixture("session_catalog_hive"), "2"), + ], +) +def test_increment_existing_sort_order_id(catalog: Catalog, format_version: str, table_schema_simple: Schema) -> None: + simple_table = _simple_table(catalog, table_schema_simple, format_version) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit() + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), + order_id=1, + ) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_LAST).desc( + "bar", IdentityTransform(), NullOrder.NULLS_FIRST + ).commit() + assert ( + len(simple_table.sort_orders()) == 3 + ) # 0: empty sort order from creating tables, 1: first sort order, 2: second sort order + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST), + SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.DESC, null_order=NullOrder.NULLS_FIRST), + order_id=2, + ) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "catalog, format_version", + [ + (pytest.lazy_fixture("session_catalog"), "1"), + (pytest.lazy_fixture("session_catalog_hive"), "1"), + (pytest.lazy_fixture("session_catalog"), "2"), + (pytest.lazy_fixture("session_catalog_hive"), "2"), + ], +) +def test_update_existing_sort_order(catalog: Catalog, format_version: str, table_schema_simple: Schema) -> None: + simple_table = _simple_table(catalog, table_schema_simple, format_version) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit() + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), + order_id=1, + ) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_LAST).desc( + "bar", IdentityTransform(), NullOrder.NULLS_FIRST + ).commit() + # Go back to the first sort order + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit() + assert ( + len(simple_table.sort_orders()) == 3 + ) # line 133 should not create a new sort order since it is the same as the first one + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), + order_id=1, + ) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "catalog, format_version", + [ + 
(pytest.lazy_fixture("session_catalog"), "1"), + (pytest.lazy_fixture("session_catalog_hive"), "1"), + (pytest.lazy_fixture("session_catalog"), "2"), + (pytest.lazy_fixture("session_catalog_hive"), "2"), + ], +) +def test_update_existing_sort_order_with_unsorted_sort_order( + catalog: Catalog, format_version: str, table_schema_simple: Schema +) -> None: + simple_table = _simple_table(catalog, table_schema_simple, format_version) + simple_table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit() + assert simple_table.sort_order() == SortOrder( + SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), + order_id=1, + ) + # Table should now be unsorted + simple_table.update_sort_order().commit() + # Go back to the first sort order + assert len(simple_table.sort_orders()) == 2 + assert simple_table.sort_order() == SortOrder(order_id=0)
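The reuse semantics asserted by the tests above can also be checked interactively: committing an identical field list twice must not add a new sort order. A minimal sketch, assuming a table with a sortable `foo` column:

```python
from pyiceberg.table.sorting import NullOrder
from pyiceberg.transforms import IdentityTransform

# Commit the same sort order twice against a hypothetical table.
table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit()
table.update_sort_order().asc("foo", IdentityTransform(), NullOrder.NULLS_FIRST).commit()

assert len(table.sort_orders()) == 2     # order 0 (unsorted) plus one user-defined order
assert table.sort_order().order_id == 1  # the existing order id is reused, not incremented
```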