From 7a894cb6074454fa864a4ae905f9ee43fcec6d45 Mon Sep 17 00:00:00 2001 From: Sujith Kumar S Date: Wed, 2 Feb 2022 07:55:17 +0530 Subject: [PATCH 1/7] Fix for handling regular CTE queries with MSSQL,#8074 --- superset/db_engine_specs/base.py | 5 +++++ superset/db_engine_specs/mssql.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index bdd1922d2ca56..70cbcb03eacda 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -292,6 +292,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # But for backward compatibility, False by default allows_hidden_cc_in_orderby = False + # Whether allow CTE as subquery or regular CTE + # If True, then it will allow in subquery , + # if False it will allow as regular CTE + allows_cte_in_subquery = True + force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 87499284712c1..c129932f642b0 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -24,6 +24,8 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType from superset.utils import core as utils +import sqlparse +from sqlparse.tokens import Keyword, CTE logger = logging.getLogger(__name__) @@ -47,6 +49,7 @@ class MssqlEngineSpec(BaseEngineSpec): engine_name = "Microsoft SQL Server" limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 + allows_cte_in_subquery = False _time_grain_expressions = { None: "{col}", @@ -134,6 +137,36 @@ def extract_error_message(cls, ex: Exception) -> str: ) return f"{cls.engine} error: {cls._extract_error_message(ex)}" + @classmethod + def get_cte_query(cls, sql) -> Optional[str]: + """ + Returns the wrapped CTE + """ + if not cls.allows_cte_in_subquery: + p = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return None + idx, tok = p.token_next(idx) + idx = p.token_index(tok) + 1 + + # extarct rest of the SQLs after CTE + remainder = u"".join(str(tok) for tok in p.tokens[idx:]) + + __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" + __query = sqlparse.format(__query, reindent=True, keyword_case='upper') + return __query + return None + + @classmethod + def test_cte_sql(cls): + sql = """ + select * from currency + """ + sql = cls.get_cte_prequery(sql) + print(sql) class AzureSynapseSpec(MssqlEngineSpec): engine = "mssql" From 5acea670d0ce0eb5a18f4fbacc0a1b8a660df1c6 Mon Sep 17 00:00:00 2001 From: sujiplr Date: Wed, 2 Feb 2022 11:27:21 +0530 Subject: [PATCH 2/7] Moved the get_cte_query function from mssql.py to base.py for using irrespetcive of dbengine --- superset/db_engine_specs/base.py | 28 ++++++++++++++++++++++++++++ superset/db_engine_specs/mssql.py | 30 ------------------------------ 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 70cbcb03eacda..40f1118bbf314 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -38,6 +38,7 @@ import pandas as pd import sqlparse +from sqlparse.tokens import Keyword, CTE from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import current_app, g @@ -668,6 +669,33 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.set_or_update_query_limit(limit) + @classmethod + def get_cte_query(cls, sql) -> Optional[str]: + """ + Convert the input CTE based SQL to the SQL for virtual table conversion + + :param sql: SQL query + :return: Query with __query alias + + """ + if not cls.allows_cte_in_subquery: + p = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return None + idx, tok = p.token_next(idx) + idx = p.token_index(tok) + 1 + + # extract rest of the SQLs after CTE + remainder = u"".join(str(tok) for tok in p.tokens[idx:]) + + __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" + __query = sqlparse.format(__query, reindent=True, keyword_case='upper') + return __query + return None + @classmethod def df_to_sql( cls, diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index c129932f642b0..0d73a7b80d280 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -137,36 +137,6 @@ def extract_error_message(cls, ex: Exception) -> str: ) return f"{cls.engine} error: {cls._extract_error_message(ex)}" - @classmethod - def get_cte_query(cls, sql) -> Optional[str]: - """ - Returns the wrapped CTE - """ - if not cls.allows_cte_in_subquery: - p = sqlparse.parse(sql)[0] - - # The first meaningful token for CTE will be with WITH - idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) - if not (tok and tok.ttype == CTE): - return None - idx, tok = p.token_next(idx) - idx = p.token_index(tok) + 1 - - # extarct rest of the SQLs after CTE - remainder = u"".join(str(tok) for tok in p.tokens[idx:]) - - __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" - __query = sqlparse.format(__query, reindent=True, keyword_case='upper') - return __query - return None - - @classmethod - def test_cte_sql(cls): - sql = """ - select * from currency - """ - sql = cls.get_cte_prequery(sql) - print(sql) class AzureSynapseSpec(MssqlEngineSpec): engine = "mssql" From b3638e2be17d7ecffe81bdf3e45e74e1a0b3c8f7 Mon Sep 17 00:00:00 2001 From: Sujith Kumar S Date: Wed, 2 Feb 2022 07:55:17 +0530 Subject: [PATCH 3/7] Fix for handling regular CTE queries with MSSQL,#8074 --- superset/db_engine_specs/base.py | 5 +++++ superset/db_engine_specs/mssql.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index bdd1922d2ca56..70cbcb03eacda 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -292,6 +292,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # But for backward compatibility, False by default allows_hidden_cc_in_orderby = False + # Whether allow CTE as subquery or regular CTE + # If True, then it will allow in subquery , + # if False it will allow as regular CTE + allows_cte_in_subquery = True + force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 87499284712c1..c129932f642b0 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -24,6 +24,8 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType from superset.utils import core as utils +import sqlparse +from sqlparse.tokens import Keyword, CTE logger = logging.getLogger(__name__) @@ -47,6 +49,7 @@ class MssqlEngineSpec(BaseEngineSpec): engine_name = "Microsoft SQL Server" limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 + allows_cte_in_subquery = False _time_grain_expressions = { None: "{col}", @@ -134,6 +137,36 @@ def extract_error_message(cls, ex: Exception) -> str: ) return f"{cls.engine} error: {cls._extract_error_message(ex)}" + @classmethod + def get_cte_query(cls, sql) -> Optional[str]: + """ + Returns the wrapped CTE + """ + if not cls.allows_cte_in_subquery: + p = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return None + idx, tok = p.token_next(idx) + idx = p.token_index(tok) + 1 + + # extarct rest of the SQLs after CTE + remainder = u"".join(str(tok) for tok in p.tokens[idx:]) + + __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" + __query = sqlparse.format(__query, reindent=True, keyword_case='upper') + return __query + return None + + @classmethod + def test_cte_sql(cls): + sql = """ + select * from currency + """ + sql = cls.get_cte_prequery(sql) + print(sql) class AzureSynapseSpec(MssqlEngineSpec): engine = "mssql" From 2eddbd145981105fd5d69ba96421b04e62c75656 Mon Sep 17 00:00:00 2001 From: sujiplr Date: Wed, 2 Feb 2022 11:27:21 +0530 Subject: [PATCH 4/7] Moved the get_cte_query function from mssql.py to base.py for using irrespetcive of dbengine --- superset/db_engine_specs/base.py | 28 ++++++++++++++++++++++++++++ superset/db_engine_specs/mssql.py | 30 ------------------------------ 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 70cbcb03eacda..40f1118bbf314 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -38,6 +38,7 @@ import pandas as pd import sqlparse +from sqlparse.tokens import Keyword, CTE from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import current_app, g @@ -668,6 +669,33 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.set_or_update_query_limit(limit) + @classmethod + def get_cte_query(cls, sql) -> Optional[str]: + """ + Convert the input CTE based SQL to the SQL for virtual table conversion + + :param sql: SQL query + :return: Query with __query alias + + """ + if not cls.allows_cte_in_subquery: + p = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return None + idx, tok = p.token_next(idx) + idx = p.token_index(tok) + 1 + + # extract rest of the SQLs after CTE + remainder = u"".join(str(tok) for tok in p.tokens[idx:]) + + __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" + __query = sqlparse.format(__query, reindent=True, keyword_case='upper') + return __query + return None + @classmethod def df_to_sql( cls, diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index c129932f642b0..0d73a7b80d280 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -137,36 +137,6 @@ def extract_error_message(cls, ex: Exception) -> str: ) return f"{cls.engine} error: {cls._extract_error_message(ex)}" - @classmethod - def get_cte_query(cls, sql) -> Optional[str]: - """ - Returns the wrapped CTE - """ - if not cls.allows_cte_in_subquery: - p = sqlparse.parse(sql)[0] - - # The first meaningful token for CTE will be with WITH - idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) - if not (tok and tok.ttype == CTE): - return None - idx, tok = p.token_next(idx) - idx = p.token_index(tok) + 1 - - # extarct rest of the SQLs after CTE - remainder = u"".join(str(tok) for tok in p.tokens[idx:]) - - __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" - __query = sqlparse.format(__query, reindent=True, keyword_case='upper') - return __query - return None - - @classmethod - def test_cte_sql(cls): - sql = """ - select * from currency - """ - sql = cls.get_cte_prequery(sql) - print(sql) class AzureSynapseSpec(MssqlEngineSpec): engine = "mssql" From 1009c110fe0d9b8a04acf40b042a3063a01a107b Mon Sep 17 00:00:00 2001 From: sujiplr Date: Thu, 3 Feb 2022 16:53:57 +0530 Subject: [PATCH 5/7] Unit test added for the db engine CTE SQL parsing. Unit test added for the db engine CTE SQL parsing. Removed additional spaces from the CTE parsing SQL generation. --- superset/db_engine_specs/base.py | 5 ++- .../unit_tests/db_engine_specs/test_mssql.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 40f1118bbf314..5c215c2abde04 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -690,10 +690,9 @@ def get_cte_query(cls, sql) -> Optional[str]: # extract rest of the SQLs after CTE remainder = u"".join(str(tok) for tok in p.tokens[idx:]) - - __query = "WITH " + tok.value + ", __query as ( " + remainder + ")" - __query = sqlparse.format(__query, reindent=True, keyword_case='upper') + __query = "WITH " + tok.value + ",\n__query as (" + remainder + ")" return __query + return None @classmethod diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 75d2dcb1089cc..507f268a3ad21 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -179,6 +179,40 @@ def test_column_datatype_to_string( actual = MssqlEngineSpec.column_datatype_to_string(original, mssql.dialect()) assert actual == expected +@pytest.mark.parametrize( + "original,expected", + [ + ("with currency as\n" + "(\n" + "select 'INR' as cur\n" + ")\n" + "select * from currency\n" + , "WITH currency as\n" + "(\n" + "select 'INR' as cur\n" + "),\n" + "__query as (\n" + "select * from currency\n" + ")",), + ("SELECT 1 as cnt", None,), + ("select 'INR' as cur\n" + "union\n" + "select 'AUD' as cur\n" + "union \n" + "select 'USD' as cur\n", None) + ], +) +def test_cte_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str +) -> None: + from superset.db_engine_specs.mssql import MssqlEngineSpec + + actual = MssqlEngineSpec.get_cte_query(original) + print(original) + print(expected) + print(actual) + assert actual == expected + def test_extract_errors(app_context: AppContext) -> None: """ From 16c3a3fead05dc3f322e6c274050edcb95ecf3e9 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 4 Feb 2022 22:31:16 +0200 Subject: [PATCH 6/7] implement in sqla model --- superset/connectors/sqla/models.py | 39 ++++++++---- superset/db_engine_specs/base.py | 12 ++-- tests/integration_tests/core_tests.py | 2 +- .../unit_tests/db_engine_specs/test_mssql.py | 59 +++++++++++-------- 4 files changed, 71 insertions(+), 41 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 74e8e6c3d1dbe..41f29f7ad74c0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -103,10 +103,12 @@ logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" +CTE_ALIAS = "__cte" class SqlaQuery(NamedTuple): applied_template_filters: List[str] + cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] prequeries: List[str] @@ -127,6 +129,12 @@ class MetadataResult: modified: List[str] = field(default_factory=list) +def _apply_cte(sql: str, cte: Optional[str]) -> str: + if cte: + sql = f"{cte}{sql}" + return sql + + class AnnotationDatasource(BaseDatasource): """Dummy object so we can query annotations using 'Viz' objects just like regular datasources. @@ -743,12 +751,9 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] tp = self.get_template_processor() + tbl, cte = self.get_from_clause(tp) - qry = ( - select([target_col.get_sqla_col()]) - .select_from(self.get_from_clause(tp)) - .distinct() - ) + qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct() if limit: qry = qry.limit(limit) @@ -756,7 +761,8 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: qry = qry.where(self.get_fetch_values_predicate()) engine = self.database.get_sqla_engine() - sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True})) + sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) + sql = _apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) @@ -778,6 +784,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) + sql = _apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( @@ -800,13 +807,14 @@ def get_sqla_table(self) -> TableClause: def get_from_clause( self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Union[TableClause, Alias]: + ) -> Tuple[Union[TableClause, Alias], Optional[str]]: """ Return where to select the columns and metrics from. Either a physical table - or a virtual table with it's own subquery. + or a virtual table with it's own subquery. If the FROM is referencing a + CTE, the CTE is returned as the second value in the return tuple. """ if not self.is_virtual: - return self.get_sqla_table() + return self.get_sqla_table(), None from_sql = self.get_rendered_sql(template_processor) parsed_query = ParsedQuery(from_sql) @@ -817,7 +825,15 @@ def get_from_clause( raise QueryObjectValidationError( _("Virtual dataset query must be read-only") ) - return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + + cte = self.db_engine_spec.get_cte_query(from_sql) + from_clause = ( + table(CTE_ALIAS) + if cte + else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + ) + + return from_clause, cte def get_rendered_sql( self, template_processor: Optional[BaseTemplateProcessor] = None @@ -1224,7 +1240,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma qry = sa.select(select_exprs) - tbl = self.get_from_clause(template_processor) + tbl, cte = self.get_from_clause(template_processor) if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -1491,6 +1507,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma return SqlaQuery( applied_template_filters=applied_template_filters, + cte=cte, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5c215c2abde04..0c0d0662ad92f 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -38,7 +38,6 @@ import pandas as pd import sqlparse -from sqlparse.tokens import Keyword, CTE from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import current_app, g @@ -55,6 +54,7 @@ from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine +from sqlparse.tokens import CTE, Keyword from typing_extensions import TypedDict from superset import security_manager, sql_parse @@ -670,12 +670,12 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: return parsed_query.set_or_update_query_limit(limit) @classmethod - def get_cte_query(cls, sql) -> Optional[str]: + def get_cte_query(cls, sql: str) -> Optional[str]: """ Convert the input CTE based SQL to the SQL for virtual table conversion :param sql: SQL query - :return: Query with __query alias + :return: CTE with the main select query aliased as `__cte` """ if not cls.allows_cte_in_subquery: @@ -689,10 +689,10 @@ def get_cte_query(cls, sql) -> Optional[str]: idx = p.token_index(tok) + 1 # extract rest of the SQLs after CTE - remainder = u"".join(str(tok) for tok in p.tokens[idx:]) - __query = "WITH " + tok.value + ",\n__query as (" + remainder + ")" + remainder = "".join(str(tok) for tok in p.tokens[idx:]).strip() + __query = "WITH " + tok.value + ",\n__cte AS (\n" + remainder + "\n)" return __query - + return None @classmethod diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 1c4682ad9a9b6..2e824ab5bca2c 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -925,7 +925,7 @@ def test_comments_in_sqlatable_query(self): sql=commented_query, database=get_example_database(), ) - rendered_query = str(table.get_from_clause()) + rendered_query = str(table.get_from_clause()[0]) self.assertEqual(clean_query, rendered_query) def test_slice_payload_no_datasource(self): diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 507f268a3ad21..ebb8a2d332f94 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -179,39 +179,52 @@ def test_column_datatype_to_string( actual = MssqlEngineSpec.column_datatype_to_string(original, mssql.dialect()) assert actual == expected + @pytest.mark.parametrize( "original,expected", [ - ("with currency as\n" - "(\n" - "select 'INR' as cur\n" - ")\n" - "select * from currency\n" - , "WITH currency as\n" - "(\n" - "select 'INR' as cur\n" - "),\n" - "__query as (\n" - "select * from currency\n" - ")",), + ( + dedent( + """ +with currency as +( +select 'INR' as cur +) +select * from currency +""" + ), + dedent( + """WITH currency as +( +select 'INR' as cur +), +__cte AS ( +select * from currency +)""" + ), + ), ("SELECT 1 as cnt", None,), - ("select 'INR' as cur\n" - "union\n" - "select 'AUD' as cur\n" - "union \n" - "select 'USD' as cur\n", None) + ( + dedent( + """ +select 'INR' as cur +union +select 'AUD' as cur +union +select 'USD' as cur +""" + ), + None, + ), ], ) def test_cte_query_parsing( app_context: AppContext, original: TypeEngine, expected: str ) -> None: - from superset.db_engine_specs.mssql import MssqlEngineSpec + from superset.db_engine_specs.mssql import MssqlEngineSpec - actual = MssqlEngineSpec.get_cte_query(original) - print(original) - print(expected) - print(actual) - assert actual == expected + actual = MssqlEngineSpec.get_cte_query(original) + assert actual == expected def test_extract_errors(app_context: AppContext) -> None: From 6a47c7d4d32ca928fd3aca7f4f875135bc891fbc Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sat, 5 Feb 2022 10:22:14 +0200 Subject: [PATCH 7/7] lint + cleanup --- superset/connectors/sqla/models.py | 26 ++++++----- superset/db_engine_specs/base.py | 20 +++++---- superset/db_engine_specs/mssql.py | 2 - tests/unit_tests/db_engine_specs/test_base.py | 43 +++++++++++++++++++ .../unit_tests/db_engine_specs/test_mssql.py | 16 ++++--- 5 files changed, 80 insertions(+), 27 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 41f29f7ad74c0..ca1d4bc57a022 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -77,7 +77,7 @@ get_physical_table_metadata, get_virtual_table_metadata, ) -from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression +from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import QueryObjectValidationError from superset.jinja_context import ( BaseTemplateProcessor, @@ -103,7 +103,6 @@ logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" -CTE_ALIAS = "__cte" class SqlaQuery(NamedTuple): @@ -129,12 +128,6 @@ class MetadataResult: modified: List[str] = field(default_factory=list) -def _apply_cte(sql: str, cte: Optional[str]) -> str: - if cte: - sql = f"{cte}{sql}" - return sql - - class AnnotationDatasource(BaseDatasource): """Dummy object so we can query annotations using 'Viz' objects just like regular datasources. @@ -570,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def __repr__(self) -> str: return self.name + @staticmethod + def _apply_cte(sql: str, cte: Optional[str]) -> str: + """ + Append a CTE before the SELECT statement if defined + + :param sql: SELECT statement + :param cte: CTE statement + :return: + """ + if cte: + sql = f"{cte}\n{sql}" + return sql + @property def db_engine_spec(self) -> Type[BaseEngineSpec]: return self.database.db_engine_spec @@ -762,7 +768,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: engine = self.database.get_sqla_engine() sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) - sql = _apply_cte(sql, cte) + sql = self._apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) @@ -784,7 +790,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) - sql = _apply_cte(sql, sqlaq.cte) + sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 0c0d0662ad92f..764f3fde70580 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -54,7 +54,7 @@ from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine -from sqlparse.tokens import CTE, Keyword +from sqlparse.tokens import CTE from typing_extensions import TypedDict from superset import security_manager, sql_parse @@ -81,6 +81,9 @@ logger = logging.getLogger() +CTE_ALIAS = "__cte" + + class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str @@ -679,19 +682,18 @@ def get_cte_query(cls, sql: str) -> Optional[str]: """ if not cls.allows_cte_in_subquery: - p = sqlparse.parse(sql)[0] + stmt = sqlparse.parse(sql)[0] # The first meaningful token for CTE will be with WITH - idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) - if not (tok and tok.ttype == CTE): + idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) + if not (token and token.ttype == CTE): return None - idx, tok = p.token_next(idx) - idx = p.token_index(tok) + 1 + idx, token = stmt.token_next(idx) + idx = stmt.token_index(token) + 1 # extract rest of the SQLs after CTE - remainder = "".join(str(tok) for tok in p.tokens[idx:]).strip() - __query = "WITH " + tok.value + ",\n__cte AS (\n" + remainder + "\n)" - return __query + remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() + return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)" return None diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 0d73a7b80d280..e5c66e046a082 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -24,8 +24,6 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType from superset.utils import core as utils -import sqlparse -from sqlparse.tokens import Keyword, CTE logger = logging.getLogger(__name__) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index d822f50de9d8a..4dc27c0928f99 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -16,7 +16,11 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +from textwrap import dedent + +import pytest from flask.ctx import AppContext +from sqlalchemy.types import TypeEngine def test_get_text_clause_with_colon(app_context: AppContext) -> None: @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None: "SELECT foo FROM tbl1", "SELECT bar FROM tbl2", ] + + +@pytest.mark.parametrize( + "original,expected", + [ + ( + dedent( + """ +with currency as +( +select 'INR' as cur +) +select * from currency +""" + ), + None, + ), + ("SELECT 1 as cnt", None,), + ( + dedent( + """ +select 'INR' as cur +union +select 'AUD' as cur +union +select 'USD' as cur +""" + ), + None, + ), + ], +) +def test_cte_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str +) -> None: + from superset.db_engine_specs.base import BaseEngineSpec + + actual = BaseEngineSpec.get_cte_query(original) + assert actual == expected diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index ebb8a2d332f94..250b8158fa320 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -186,20 +186,24 @@ def test_column_datatype_to_string( ( dedent( """ -with currency as -( +with currency as ( select 'INR' as cur +), +currency_2 as ( +select 'EUR' as cur ) -select * from currency +select * from currency union all select * from currency_2 """ ), dedent( - """WITH currency as -( + """WITH currency as ( select 'INR' as cur ), +currency_2 as ( +select 'EUR' as cur +), __cte AS ( -select * from currency +select * from currency union all select * from currency_2 )""" ), ),