Skip to content

Commit

Permalink
fix: sqlparse fallback for formatting queries
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Oct 11, 2024
1 parent ef0ede7 commit 882e411
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 21 deletions.
86 changes: 66 additions & 20 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Any, Generic, TypeVar

import sqlglot
import sqlparse
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
Expand Down Expand Up @@ -138,24 +139,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.
The class can be instantiated with a string representation of the script or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a script in multiple already parsed statements.
The class should be instantiated with a string representation of the script and, for
efficiency reasons, optionally with a pre-parsed AST. This is useful with
`sqlglot.parse`, which will split a script in multiple already parsed statements.
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
spec.
"""

def __init__(
self,
statement: str | InternalRepresentation,
statement: str,
engine: str,
ast: InternalRepresentation | None = None,
):
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self._sql = statement
self._parsed = ast or self._parse_statement(statement, engine)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

Expand Down Expand Up @@ -239,11 +238,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):

def __init__(
self,
statement: str | exp.Expression,
statement: str,
engine: str,
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
super().__init__(statement, engine, ast)

@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
Expand Down Expand Up @@ -275,11 +275,37 @@ def split_script(
script: str,
engine: str,
) -> list[SQLStatement]:
return [
cls(statement, engine)
for statement in cls._parse(script, engine)
if statement
]
if dialect := SQLGLOT_DIALECTS.get(engine):
try:
return [
cls(ast.sql(), engine, ast)
for ast in cls._parse(script, engine)
if ast
]
except ValueError:
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
# FROM`). In this case, we rely on the tokenizer to generate the
# statements.
pass

# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
# generate the SQL of each statement, so we tokenize the script and split it
# based on the location of semi-colons.
statements = []
start = 0
remainder = script
for token in sqlglot.tokenize(script):
if token.token_type == sqlglot.TokenType.SEMICOLON:
statement, start = script[start : token.start], token.end + 1
ast = sqlglot.parse_one(statement, dialect)
statements.append(cls(statement.strip(), engine, ast))
remainder = script[start:]

if remainder.strip():
ast = sqlglot.parse_one(remainder, dialect)
statements.append(cls(remainder, engine, ast))

return statements

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -349,8 +375,23 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
if self._dialect:
try:
write = Dialect.get_or_raise(self._dialect)
return write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
except ValueError:
pass

# Reformatting SQL using the generic sqlglot dialect is known to break queries.
# For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
# breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
# when the dialect is not known.
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -456,7 +497,9 @@ def split_script(
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
for more information.
"""
return [cls(statement, engine) for statement in split_kql(script)]
return [
cls(statement, engine, statement.strip()) for statement in split_kql(script)
]

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -498,7 +541,7 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._sql.strip()

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -548,6 +591,9 @@ def __init__(
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL script.
Note that even though KQL is very different from SQL, multiple statements are
still separated by semi-colons.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_format_sql_request(self):
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
Expand Down
33 changes: 33 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,39 @@ def test_extract_tables_show_tables_from() -> None:
)


def test_format_show_tables() -> None:
"""
Test format when `ast.sql()` raises an exception.
In that case sqlparse should be used instead.
"""
assert (
SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
== "SHOW TABLES FROM s1 LIKE '%order%'"
)


def test_format_no_dialect() -> None:
"""
Test format with an engine that has no corresponding dialect.
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
)


def test_split_no_dialect() -> None:
"""
Test the statement split when the engine has no corresponding dialect.
"""
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t;"
statements = SQLScript(sql, "firebolt").statements
assert len(statements) == 2
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
assert statements[1]._sql == "SELECT * FROM t"


def test_extract_tables_show_columns_from() -> None:
"""
Test `SHOW COLUMNS FROM`.
Expand Down

0 comments on commit 882e411

Please sign in to comment.