Skip to content

Commit

Permalink
fix(python): always encapsulate column names in backticks in _all fun…
Browse files Browse the repository at this point in the history
…ctions (#2271)

# Description
- Always encapsulates column names in backticks to in the insert_all and
update_all calls.
- Added note that users need to add backticks for special column names
- Removed bigint cast, this was temporarily needed while we were still
relying on a physical plan

# Related Issue(s)
- closes #2230
- closes #2167
  • Loading branch information
ion-elgreco authored Mar 9, 2024
1 parent 1e19cf3 commit 3161713
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 11 deletions.
38 changes: 33 additions & 5 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,10 @@ def when_matched_update(
"""Update a matched table row based on the rules defined by ``updates``.
If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
updates: a mapping of column name to update SQL expression.
predicate: SQL like predicate on when to update.
Expand All @@ -1362,18 +1366,18 @@ def when_matched_update(
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
data = pa.table({"x": [1, 2, 3], "1y": [4, 5, 6]})
write_deltalake("tmp", data)
dt = DeltaTable("tmp")
new_data = pa.table({"x": [1], "y": [7]})
new_data = pa.table({"x": [1], "1y": [7]})
(
dt.merge(
source=new_data,
predicate="target.x = source.x",
source_alias="source",
target_alias="target")
.when_matched_update(updates={"x": "source.x", "y": "source.y"})
.when_matched_update(updates={"x": "source.x", "`1y`": "source.`1y`"})
.execute()
)
{'num_source_rows': 1, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 2, 'num_output_rows': 3, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...}
Expand All @@ -1399,6 +1403,10 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
"""Updating all source fields to target fields, source and target are required to have the same field names.
If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate: SQL like predicate on when to update all columns.
Expand Down Expand Up @@ -1438,7 +1446,7 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""

updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
}

Expand All @@ -1457,6 +1465,10 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
"""Delete a matched row from the table only if the given ``predicate`` (if specified) is
true for the matched row. If not specified it deletes all matches.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate (str | None, Optional): SQL like predicate on when to delete.
Expand Down Expand Up @@ -1533,6 +1545,10 @@ def when_not_matched_insert(
"""Insert a new row to the target table based on the rules defined by ``updates``. If a
``predicate`` is specified, then it must evaluate to true for the new row to be inserted.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
updates (dict): a mapping of column name to insert SQL expression.
predicate (str | None, Optional): SQL like predicate on when to insert.
Expand Down Expand Up @@ -1592,6 +1608,10 @@ def when_not_matched_insert_all(
required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for
the new row to be inserted.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate: SQL like predicate on when to insert.
Expand Down Expand Up @@ -1631,7 +1651,7 @@ def when_not_matched_insert_all(
src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
}
if isinstance(self.not_matched_insert_updates, list) and isinstance(
Expand All @@ -1651,6 +1671,10 @@ def when_not_matched_by_source_update(
"""Update a target row that has no matches in the source based on the rules defined by ``updates``.
If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
updates: a mapping of column name to update SQL expression.
predicate: SQL like predicate on when to update.
Expand Down Expand Up @@ -1705,6 +1729,10 @@ def when_not_matched_by_source_delete(
"""Delete a target row that has no matches in the source from the table only if the given
``predicate`` (if specified) is true for the target row.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate: SQL like predicate on when to delete when not matched by source.
Expand Down
13 changes: 13 additions & 0 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,19 @@ def sample_table():
)


@pytest.fixture()
def sample_table_with_spaces_numbers():
nrows = 5
return pa.table(
{
"1id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold items": pa.array(list(range(nrows)), pa.int32()),
"deleted": pa.array([False] * nrows),
}
)


@pytest.fixture()
def writer_properties():
return WriterProperties(compression="GZIP", compression_level=0)
52 changes: 46 additions & 6 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_merge_when_not_matched_insert_with_predicate(
"sold": "source.sold",
"deleted": "False",
},
predicate="source.price < bigint'50'",
predicate="source.price < 50",
).execute()

expected = pa.table(
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_merge_when_not_matched_insert_all_with_predicate(
target_alias="target",
predicate="target.id = source.id",
).when_not_matched_insert_all(
predicate="source.price < bigint'50'",
predicate="source.price < 50",
).execute()

expected = pa.table(
Expand All @@ -332,6 +332,46 @@ def test_merge_when_not_matched_insert_all_with_predicate(
assert result == expected


def test_merge_when_not_matched_insert_all_with_predicate_special_column_names(
tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table
):
write_deltalake(tmp_path, sample_table_with_spaces_numbers, mode="append")

dt = DeltaTable(tmp_path)

source_table = pa.table(
{
"1id": pa.array(["6", "10"]),
"price": pa.array([10, 100], pa.int64()),
"sold items": pa.array([10, 20], pa.int32()),
"deleted": pa.array([None, None], pa.bool_()),
}
)

dt.merge(
source=source_table,
source_alias="source",
target_alias="target",
predicate="target.`1id` = source.`1id`",
).when_not_matched_insert_all(
predicate="source.price < 50",
).execute()

expected = pa.table(
{
"1id": pa.array(["1", "2", "3", "4", "5", "6"]),
"price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()),
"sold items": pa.array([0, 1, 2, 3, 4, 10], pa.int32()),
"deleted": pa.array([False, False, False, False, False, None]),
}
)
result = dt.to_pyarrow_table().sort_by([("1id", "ascending")])
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_when_not_matched_by_source_update_wo_predicate(
tmp_path: pathlib.Path, sample_table: pa.Table
):
Expand Down Expand Up @@ -399,7 +439,7 @@ def test_merge_when_not_matched_by_source_update_with_predicate(
updates={
"sold": "int'10'",
},
predicate="target.price > bigint'3'",
predicate="target.price > 3",
).execute()

expected = pa.table(
Expand Down Expand Up @@ -438,7 +478,7 @@ def test_merge_when_not_matched_by_source_delete_with_predicate(
source_alias="source",
target_alias="target",
predicate="target.id = source.id",
).when_not_matched_by_source_delete(predicate="target.price > bigint'3'").execute()
).when_not_matched_by_source_delete(predicate="target.price > 3").execute()

expected = pa.table(
{
Expand Down Expand Up @@ -608,15 +648,15 @@ def test_merge_multiple_when_not_matched_insert_with_predicate(
"sold": "source.sold",
"deleted": "False",
},
predicate="source.price < bigint'50'",
predicate="source.price < 50",
).when_not_matched_insert(
updates={
"id": "source.id",
"price": "source.price",
"sold": "source.sold",
"deleted": "False",
},
predicate="source.price > bigint'50'",
predicate="source.price > 50",
).execute()

expected = pa.table(
Expand Down

0 comments on commit 3161713

Please sign in to comment.