From 31617134ec7265df8c7591f20aab8439a8ef60e2 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 9 Mar 2024 17:27:50 +0100 Subject: [PATCH] fix(python): always encapsulate column names in backticks in _all functions (#2271) # Description - Always encapsulates column names in backticks in the insert_all and update_all calls. - Added note that users need to add backticks for special column names - Removed bigint cast; this was only needed temporarily while we were still relying on a physical plan # Related Issue(s) - closes https://github.com/delta-io/delta-rs/issues/2230 - closes https://github.com/delta-io/delta-rs/issues/2167 --- python/deltalake/table.py | 38 ++++++++++++++++++++++++---- python/tests/conftest.py | 13 ++++++++++ python/tests/test_merge.py | 52 +++++++++++++++++++++++++++++++++----- 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 5869ceb2e2..064ee3a83c 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1350,6 +1350,10 @@ def when_matched_update( """Update a matched table row based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: updates: a mapping of column name to update SQL expression. predicate: SQL like predicate on when to update. 
@@ -1362,10 +1366,10 @@ def when_matched_update( from deltalake import DeltaTable, write_deltalake import pyarrow as pa - data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + data = pa.table({"x": [1, 2, 3], "1y": [4, 5, 6]}) write_deltalake("tmp", data) dt = DeltaTable("tmp") - new_data = pa.table({"x": [1], "y": [7]}) + new_data = pa.table({"x": [1], "1y": [7]}) ( dt.merge( @@ -1373,7 +1377,7 @@ def when_matched_update( predicate="target.x = source.x", source_alias="source", target_alias="target") - .when_matched_update(updates={"x": "source.x", "y": "source.y"}) + .when_matched_update(updates={"x": "source.x", "`1y`": "source.`1y`"}) .execute() ) {'num_source_rows': 1, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 2, 'num_output_rows': 3, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...} @@ -1399,6 +1403,10 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg """Updating all source fields to target fields, source and target are required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: predicate: SQL like predicate on when to update all columns. 
@@ -1438,7 +1446,7 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" updates = { - f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" for col in self.source.schema } @@ -1457,6 +1465,10 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": """Delete a matched row from the table only if the given ``predicate`` (if specified) is true for the matched row. If not specified it deletes all matches. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: predicate (str | None, Optional): SQL like predicate on when to delete. @@ -1533,6 +1545,10 @@ def when_not_matched_insert( """Insert a new row to the target table based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: updates (dict): a mapping of column name to insert SQL expression. predicate (str | None, Optional): SQL like predicate on when to insert. @@ -1592,6 +1608,10 @@ def when_not_matched_insert_all( required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: predicate: SQL like predicate on when to insert. 
@@ -1631,7 +1651,7 @@ def when_not_matched_insert_all( src_alias = (self.source_alias + ".") if self.source_alias is not None else "" trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" updates = { - f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" for col in self.source.schema } if isinstance(self.not_matched_insert_updates, list) and isinstance( @@ -1651,6 +1671,10 @@ def when_not_matched_by_source_update( """Update a target row that has no matches in the source based on the rules defined by ``updates``. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: updates: a mapping of column name to update SQL expression. predicate: SQL like predicate on when to update. @@ -1705,6 +1729,10 @@ def when_not_matched_by_source_delete( """Delete a target row that has no matches in the source from the table only if the given ``predicate`` (if specified) is true for the target row. + Note: + Column names with special characters, such as numbers or spaces should be encapsulated + in backticks: "target.`123column`" or "target.`my column`" + Args: predicate: SQL like predicate on when to delete when not matched by source. 
diff --git a/python/tests/conftest.py b/python/tests/conftest.py index c81b6fb91e..6621bc9afb 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -250,6 +250,19 @@ def sample_table(): ) +@pytest.fixture() +def sample_table_with_spaces_numbers(): + nrows = 5 + return pa.table( + { + "1id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold items": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + + @pytest.fixture() def writer_properties(): return WriterProperties(compression="GZIP", compression_level=0) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index b609d88d21..82776c60fc 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -274,7 +274,7 @@ def test_merge_when_not_matched_insert_with_predicate( "sold": "source.sold", "deleted": "False", }, - predicate="source.price < bigint'50'", + predicate="source.price < 50", ).execute() expected = pa.table( @@ -314,7 +314,7 @@ def test_merge_when_not_matched_insert_all_with_predicate( target_alias="target", predicate="target.id = source.id", ).when_not_matched_insert_all( - predicate="source.price < bigint'50'", + predicate="source.price < 50", ).execute() expected = pa.table( @@ -332,6 +332,46 @@ def test_merge_when_not_matched_insert_all_with_predicate( assert result == expected +def test_merge_when_not_matched_insert_all_with_predicate_special_column_names( + tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table +): + write_deltalake(tmp_path, sample_table_with_spaces_numbers, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "1id": pa.array(["6", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold items": pa.array([10, 20], pa.int32()), + "deleted": pa.array([None, None], pa.bool_()), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.`1id` = 
source.`1id`", + ).when_not_matched_insert_all( + predicate="source.price < 50", + ).execute() + + expected = pa.table( + { + "1id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()), + "sold items": pa.array([0, 1, 2, 3, 4, 10], pa.int32()), + "deleted": pa.array([False, False, False, False, False, None]), + } + ) + result = dt.to_pyarrow_table().sort_by([("1id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + def test_merge_when_not_matched_by_source_update_wo_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): @@ -399,7 +439,7 @@ def test_merge_when_not_matched_by_source_update_with_predicate( updates={ "sold": "int'10'", }, - predicate="target.price > bigint'3'", + predicate="target.price > 3", ).execute() expected = pa.table( @@ -438,7 +478,7 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( source_alias="source", target_alias="target", predicate="target.id = source.id", - ).when_not_matched_by_source_delete(predicate="target.price > bigint'3'").execute() + ).when_not_matched_by_source_delete(predicate="target.price > 3").execute() expected = pa.table( { @@ -608,7 +648,7 @@ def test_merge_multiple_when_not_matched_insert_with_predicate( "sold": "source.sold", "deleted": "False", }, - predicate="source.price < bigint'50'", + predicate="source.price < 50", ).when_not_matched_insert( updates={ "id": "source.id", @@ -616,7 +656,7 @@ def test_merge_multiple_when_not_matched_insert_with_predicate( "sold": "source.sold", "deleted": "False", }, - predicate="source.price > bigint'50'", + predicate="source.price > 50", ).execute() expected = pa.table(