Skip to content

Commit

Permalink
fix(dask_normalize): use sqlglot to generate escaped name string (let…
Browse files Browse the repository at this point in the history
  • Loading branch information
dlovell committed Jun 11, 2024
1 parent 27fb007 commit 23de776
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/letsql/common/utils/dask_normalize_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dask
import ibis
import ibis.expr.operations.relations as ir
import sqlglot as sg

from letsql.expr.relations import (
make_native_op,
Expand Down Expand Up @@ -111,7 +112,10 @@ def normalize_snowflake_databasetable(dt):
def normalize_duckdb_databasetable(dt):
if dt.source.name != "duckdb":
raise ValueError
((_, plan),) = dt.source.raw_sql(f"EXPLAIN SELECT * FROM {dt.name}").fetchall()
name = sg.table(dt.name, quoted=dt.source.compiler.quoted).sql(
dialect=dt.source.name
)
((_, plan),) = dt.source.raw_sql(f"EXPLAIN SELECT * FROM {name}").fetchall()
scan_line = plan.split("\n")[1]
execution_plan_name = r"\s*│\s*(\w+)\s*│\s*"
match re.match(execution_plan_name, scan_line).group(1):
Expand All @@ -124,8 +128,9 @@ def normalize_duckdb_databasetable(dt):


def normalize_duckdb_file_read(dt):
name = sg.exp.convert(dt.name).sql(dialect=dt.source.name)
(sql_ddl_statement,) = dt.source.con.sql(
f"select sql from duckdb_views() where view_name = '{dt.name}'"
f"select sql from duckdb_views() where view_name = {name}"
).fetchone()
return dask.base._normalize_seq_func(
(
Expand Down
8 changes: 8 additions & 0 deletions python/letsql/common/utils/tests/test_dask_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ def test_tokenize_pandas_expr(alltypes_df):
actual = dask.base.tokenize(t)
expected = "7b0019049171a3ef78ecbd5f463ac728"
assert actual == expected


def test_tokenize_duckdb_dt(batting):
    """Check that a DuckDB table with a dashed name tokenizes to a fixed value.

    The ``batting`` fixture is converted to a pyarrow object and registered
    on a fresh DuckDB connection under the name ``"dashed-name"``; the dask
    token of the resulting table expression is then compared against a
    pinned hash.

    NOTE(review): the dash in the table name presumably exercises
    identifier quoting during normalization -- confirm against the
    normalizer implementation.
    """
    con = ibis.duckdb.connect()
    table = con.register(batting.to_pyarrow(), "dashed-name")
    # Pinned token: any change to how the table is normalized alters it.
    assert dask.base.tokenize(table) == "e5d0040b184eaa719ebb5dc0efff3cc7"

0 comments on commit 23de776

Please sign in to comment.