From 25c599d0400cf320ce5accf50ed88e76ccff5980 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 27 Jul 2017 09:47:31 -0700 Subject: [PATCH] Escaping the user's SQL in the explore view (#3186) * Escaping the user's SQL in the explore view When executing SQL from SQL Lab, we use a lower level API to the database which doesn't require escaping the SQL. When going through the explore view, the stack chain leading to the same method may need escaping depending on how the DBAPI driver is written, and that is the case for Presto (and perhaps other drivers). * Using regex to avoid doubling doubles --- superset/connectors/sqla/models.py | 16 ++++++++++------ superset/db_engine_specs.py | 17 +++++++++-------- superset/sql_lab.py | 1 - 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 147c667df04bf..0d06bfbde04c8 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -285,10 +285,12 @@ def values_for_column(self, column_name, limit=10000): """ cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] + tp = self.get_template_processor() + db_engine_spec = self.database.db_engine_spec qry = ( select([target_col.sqla_col]) - .select_from(self.get_from_clause()) + .select_from(self.get_from_clause(tp, db_engine_spec)) .distinct(column_name) ) if limit: @@ -322,7 +324,6 @@ def get_query_str(self, query_obj): ) logging.info(sql) sql = sqlparse.format(sql, reindent=True) - sql = self.database.db_engine_spec.sql_preprocessor(sql) return sql def get_sqla_table(self): @@ -331,12 +332,14 @@ def get_sqla_table(self): tbl.schema = self.schema return tbl - def get_from_clause(self, template_processor=None): + def get_from_clause(self, template_processor=None, db_engine_spec=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) + if db_engine_spec: + from_sql = db_engine_spec.escape_sql(from_sql) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() @@ -367,13 +370,14 @@ def get_sqla_query( # sqla 'form_data': form_data, } template_processor = self.get_template_processor(**template_kwargs) + db_engine_spec = self.database.db_engine_spec # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping - time_groupby_inline = self.database.db_engine_spec.time_groupby_inline + time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} @@ -428,7 +432,7 @@ def get_sqla_query( # sqla groupby_exprs += [timestamp] # Use main dttm column to support index with secondary dttm columns - if self.database.db_engine_spec.time_secondary_columns and \ + if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col]. @@ -438,7 +442,7 @@ def get_sqla_query( # sqla select_exprs += metrics_exprs qry = sa.select(select_exprs) - tbl = self.get_from_clause(template_processor) + tbl = self.get_from_clause(template_processor, db_engine_spec) if not columns: qry = qry.group_by(*groupby_exprs) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index b460226c1aa6c..d08f2a8feb795 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -73,6 +73,11 @@ def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" return {} + @classmethod + def escape_sql(cls, sql): + """Escapes the raw SQL""" + return sql + @classmethod def convert_dttm(cls, target_type, dttm): return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) @@ -139,14 +144,6 @@ def adjust_database_uri(cls, uri, selected_schema): """ return uri - @classmethod - def sql_preprocessor(cls, sql): - """If the SQL needs to be altered prior to running it - - For example Presto needs to double `%` characters - """ - return sql - @classmethod def patch(cls): pass @@ -399,6 +396,10 @@ def adjust_database_uri(cls, uri, selected_schema=None): uri.database = database return uri + @classmethod + def escape_sql(cls, sql): + return re.sub(r'%%|%', "%%", sql) + @classmethod def convert_dttm(cls, target_type, dttm): tt = target_type.upper() diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 4b0bd863bcd04..638b29abbee19 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -154,7 +154,6 @@ def handle_error(msg): template_processor = get_template_processor( database=database, query=query) executed_sql = template_processor.process_template(executed_sql) - executed_sql = db_engine_spec.sql_preprocessor(executed_sql) except Exception as e: logging.exception(e) msg = "Template rendering failed: " + utils.error_msg_from_exception(e)