diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
index 402e26462e041..949323b9aa75c 100644
--- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
+++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
@@ -61,7 +61,7 @@ export type ExploreQuery = QueryResponse & {
 };
 
 export interface ISimpleColumn {
-  column_name?: string | null;
+  name?: string | null;
   type?: string | null;
   is_dttm?: boolean | null;
 }
@@ -216,7 +216,7 @@ export const SaveDatasetModal = ({
           ...formDataWithDefaults,
           datasource: `${datasetToOverwrite.datasetid}__table`,
           ...(defaultVizType === 'table' && {
-            all_columns: datasource?.columns?.map(column => column.column_name),
+            all_columns: datasource?.columns?.map(column => column.name),
           }),
         }),
       ]);
@@ -301,7 +301,7 @@ export const SaveDatasetModal = ({
             ...formDataWithDefaults,
             datasource: `${data.table_id}__table`,
             ...(defaultVizType === 'table' && {
-              all_columns: selectedColumns.map(column => column.column_name),
+              all_columns: selectedColumns.map(column => column.name),
             }),
           }),
         )
diff --git a/superset-frontend/src/SqlLab/fixtures.ts b/superset-frontend/src/SqlLab/fixtures.ts
index ba88a41b0accc..fcb0fff8e3d70 100644
--- a/superset-frontend/src/SqlLab/fixtures.ts
+++ b/superset-frontend/src/SqlLab/fixtures.ts
@@ -692,17 +692,17 @@ export const testQuery: ISaveableDatasource = {
   sql: 'SELECT *',
   columns: [
     {
-      column_name: 'Column 1',
+      name: 'Column 1',
       type: DatasourceType.Query,
       is_dttm: false,
     },
     {
-      column_name: 'Column 3',
+      name: 'Column 3',
       type: DatasourceType.Query,
       is_dttm: false,
     },
     {
-      column_name: 'Column 2',
+      name: 'Column 2',
       type: DatasourceType.Query,
       is_dttm: true,
     },
diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
index c74212f0baf6b..80cf879f7f256 100644
--- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
+++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
@@ -48,7 +48,7 @@ class AdhocMetricOption extends React.PureComponent {
   }
 
   onRemoveMetric(e) {
-    e?.stopPropagation();
+    e.stopPropagation();
     this.props.onRemoveMetric(this.props.index);
   }
 
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 6180a546e7500..593c5f853b935 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -31,6 +31,7 @@
     Dict,
     Hashable,
     List,
+    NamedTuple,
    Optional,
     Set,
     Tuple,
@@ -49,9 +50,11 @@
 from jinja2.exceptions import TemplateError
 from sqlalchemy import (
     and_,
+    asc,
     Boolean,
     Column,
     DateTime,
+    desc,
     Enum,
     ForeignKey,
     inspect,
@@ -77,11 +80,13 @@
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table
 from sqlalchemy.sql.elements import ColumnClause, TextClause
-from sqlalchemy.sql.expression import Label, TextAsFrom
+from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 
 from superset import app, db, is_feature_enabled, security_manager
+from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
+from superset.common.utils.time_range_utils import get_since_until_from_time_range
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.connectors.sqla.utils import (
     find_cached_objects_in_session,
@@ -93,6 +98,7 @@
 from superset.datasets.models import Dataset as NewDataset
 from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
 from superset.exceptions import (
+    AdvancedDataTypeResponseError,
     ColumnNotFoundException,
     DatasetInvalidPermissionEvaluationException,
     QueryClauseValidationException,
@@ -100,6 +106,7 @@
     SupersetGenericDBErrorException,
     SupersetSecurityException,
 )
+from superset.extensions import feature_flag_manager
 from superset.jinja_context import (
     BaseTemplateProcessor,
     ExtraCache,
@@ -107,17 +114,26 @@
 )
 from superset.models.annotations import Annotation
 from superset.models.core import Database
-from superset.models.helpers import (
-    AuditMixinNullable,
-    CertificationMixin,
-    ExploreMixin,
-    QueryResult,
-    QueryStringExtended,
-)
+from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult
 from superset.sql_parse import ParsedQuery, sanitize_clause
-from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict
+from superset.superset_typing import (
+    AdhocColumn,
+    AdhocMetric,
+    Column as ColumnTyping,
+    Metric,
+    OrderBy,
+    QueryObjectDict,
+)
 from superset.utils import core as utils
-from superset.utils.core import GenericDataType, get_username, MediumText
+from superset.utils.core import (
+    GenericDataType,
+    get_column_name,
+    get_username,
+    is_adhoc_column,
+    MediumText,
+    QueryObjectFilterClause,
+    remove_duplicates,
+)
 
 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
@@ -134,6 +150,26 @@
 ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
 
 
+class SqlaQuery(NamedTuple):
+    applied_template_filters: List[str]
+    applied_filter_columns: List[ColumnTyping]
+    rejected_filter_columns: List[ColumnTyping]
+    cte: Optional[str]
+    extra_cache_keys: List[Any]
+    labels_expected: List[str]
+    prequeries: List[str]
+    sqla_query: Select
+
+
+class QueryStringExtended(NamedTuple):
+    applied_template_filters: Optional[List[str]]
+    applied_filter_columns: List[ColumnTyping]
+    rejected_filter_columns: List[ColumnTyping]
+    labels_expected: List[str]
+    prequeries: List[str]
+    sql: str
+
+
 @dataclass
 class MetadataResult:
     added: List[str] = field(default_factory=list)
@@ -274,35 +310,6 @@ def db_extra(self) -> Dict[str, Any]:
     def type_generic(self) -> Optional[utils.GenericDataType]:
         if self.is_dttm:
             return GenericDataType.TEMPORAL
-
-        bool_types = ("BOOL",)
-        num_types = (
-            "DOUBLE",
-            "FLOAT",
-            "INT",
-            "BIGINT",
-            "NUMBER",
-            "LONG",
-            "REAL",
-            "NUMERIC",
-            "DECIMAL",
-            "MONEY",
-        )
-        date_types = ("DATE", "TIME")
-        str_types = ("VARCHAR", "STRING", "CHAR")
-
-        if self.table is None:
-            # Query.TableColumns don't have a reference to a table.db_engine_spec
-            # reference so this logic will manage rendering types
-            if self.type and any(map(lambda t: t in self.type.upper(), str_types)):
-                return GenericDataType.STRING
-            if self.type and any(map(lambda t: t in self.type.upper(), bool_types)):
-                return GenericDataType.BOOLEAN
-            if self.type and any(map(lambda t: t in self.type.upper(), num_types)):
-                return GenericDataType.NUMERIC
-            if self.type and any(map(lambda t: t in self.type.upper(), date_types)):
-                return GenericDataType.TEMPORAL
-
         column_spec = self.db_engine_spec.get_column_spec(
             self.type, db_extra=self.db_extra
         )
@@ -538,10 +545,8 @@ def _process_sql_expression(
     return expression
 
 
-class SqlaTable(
-    Model, BaseDatasource, ExploreMixin
-):  # pylint:
disable=too-many-public-methods - """An ORM object for SqlAlchemy table references""" +class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods + """An ORM object for SqlAlchemy table references.""" type = "table" query_language = "sql" @@ -621,10 +626,6 @@ class SqlaTable( def __repr__(self) -> str: # pylint: disable=invalid-repr-returned return self.name - @property - def db_extra(self) -> Dict[str, Any]: - return self.database.get_extra() - @staticmethod def _apply_cte(sql: str, cte: Optional[str]) -> str: """ @@ -1150,6 +1151,680 @@ def get_sqla_row_level_filters( def text(self, clause: str) -> TextClause: return self.db_engine_spec.get_text_clause(clause) + def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + self, + apply_fetch_values_predicate: bool = False, + columns: Optional[List[ColumnTyping]] = None, + extras: Optional[Dict[str, Any]] = None, + filter: Optional[ # pylint: disable=redefined-builtin + List[QueryObjectFilterClause] + ] = None, + from_dttm: Optional[datetime] = None, + granularity: Optional[str] = None, + groupby: Optional[List[Column]] = None, + inner_from_dttm: Optional[datetime] = None, + inner_to_dttm: Optional[datetime] = None, + is_rowcount: bool = False, + is_timeseries: bool = True, + metrics: Optional[List[Metric]] = None, + orderby: Optional[List[OrderBy]] = None, + order_desc: bool = True, + to_dttm: Optional[datetime] = None, + series_columns: Optional[List[Column]] = None, + series_limit: Optional[int] = None, + series_limit_metric: Optional[Metric] = None, + row_limit: Optional[int] = None, + row_offset: Optional[int] = None, + timeseries_limit: Optional[int] = None, + timeseries_limit_metric: Optional[Metric] = None, + time_shift: Optional[str] = None, + ) -> SqlaQuery: + """Querying any sqla table from this common interface""" + if granularity not in self.dttm_cols and granularity is not None: + granularity = self.main_dttm_col + + extras = extras or {} + time_grain = extras.get("time_grain_sqla") + + template_kwargs = { + "columns": columns, + "from_dttm": from_dttm.isoformat() if from_dttm else None, + "groupby": groupby, + "metrics": metrics, + "row_limit": row_limit, + "row_offset": row_offset, + "time_column": granularity, + "time_grain": time_grain, + "to_dttm": to_dttm.isoformat() if to_dttm else None, + "table_columns": [col.column_name for col in self.columns], + "filter": filter, + } + columns = columns or [] + groupby = groupby or [] + rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + series_column_names = utils.get_column_names(series_columns or []) + # deprecated, to be removed in 2.0 + if is_timeseries and timeseries_limit: + series_limit = timeseries_limit + series_limit_metric = series_limit_metric or timeseries_limit_metric + template_kwargs.update(self.template_params_dict) + extra_cache_keys: List[Any] = [] + template_kwargs["extra_cache_keys"] = extra_cache_keys + removed_filters: List[str] = [] + applied_template_filters: List[str] = [] + template_kwargs["removed_filters"] = removed_filters + template_kwargs["applied_filters"] = applied_template_filters + template_processor = self.get_template_processor(**template_kwargs) + db_engine_spec = self.db_engine_spec + prequeries: List[str] = [] + orderby = orderby or [] + need_groupby = bool(metrics is not None or groupby) + metrics = metrics or [] + + # For backward compatibility + if granularity not in 
self.dttm_cols and granularity is not None: + granularity = self.main_dttm_col + + columns_by_name: Dict[str, TableColumn] = { + col.column_name: col for col in self.columns + } + + metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + + if not granularity and is_timeseries: + raise QueryObjectValidationError( + _( + "Datetime column not provided as part table configuration " + "and is required by this type of chart" + ) + ) + if not metrics and not columns and not groupby: + raise QueryObjectValidationError(_("Empty query?")) + + metrics_exprs: List[ColumnElement] = [] + for metric in metrics: + if utils.is_adhoc_metric(metric): + assert isinstance(metric, dict) + metrics_exprs.append( + self.adhoc_metric_to_sqla( + metric=metric, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) + ) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append( + metrics_by_name[metric].get_sqla_col( + template_processor=template_processor + ) + ) + else: + raise QueryObjectValidationError( + _("Metric '%(metric)s' does not exist", metric=metric) + ) + + if metrics_exprs: + main_metric_expr = metrics_exprs[0] + else: + main_metric_expr, label = literal_column("COUNT(*)"), "count" + main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) + + # To ensure correct handling of the ORDER BY labeling we need to reference the + # metric instance if defined in the SELECT clause. + # use the key of the ColumnClause for the expected label + metrics_exprs_by_label = {m.key: m for m in metrics_exprs} + metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} + + # Since orderby may use adhoc metrics, too; we need to process them first + orderby_exprs: List[ColumnElement] = [] + for orig_col, ascending in orderby: + col: Union[AdhocMetric, ColumnElement] = orig_col + if isinstance(col, dict): + col = cast(AdhocMetric, col) + if col.get("sqlExpression"): + col["sqlExpression"] = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + if utils.is_adhoc_metric(col): + # add adhoc sort by column to columns_by_name if not exists + col = self.adhoc_metric_to_sqla(col, columns_by_name) + # if the adhoc metric has been defined before + # use the existing instance. 
+ col = metrics_exprs_by_expr.get(str(col), col) + need_groupby = True + elif col in columns_by_name: + col = columns_by_name[col].get_sqla_col( + template_processor=template_processor + ) + elif col in metrics_exprs_by_label: + col = metrics_exprs_by_label[col] + need_groupby = True + elif col in metrics_by_name: + col = metrics_by_name[col].get_sqla_col( + template_processor=template_processor + ) + need_groupby = True + + if isinstance(col, ColumnElement): + orderby_exprs.append(col) + else: + # Could not convert a column reference to valid ColumnElement + raise QueryObjectValidationError( + _("Unknown column used in orderby: %(col)s", col=orig_col) + ) + + select_exprs: List[Union[Column, Label]] = [] + groupby_all_columns = {} + groupby_series_columns = {} + + # filter out the pseudo column __timestamp from columns + columns = [col for col in columns if col != utils.DTTM_ALIAS] + dttm_col = columns_by_name.get(granularity) if granularity else None + + if need_groupby: + # dedup columns while preserving order + columns = groupby or columns + for selected in columns: + if isinstance(selected, str): + # if groupby field/expr equals granularity field/expr + if selected == granularity: + table_col = columns_by_name[selected] + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) + # if groupby field equals a selected column + elif selected in columns_by_name: + outer = columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) + else: + selected = validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) + outer = literal_column(f"({selected})") + outer = self.make_sqla_column_compatible(outer, selected) + else: + outer = self.adhoc_column_to_sqla( + col=selected, template_processor=template_processor + ) + groupby_all_columns[outer.name] = outer + if ( + is_timeseries and not series_column_names + ) or outer.name in series_column_names: + groupby_series_columns[outer.name] = outer + select_exprs.append(outer) + elif columns: + for selected in columns: + if is_adhoc_column(selected): + _sql = selected["sqlExpression"] + _column_label = selected["label"] + elif isinstance(selected, str): + _sql = selected + _column_label = selected + + selected = validate_adhoc_subquery( + _sql, + self.database_id, + self.schema, + ) + select_exprs.append( + columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) + if isinstance(selected, str) and selected in columns_by_name + else self.make_sqla_column_compatible( + literal_column(selected), _column_label + ) + ) + metrics_exprs = [] + + if granularity: + if granularity not in columns_by_name or not dttm_col: + raise QueryObjectValidationError( + _( + 'Time column "%(col)s" does not exist in dataset', + col=granularity, + ) + ) + time_filters = [] + + if is_timeseries: + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) + # always put timestamp as the first column + select_exprs.insert(0, timestamp) + groupby_all_columns[timestamp.name] = timestamp + + # Use main dttm column to support index with secondary dttm columns. 
+ if ( + db_engine_spec.time_secondary_columns + and self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col.column_name + ): + time_filters.append( + columns_by_name[self.main_dttm_col].get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) + time_filters.append( + dttm_col.get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) + + # Always remove duplicates by column name, as sometimes `metrics_exprs` + # can have the same name as a groupby column (e.g. when users use + # raw columns as custom SQL adhoc metric). + select_exprs = remove_duplicates( + select_exprs + metrics_exprs, key=lambda x: x.name + ) + + # Expected output columns + labels_expected = [c.key for c in select_exprs] + + # Order by columns are "hidden" columns, some databases require them + # always be present in SELECT if an aggregation function is used + if not db_engine_spec.allows_hidden_orderby_agg: + select_exprs = remove_duplicates(select_exprs + orderby_exprs) + + qry = sa.select(select_exprs) + + tbl, cte = self.get_from_clause(template_processor) + + if groupby_all_columns: + qry = qry.group_by(*groupby_all_columns.values()) + + where_clause_and = [] + having_clause_and = [] + + for flt in filter: # type: ignore + if not all(flt.get(s) for s in ["col", "op"]): + continue + flt_col = flt["col"] + val = flt.get("val") + op = flt["op"].upper() + col_obj: Optional[TableColumn] = None + sqla_col: Optional[Column] = None + if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: + col_obj = dttm_col + elif is_adhoc_column(flt_col): + try: + sqla_col = self.adhoc_column_to_sqla( + col=flt_col, + force_type_check=True, + template_processor=template_processor, + ) + applied_adhoc_filters_columns.append(flt_col) + except ColumnNotFoundException: + rejected_adhoc_filters_columns.append(flt_col) + continue + else: + col_obj = columns_by_name.get(cast(str, flt_col)) + filter_grain = flt.get("grain") + + if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): + if get_column_name(flt_col) in removed_filters: + # Skip generating SQLA filter when the jinja template handles it. 
+ continue + + if col_obj or sqla_col is not None: + if sqla_col is not None: + pass + elif col_obj and filter_grain: + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, template_processor=template_processor + ) + elif col_obj: + sqla_col = col_obj.get_sqla_col( + template_processor=template_processor + ) + col_type = col_obj.type if col_obj else None + col_spec = db_engine_spec.get_column_spec( + native_type=col_type, + db_extra=self.database.get_extra(), + ) + is_list_target = op in ( + utils.FilterOperator.IN.value, + utils.FilterOperator.NOT_IN.value, + ) + + col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" + + if col_spec and not col_advanced_data_type: + target_generic_type = col_spec.generic_type + else: + target_generic_type = GenericDataType.STRING + eq = self.filter_values_handler( + values=val, + operator=op, + target_generic_type=target_generic_type, + target_native_type=col_type, + is_list_target=is_list_target, + db_engine_spec=db_engine_spec, + db_extra=self.database.get_extra(), + ) + if ( + col_advanced_data_type != "" + and feature_flag_manager.is_feature_enabled( + "ENABLE_ADVANCED_DATA_TYPES" + ) + and col_advanced_data_type in ADVANCED_DATA_TYPES + ): + values = eq if is_list_target else [eq] # type: ignore + bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ + col_advanced_data_type + ].translate_type( + { + "type": col_advanced_data_type, + "values": values, + } + ) + if bus_resp["error_message"]: + raise AdvancedDataTypeResponseError( + _(bus_resp["error_message"]) + ) + + where_clause_and.append( + ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( + sqla_col, op, bus_resp["values"] + ) + ) + elif is_list_target: + assert isinstance(eq, (tuple, list)) + if len(eq) == 0: + raise QueryObjectValidationError( + _("Filter value list cannot be empty") + ) + if len(eq) > len( + eq_without_none := [x for x in eq if x is not None] + ): + is_null_cond = sqla_col.is_(None) + if eq: + cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) + else: + cond = is_null_cond + else: + cond = sqla_col.in_(eq) + if op == utils.FilterOperator.NOT_IN.value: + cond = ~cond + where_clause_and.append(cond) + elif op == utils.FilterOperator.IS_NULL.value: + where_clause_and.append(sqla_col.is_(None)) + elif op == utils.FilterOperator.IS_NOT_NULL.value: + where_clause_and.append(sqla_col.isnot(None)) + elif op == utils.FilterOperator.IS_TRUE.value: + where_clause_and.append(sqla_col.is_(True)) + elif op == utils.FilterOperator.IS_FALSE.value: + where_clause_and.append(sqla_col.is_(False)) + else: + if ( + op + not in { + utils.FilterOperator.EQUALS.value, + utils.FilterOperator.NOT_EQUALS.value, + } + and eq is None + ): + raise QueryObjectValidationError( + _( + "Must specify a value for filters " + "with comparison operators" + ) + ) + if op == utils.FilterOperator.EQUALS.value: + where_clause_and.append(sqla_col == eq) + elif op == utils.FilterOperator.NOT_EQUALS.value: + where_clause_and.append(sqla_col != eq) + elif op == utils.FilterOperator.GREATER_THAN.value: + where_clause_and.append(sqla_col > eq) + elif op == utils.FilterOperator.LESS_THAN.value: + where_clause_and.append(sqla_col < eq) + elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: + where_clause_and.append(sqla_col >= eq) + elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: + where_clause_and.append(sqla_col <= eq) + elif op == utils.FilterOperator.LIKE.value: + where_clause_and.append(sqla_col.like(eq)) + elif op == 
utils.FilterOperator.ILIKE.value: + where_clause_and.append(sqla_col.ilike(eq)) + elif ( + op == utils.FilterOperator.TEMPORAL_RANGE.value + and isinstance(eq, str) + and col_obj is not None + ): + _since, _until = get_since_until_from_time_range( + time_range=eq, + time_shift=time_shift, + extras=extras, + ) + where_clause_and.append( + col_obj.get_time_filter( + start_dttm=_since, + end_dttm=_until, + label=sqla_col.key, + template_processor=template_processor, + ) + ) + else: + raise QueryObjectValidationError( + _("Invalid filter operation type: %(op)s", op=op) + ) + where_clause_and += self.get_sqla_row_level_filters(template_processor) + if extras: + where = extras.get("where") + if where: + try: + where = template_processor.process_template(f"({where})") + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in WHERE clause: %(msg)s", + msg=ex.message, + ) + ) from ex + where = _process_sql_expression( + expression=where, + database_id=self.database_id, + schema=self.schema, + ) + where_clause_and += [self.text(where)] + having = extras.get("having") + if having: + try: + having = template_processor.process_template(f"({having})") + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in HAVING clause: %(msg)s", + msg=ex.message, + ) + ) from ex + having = _process_sql_expression( + expression=having, + database_id=self.database_id, + schema=self.schema, + ) + having_clause_and += [self.text(having)] + + if apply_fetch_values_predicate and self.fetch_values_predicate: + qry = qry.where( + self.get_fetch_values_predicate(template_processor=template_processor) + ) + if granularity: + qry = qry.where(and_(*(time_filters + where_clause_and))) + else: + qry = qry.where(and_(*where_clause_and)) + qry = qry.having(and_(*having_clause_and)) + + self.make_orderby_compatible(select_exprs, orderby_exprs) + + for col, (orig_col, ascending) in zip(orderby_exprs, orderby): + if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): + # if engine does not allow using SELECT alias in ORDER BY + # revert to the underlying column + col = col.element + + if ( + db_engine_spec.allows_alias_in_select + and db_engine_spec.allows_hidden_cc_in_orderby + and col.name in [select_col.name for select_col in select_exprs] + ): + col = literal_column(col.name) + direction = asc if ascending else desc + qry = qry.order_by(direction(col)) + + if row_limit: + qry = qry.limit(row_limit) + if row_offset: + qry = qry.offset(row_offset) + + if series_limit and groupby_series_columns: + if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: + # some sql dialects require for order by expressions + # to also be in the select clause -- others, e.g. 
vertica, + # require a unique inner alias + inner_main_metric_expr = self.make_sqla_column_compatible( + main_metric_expr, "mme_inner__" + ) + inner_groupby_exprs = [] + inner_select_exprs = [] + for gby_name, gby_obj in groupby_series_columns.items(): + label = get_column_name(gby_name) + inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") + inner_groupby_exprs.append(inner) + inner_select_exprs.append(inner) + + inner_select_exprs += [inner_main_metric_expr] + subq = select(inner_select_exprs).select_from(tbl) + inner_time_filter = [] + + if dttm_col and not db_engine_spec.time_groupby_inline: + inner_time_filter = [ + dttm_col.get_time_filter( + start_dttm=inner_from_dttm or from_dttm, + end_dttm=inner_to_dttm or to_dttm, + template_processor=template_processor, + ) + ] + subq = subq.where(and_(*(where_clause_and + inner_time_filter))) + subq = subq.group_by(*inner_groupby_exprs) + + ob = inner_main_metric_expr + if series_limit_metric: + ob = self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) + direction = desc if order_desc else asc + subq = subq.order_by(direction(ob)) + subq = subq.limit(series_limit) + + on_clause = [] + for gby_name, gby_obj in groupby_series_columns.items(): + # in this case the column name, not the alias, needs to be + # conditionally mutated, as it refers to the column alias in + # the inner query + col_name = db_engine_spec.make_label_compatible(gby_name + "__") + on_clause.append(gby_obj == column(col_name)) + + tbl = tbl.join(subq.alias(), and_(*on_clause)) + else: + if series_limit_metric: + orderby = [ + ( + self._get_series_orderby( + series_limit_metric=series_limit_metric, + metrics_by_name=metrics_by_name, + columns_by_name=columns_by_name, + template_processor=template_processor, + ), + not order_desc, + ) + ] + + # run prequery to get top groups + prequery_obj = { + "is_timeseries": False, + "row_limit": series_limit, + "metrics": metrics, + "granularity": granularity, + "groupby": groupby, + "from_dttm": inner_from_dttm or from_dttm, + "to_dttm": inner_to_dttm or to_dttm, + "filter": filter, + "orderby": orderby, + "extras": extras, + "columns": columns, + "order_desc": True, + } + + result = self.query(prequery_obj) + prequeries.append(result.query) + dimensions = [ + c + for c in result.df.columns + if c not in metrics and c in groupby_series_columns + ] + top_groups = self._get_top_groups( + result.df, dimensions, groupby_series_columns, columns_by_name + ) + qry = qry.where(top_groups) + + qry = qry.select_from(tbl) + + if is_rowcount: + if not db_engine_spec.allows_subqueries: + raise QueryObjectValidationError( + _("Database does not support subqueries") + ) + label = "rowcount" + col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) + qry = select([col]).select_from(qry.alias("rowcount_qry")) + labels_expected = [label] + + filter_columns = [flt.get("col") for flt in filter] if filter else [] + rejected_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and col not in self.column_names + and col not in applied_template_filters + ] + rejected_adhoc_filters_columns + applied_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and (col in self.column_names or col in applied_template_filters) + ] + applied_adhoc_filters_columns + + return SqlaQuery( + 
applied_template_filters=applied_template_filters, + rejected_filter_columns=rejected_filter_columns, + applied_filter_columns=applied_filter_columns, + cte=cte, + extra_cache_keys=extra_cache_keys, + labels_expected=labels_expected, + sqla_query=qry, + prequeries=prequeries, + ) + def _get_series_orderby( self, series_limit_metric: Metric, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index cc3b34ae627cb..0790e3709abd6 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -14,22 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" -import dataclasses +# pylint: disable=too-many-lines import json import logging import re import uuid -from collections import defaultdict from datetime import datetime, timedelta from json.decoder import JSONDecodeError from typing import ( Any, cast, Dict, - Hashable, List, + Mapping, NamedTuple, Optional, Set, @@ -73,7 +71,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( AdvancedDataTypeResponseError, - ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, SupersetSecurityException, @@ -91,13 +88,7 @@ QueryObjectDict, ) from superset.utils import core as utils -from superset.utils.core import ( - GenericDataType, - get_column_name, - get_user_id, - is_adhoc_column, - remove_duplicates, -) +from superset.utils.core import get_user_id from superset.utils.dates import datetime_to_epoch if TYPE_CHECKING: @@ -677,8 +668,6 @@ def clone_model( # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): applied_template_filters: Optional[List[str]] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] labels_expected: List[str] prequeries: List[str] sql: str @@ -686,8 +675,6 @@ class QueryStringExtended(NamedTuple): class SqlaQuery(NamedTuple): applied_template_filters: List[str] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] @@ -711,18 +698,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods } @property - def fetch_value_predicate(self) -> str: - return "fix this!" 
- - @property - def type(self) -> str: - raise NotImplementedError() - - @property - def db_extra(self) -> Optional[Dict[str, Any]]: - raise NotImplementedError() - - def query(self, query_obj: QueryObjectDict) -> QueryResult: + def query(self) -> str: raise NotImplementedError() @property @@ -735,7 +711,7 @@ def owners_data(self) -> List[Any]: @property def metrics(self) -> List[Any]: - return [] + raise NotImplementedError() @property def uid(self) -> str: @@ -785,59 +761,17 @@ def sql(self) -> str: def columns(self) -> List[Any]: raise NotImplementedError() - def get_fetch_values_predicate( - self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> TextClause: + @property + def get_fetch_values_predicate(self) -> List[Any]: raise NotImplementedError() - def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: + @staticmethod + def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: raise NotImplementedError() - def get_sqla_row_level_filters( - self, - template_processor: BaseTemplateProcessor, - ) -> List[TextClause]: - """ - Return the appropriate row level security filters for this table and the - current user. A custom username can be passed when the user is not present in the - Flask global namespace. - - :param template_processor: The template processor to apply to the filters. - :returns: A list of SQL clauses to be ANDed together. - """ - all_filters: List[TextClause] = [] - filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) - try: - for filter_ in security_manager.get_rls_filters(self): - clause = self.text( - f"({template_processor.process_template(filter_.clause)})" - ) - if filter_.group_key: - filter_groups[filter_.group_key].append(clause) - else: - all_filters.append(clause) - - if is_feature_enabled("EMBEDDED_SUPERSET"): - for rule in security_manager.get_guest_rls_filters(self): - clause = self.text( - f"({template_processor.process_template(rule['clause'])})" - ) - all_filters.append(clause) - - grouped_filters = [or_(*clauses) for clauses in filter_groups.values()] - all_filters.extend(grouped_filters) - return all_filters - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in RLS filters: %(msg)s", - msg=ex.message, - ) - ) from ex - def _process_sql_expression( # pylint: disable=no-self-use self, expression: Optional[str], @@ -936,19 +870,14 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) - def get_query_str_extended( - self, query_obj: QueryObjectDict, mutate: bool = True - ) -> QueryStringExtended: + def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) - if mutate: - sql = self.mutate_query_from_config(sql) + sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, - applied_filter_columns=sqlaq.applied_filter_columns, - rejected_filter_columns=sqlaq.rejected_filter_columns, labels_expected=sqlaq.labels_expected, prequeries=sqlaq.prequeries, sql=sql, @@ -1073,16 +1002,9 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: logger.warning( "Query %s on schema %s failed", sql, self.schema, 
exc_info=True ) - db_engine_spec = self.db_engine_spec - errors = [ - dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex) - ] error_message = utils.error_msg_from_exception(ex) return QueryResult( - applied_template_filters=query_str_ext.applied_template_filters, - applied_filter_columns=query_str_ext.applied_filter_columns, - rejected_filter_columns=query_str_ext.rejected_filter_columns, status=status, df=df, duration=datetime.now() - qry_start_dttm, @@ -1152,7 +1074,7 @@ def get_from_clause( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument + columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ @@ -1252,20 +1174,19 @@ def get_query_str(self, query_obj: QueryObjectDict) -> str: def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Dict[str, "SqlMetric"], - columns_by_name: Dict[str, "TableColumn"], - template_processor: Optional[BaseTemplateProcessor] = None, + metrics_by_name: Mapping[str, "SqlMetric"], + columns_by_name: Mapping[str, "TableColumn"], ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) - ob = self.adhoc_metric_to_sqla(series_limit_metric, columns_by_name) + ob = self.adhoc_metric_to_sqla( + series_limit_metric, columns_by_name # type: ignore + ) elif ( isinstance(series_limit_metric, str) and series_limit_metric in metrics_by_name ): - ob = metrics_by_name[series_limit_metric].get_sqla_col( - template_processor=template_processor - ) + ob = metrics_by_name[series_limit_metric].get_sqla_col() else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=series_limit_metric) @@ -1274,11 +1195,26 @@ def _get_series_orderby( def adhoc_column_to_sqla( self, - col: "AdhocColumn", # type: ignore - force_type_check: bool = False, + col: Type["AdhocColumn"], # type: ignore template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - raise NotImplementedError() + """ + Turn an adhoc column into a sqlalchemy column. 
+ + :param col: Adhoc column definition + :param template_processor: template_processor instance + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ + label = utils.get_column_name(col) # type: ignore + expression = self._process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_column = literal_column(expression) + return self.make_sqla_column_compatible(sqla_column, label) def _get_top_groups( self, @@ -1316,30 +1252,29 @@ def dttm_sql_literal(self, dttm: sa.DateTime, col_type: Optional[str]) -> str: return f'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}' - def get_time_filter( # pylint: disable=too-many-arguments + def get_time_filter( self, - time_col: "TableColumn", + time_col: Dict[str, Any], start_dttm: Optional[sa.DateTime], end_dttm: Optional[sa.DateTime], - label: Optional[str] = "__time", - template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - col = self.convert_tbl_column_to_sqla_col( - time_col, label=label, template_processor=template_processor - ) + label = "__time" + col = time_col.get("column_name") + sqla_col = literal_column(col) + my_col = self.make_sqla_column_compatible(sqla_col, label) l = [] if start_dttm: l.append( - col + my_col >= self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(start_dttm, time_col.type) + self.dttm_sql_literal(start_dttm, time_col.get("type")) ) ) if end_dttm: l.append( - col + my_col < self.db_engine_spec.get_text_clause( - self.dttm_sql_literal(end_dttm, time_col.type) + self.dttm_sql_literal(end_dttm, time_col.get("type")) ) ) return and_(*l) @@ -1403,24 +1338,11 @@ def get_timestamp_expression( time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) return self.make_sqla_column_compatible(time_expr, label) - def convert_tbl_column_to_sqla_col( - self, - tbl_column: "TableColumn", - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, - ) -> Column: - label = label or tbl_column.column_name - db_engine_spec = self.db_engine_spec - column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra) - type_ = column_spec.sqla_type if column_spec else None - if expression := tbl_column.expression: - if template_processor: - expression = template_processor.process_template(expression) - col = literal_column(expression, type_=type_) - else: - col = sa.column(tbl_column.column_name, type_=type_) - col = self.make_sqla_column_compatible(col, label) - return col + def get_sqla_col(self, col: Dict[str, Any]) -> Column: + label = col.get("column_name") + col_type = col.get("type") + col = sa.column(label, type_=col_type) + return self.make_sqla_column_compatible(col, label) def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, @@ -1467,13 +1389,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.column_name for col in self.columns], + "table_columns": [col.get("column_name") for col in self.columns], "filter": filter, } columns = columns or [] groupby = groupby or [] - rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] - applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, 
to be removed in 2.0 if is_timeseries and timeseries_limit: @@ -1498,11 +1418,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma granularity = self.main_dttm_col columns_by_name: Dict[str, "TableColumn"] = { - col.column_name: col for col in self.columns - } - - metrics_by_name: Dict[str, "SqlMetric"] = { - m.metric_name: m for m in self.metrics + col.get("column_name"): col + for col in self.columns # col.column_name: col for col in self.columns } if not granularity and is_timeseries: @@ -1526,12 +1443,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma template_processor=template_processor, ) ) - elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append( - metrics_by_name[metric].get_sqla_col( - template_processor=template_processor - ) - ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1570,17 +1481,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - col = self.convert_tbl_column_to_sqla_col( - columns_by_name[col], template_processor=template_processor - ) + gb_column_obj = columns_by_name[col] + if isinstance(gb_column_obj, dict): + col = self.get_sqla_col(gb_column_obj) + else: + col = gb_column_obj.get_sqla_col() elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True - elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col( - template_processor=template_processor - ) - need_groupby = True if isinstance(col, ColumnElement): orderby_exprs.append(col) @@ -1606,24 +1514,33 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) + if isinstance(table_col, dict): + outer = self.get_timestamp_expression( + column=table_col, + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) + else: + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) # if groupby field equals a selected column elif selected in columns_by_name: - outer = self.convert_tbl_column_to_sqla_col( - columns_by_name[selected], - template_processor=template_processor, - ) + if isinstance(columns_by_name[selected], dict): + outer = sa.column(f"{selected}") + outer = self.make_sqla_column_compatible(outer, selected) + else: + outer = columns_by_name[selected].get_sqla_col() else: - selected = validate_adhoc_subquery( + selected = self.validate_adhoc_subquery( selected, self.database_id, self.schema, ) - outer = literal_column(f"({selected})") + outer = sa.column(f"{selected}") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( @@ -1637,28 +1554,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma select_exprs.append(outer) elif columns: for selected in columns: - if is_adhoc_column(selected): - _sql = selected["sqlExpression"] - _column_label = selected["label"] - elif isinstance(selected, str): - _sql = selected - _column_label = selected - - selected = validate_adhoc_subquery( - _sql, + selected = self.validate_adhoc_subquery( + selected, 
self.database_id, self.schema, ) - - select_exprs.append( - self.convert_tbl_column_to_sqla_col( - columns_by_name[selected], template_processor=template_processor - ) - if isinstance(selected, str) and selected in columns_by_name - else self.make_sqla_column_compatible( - literal_column(selected), _column_label + if isinstance(columns_by_name[selected], dict): + select_exprs.append(sa.column(f"{selected}")) + else: + select_exprs.append( + columns_by_name[selected].get_sqla_col() + if selected in columns_by_name + else self.make_sqla_column_compatible(literal_column(selected)) ) - ) metrics_exprs = [] if granularity: @@ -1669,43 +1577,57 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col=granularity, ) ) - time_filters = [] + time_filters: List[Any] = [] if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + if isinstance(dttm_col, dict): + timestamp = self.get_timestamp_expression( + dttm_col, time_grain, template_processor=template_processor + ) + else: + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns. - if ( - db_engine_spec.time_secondary_columns - and self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col.column_name - ): - time_filters.append( - self.get_time_filter( - time_col=columns_by_name[self.main_dttm_col], - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) + if db_engine_spec.time_secondary_columns: + if isinstance(dttm_col, dict): + dttm_col_name = dttm_col.get("column_name") + else: + dttm_col_name = dttm_col.column_name - time_filter_column = self.get_time_filter( - time_col=dttm_col, - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - time_filters.append(time_filter_column) + if ( + self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col_name + ): + if isinstance(self.main_dttm_col, dict): + time_filters.append( + self.get_time_filter( + self.main_dttm_col, + from_dttm, + to_dttm, + ) + ) + else: + time_filters.append( + columns_by_name[self.main_dttm_col].get_time_filter( + from_dttm, + to_dttm, + ) + ) + + if isinstance(dttm_col, dict): + time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm)) + else: + time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use # raw columns as custom SQL adhoc metric). 
- select_exprs = remove_duplicates( + select_exprs = utils.remove_duplicates( select_exprs + metrics_exprs, key=lambda x: x.name ) @@ -1715,7 +1637,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma # Order by columns are "hidden" columns, some databases require them # always be present in SELECT if an aggregation function is used if not db_engine_spec.allows_hidden_orderby_agg: - select_exprs = remove_duplicates(select_exprs + orderby_exprs) + select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) qry = sa.select(select_exprs) @@ -1737,19 +1659,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col - elif is_adhoc_column(flt_col): - try: - sqla_col = self.adhoc_column_to_sqla(flt_col, force_type_check=True) - applied_adhoc_filters_columns.append(flt_col) - except ColumnNotFoundException: - rejected_adhoc_filters_columns.append(flt_col) - continue + elif utils.is_adhoc_column(flt_col): + sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore else: col_obj = columns_by_name.get(cast(str, flt_col)) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if get_column_name(flt_col) in removed_filters: + if utils.get_column_name(flt_col) in removed_filters: # Skip generating SQLA filter when the jinja template handles it. continue @@ -1757,29 +1674,44 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if sqla_col is not None: pass elif col_obj and filter_grain: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, template_processor=template_processor - ) + if isinstance(col_obj, dict): + sqla_col = self.get_timestamp_expression( + col_obj, time_grain, template_processor=template_processor + ) + else: + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, + template_processor=template_processor, + ) + elif col_obj and isinstance(col_obj, dict): + sqla_col = sa.column(col_obj.get("column_name")) elif col_obj: - sqla_col = self.convert_tbl_column_to_sqla_col( - tbl_column=col_obj, template_processor=template_processor - ) - col_type = col_obj.type if col_obj else None + sqla_col = col_obj.get_sqla_col() + + if col_obj and isinstance(col_obj, dict): + col_type = col_obj.get("type") + else: + col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, - # db_extra=self.database.get_extra(), + db_extra=self.database.get_extra(), # type: ignore ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" + if col_obj and isinstance(col_obj, dict): + col_advanced_data_type = "" + else: + col_advanced_data_type = ( + col_obj.advanced_data_type if col_obj else "" + ) if col_spec and not col_advanced_data_type: target_generic_type = col_spec.generic_type else: - target_generic_type = GenericDataType.STRING + target_generic_type = utils.GenericDataType.STRING eq = self.filter_values_handler( values=val, operator=op, @@ -1787,7 +1719,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, - # db_extra=self.database.get_extra(), + db_extra=self.database.get_extra(), # type: ignore ) if ( col_advanced_data_type != "" @@ -1843,14 
+1775,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif op == utils.FilterOperator.IS_FALSE.value: where_clause_and.append(sqla_col.is_(False)) else: - if ( - op - not in { - utils.FilterOperator.EQUALS.value, - utils.FilterOperator.NOT_EQUALS.value, - } - and eq is None - ): + if eq is None: raise QueryObjectValidationError( _( "Must specify a value for filters " @@ -1888,20 +1813,19 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma time_col=col_obj, start_dttm=_since, end_dttm=_until, - label=sqla_col.key, - template_processor=template_processor, ) ) else: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - where_clause_and += self.get_sqla_row_level_filters(template_processor) + # todo(hugh): fix this w/ template_processor + # where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: try: - where = template_processor.process_template(f"({where})") + where = template_processor.process_template(f"{where}") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1909,17 +1833,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - where = self._process_sql_expression( - expression=where, - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(f"({having})") + having = template_processor.process_template(f"{having}") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1927,18 +1845,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - having = self._process_sql_expression( - expression=having, - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) having_clause_and += [self.text(having)] - if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore - qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) - ) + qry = qry.where(self.get_fetch_values_predicate()) # type: ignore if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -1978,7 +1887,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_series_columns.items(): - label = get_column_name(gby_name) + label = utils.get_column_name(gby_name) inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) @@ -1988,25 +1897,26 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_time_filter = [] if dttm_col and not db_engine_spec.time_groupby_inline: - inner_time_filter = [ - self.get_time_filter( - time_col=dttm_col, - start_dttm=inner_from_dttm or from_dttm, - end_dttm=inner_to_dttm or to_dttm, - template_processor=template_processor, - ) - ] + if isinstance(dttm_col, dict): + inner_time_filter = [ + self.get_time_filter( + dttm_col, + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, + ) + ] + else: + inner_time_filter = [ + dttm_col.get_time_filter( + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, + ) + ] + subq = subq.where(and_(*(where_clause_and + inner_time_filter))) subq = 
subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr - if series_limit_metric: - ob = self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) direction = sa.desc if order_desc else sa.asc subq = subq.order_by(direction(ob)) subq = subq.limit(series_limit) @@ -2020,19 +1930,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma on_clause.append(gby_obj == sa.column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) - else: - if series_limit_metric: - orderby = [ - ( - self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ), - not order_desc, - ) - ] # run prequery to get top groups prequery_obj = { @@ -2049,8 +1946,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma "columns": columns, "order_desc": True, } - - result = self.query(prequery_obj) + result = self.exc_query(prequery_obj) prequeries.append(result.query) dimensions = [ c @@ -2074,29 +1970,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] - filter_columns = [flt.get("col") for flt in filter] if filter else [] - rejected_filter_columns = [ - col - for col in filter_columns - if col - and not is_adhoc_column(col) - and col not in self.column_names - and col not in applied_template_filters - ] + rejected_adhoc_filters_columns - - applied_filter_columns = [ - col - for col in filter_columns - if col - and not is_adhoc_column(col) - and (col in self.column_names or col in applied_template_filters) - ] + applied_adhoc_filters_columns - return SqlaQuery( applied_template_filters=applied_template_filters, cte=cte, - applied_filter_columns=applied_filter_columns, - rejected_filter_columns=rejected_filter_columns, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index d37ed440db862..7a000e839093a 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -19,7 +19,7 @@ import logging import re from datetime import datetime -from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import simplejson as json import sqlalchemy as sqla @@ -52,10 +52,9 @@ ) from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sqllab.limiting_factor import LimitingFactor -from superset.utils.core import QueryStatus, user_label +from superset.utils.core import GenericDataType, QueryStatus, user_label if TYPE_CHECKING: - from superset.connectors.sqla.models import TableColumn from superset.db_engine_specs import BaseEngineSpec @@ -184,33 +183,47 @@ def sql_tables(self) -> List[Table]: return list(ParsedQuery(self.sql).tables) @property - def columns(self) -> List["TableColumn"]: - from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel - TableColumn, + def columns(self) -> List[Dict[str, Any]]: + bool_types = ("BOOL",) + num_types = ( + "DOUBLE", + "FLOAT", + "INT", + "BIGINT", + "NUMBER", + "LONG", + "REAL", + "NUMERIC", + "DECIMAL", + "MONEY", ) - + date_types = ("DATE", "TIME") + str_types = ("VARCHAR", "STRING", "CHAR") columns = [] + col_type = "" for 
col in self.extra.get("columns", []):
-            columns.append(
-                TableColumn(
-                    column_name=col["name"],
-                    type=col["type"],
-                    is_dttm=col["is_dttm"],
-                    groupby=True,
-                    filterable=True,
-                )
-            )
+            computed_column = {**col}
+            col_type = col.get("type")
+
+            if col_type and any(map(lambda t: t in col_type.upper(), str_types)):
+                computed_column["type_generic"] = GenericDataType.STRING
+            if col_type and any(map(lambda t: t in col_type.upper(), bool_types)):
+                computed_column["type_generic"] = GenericDataType.BOOLEAN
+            if col_type and any(map(lambda t: t in col_type.upper(), num_types)):
+                computed_column["type_generic"] = GenericDataType.NUMERIC
+            if col_type and any(map(lambda t: t in col_type.upper(), date_types)):
+                computed_column["type_generic"] = GenericDataType.TEMPORAL
+
+            computed_column["column_name"] = col.get("name")
+            computed_column["groupby"] = True
+            columns.append(computed_column)
         return columns
 
