diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944814..b3685b850d7df 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,7 @@
 from typing import Any, Generic, TypeVar
 
 import sqlglot
+import sqlparse
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError
@@ -138,9 +139,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
     """
     Base class for SQL statements.
 
-    The class can be instantiated with a string representation of the script or, for
-    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
-    which will split a script in multiple already parsed statements.
+    The class should be instantiated with a string representation of the script and, for
+    efficiency reasons, optionally with a pre-parsed AST. This is useful with
+    `sqlglot.parse`, which will split a script into multiple already parsed statements.
 
     The `engine` parameters comes from the `engine` attribute in a Superset DB engine
     spec.
@@ -148,14 +149,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
     def __init__(
         self,
-        statement: str | InternalRepresentation,
+        statement: str,
         engine: str,
+        ast: InternalRepresentation | None = None,
     ):
-        self._parsed: InternalRepresentation = (
-            self._parse_statement(statement, engine)
-            if isinstance(statement, str)
-            else statement
-        )
+        self._sql = statement
+        self._parsed = ast or self._parse_statement(statement, engine)
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
@@ -239,11 +238,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
     def __init__(
         self,
-        statement: str | exp.Expression,
+        statement: str,
         engine: str,
+        ast: exp.Expression | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine)
+        super().__init__(statement, engine, ast)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +275,39 @@ def split_script(
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
-        return [
-            cls(statement, engine)
-            for statement in cls._parse(script, engine)
-            if statement
-        ]
+        if dialect := SQLGLOT_DIALECTS.get(engine):
+            try:
+                return [
+                    cls(ast.sql(), engine, ast)
+                    for ast in cls._parse(script, engine)
+                    if ast
+                ]
+            except ValueError:
+                # `ast.sql()` might raise an error in some cases (e.g., `SHOW TABLES
+                # FROM`). In this case, we rely on the tokenizer to generate the
+                # statements.
+                pass
+
+        # When we don't have a sqlglot dialect (or `ast.sql()` failed) we can't rely on
+        # `ast.sql()` to correctly generate the SQL of each statement, so we tokenize
+        # the script and split it based on the location of semi-colons.
+        statements = []
+        start = 0
+        remainder = script
+        for token in sqlglot.tokenize(script, dialect):
+            if token.token_type == sqlglot.TokenType.SEMICOLON:
+                statement, start = script[start : token.start].strip(), token.end + 1
+                remainder = script[start:]
+                # Skip empty statements (e.g., `;;`), which `parse_one` rejects.
+                if statement:
+                    ast = sqlglot.parse_one(statement, dialect)
+                    statements.append(cls(statement, engine, ast))
+
+        if remainder := remainder.strip():
+            ast = sqlglot.parse_one(remainder, dialect)
+            statements.append(cls(remainder, engine, ast))
+
+        return statements
 
     @classmethod
     def _parse_statement(
@@ -349,8 +377,28 @@ def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
         """
-        write = Dialect.get_or_raise(self._dialect)
-        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
+        if self._dialect:
+            try:
+                write = Dialect.get_or_raise(self._dialect)
+                return write.generate(
+                    self._parsed,
+                    copy=False,
+                    comments=comments,
+                    pretty=True,
+                )
+            except ValueError:
+                pass
+
+        # Reformatting SQL using the generic sqlglot dialect is known to break queries.
+        # For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
+        # breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
+        # when the dialect is not known, or when sqlglot can't generate the SQL.
+        return sqlparse.format(
+            self._sql,
+            reindent=True,
+            keyword_case="upper",
+            strip_comments=not comments,
+        )
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -456,7 +504,9 @@ def split_script(
         https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
         for more information.
         """
-        return [cls(statement, engine) for statement in split_kql(script)]
+        return [
+            cls(statement, engine, statement.strip()) for statement in split_kql(script)
+        ]
 
     @classmethod
     def _parse_statement(
@@ -498,7 +548,7 @@ def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._sql.strip()
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -548,6 +598,9 @@ def __init__(
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL script.
+
+        Note that even though KQL is very different from SQL, multiple statements are
+        still separated by semi-colons.
         """
         return ";\n".join(statement.format(comments) for statement in self.statements)
 
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 19d6e56fb6441..cf1e190bbb9ba 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -281,7 +281,7 @@ def test_format_sql_request(self):
             "/api/v1/sqllab/format_sql/",
             json=data,
         )
-        success_resp = {"result": "SELECT\n  1\nFROM my_table"}
+        success_resp = {"result": "SELECT 1\nFROM my_table"}
         resp_data = json.loads(rv.data.decode("utf-8"))
         self.assertDictEqual(resp_data, success_resp)  # noqa: PT009
         assert rv.status_code == 200
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index ae5ebf89a8b96..3fe72142635a2 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -284,6 +284,47 @@ def test_extract_tables_show_tables_from() -> None:
     )
 
 
+def test_format_show_tables() -> None:
+    """
+    Test format when `ast.sql()` raises an exception.
+
+    In that case sqlparse should be used instead.
+    """
+    assert (
+        SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
+        == "SHOW TABLES FROM s1 LIKE '%order%'"
+    )
+
+
+def test_format_no_dialect() -> None:
+    """
+    Test format with an engine that has no corresponding dialect.
+    """
+    assert (
+        SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
+        == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n                  2)"
+    )
+
+
+def test_split_no_dialect() -> None:
+    """
+    Test the statement split when the engine has no corresponding dialect.
+    """
+    sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t;"
+    statements = SQLScript(sql, "firebolt").statements
+    assert len(statements) == 2
+    assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
+    assert statements[1]._sql == "SELECT * FROM t"
+
+
+def test_split_no_dialect_edge_cases() -> None:
+    """
+    Test that the tokenizer-based split strips the last statement and skips empty ones.
+    """
+    statements = SQLScript("SELECT 1;; SELECT 2 ", "firebolt").statements
+    assert [statement._sql for statement in statements] == ["SELECT 1", "SELECT 2"]
+
+
 def test_extract_tables_show_columns_from() -> None:
     """
     Test `SHOW COLUMNS FROM`.