Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ChartData): move ChartDataResult enums to common #17399

Merged
merged 1 commit into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,14 @@
)
from superset.commands.importers.exceptions import NoValidFilesFoundError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger, security_manager
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import (
ChartDataResultFormat,
ChartDataResultType,
json_int_dttm_ser,
)
from superset.utils.core import json_int_dttm_ser
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
Expand Down
8 changes: 2 additions & 6 deletions superset/charts/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@

import pandas as pd

from superset.utils.core import (
ChartDataResultFormat,
DTTM_ALIAS,
extract_dataframe_dtypes,
get_metric_name,
)
from superset.common.chart_data import ChartDataResultFormat
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name


def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
Expand Down
3 changes: 1 addition & 2 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from marshmallow_enum import EnumField

from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import schema as utils
from superset.utils.core import (
AnnotationType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
Expand Down
40 changes: 40 additions & 0 deletions superset/common/chart_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum


class ChartDataResultFormat(str, Enum):
    """
    Chart data response format.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``ChartDataResultFormat.CSV == "csv"``), allowing request
    parameters to be matched directly against members.
    """

    CSV = "csv"    # comma-separated values download
    JSON = "json"  # JSON payload


class ChartDataResultType(str, Enum):
    """
    Chart data response type.

    Mixes in ``str`` so members compare equal to their plain string values.
    NOTE: definition order is significant — callers iterate
    ``list(ChartDataResultType)`` to match request parameters, so do not
    reorder members.
    """

    COLUMNS = "columns"        # column metadata only — TODO confirm exact payload
    FULL = "full"              # complete chart data payload
    QUERY = "query"            # the generated query string, without executing it
    RESULTS = "results"        # raw query results
    SAMPLES = "samples"        # a sample of rows (row_limit-bounded per tests)
    TIMEGRAINS = "timegrains"  # available time grains — TODO confirm exact payload
    POST_PROCESSED = "post_processed"  # results after post-processing is applied
2 changes: 1 addition & 1 deletion superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from flask_babel import _

from superset import app
from superset.common.chart_data import ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import (
ChartDataResultType,
extract_column_dtype,
extract_dataframe_dtypes,
ExtraFiltersReasonType,
Expand Down
3 changes: 1 addition & 2 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset import app, db, is_feature_enabled
from superset.annotation_layers.dao import AnnotationLayerDAO
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject
Expand All @@ -42,8 +43,6 @@
from superset.utils import csv
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import (
ChartDataResultFormat,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
error_msg_from_exception,
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
from pandas import DataFrame

from superset import app, db
from superset.common.chart_data import ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
find_duplicates,
Expand Down
2 changes: 1 addition & 1 deletion superset/reports/commands/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from superset import app
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandException
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.extensions import feature_flag_manager, machine_auth_provider_factory
from superset.models.reports import (
ReportDataFormat,
Expand Down Expand Up @@ -64,7 +65,6 @@
from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError
from superset.utils.celery import session_scope
from superset.utils.core import ChartDataResultFormat, ChartDataResultType
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.screenshots import (
BaseScreenshot,
Expand Down
23 changes: 0 additions & 23 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,29 +174,6 @@ class GenericDataType(IntEnum):
# ROW = 7


class ChartDataResultFormat(str, Enum):
"""
Chart data response format
"""

CSV = "csv"
JSON = "json"


class ChartDataResultType(str, Enum):
"""
Chart data response type
"""

COLUMNS = "columns"
FULL = "full"
QUERY = "query"
RESULTS = "results"
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"


class DatasourceDict(TypedDict):
type: str
id: int
Expand Down
23 changes: 12 additions & 11 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
viz,
)
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
Expand Down Expand Up @@ -459,18 +460,18 @@ def send_data_payload_response(viz_obj: BaseViz, payload: Any) -> FlaskResponse:
def generate_json(
self, viz_obj: BaseViz, response_type: Optional[str] = None
) -> FlaskResponse:
if response_type == utils.ChartDataResultFormat.CSV:
if response_type == ChartDataResultFormat.CSV:
return CsvResponse(
viz_obj.get_csv(), headers=generate_download_headers("csv")
)

if response_type == utils.ChartDataResultType.QUERY:
if response_type == ChartDataResultType.QUERY:
return self.get_query_string_response(viz_obj)

if response_type == utils.ChartDataResultType.RESULTS:
if response_type == ChartDataResultType.RESULTS:
return self.get_raw_results(viz_obj)

if response_type == utils.ChartDataResultType.SAMPLES:
if response_type == ChartDataResultType.SAMPLES:
return self.get_samples(viz_obj)

payload = viz_obj.get_payload()
Expand Down Expand Up @@ -598,19 +599,19 @@ def explore_json(

TODO: break into one endpoint for each return shape"""

response_type = utils.ChartDataResultFormat.JSON.value
responses: List[
Union[utils.ChartDataResultFormat, utils.ChartDataResultType]
] = list(utils.ChartDataResultFormat)
responses.extend(list(utils.ChartDataResultType))
response_type = ChartDataResultFormat.JSON.value
responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list(
ChartDataResultFormat
)
responses.extend(list(ChartDataResultType))
for response_option in responses:
if request.args.get(response_option) == "true":
response_type = response_option
break

# Verify user has permission to export CSV file
if (
response_type == utils.ChartDataResultFormat.CSV
response_type == ChartDataResultFormat.CSV
and not security_manager.can_access("can_csv", "Superset")
):
return json_error_response(
Expand All @@ -628,7 +629,7 @@ def explore_json(
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and response_type == utils.ChartDataResultFormat.JSON
and response_type == ChartDataResultFormat.JSON
):
# First, look for the chart query results in the cache.
try:
Expand Down
24 changes: 11 additions & 13 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,13 @@
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.utils.core import (
AnnotationType,
ChartDataResultFormat,
get_example_database,
get_example_default_schema,
get_main_database,
)

from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType

from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import (
Expand Down Expand Up @@ -1239,7 +1237,7 @@ def test_chart_data_sample_default_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand All @@ -1258,7 +1256,7 @@ def test_chart_data_sample_custom_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand All @@ -1276,7 +1274,7 @@ def test_chart_data_sql_max_row_sample_limit(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1326,7 +1324,7 @@ def test_chart_data_query_result_type(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)

Expand Down Expand Up @@ -1453,7 +1451,7 @@ def test_chart_data_query_missing_filter(self):
request_payload["queries"][0]["filters"] = [
{"col": "non_existent_filter", "op": "==", "val": "foo"},
]
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1532,7 +1530,7 @@ def test_chart_data_jinja_filter_request(self):
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["result_type"] = ChartDataResultType.QUERY
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "boy"}
]
Expand Down Expand Up @@ -1574,7 +1572,7 @@ def test_chart_data_async_cached_sync_response(self):

class QueryContext:
result_format = ChartDataResultFormat.JSON
result_type = utils.ChartDataResultType.FULL
result_type = ChartDataResultType.FULL

cmd_run_val = {
"query_context": QueryContext(),
Expand All @@ -1585,7 +1583,7 @@ class QueryContext:
ChartDataCommand, "run", return_value=cmd_run_val
) as patched_run:
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.FULL
request_payload["result_type"] = ChartDataResultType.FULL
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1997,8 +1995,8 @@ def test_chart_data_timegrains(self):
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"] = [
{"result_type": utils.ChartDataResultType.TIMEGRAINS},
{"result_type": utils.ChartDataResultType.COLUMNS},
{"result_type": ChartDataResultType.TIMEGRAINS},
{"result_type": ChartDataResultType.COLUMNS},
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
Expand Down
9 changes: 2 additions & 7 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@

from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.extensions import cache_manager
from superset.utils.core import (
AdhocMetricExpressionType,
backend,
ChartDataResultFormat,
ChartDataResultType,
TimeRangeEndpoint,
)
from superset.utils.core import AdhocMetricExpressionType, backend, TimeRangeEndpoint
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down