diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index a2732ae5537f1..d993eca279093 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -16,12 +16,21 @@ # under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from superset.common.chart_data import ChartDataResultType from superset.common.query_object import QueryObject from superset.common.utils.time_range_utils import get_since_until_from_time_range -from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType +from superset.constants import NO_TIME_RANGE +from superset.superset_typing import Column +from superset.utils.core import ( + apply_max_row_limit, + DatasourceDict, + DatasourceType, + FilterOperator, + get_xaxis_label, + QueryObjectFilterClause, +) if TYPE_CHECKING: from sqlalchemy.orm import sessionmaker @@ -61,8 +70,11 @@ def create( # pylint: disable=too-many-arguments processed_extras = self._process_extras(extras) result_type = kwargs.setdefault("result_type", parent_result_type) row_limit = self._process_row_limit(row_limit, result_type) + processed_time_range = self._process_time_range( + time_range, kwargs.get("filters"), kwargs.get("columns") + ) from_dttm, to_dttm = get_since_until_from_time_range( - time_range, time_shift, processed_extras + processed_time_range, time_shift, processed_extras ) kwargs["from_dttm"] = from_dttm kwargs["to_dttm"] = to_dttm @@ -99,6 +111,33 @@ def _process_row_limit( ) return apply_max_row_limit(row_limit or default_row_limit) + @staticmethod + def _process_time_range( + time_range: str | None, + filters: list[QueryObjectFilterClause] | None = None, + columns: list[Column] | None = None, + ) -> str: + if time_range is None: + time_range = NO_TIME_RANGE + temporal_flt = [ + flt + for flt in filters or [] + if flt.get("op") == FilterOperator.TEMPORAL_RANGE + ] + if temporal_flt: + # Use the temporal filter as the time range. + # if the temporal filters uses x-axis as the temporal filter + # then use it or use the first temporal filter + xaxis_label = get_xaxis_label(columns or []) + match_flt = [ + flt for flt in temporal_flt if flt.get("col") == xaxis_label + ] + if match_flt: + time_range = cast(str, match_flt[0].get("val")) + else: + time_range = cast(str, temporal_flt[0].get("val")) + return time_range + # light version of the view.utils.core # import view.utils require application context # Todo: move it and the view.utils.core to utils package diff --git a/tests/unit_tests/common/test_process_time_range.py b/tests/unit_tests/common/test_process_time_range.py new file mode 100644 index 0000000000000..12ee6d21aa3c1 --- /dev/null +++ b/tests/unit_tests/common/test_process_time_range.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from superset.common.query_object_factory import QueryObjectFactory +from superset.constants import NO_TIME_RANGE + + +def test_process_time_range(): + """ + correct empty time range + """ + assert QueryObjectFactory._process_time_range(None) == NO_TIME_RANGE + + """ + Use the first temporal filter as time range + """ + filters = [ + {"col": "dttm", "op": "TEMPORAL_RANGE", "val": "2001 : 2002"}, + {"col": "dttm2", "op": "TEMPORAL_RANGE", "val": "2002 : 2003"}, + ] + assert QueryObjectFactory._process_time_range(None, filters) == "2001 : 2002" + + """ + Use the BASE_AXIS temporal filter as time range + """ + columns = [ + { + "columnType": "BASE_AXIS", + "label": "dttm2", + "sqlExpression": "dttm", + } + ] + assert ( + QueryObjectFactory._process_time_range(None, filters, columns) == "2002 : 2003" + )