Skip to content

Commit

Permalink
Fixes mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina committed Feb 8, 2023
1 parent 1d670ba commit 20a3e45
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from superset.common.query_object_factory import QueryObjectFactory
from superset.datasource.dao import DatasourceDAO
from superset.models.slice import Slice
from superset.superset_typing import AdhocColumn
from superset.utils.core import DatasourceDict, DatasourceType

if TYPE_CHECKING:
Expand Down Expand Up @@ -107,7 +108,7 @@ def _get_slice(self, slice_id: Any) -> Optional[Slice]:
def _process_query_object(
self,
datasource: BaseDatasource,
form_data: Dict[str, Any],
form_data: Optional[Dict[str, Any]],
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
Expand All @@ -117,31 +118,38 @@ def _process_query_object(
def _apply_granularity(
self,
query_object: QueryObject,
form_data: Dict[str, Any],
form_data: Optional[Dict[str, Any]],
datasource: BaseDatasource,
):
) -> None:
temporal_columns = {
column.column_name for column in datasource.columns if column.is_temporal
}
granularity = query_object.granularity
x_axis = form_data.get("x_axis")
x_axis = form_data and form_data.get("x_axis")

if granularity:
filter_to_remove = None
if x_axis in temporal_columns:
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
x_axis_column = next(
(
column
for column in query_object.columns
if column["sqlExpression"] == x_axis
if column == x_axis
or (type(column) is dict and column["sqlExpression"] == x_axis)
),
None,
)
# Replaces x-axis column values with granularity
if x_axis_column:
x_axis_column["sqlExpression"] = granularity
x_axis_column["label"] = granularity
if type(x_axis_column) is dict:
x_axis_column["sqlExpression"] = granularity
x_axis_column["label"] = granularity
else:
query_object.columns = [
granularity if column == x_axis_column else column
for column in query_object.columns
]
for post_processing in query_object.post_processing:
if post_processing.get("operation") == "pivot":
post_processing["options"]["index"] = [granularity]
Expand Down Expand Up @@ -172,7 +180,7 @@ def _apply_granularity(
if filter["col"] != filter_to_remove
]

def _apply_filters(self, query_object: QueryObject):
def _apply_filters(self, query_object: QueryObject) -> None:
if query_object.time_range:
for filter_object in query_object.filter:
if filter_object["op"] == "TEMPORAL_RANGE":
Expand Down

0 comments on commit 20a3e45

Please sign in to comment.