Skip to content

Commit

Permalink
feat(sqlparse): improve table parsing (#26476)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored and michael-s-molina committed Feb 1, 2024
1 parent c739ce7 commit fb64100
Show file tree
Hide file tree
Showing 17 changed files with 253 additions and 118 deletions.
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==20.8.0
# via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
Expand Down
6 changes: 1 addition & 5 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
ephem==4.1.4
# via lunarcalendar
exceptiongroup==1.1.1
# via pytest
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
Expand Down Expand Up @@ -119,7 +115,7 @@ pyee==9.0.4
# via playwright
pyfakefs==5.2.2
# via -r requirements/testing.in
pyhive[presto]==0.6.5
pyhive[presto]==0.7.0
# via apache-superset
pytest==7.3.1
# via
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def get_git_sha() -> str:
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def get_from_clause(
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
Expand Down
5 changes: 4 additions & 1 deletion superset/datasets/commands/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def run(self) -> Model:
table.template_params = self._base_model.template_params
table.normalize_columns = self._base_model.normalize_columns
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
Expand Down
12 changes: 6 additions & 6 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def apply_limit_to_sql(
return database.compile_sqla_query(qry)

if cls.limit_method == LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)

return sql
Expand Down Expand Up @@ -955,7 +955,7 @@ def get_limit_from_sql(cls, sql: str) -> int | None:
:param sql: SQL query
:return: Value of limit clause in query
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit

@classmethod
Expand All @@ -967,7 +967,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.set_or_update_query_limit(limit)

@classmethod
Expand Down Expand Up @@ -1450,7 +1450,7 @@ def process_statement(cls, statement: str, database: Database) -> str:
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
Expand Down Expand Up @@ -1483,7 +1483,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()

costs = []
Expand Down Expand Up @@ -1544,7 +1544,7 @@ def execute( # pylint: disable=unused-argument
:return:
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)

if cls.arraysize:
cursor.arraysize = cls.arraysize
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
for statement in statements:
Expand Down
2 changes: 1 addition & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def get_from_clause(
"""

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
6 changes: 4 additions & 2 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def username(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

@property
def columns(self) -> list["TableColumn"]:
Expand Down Expand Up @@ -428,7 +428,9 @@ def url(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)

@property
def last_run_humanized(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,7 +1855,10 @@ def raise_for_access(
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(query.sql).tables
for table_ in sql_parse.ParsedQuery(
query.sql,
engine=database.db_engine_spec.engine,
).tables
}
elif table:
tables = {table}
Expand Down
11 changes: 8 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
database: Database = query.database
db_engine_spec = database.db_engine_spec

parsed_query = ParsedQuery(sql_statement)
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
Expand All @@ -213,7 +213,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
database.id,
query.schema,
)
)
),
engine=db_engine_spec.engine,
)

sql = parsed_query.stripped()
Expand Down Expand Up @@ -404,7 +405,11 @@ def execute_sql_statements(
)

# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
parsed_query = ParsedQuery(
rendered_query,
strip_comments=True,
engine=db_engine_spec.engine,
)
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(
Expand Down
Loading

0 comments on commit fb64100

Please sign in to comment.