Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Dec 23, 2020
1 parent a52031a commit f795874
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 12 deletions.
57 changes: 46 additions & 11 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.sql_parse import CtasMethod, ParsedQuery
from superset.utils.celery import session_scope
from superset.utils.core import (
json_iso_dttm_ser,
Expand Down Expand Up @@ -160,6 +160,7 @@ def execute_sql_statement(
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
apply_ctas: bool = False,
) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
Expand All @@ -171,14 +172,7 @@ def execute_sql_statement(
raise SqlLabSecurityException(
_("Only `SELECT` statements are allowed against this database")
)
if query.select_as_cta:
if not parsed_query.is_select():
raise SqlLabException(
_(
"Only `SELECT` statements can be used with the CREATE TABLE "
"feature."
)
)
if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = "tmp_{}_table_{}".format(
Expand Down Expand Up @@ -322,8 +316,8 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
raise SqlLabException("Results backend isn't configured.")

# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
if not db_engine_spec.run_multiple_statements_as_one:
parsed_query = ParsedQuery(rendered_query)
statements = parsed_query.get_statements()
logger.info(
"Query %s: Executing %i statement(s)", str(query_id), len(statements)
Expand All @@ -337,6 +331,32 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
query.start_running_time = now_as_float()
session.commit()

# Should we create a table or view from the select?
if (
query.select_as_cta
and query.ctas_method == CtasMethod.TABLE
and not parsed_query.is_valid_ctas()
):
raise SqlLabException(
_(
"CTAS (create table as select) can only be run with a query where "
"the last statement is a SELECT. Please make sure your query has "
"a SELECT as its last statement. Then, try running your query again."
)
)
if (
query.select_as_cta
and query.ctas_method == CtasMethod.VIEW
and not parsed_query.is_valid_cvas()
):
raise SqlLabException(
_(
"CVAS (create view as select) can only be run with a query with "
"a single SELECT statement. Please make sure your query has only "
"a SELECT statement. Then, try running your query again."
)
)

engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
Expand All @@ -354,14 +374,29 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
if query.status == QueryStatus.STOPPED:
return None

# For CTAS we create the table only on the last statement
apply_ctas = query.select_as_cta and (
query.ctas_method == CtasMethod.VIEW
or (
query.ctas_method == CtasMethod.TABLE
and i == len(statements) - 1
)
)

# Run statement
msg = f"Running statement {i+1} out of {statement_count}"
logger.info("Query %s: %s", str(query_id), msg)
query.set_extra_json_key("progress", msg)
session.commit()
try:
result_set = execute_sql_statement(
statement, query, user_name, session, cursor, log_params
statement,
query,
user_name,
session,
cursor,
log_params,
apply_ctas,
)
except Exception as ex: # pylint: disable=broad-except
msg = str(ex)
Expand Down
11 changes: 10 additions & 1 deletion superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def __str__(self) -> str:


class ParsedQuery:
def __init__(self, sql_statement: str):
def __init__(self, sql_statement: str, strip_comments: bool = False):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)

self.sql: str = sql_statement
self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
Expand Down Expand Up @@ -110,6 +113,12 @@ def limit(self) -> Optional[int]:
def is_select(self) -> bool:
return self._parsed[0].get_type() == "SELECT"

def is_valid_ctas(self) -> bool:
return self._parsed[-1].get_type() == "SELECT"

def is_valid_cvas(self) -> bool:
return len(self._parsed) == 1 and self._parsed[0].get_type() == "SELECT"

def is_explain(self) -> bool:
# Remove comments
statements_without_comments = sqlparse.format(
Expand Down

0 comments on commit f795874

Please sign in to comment.