-    @property
-    def db_extra(self) -> Optional[Dict[str, Any]]:
-        return None
-
     @property
     def data(self) -> Dict[str, Any]:
         order_by_choices = []
         for col in self.columns:
-            column_name = str(col.column_name or "")
+            column_name = str(col.get("column_name") or "")
             order_by_choices.append(
                 (json.dumps([column_name, True]), f"{column_name} " + __("[asc]"))
             )
@@ -224,7 +237,7 @@ def data(self) -> Dict[str, Any]:
             ],
             "filter_select": True,
             "name": self.tab_name,
-            "columns": [o.data for o in self.columns],
+            "columns": self.columns,
             "metrics": [],
             "id": self.id,
             "type": self.type,
@@ -267,7 +280,7 @@ def cache_timeout(self) -> int:
 
     @property
     def column_names(self) -> List[Any]:
-        return [col.column_name for col in self.columns]
+        return [col.get("column_name") for col in self.columns]
 
     @property
     def offset(self) -> int:
@@ -282,7 +295,7 @@ def main_dttm_col(self) -> Optional[str]:
 
     @property
     def dttm_cols(self) -> List[Any]:
-        return [col.column_name for col in self.columns if col.is_dttm]
+        return [col.get("column_name") for col in self.columns if col.get("is_dttm")]
 
     @property
     def schema_perm(self) -> str:
