From 8659381fa512ea5c7b16340b67883250d9c62969 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Thu, 21 Apr 2022 11:32:12 -0700 Subject: [PATCH] create mixin for chart query --- superset/config.py | 13 ++++- superset/models/helpers.py | 36 ++++++++++++ superset/models/sql_lab.py | 3 +- tests/integration_tests/charts/api_tests.py | 61 +++++++++++++++++++++ tests/integration_tests/fixtures/query.py | 57 +++++++++++++++++++ 5 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 tests/integration_tests/fixtures/query.py diff --git a/superset/config.py b/superset/config.py index 62198c484c64a..33313a89488fa 100644 --- a/superset/config.py +++ b/superset/config.py @@ -659,6 +659,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: [ ("superset.connectors.sqla.models", ["SqlaTable"]), ("superset.connectors.druid.models", ["DruidDatasource"]), + ("superset.models.sql_lab", ["Query"]), ] ) ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} @@ -984,7 +985,11 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # Provide a callable that receives a tracking_url and returns another # URL. This is used to translate internal Hadoop job tracker URL # into a proxied one -TRACKING_URL_TRANSFORMER = lambda x: x + + +def TRACKING_URL_TRANSFORMER(x): + return x + # Interval between consecutive polls when using Hive Engine HIVE_POLL_INTERVAL = int(timedelta(seconds=5).total_seconds()) @@ -1203,7 +1208,11 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument # to allow mutating the object with this callback. # This can be used to set any properties of the object based on naming # conventions and such. You can find examples in the tests. -SQLA_TABLE_MUTATOR = lambda table: table + + +def SQLA_TABLE_MUTATOR(table): + return table + # Global async query config options. # Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled. 
diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 3b4e99159f0b8..185dd9d724530 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -39,6 +39,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy_utils import UUIDType +from superset import security_manager from superset.common.db_query_status import QueryStatus logger = logging.getLogger(__name__) @@ -543,3 +544,38 @@ def clone_model( data.update(kwargs) return target.__class__(**data) + + +class ExploreMixin: + """ + Sets up data to allow an object to be used to power a chart + """ + + @property + def database(self): + raise NotImplementedError + + @property + def schema(self): + raise NotImplementedError + + # @property + # def type(self) -> str: + # return f"{self.__class__.__name__.lower()}" + + type = "query" + + @staticmethod + def default_query(qry): + return qry + + @property + def perm(self) -> Optional[str]: + return f"[{self.database.database_name}].(id:{self.database.id})" + + def get_perm(self) -> Optional[str]: + return self.perm + + def get_schema_perm(self) -> Optional[str]: + """Returns schema permission if present, database one otherwise.""" + return security_manager.get_schema_perm(self.database, self.schema) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 04d5fc9a94359..2039eca945028 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -42,6 +42,7 @@ from superset import security_manager from superset.models.helpers import ( AuditMixinNullable, + ExploreMixin, ExtraJSONMixin, ImportExportMixin, ) @@ -51,7 +52,7 @@ from superset.utils.core import QueryStatus, user_label -class Query(Model, ExtraJSONMixin): +class Query(Model, ExtraJSONMixin, ExploreMixin): """ORM model for SQL query Now that SQL Lab support multi-statement execution, an entry in this diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 
6b8d625d567e3..927cd4f9e3b5e 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -18,6 +18,7 @@ """Unit tests for Superset""" import json from io import BytesIO +from superset.models.sql_lab import Query from zipfile import is_zipfile, ZipFile import prison @@ -51,6 +52,7 @@ dataset_config, dataset_metadata_config, ) +from tests.integration_tests.fixtures.query import get_query_datasource from tests.integration_tests.fixtures.unicode_dashboard import ( load_unicode_dashboard_with_slice, load_unicode_data, @@ -143,6 +145,35 @@ def create_chart_with_report(self): db.session.delete(chart) db.session.commit() + @pytest.fixture() + def create_charts_from_query(self): + with self.create_app().app_context(): + charts = [] + admin = self.get_user("admin") + query = db.session.query(Query).first() + for cx in range(CHARTS_FIXTURE_COUNT - 1): + charts.append( + self.insert_chart( + f"name{cx}", [admin.id], query.id, datasource_type=query.type + ) + ) + fav_charts = [] + for cx in range(round(CHARTS_FIXTURE_COUNT / 2)): + fav_star = FavStar( + user_id=admin.id, class_name="slice", obj_id=charts[cx].id + ) + db.session.add(fav_star) + db.session.commit() + fav_charts.append(fav_star) + yield charts + + # rollback changes + for chart in charts: + db.session.delete(chart) + for fav_chart in fav_charts: + db.session.delete(fav_chart) + db.session.commit() + @pytest.fixture() def add_dashboard_to_chart(self): with self.create_app().app_context(): @@ -534,6 +565,36 @@ def test_create_chart_validate_datasource(self): response, {"message": {"datasource_id": ["Dataset does not exist"]}} ) + @pytest.mark.usefixtures("get_query_datasource", "create_charts_from_query") + def test_create_chart_from_query(self): + """ + Chart API: Test create chart from Query + """ + dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) + admin_id = self.get_user("admin").id + query_id = db.session.query(Query).first().id + chart_data 
= { + "slice_name": "name1", + "description": "description1", + "owners": [admin_id], + "viz_type": "viz_type1", + "params": "1234", + "cache_timeout": 1000, + "datasource_id": query_id, + "datasource_type": "query", + "dashboards": dashboards_ids, + "certified_by": "John Doe", + "certification_details": "Sample certification", + } + self.login(username="admin") + uri = f"api/v1/chart/" + rv = self.post_assert_metric(uri, chart_data, "post") + self.assertEqual(rv.status_code, 201) + data = json.loads(rv.data.decode("utf-8")) + model = db.session.query(Slice).get(data.get("id")) + db.session.delete(model) + db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_update_chart(self): """ diff --git a/tests/integration_tests/fixtures/query.py b/tests/integration_tests/fixtures/query.py new file mode 100644 index 0000000000000..6ee79d16b8597 --- /dev/null +++ b/tests/integration_tests/fixtures/query.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import uuid +from typing import Any, Dict + +import pytest +from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table + +from superset.extensions import db +from superset.models.core import Database +from superset.models.sql_lab import Query +from tests.integration_tests.test_app import app + + +@pytest.fixture() +def get_query_datasource(): + with app.app_context(): + engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) + meta = MetaData() + + students = Table( + "students", + meta, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("lastname", String), + Column("ds", Date), + ) + meta.create_all(engine) + + students.insert().values(name="George", ds="2021-01-01") + + query = Query( + database_id=db.session.query(Database).first().id, + client_id=str(uuid.uuid4())[0:10], + sql="select * from students", + ) + db.session.add(query) + db.session.commit() + yield query + + # rollback changes + # todo