Skip to content

Commit

Permalink
unalias
Browse files Browse the repository at this point in the history
  • Loading branch information
tekumara committed Sep 8, 2024
1 parent 63948c1 commit dfa68e1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
19 changes: 19 additions & 0 deletions fakesnow/transforms_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ def _remove_table_alias(eq_exp: exp.Condition) -> exp.Condition:
return eq_exp


def _unalias_table_identifiers(expression: exp.Expression) -> exp.Expression:
"""Replace aliased table identifiers with the table name"""

table_aliases = {a.parent.alias: a.parent.this for a in expression.find_all(exp.TableAlias) if a.parent}

# unalias columns that reference the table alias
for c in expression.find_all(exp.Column):
if unaliased := table_aliases.get(c.table):
c.args["table"] = unaliased

# remove table aliases
for a in expression.find_all(exp.TableAlias):
assert a.parent
a.parent.set("alias", None)

return expression


def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
"""
Create multiple compatible duckdb statements to be functionally equivalent to Snowflake's MERGE INTO.
Expand All @@ -107,6 +125,7 @@ def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
if not isinstance(merge_expr, exp.Merge):
return [merge_expr]

merge_expr = _unalias_table_identifiers(merge_expr)
temp_table_inserts = _create_temp_tables()
output_expressions = []

Expand Down
56 changes: 54 additions & 2 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import snowflake.connector
import sqlglot

from fakesnow import transforms
from fakesnow import transforms, transforms_merge


def test_merge_transform() -> None:
def test_transform_merge() -> None:
assert [
e.sql(dialect="duckdb")
for e in transforms.merge(
Expand Down Expand Up @@ -38,6 +38,58 @@ def test_merge_transform() -> None:
]


def test_unalias_table_identifiers() -> None:
assert (
(
transforms_merge._unalias_table_identifiers( # noqa: SLF001
sqlglot.parse_one(
"""
merge into dest as dst
using source as src
on dst.a = src.a
when not matched then
insert (a,b)
values (src.a,src.b)
when matched then
update set a = src.a, b = src.b
"""
)
).sql()
)
== "MERGE INTO dest USING source ON dest.a = source.a WHEN NOT MATCHED THEN INSERT (a, b) VALUES (source.a, source.b) WHEN MATCHED THEN UPDATE SET a = source.a, b = source.b"
)


def test_transform_merge_with_as() -> None:
assert [
e.sql(dialect="duckdb")
for e in transforms.merge(
sqlglot.parse_one(
"""
merge into test as dst
using TMP_TEST_1698864265 as src
on dst.a = src.a
when not matched then
insert (a,b)
values (src.a,src.b)
when matched then
update set a = src.a, b = src.b
"""
)
)
] == [
"BEGIN",
"CREATE OR REPLACE TEMPORARY TABLE temp_merge_updates_deletes (target_rowid INT, when_id INT, type TEXT(1))",
"CREATE OR REPLACE TEMPORARY TABLE temp_merge_inserts (source_rowid INT, when_id INT)",
"INSERT INTO temp_merge_inserts SELECT rowid, 0 FROM TMP_TEST_1698864265 AS src WHERE NOT EXISTS(SELECT 1 FROM test AS dst WHERE dst.a = src.a) AND NOT EXISTS(SELECT 1 FROM temp_merge_inserts WHERE TMP_TEST_1698864265 AS src.rowid = source_rowid)",
"INSERT INTO temp_merge_updates_deletes SELECT rowid, 1, 'U' FROM test AS dst WHERE EXISTS(SELECT 1 FROM TMP_TEST_1698864265 AS src WHERE dst.a = src.a) AND NOT EXISTS(SELECT 1 FROM temp_merge_updates_deletes WHERE test AS dst.rowid = target_rowid)",
"INSERT INTO test AS dst (a, b) SELECT src.a, src.b FROM TMP_TEST_1698864265 AS src WHERE TMP_TEST_1698864265 AS src.rowid IN (SELECT source_rowid FROM temp_merge_inserts WHERE when_id = 0 AND source_rowid = TMP_TEST_1698864265 AS src.rowid)",
"UPDATE test AS dst SET a = src.a, b = src.b FROM TMP_TEST_1698864265 AS src WHERE dst.a = src.a AND test AS dst.rowid IN (SELECT target_rowid FROM temp_merge_updates_deletes WHERE when_id = 1 AND target_rowid = test AS dst.rowid)",
"COMMIT",
'WITH merge_update_deletes AS (SELECT CAST(COUNT_IF(type = \'U\') AS INT) AS "updates", CAST(COUNT_IF(type = \'D\') AS INT) AS "deletes" FROM temp_merge_updates_deletes), merge_inserts AS (SELECT COUNT() AS "inserts" FROM temp_merge_inserts) SELECT mi.inserts AS "number of rows inserted", mud.updates AS "number of rows updated", mud.deletes AS "number of rows deleted" FROM merge_update_deletes AS mud, merge_inserts AS mi',
]


# TODO: Also consider nondeterministic config for throwing errors when multiple source criteria match a target row
# https://docs.snowflake.com/en/sql-reference/sql/merge#nondeterministic-results-for-update-and-delete
def test_merge(conn: snowflake.connector.SnowflakeConnection):
Expand Down

0 comments on commit dfa68e1

Please sign in to comment.