Skip to content

Commit

Permalink
ploomber#892 Use sqlparse instead of sqlglot
Browse files Browse the repository at this point in the history
  • Loading branch information
marshallwhiteorg committed Oct 5, 2023
1 parent 01cf0c5 commit 16823a1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
54 changes: 27 additions & 27 deletions src/sql/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
)
from IPython.core.error import UsageError
import sqlglot
from sqlglot import parse_one, exp
from sqlglot.generator import Generator
import sqlparse
from ploomber_core.exceptions import modify_exceptions

Expand Down Expand Up @@ -731,13 +729,7 @@ def _connection_execute(self, query, parameters=None):
# empty results if we commit after a SELECT or SUMMARIZE statement,
# see: https://github.com/Mause/duckdb_engine/issues/734.
if self.dialect == "duckdb":
is_duckdb_sqlalchemy = not self.is_dbapi_connection
if is_duckdb_sqlalchemy:
parse_dialect = "tsql"
else:
parse_dialect = "duckdb"

no_commit = detect_duckdb_summarize_or_select(query, parse_dialect)
no_commit = detect_duckdb_summarize_or_select(query)
if no_commit:
return out

Expand Down Expand Up @@ -1074,24 +1066,6 @@ def _check_if_duckdb_dbapi_connection(conn):
return hasattr(conn, "df") and hasattr(conn, "pl")


def detect_duckdb_summarize_or_select(query, parse_dialect):
# Attempt to use sqlglot to detect SELECT and SUMMARIZE.
try:
expression = parse_one(query, dialect=parse_dialect)
sql_stripped = Generator(comments=False).generate(expression)
words = sql_stripped.split()
return (
words
and (
words[0].lower() == "select"
or words[0].lower() == "summarize"
)
or isinstance(expression, exp.Select)
)
except sqlglot.errors.ParseError:
return False


def _suggest_fix(env_var, connect_str=None):
"""
Returns an error message that we can display to the user
Expand Down Expand Up @@ -1200,4 +1174,30 @@ def set_sqlalchemy_isolation_level(conn):
return False


def detect_duckdb_summarize_or_select(query):
"""
Checks if the SQL query is a DuckDB SELECT or SUMMARIZE statement.
Note:
Assumes there is only one SQL statement in the query.
"""
statements = sqlparse.parse(query)
if statements:
assert len(statements) == 1
stype = statements[0].get_type()
if stype == "SELECT":
return True
elif stype == "UNKNOWN":
# Further analysis is required
sql_stripped = sqlparse.format(query, strip_comments=True)
words = sql_stripped.split()
return (
len(words) > 0
and (
words[0].lower() == "from"
or words[0].lower() == "summarize"
)
)
return False

atexit.register(ConnectionManager.close_all, verbose=True)
10 changes: 4 additions & 6 deletions src/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,21 +1209,19 @@ def test_database_in_directory_that_doesnt_exist(tmp_empty, uri, expected):
("SELECT column FROM (SELECT * FROM table WHERE column = 'SELECT') AS x", True),

# Invalid SQL returns false
("SELECT FROM table WHERE (column = 'value'", False),
("INSERT INTO table (column) VALUES ('SELECT')", False),
pytest.param("SELECT FROM table WHERE (column = 'value'", False, marks=pytest.mark.xfail(reason="sqlparse does not notice the missing close paren")),

# Comments have no effect
("-- SELECT * FROM table", False),
("-- SELECT * FROM table\nSELECT * FROM table", True),
("-- SELECT * FROM table\nINSERT INTO table SELECT * FROM table2", False),
("-- FROM table SELECT *", False),
("-- FROM table SELECT *\nFROM table SELECT *", True),
("-- FROM table SELECT *\n/**/FROM/**/ table SELECT */**/", True),
("-- FROM table SELECT *\nINSERT INTO table FROM table2 SELECT *", False),
("-- INSERT INTO table SELECT * FROM table2\nSELECT /**/ * FROM tbl /**/", True),
("-- INSERT INTO table SELECT * FROM table2\n/**/SUMMARIZE/**/ /**//**/tbl/**/", True),
]
_dialects = ["duckdb", "tsql"]
@pytest.mark.parametrize("query, expected_output", _query_expected_outputs)
@pytest.mark.parametrize("parse_dialect", _dialects)
def test_detect_duckdb_summarize_or_select(query, parse_dialect, expected_output):
assert detect_duckdb_summarize_or_select(query, parse_dialect) == expected_output
def test_detect_duckdb_summarize_or_select(query, expected_output):
assert detect_duckdb_summarize_or_select(query) == expected_output

0 comments on commit 16823a1

Please sign in to comment.