@@ -297,7 +310,7 @@ def default_endpoint(self) -> str:
         return ""
 
     @staticmethod
-    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
+    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
         return []
 
     @property
@@ -325,7 +338,7 @@ def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]:
         if not column_name:
             return None
         for col in self.columns:
-            if col.column_name == column_name:
+            if col.get("column_name") == column_name:
                 return col
         return None
 
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 8cf1076f5eef3..460c17b949dac 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1732,13 +1732,13 @@ def extract_dataframe_dtypes(
     if datasource:
         for column in datasource.columns:
             if isinstance(column, dict):
-                columns_by_name[column.get("column_name")] = column
+                columns_by_name[column.get("column_name")] = column  # type: ignore
             else:
                 columns_by_name[column.column_name] = column
 
     generic_types: List[GenericDataType] = []
     for column in df.columns:
-        column_object = columns_by_name.get(column)
+        column_object = columns_by_name.get(column)  # type: ignore
         series = df[column]
         inferred_type = infer_dtype(series)
         if isinstance(column_object, dict):
@@ -1786,9 +1786,15 @@ def get_time_filter_status(
     datasource: "BaseDatasource",
     applied_time_extras: Dict[str, str],
 ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
-    temporal_columns: Set[Any] = {
-        col.column_name for col in datasource.columns if col.is_dttm
-    }
+    temporal_columns: Set[Any]
+    if datasource.type == "query":
+        temporal_columns = {
+            col.get("column_name") for col in datasource.columns if col.get("is_dttm")
+        }
+    else:
+        temporal_columns = {
+            col.column_name for col in datasource.columns if col.is_dttm
+        }
     applied: List[Dict[str, str]] = []
     rejected: List[Dict[str, str]] = []
     time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
diff --git a/superset/views/core.py b/superset/views/core.py
index 1d9d80c5b187a..7576f042e8ea5 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2032,7 +2032,7 @@ def sqllab_viz(self) -> FlaskResponse:  # pylint: disable=no-self-use
         db.session.add(table)
         cols = []
         for config_ in data.get("columns"):
-            column_name = config_.get("column_name") or config_.get("name")
+            column_name = config_.get("name")
             col = TableColumn(
                 column_name=column_name,
                 filterable=True,
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index db81488c3f9d2..821c80ec42083 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -1242,8 +1242,8 @@ def test_chart_cache_timeout_chart_not_found(
        [
            (200, {"where": "1 = 1"}),
            (200, {"having": "count(*) > 0"}),
-            (403, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
-            (403, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+            (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+            (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
        ],
    )
    @with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 27ccdde96be29..d9f26239d1394 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -493,16 +493,8 @@ def test_sqllab_viz(self):
             "datasourceName": f"test_viz_flow_table_{random()}",
             "schema": "superset",
             "columns": [
-                {
-                    "is_dttm": False,
-                    "type": "STRING",
-                    "column_name": f"viz_type_{random()}",
-                },
-                {
-                    "is_dttm": False,
-                    "type": "OBJECT",
-                    "column_name": f"ccount_{random()}",
-                },
+                {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
+                {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
             ],
             "sql": """\
                 SELECT *
@@ -531,16 +523,8 @@ def test_sqllab_viz_bad_payload(self):
             "chartType": "dist_bar",
             "schema": "superset",
             "columns": [
-                {
-                    "is_dttm": False,
-                    "type": "STRING",
-                    "column_name": f"viz_type_{random()}",
-                },
-                {
-                    "is_dttm": False,
-                    "type": "OBJECT",
-                    "column_name": f"ccount_{random()}",
-                },
+                {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
+                {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
             ],
             "sql": """\
                 SELECT *