From 8a6ecd3b972322858a67ad6581c7785f0c549bd6 Mon Sep 17 00:00:00 2001 From: ofekisr <35701650+ofekisr@users.noreply.github.com> Date: Sun, 21 Nov 2021 14:35:46 +0200 Subject: [PATCH] refactor(QueryContext): add QueryContextFactory (#17495) --- superset/charts/schemas.py | 19 ++++-- superset/common/query_context.py | 45 +++++-------- superset/common/query_context_factory.py | 83 ++++++++++++++++++++++++ superset/models/slice.py | 22 +++++-- superset/views/api.py | 24 +++++-- superset/viz.py | 17 +++-- 6 files changed, 162 insertions(+), 48 deletions(-) create mode 100644 superset/common/query_context_factory.py diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 21d4a006a4f19..46aa9a48af0e5 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=too-many-lines -from typing import Any, Dict +from __future__ import annotations + +from typing import Any, Dict, Optional, TYPE_CHECKING from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate @@ -24,7 +26,7 @@ from superset import app from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType -from superset.common.query_context import QueryContext +from superset.common.query_context_factory import QueryContextFactory from superset.db_engine_specs.base import builtin_time_grains from superset.utils import schema as utils from superset.utils.core import ( @@ -35,6 +37,9 @@ TimeRangeEndpoint, ) +if TYPE_CHECKING: + from superset.common.query_context import QueryContext + config = app.config # @@ -1129,6 +1134,7 @@ class Meta: # pylint: disable=too-few-public-methods class ChartDataQueryContextSchema(Schema): + query_context_factory: Optional[QueryContextFactory] = None datasource = fields.Nested(ChartDataDatasourceSchema) queries = fields.List(fields.Nested(ChartDataQueryObjectSchema)) force = fields.Boolean( @@ -1139,13 +1145,16 @@ class ChartDataQueryContextSchema(Schema): result_type = EnumField(ChartDataResultType, by_value=True) result_format = EnumField(ChartDataResultFormat, by_value=True) - # pylint: disable=no-self-use,unused-argument + # pylint: disable=unused-argument @post_load def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext: - query_context = QueryContext(**data) + query_context = self.get_query_context_factory().create(**data) return query_context - # pylint: enable=no-self-use,unused-argument + def get_query_context_factory(self) -> QueryContextFactory: + if self.query_context_factory is None: + self.query_context_factory = QueryContextFactory() + return self.query_context_factory class AnnotationDataSchema(Schema): diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 8f17db985ed77..310adc6ad9079 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -26,17 +26,14 @@ from pandas import DateOffset from typing_extensions import TypedDict -from superset import app, db, is_feature_enabled +from superset import app, is_feature_enabled from superset.annotation_layers.dao import AnnotationLayerDAO from superset.charts.dao import ChartDAO from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.db_query_status import QueryStatus from superset.common.query_actions import get_query_results from superset.common.query_object import QueryObject -from superset.common.query_object_factory import QueryObjectFactory from superset.common.utils import QueryCacheManager -from superset.connectors.base.models import BaseDatasource -from superset.connectors.connector_registry import ConnectorRegistry from superset.constants import CacheRegion from superset.exceptions import QueryObjectValidationError, SupersetException from superset.extensions import cache_manager, security_manager @@ -44,7 +41,6 @@ from superset.utils import csv from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( - DatasourceDict, DTTM_ALIAS, error_msg_from_exception, get_column_names_from_columns, @@ -57,6 +53,7 @@ from superset.views.utils import get_viz if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource from superset.stats_logger import BaseStatsLogger config = app.config @@ -70,10 +67,6 @@ class CachedTimeOffset(TypedDict): cache_keys: List[Optional[str]] -def create_query_object_factory() -> QueryObjectFactory: - return QueryObjectFactory(config, ConnectorRegistry(), db.session) - - class QueryContext: """ The query context contains the query object and additional fields necessary @@ -90,36 +83,28 @@ class QueryContext: force: bool custom_cache_timeout: Optional[int] + cache_values: Dict[str, Any] + # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 - # pylint: disable=too-many-arguments def __init__( self, - datasource: DatasourceDict, - queries: List[Dict[str, Any]], - result_type: Optional[ChartDataResultType] = None, - result_format: Optional[ChartDataResultFormat] = None, + *, + datasource: BaseDatasource, + queries: List[QueryObject], + result_type: ChartDataResultType, + result_format: ChartDataResultFormat, force: bool = False, custom_cache_timeout: Optional[int] = None, + cache_values: Dict[str, Any] ) -> None: - self.datasource = ConnectorRegistry.get_datasource( - str(datasource["type"]), int(datasource["id"]), db.session - ) - self.result_type = result_type or ChartDataResultType.FULL - self.result_format = result_format or ChartDataResultFormat.JSON - query_object_factory = create_query_object_factory() - self.queries = [ - query_object_factory.create(self.result_type, **query_obj) - for query_obj in queries - ] + self.datasource = datasource + self.result_type = result_type + self.result_format = result_format + self.queries = queries self.force = force self.custom_cache_timeout = custom_cache_timeout - self.cache_values = { - "datasource": datasource, - "queries": queries, - "result_type": self.result_type, - "result_format": self.result_format, - } + self.cache_values = cache_values @staticmethod def left_join_df( diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py new file mode 100644 index 0000000000000..61cf92835b37f --- /dev/null +++ b/superset/common/query_context_factory.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from superset import app, db +from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType +from superset.common.query_context import QueryContext +from superset.common.query_object_factory import QueryObjectFactory +from superset.connectors.connector_registry import ConnectorRegistry +from superset.utils.core import DatasourceDict + +if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource + +config = app.config + + +def create_query_object_factory() -> QueryObjectFactory: + return QueryObjectFactory(config, ConnectorRegistry(), db.session) + + +class QueryContextFactory: # pylint: disable=too-few-public-methods + _query_object_factory: QueryObjectFactory + + def __init__(self) -> None: + self._query_object_factory = create_query_object_factory() + + def create( + self, + *, + datasource: DatasourceDict, + queries: List[Dict[str, Any]], + result_type: Optional[ChartDataResultType] = None, + result_format: Optional[ChartDataResultFormat] = None, + force: bool = False, + custom_cache_timeout: Optional[int] = None + ) -> QueryContext: + datasource_model_instance = None + if datasource: + datasource_model_instance = self._convert_to_model(datasource) + result_type = result_type or ChartDataResultType.FULL + result_format = result_format or ChartDataResultFormat.JSON + queries_ = [ + self._query_object_factory.create(result_type, **query_obj) + for query_obj in queries + ] + cache_values = { + "datasource": datasource, + "queries": queries, + "result_type": result_type, + "result_format": result_format, + } + return QueryContext( + datasource=datasource_model_instance, + queries=queries_, + result_type=result_type, + result_format=result_format, + force=force, + custom_cache_timeout=custom_cache_timeout, + cache_values=cache_values, + ) + + # pylint: disable=no-self-use + def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: + return ConnectorRegistry.get_datasource( + str(datasource["type"]), int(datasource["id"]), db.session + ) diff --git a/superset/models/slice.py b/superset/models/slice.py index f4d71953f24de..7e36619fd89c7 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import json import logging from typing import Any, Dict, Optional, Type, TYPE_CHECKING @@ -41,6 +43,7 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext + from superset.common.query_context_factory import QueryContextFactory from superset.connectors.base.models import BaseDatasource metadata = Model.metadata # pylint: disable=no-member @@ -59,6 +62,8 @@ class Slice( # pylint: disable=too-many-public-methods ): """A slice is essentially a report or a view on data""" + query_context_factory: Optional[QueryContextFactory] = None + __tablename__ = "slices" id = Column(Integer, primary_key=True) slice_name = Column(String(250)) @@ -248,13 +253,12 @@ def form_data(self) -> Dict[str, Any]: update_time_range(form_data) return form_data - def get_query_context(self) -> Optional["QueryContext"]: - # pylint: disable=import-outside-toplevel - from superset.common.query_context import QueryContext - + def get_query_context(self) -> Optional[QueryContext]: if self.query_context: try: - return QueryContext(**json.loads(self.query_context)) + return self.get_query_context_factory().create( + **json.loads(self.query_context) + ) except json.decoder.JSONDecodeError as ex: logger.error("Malformed json in slice's query context", exc_info=True) logger.exception(ex) @@ -313,6 +317,14 @@ def icons(self) -> str: def url(self) -> str: return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D" + def get_query_context_factory(self) -> QueryContextFactory: + if self.query_context_factory is None: + # pylint: disable=import-outside-toplevel + from superset.common.query_context_factory import QueryContextFactory + + self.query_context_factory = QueryContextFactory() + return self.query_context_factory + def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None: src_class = target.cls_model diff --git a/superset/views/api.py b/superset/views/api.py index c205ce9b091c2..d4d94ce72346c 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any +from __future__ import annotations + +from typing import Any, TYPE_CHECKING import simplejson as json from flask import request @@ -27,7 +29,6 @@ TimeRangeAmbiguousError, TimeRangeParseFailError, ) -from superset.common.query_context import QueryContext from superset.legacy import update_time_range from superset.models.slice import Slice from superset.typing import FlaskResponse @@ -35,23 +36,30 @@ from superset.utils.date_parser import get_since_until from superset.views.base import api, BaseSupersetView, handle_api_exception +if TYPE_CHECKING: + from superset.common.query_context_factory import QueryContextFactory + get_time_range_schema = {"type": "string"} class Api(BaseSupersetView): + query_context_factory = None + @event_logger.log_this @api @handle_api_exception @has_access_api @expose("/v1/query/", methods=["POST"]) - def query(self) -> FlaskResponse: # pylint: disable=no-self-use + def query(self) -> FlaskResponse: """ Takes a query_obj constructed in the client and returns payload data response for the given query_obj. raises SupersetSecurityException: If the user cannot access the resource """ - query_context = QueryContext(**json.loads(request.form["query_context"])) + query_context = self.get_query_context_factory().create( + **json.loads(request.form["query_context"]) + ) query_context.raise_for_access() result = query_context.get_payload() payload_json = result["queries"] @@ -99,3 +107,11 @@ def time_range(self, **kwargs: Any) -> FlaskResponse: except (ValueError, TimeRangeParseFailError, TimeRangeAmbiguousError) as error: error_msg = {"message": f"Unexpected time range: {error}"} return self.json_response(error_msg, 400) + + def get_query_context_factory(self) -> QueryContextFactory: + if self.query_context_factory is None: + # pylint: disable=import-outside-toplevel + from superset.common.query_context_factory import QueryContextFactory + + self.query_context_factory = QueryContextFactory() + return self.query_context_factory diff --git a/superset/viz.py b/superset/viz.py index bedd3c879b9e6..53bc333224160 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -20,6 +20,8 @@ These objects represent the backend of all the visualizations that Superset can render. """ +from __future__ import annotations + import copy import dataclasses import logging @@ -88,6 +90,7 @@ from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: + from superset.common.query_context_factory import QueryContextFactory from superset.connectors.base.models import BaseDatasource config = app.config @@ -2097,6 +2100,7 @@ class FilterBoxViz(BaseViz): """A multi filter, multi-choice filter box to make dashboards interactive""" + query_context_factory: Optional[QueryContextFactory] = None viz_type = "filter_box" verbose_name = _("Filters") is_timeseries = False @@ -2108,9 +2112,6 @@ def query_obj(self) -> QueryObjectDict: return {} def run_extra_queries(self) -> None: - # pylint: disable=import-outside-toplevel - from superset.common.query_context import QueryContext - query_obj = super().query_obj() filters = self.form_data.get("filter_configs") or [] query_obj["row_limit"] = self.filter_row_limit @@ -2127,7 +2128,7 @@ def run_extra_queries(self) -> None: asc = flt.get("asc") if metric and asc is not None: query_obj["orderby"] = [(metric, asc)] - QueryContext( + self.get_query_context_factory().create( datasource={"id": self.datasource.id, "type": self.datasource.type}, queries=[query_obj], ).raise_for_access() @@ -2160,6 +2161,14 @@ def get_data(self, df: pd.DataFrame) -> VizData: data[col] = [] return data + def get_query_context_factory(self) -> QueryContextFactory: + if self.query_context_factory is None: + # pylint: disable=import-outside-toplevel + from superset.common.query_context_factory import QueryContextFactory + + self.query_context_factory = QueryContextFactory() + return self.query_context_factory + class ParallelCoordinatesViz(BaseViz):