From 6560afce74f17f4f45211dbd9f24519b3f0d810d Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Tue, 31 Jan 2023 22:31:38 -0800 Subject: [PATCH 1/8] First draft --- .../src/SqlLab/actions/sqlLab.js | 8 +- .../ExploreCtasResultsButton/index.tsx | 6 +- superset/connectors/sqla/models.py | 3 - superset/datasource/api.py | 78 +++++++++++++++++++ superset/datasource/commands/__init__.py | 16 ++++ superset/datasource/commands/create_table.py | 73 +++++++++++++++++ superset/datasource/commands/exceptions.py | 26 +++++++ superset/datasource/schemas.py | 28 +++++++ superset/views/core.py | 2 + .../integration_tests/datasource/api_tests.py | 75 ++++++++++++++++++ .../datasource/commands_tests.py | 68 ++++++++++++++++ 11 files changed, 373 insertions(+), 10 deletions(-) create mode 100644 superset/datasource/commands/__init__.py create mode 100644 superset/datasource/commands/create_table.py create mode 100644 superset/datasource/commands/exceptions.py create mode 100644 superset/datasource/schemas.py create mode 100644 tests/integration_tests/datasource/commands_tests.py diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index cd63a464b1402..6f24c0ef5b760 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -1485,13 +1485,13 @@ export function createCtasDatasource(vizOptions) { return dispatch => { dispatch(createDatasourceStarted()); return SupersetClient.post({ - endpoint: '/superset/get_or_create_table/', - postPayload: { data: vizOptions }, + endpoint: '/api/v1/datasource/table/get_or_create/', + jsonPayload: vizOptions, }) .then(({ json }) => { - dispatch(createDatasourceSuccess(json)); + dispatch(createDatasourceSuccess(json.result)); - return json; + return json.result; }) .catch(() => { const errorMsg = t('An error occurred while creating the data source'); diff --git a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx index 2fe1e14a07853..a4c71139c0d3c 100644 --- a/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx +++ b/superset-frontend/src/SqlLab/components/ExploreCtasResultsButton/index.tsx @@ -48,10 +48,10 @@ const ExploreCtasResultsButton = ({ const dispatch = useDispatch<(dispatch: any) => Promise>(); const buildVizOptions = { - datasourceName: table, + table_name: table, schema, - dbId, - templateParams, + database_id: dbId, + template_params: templateParams, }; const visualize = () => { diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index cffff7363055d..678ed4a122a21 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1918,9 +1918,6 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: error_message=error_message, ) - def get_sqla_table_object(self) -> Table: - return self.database.get_table(self.table_name, schema=self.schema) - def fetch_metadata(self, commit: bool = True) -> MetadataResult: """ Fetches the metadata for the table and merges it in diff --git a/superset/datasource/api.py b/superset/datasource/api.py index be246c915d535..61ff529f3b41c 100644 --- a/superset/datasource/api.py +++ b/superset/datasource/api.py @@ -16,11 +16,18 @@ # under the License. import logging +from flask import request from flask_appbuilder.api import expose, protect, safe +from marshmallow import ValidationError from superset import app, db, event_logger +from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.datasource.commands.create_table import CreateSqlaTableCommand +from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError from superset.datasource.dao import DatasourceDAO +from superset.datasource.schemas import GetOrCreateTableSchema from superset.exceptions import SupersetSecurityException from superset.superset_typing import FlaskResponse from superset.utils.core import apply_max_row_limit, DatasourceType @@ -35,6 +42,77 @@ class DatasourceRestApi(BaseSupersetApi): resource_name = "datasource" openapi_spec_tag = "Datasources" + openapi_spec_component_schemas = (GetOrCreateTableSchema,) + + @expose("/table/get_or_create/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".get_or_create_table", + log_to_statsd=False, + ) + def get_or_create_table(self) -> FlaskResponse: + """Retrieve a table by name, or create it if it does not exist + --- + post: + summary: Retrieve a table by name, or create it if it does not exist + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GetOrCreateTableSchema' + responses: + 200: + description: The ID of the table + content: + application/json: + schema: + type: object + properties: + result: + type: object + properties: + table_id: + type: integer + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + try: + body = GetOrCreateTableSchema().load(request.json) + except ValidationError as ex: + return self.response(400, message=ex.messages) + table_name = body["table_name"] + database_id = body["database_id"] + table = ( + db.session.query(SqlaTable) + .filter_by(database_id=database_id, table_name=table_name) + .one_or_none() + ) + if not table: + try: + table = CreateSqlaTableCommand( + table_name, + database_id, + body.get("schema"), + body.get("template_params"), + ).run() + except DatabaseNotFoundError as ex: + return self.response(404, message=ex.message) + except GetTableFromDatabaseFailedError as ex: + return self.response(400, message=ex.message) + + payload = {"table_id": table.id} + return self.response(200, result=payload) + @expose( "///column//values/", methods=["GET"], diff --git a/superset/datasource/commands/__init__.py b/superset/datasource/commands/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/datasource/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/datasource/commands/create_table.py b/superset/datasource/commands/create_table.py new file mode 100644 index 0000000000000..c661a9fc2965e --- /dev/null +++ b/superset/datasource/commands/create_table.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask import g + +from superset import db +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.databases.dao import DatabaseDAO +from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError +from superset.models.core import Database +from superset.views.base import create_table_permissions + +logger = logging.getLogger(__name__) + + +class CreateSqlaTableCommand(BaseCommand): + def __init__( + self, + table_name: str, + database_id: int, + schema: Optional[str] = None, + template_params: Optional[str] = None, + ): + self._table_name = table_name + self._database_id = database_id + self._schema = schema + self._template_params = template_params + self._database: Database = None # type: ignore + + def run(self) -> SqlaTable: + self.validate() + table = SqlaTable(table_name=self._table_name, owners=[g.user]) + table.database = self._database + table.schema = self._schema + table.template_params = self._template_params + db.session.add(table) + table.fetch_metadata() + create_table_permissions(table) + db.session.commit() + return table + + def validate(self) -> None: + database = DatabaseDAO.find_by_id(self._database_id) + if not database: + raise DatabaseNotFoundError() + self._database = database + try: + self._database.get_table(self._table_name, schema=self._schema) + except Exception as ex: + logger.exception( + "Error getting table %s for database %s", + self._table_name, + self._database.id, + ) + raise GetTableFromDatabaseFailedError() from ex diff --git a/superset/datasource/commands/exceptions.py b/superset/datasource/commands/exceptions.py new file mode 100644 index 0000000000000..70bfeee9c075a --- /dev/null +++ b/superset/datasource/commands/exceptions.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ + +from superset.commands.exceptions import CommandException + + +class GetTableFromDatabaseFailedError(CommandException): + message = _( + "Table could not be found, please check your " + "database connection, schema, and table name" + ) diff --git a/superset/datasource/schemas.py b/superset/datasource/schemas.py new file mode 100644 index 0000000000000..de93956792c26 --- /dev/null +++ b/superset/datasource/schemas.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from marshmallow import fields, Schema + + +class GetOrCreateTableSchema(Schema): + table_name = fields.String(required=True, description="Name of table") + database_id = fields.Integer( + required=True, description="ID of database table belongs to" + ) + schema = fields.String( + description="The schema the table belongs to", allow_none=True + ) + template_params = fields.String(description="Template params for the table") diff --git a/superset/views/core.py b/superset/views/core.py index 8d632dcde21bf..6f5255bb25366 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1918,6 +1918,7 @@ def log(self) -> FlaskResponse: # pylint: disable=no-self-use @has_access @expose("/get_or_create_table/", methods=["POST"]) @event_logger.log_this + @deprecated() def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use """Gets or creates a table object with attributes passed to the API. @@ -1947,6 +1948,7 @@ def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use table.schema = data.get("schema") table.template_params = data.get("templateParams") # needed for the table validation. + # fn can be deleted when this endpoint is removed validate_sqlatable(table) db.session.add(table) diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py index 522aa33383e62..2b1cc6812296a 100644 --- a/tests/integration_tests/datasource/api_tests.py +++ b/tests/integration_tests/datasource/api_tests.py @@ -22,6 +22,8 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DatasourceTypeNotSupportedError +from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError +from superset.utils.database import get_example_database from tests.integration_tests.base_tests import SupersetTestCase @@ -135,3 +137,76 @@ def test_get_column_values_not_implemented_error(self, get_datasource_mock): response["message"], "Unable to get column values for datasource type: sl_table", ) + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_or_create_table_already_exists(self): + self.login(username="admin") + rv = self.client.post( + "api/v1/datasource/table/get_or_create/", + json={ + "table_name": "virtual_dataset", + "database_id": get_example_database().id, + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + response["result"], {"table_id": self.get_virtual_dataset().id} + ) + + def test_get_or_create_table_database_not_found(self): + self.login(username="admin") + rv = self.client.post( + "api/v1/datasource/table/get_or_create/", + json={"table_name": "virtual_dataset", "database_id": 999}, + ) + self.assertEqual(rv.status_code, 404) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["message"], "Database not found.") + + @patch("superset.datasource.commands.create_table.CreateSqlaTableCommand.run") + def test_get_or_create_table_get_table_fails(self, run_command_mock): + run_command_mock.side_effect = GetTableFromDatabaseFailedError + self.login(username="admin") + rv = self.client.post( + "api/v1/datasource/table/get_or_create/", + json={"table_name": "tbl", "database_id": get_example_database().id}, + ) + self.assertEqual(rv.status_code, 400) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + response["message"], + "Table could not be found, please check your " + "database connection, schema, and table name", + ) + + def test_get_or_create_table_creates_table(self): + self.login(username="admin") + + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api") + engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col") + + rv = self.client.post( + "api/v1/datasource/table/get_or_create/", + json={ + "table_name": "test_create_sqla_table_api", + "database_id": examples_db.id, + "template_params": '{"param": 1}', + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_sqla_table_api") + .one() + ) + self.assertEqual(response["result"], {"table_id": table.id}) + self.assertEqual(table.template_params, '{"param": 1}') + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_sqla_table_api") + db.session.commit() diff --git a/tests/integration_tests/datasource/commands_tests.py b/tests/integration_tests/datasource/commands_tests.py new file mode 100644 index 0000000000000..d4682e6ab0a13 --- /dev/null +++ b/tests/integration_tests/datasource/commands_tests.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import patch + +from superset import db, security_manager +from superset.connectors.sqla.models import SqlaTable +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.datasource.commands.create_table import CreateSqlaTableCommand +from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError +from superset.utils.database import get_example_database +from tests.integration_tests.base_tests import SupersetTestCase + + +class TestCreateSqlaTableCommand(SupersetTestCase): + def test_database_not_found(self): + self.login(username="admin") + with self.assertRaises(DatabaseNotFoundError): + CreateSqlaTableCommand("table", 9999).run() + + @patch("superset.security.manager.g") + @patch("superset.models.core.Database.get_table") + def test_get_table_from_database_error(self, get_table_mock, mock_g): + mock_g.user = security_manager.find_user("admin") + get_table_mock.side_effect = Exception + with self.assertRaises(GetTableFromDatabaseFailedError): + CreateSqlaTableCommand("table", get_example_database().id).run() + + @patch("superset.security.manager.g") + @patch("superset.datasource.commands.create_table.g") + def test_create_sqla_table_command(self, mock_g, mock_g2): + mock_g.user = security_manager.find_user("admin") + mock_g2.user = mock_g.user + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_command") + engine.execute( + "CREATE TABLE test_create_sqla_table_command AS SELECT 2 as col" + ) + + table = CreateSqlaTableCommand( + "test_create_sqla_table_command", examples_db.id + ).run() + fetched_table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_sqla_table_command") + .one() + ) + self.assertEqual(table, fetched_table) + self.assertEqual([owner.username for owner in table.owners], ["admin"]) + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_sqla_table_command") + db.session.commit() From 40a70ef44b1528526d6aaf70f1c1e9cdac3bdc1b Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Tue, 31 Jan 2023 22:39:59 -0800 Subject: [PATCH 2/8] Re-add method to sqlatable --- superset/connectors/sqla/models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 678ed4a122a21..cffff7363055d 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1918,6 +1918,9 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: error_message=error_message, ) + def get_sqla_table_object(self) -> Table: + return self.database.get_table(self.table_name, schema=self.schema) + def fetch_metadata(self, commit: bool = True) -> MetadataResult: """ Fetches the metadata for the table and merges it in From ac2e8337e3c0793bc7b5e7cb43bf10ef34121453 Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Wed, 1 Feb 2023 09:42:03 -0800 Subject: [PATCH 3/8] Move new endpoint to dataset API --- .../src/SqlLab/actions/sqlLab.js | 2 +- superset/datasets/api.py | 76 +++++++++++++++- superset/datasets/commands/create.py | 2 + superset/datasets/schemas.py | 11 +++ superset/datasource/api.py | 78 ---------------- superset/datasource/commands/__init__.py | 16 ---- superset/datasource/commands/create_table.py | 73 --------------- superset/datasource/commands/exceptions.py | 26 ------ superset/datasource/schemas.py | 28 ------ tests/integration_tests/datasets/api_tests.py | 88 +++++++++++++++++++ .../datasets/commands_tests.py | 51 ++++++++++- .../integration_tests/datasource/api_tests.py | 75 ---------------- .../datasource/commands_tests.py | 68 -------------- 13 files changed, 227 insertions(+), 367 deletions(-) delete mode 100644 superset/datasource/commands/__init__.py delete mode 100644 superset/datasource/commands/create_table.py delete mode 100644 superset/datasource/commands/exceptions.py delete mode 100644 superset/datasource/schemas.py delete mode 100644 tests/integration_tests/datasource/commands_tests.py diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 1e6a0a1e5fde9..ab8abe0edca21 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -1513,7 +1513,7 @@ export function createCtasDatasource(vizOptions) { return dispatch => { dispatch(createDatasourceStarted()); return SupersetClient.post({ - endpoint: '/api/v1/datasource/table/get_or_create/', + endpoint: '/api/v1/dataset/get_or_create/', jsonPayload: vizOptions, }) .then(({ json }) => { diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 925c3c7cb8c71..f9a8b84ed6b0b 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -28,7 +28,7 @@ from flask_babel import ngettext from marshmallow import ValidationError -from superset import event_logger, is_feature_enabled +from superset import db, event_logger, is_feature_enabled from superset.commands.importers.exceptions import NoValidFilesFoundError from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.connectors.sqla.models import SqlaTable @@ -61,6 +61,7 @@ DatasetRelatedObjectsResponse, get_delete_ids_schema, get_export_ids_schema, + GetOrCreateDatasetSchema, ) from superset.utils.core import parse_boolean_string from superset.views.base import DatasourceFilter, generate_download_headers @@ -93,6 +94,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): "refresh", "related_objects", "duplicate", + "get_or_create_dataset", } list_columns = [ "id", @@ -240,6 +242,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): openapi_spec_component_schemas = ( DatasetRelatedObjectsResponse, DatasetDuplicateSchema, + GetOrCreateDatasetSchema, ) list_outer_default_load = True @@ -877,3 +880,74 @@ def import_(self) -> Response: ) command.run() return self.response(200, message="OK") + + @expose("/get_or_create/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".get_or_create_dataset", + log_to_statsd=False, + ) + def get_or_create_dataset(self) -> Response: + """Retrieve a dataset by name, or create it if it does not exist + --- + post: + summary: Retrieve a table by name, or create it if it does not exist + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GetOrCreateDatasetSchema' + responses: + 200: + description: The ID of the table + content: + application/json: + schema: + type: object + properties: + result: + type: object + properties: + table_id: + type: integer + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + body = GetOrCreateDatasetSchema().load(request.json) + except ValidationError as ex: + return self.response(400, message=ex.messages) + table_name = body["table_name"] + database_id = body["database_id"] + table = ( + db.session.query(SqlaTable) + .filter_by(database_id=database_id, table_name=table_name) + .one_or_none() + ) + if not table: + body["database"] = database_id + try: + table = CreateDatasetCommand(body).run() + except DatasetInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except DatasetCreateFailedError as ex: + logger.error( + "Error creating model %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=ex.message) + + payload = {"table_id": table.id} + return self.response(200, result=payload) diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 809eec7a1159a..74c58afb324cc 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -32,6 +32,7 @@ ) from superset.datasets.dao import DatasetDAO from superset.extensions import db +from superset.views.base import create_table_permissions logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def run(self) -> Model: dataset = DatasetDAO.create(self._properties, commit=False) # Updates columns and metrics from the dataset dataset.fetch_metadata(commit=False) + create_table_permissions(dataset) db.session.commit() except (SQLAlchemyError, DAOCreateFailedError) as ex: logger.warning(ex, exc_info=True) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 223324da3aa9b..103359a2c3f03 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -228,6 +228,17 @@ def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: external_url = fields.String(allow_none=True) +class GetOrCreateDatasetSchema(Schema): + table_name = fields.String(required=True, description="Name of table") + database_id = fields.Integer( + required=True, description="ID of database table belongs to" + ) + schema = fields.String( + description="The schema the table belongs to", allow_none=True + ) + template_params = fields.String(description="Template params for the table") + + class DatasetSchema(SQLAlchemyAutoSchema): """ Schema for the ``Dataset`` model. diff --git a/superset/datasource/api.py b/superset/datasource/api.py index 61ff529f3b41c..be246c915d535 100644 --- a/superset/datasource/api.py +++ b/superset/datasource/api.py @@ -16,18 +16,11 @@ # under the License. import logging -from flask import request from flask_appbuilder.api import expose, protect, safe -from marshmallow import ValidationError from superset import app, db, event_logger -from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError -from superset.databases.commands.exceptions import DatabaseNotFoundError -from superset.datasource.commands.create_table import CreateSqlaTableCommand -from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError from superset.datasource.dao import DatasourceDAO -from superset.datasource.schemas import GetOrCreateTableSchema from superset.exceptions import SupersetSecurityException from superset.superset_typing import FlaskResponse from superset.utils.core import apply_max_row_limit, DatasourceType @@ -42,77 +35,6 @@ class DatasourceRestApi(BaseSupersetApi): resource_name = "datasource" openapi_spec_tag = "Datasources" - openapi_spec_component_schemas = (GetOrCreateTableSchema,) - - @expose("/table/get_or_create/", methods=["POST"]) - @protect() - @safe - @statsd_metrics - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" - f".get_or_create_table", - log_to_statsd=False, - ) - def get_or_create_table(self) -> FlaskResponse: - """Retrieve a table by name, or create it if it does not exist - --- - post: - summary: Retrieve a table by name, or create it if it does not exist - requestBody: - required: true - content: - application/json: - schema: - $ref: '#/components/schemas/GetOrCreateTableSchema' - responses: - 200: - description: The ID of the table - content: - application/json: - schema: - type: object - properties: - result: - type: object - properties: - table_id: - type: integer - 400: - $ref: '#/components/responses/400' - 401: - $ref: '#/components/responses/401' - 404: - $ref: '#/components/responses/404' - 500: - $ref: '#/components/responses/500' - """ - try: - body = GetOrCreateTableSchema().load(request.json) - except ValidationError as ex: - return self.response(400, message=ex.messages) - table_name = body["table_name"] - database_id = body["database_id"] - table = ( - db.session.query(SqlaTable) - .filter_by(database_id=database_id, table_name=table_name) - .one_or_none() - ) - if not table: - try: - table = CreateSqlaTableCommand( - table_name, - database_id, - body.get("schema"), - body.get("template_params"), - ).run() - except DatabaseNotFoundError as ex: - return self.response(404, message=ex.message) - except GetTableFromDatabaseFailedError as ex: - return self.response(400, message=ex.message) - - payload = {"table_id": table.id} - return self.response(200, result=payload) - @expose( "///column//values/", methods=["GET"], diff --git a/superset/datasource/commands/__init__.py b/superset/datasource/commands/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/superset/datasource/commands/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/superset/datasource/commands/create_table.py b/superset/datasource/commands/create_table.py deleted file mode 100644 index c661a9fc2965e..0000000000000 --- a/superset/datasource/commands/create_table.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from typing import Optional - -from flask import g - -from superset import db -from superset.commands.base import BaseCommand -from superset.connectors.sqla.models import SqlaTable -from superset.databases.commands.exceptions import DatabaseNotFoundError -from superset.databases.dao import DatabaseDAO -from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError -from superset.models.core import Database -from superset.views.base import create_table_permissions - -logger = logging.getLogger(__name__) - - -class CreateSqlaTableCommand(BaseCommand): - def __init__( - self, - table_name: str, - database_id: int, - schema: Optional[str] = None, - template_params: Optional[str] = None, - ): - self._table_name = table_name - self._database_id = database_id - self._schema = schema - self._template_params = template_params - self._database: Database = None # type: ignore - - def run(self) -> SqlaTable: - self.validate() - table = SqlaTable(table_name=self._table_name, owners=[g.user]) - table.database = self._database - table.schema = self._schema - table.template_params = self._template_params - db.session.add(table) - table.fetch_metadata() - create_table_permissions(table) - db.session.commit() - return table - - def validate(self) -> None: - database = DatabaseDAO.find_by_id(self._database_id) - if not database: - raise DatabaseNotFoundError() - self._database = database - try: - self._database.get_table(self._table_name, schema=self._schema) - except Exception as ex: - logger.exception( - "Error getting table %s for database %s", - self._table_name, - self._database.id, - ) - raise GetTableFromDatabaseFailedError() from ex diff --git a/superset/datasource/commands/exceptions.py b/superset/datasource/commands/exceptions.py deleted file mode 100644 index 70bfeee9c075a..0000000000000 --- a/superset/datasource/commands/exceptions.py +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from flask_babel import lazy_gettext as _ - -from superset.commands.exceptions import CommandException - - -class GetTableFromDatabaseFailedError(CommandException): - message = _( - "Table could not be found, please check your " - "database connection, schema, and table name" - ) diff --git a/superset/datasource/schemas.py b/superset/datasource/schemas.py deleted file mode 100644 index de93956792c26..0000000000000 --- a/superset/datasource/schemas.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from marshmallow import fields, Schema - - -class GetOrCreateTableSchema(Schema): - table_name = fields.String(required=True, description="Name of table") - database_id = fields.Integer( - required=True, description="ID of database table belongs to" - ) - schema = fields.String( - description="The schema the table belongs to", allow_none=True - ) - template_params = fields.String(description="Template params for the table") diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 95236af09041e..02010e4feac28 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -34,6 +34,7 @@ DAODeleteFailedError, DAOUpdateFailedError, ) +from superset.datasets.commands.exceptions import DatasetCreateFailedError from superset.datasets.models import Dataset from superset.extensions import db, security_manager from superset.models.core import Database @@ -2302,3 +2303,90 @@ def test_duplicate_invalid_dataset(self): } rv = self.post_assert_metric(uri, table_data, "duplicate") assert rv.status_code == 422 + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_or_create_dataset_already_exists(self): + """ + Dataset API: Test get or create endpoint when table already exists + """ + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "virtual_dataset", + "database_id": get_example_database().id, + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + dataset = ( + db.session.query(SqlaTable) + .filter(SqlaTable.table_name == "virtual_dataset") + .one() + ) + self.assertEqual(response["result"], {"table_id": dataset.id}) + + def test_get_or_create_dataset_database_not_found(self): + """ + Dataset API: Test get or create endpoint when database doesn't exist + """ + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={"table_name": "virtual_dataset", "database_id": 999}, + ) + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["message"], {"database": ["Database does not exist"]}) + + @patch("superset.datasets.commands.create.CreateDatasetCommand.run") + def test_get_or_create_dataset_create_fails(self, command_run_mock): + """ + Dataset API: Test get or create endpoint when create fails + """ + command_run_mock.side_effect = DatasetCreateFailedError + self.login(username="admin") + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "virtual_dataset", + "database_id": get_example_database().id, + }, + ) + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["message"], "Dataset could not be created.") + + def test_get_or_create_dataset_creates_table(self): + """ + Dataset API: Test get or create endpoint when table is created + """ + self.login(username="admin") + + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api") + engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col") + + rv = self.client.post( + "api/v1/dataset/get_or_create/", + json={ + "table_name": "test_create_sqla_table_api", + "database_id": examples_db.id, + "template_params": '{"param": 1}', + }, + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_sqla_table_api") + .one() + ) + self.assertEqual(response["result"], {"table_id": table.id}) + self.assertEqual(table.template_params, '{"param": 1}') + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_sqla_table_api") + db.session.commit() diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 5cc5c85beab37..a4ff8e2aca5f1 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -19,6 +19,7 @@ from unittest.mock import patch import pytest +from sqlalchemy.exc import SQLAlchemyError import yaml from superset import db, security_manager @@ -26,7 +27,11 @@ from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.sqla.models import SqlaTable from superset.databases.commands.importers.v1 import ImportDatabasesCommand -from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.datasets.commands.create import CreateDatasetCommand +from superset.datasets.commands.exceptions import ( + DatasetInvalidError, + DatasetNotFoundError, +) from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.importers import v0, v1 from superset.models.core import Database @@ -519,3 +524,47 @@ def _get_table_from_list_by_name(name: str, tables: List[Any]): if table.table_name == name: return table raise ValueError(f"Table {name} does not exists in database") + + +class TestCreateDatasetCommand(SupersetTestCase): + def test_database_not_found(self): + self.login(username="admin") + with self.assertRaises(DatasetInvalidError): + CreateDatasetCommand({"table_name": "table", "database": 9999}).run() + + @patch("superset.models.core.Database.get_table") + def test_get_table_from_database_error(self, get_table_mock): + self.login(username="admin") + get_table_mock.side_effect = SQLAlchemyError + with self.assertRaises(DatasetInvalidError): + CreateDatasetCommand( + {"table_name": "table", "database": get_example_database().id} + ).run() + + @patch("superset.security.manager.g") + @patch("superset.commands.utils.g") + def test_create_dataset_command(self, mock_g, mock_g2): + mock_g.user = security_manager.find_user("admin") + mock_g2.user = mock_g.user + examples_db = get_example_database() + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS test_create_dataset_command") + engine.execute( + "CREATE TABLE test_create_dataset_command AS SELECT 2 as col" + ) + + table = CreateDatasetCommand( + {"table_name": "test_create_dataset_command", "database": examples_db.id} + ).run() + fetched_table = ( + db.session.query(SqlaTable) + .filter_by(table_name="test_create_dataset_command") + .one() + ) + self.assertEqual(table, fetched_table) + self.assertEqual([owner.username for owner in table.owners], ["admin"]) + + db.session.delete(table) + with examples_db.get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE test_create_dataset_command") + db.session.commit() diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py index 2b1cc6812296a..522aa33383e62 100644 --- a/tests/integration_tests/datasource/api_tests.py +++ b/tests/integration_tests/datasource/api_tests.py @@ -22,8 +22,6 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DatasourceTypeNotSupportedError -from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError -from superset.utils.database import get_example_database from tests.integration_tests.base_tests import SupersetTestCase @@ -137,76 +135,3 @@ def test_get_column_values_not_implemented_error(self, get_datasource_mock): response["message"], "Unable to get column values for datasource type: sl_table", ) - - @pytest.mark.usefixtures("app_context", "virtual_dataset") - def test_get_or_create_table_already_exists(self): - self.login(username="admin") - rv = self.client.post( - "api/v1/datasource/table/get_or_create/", - json={ - "table_name": "virtual_dataset", - "database_id": get_example_database().id, - }, - ) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - response["result"], {"table_id": self.get_virtual_dataset().id} - ) - - def test_get_or_create_table_database_not_found(self): - self.login(username="admin") - rv = self.client.post( - "api/v1/datasource/table/get_or_create/", - json={"table_name": "virtual_dataset", "database_id": 999}, - ) - self.assertEqual(rv.status_code, 404) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["message"], "Database not found.") - - @patch("superset.datasource.commands.create_table.CreateSqlaTableCommand.run") - def test_get_or_create_table_get_table_fails(self, run_command_mock): - run_command_mock.side_effect = GetTableFromDatabaseFailedError - self.login(username="admin") - rv = self.client.post( - "api/v1/datasource/table/get_or_create/", - json={"table_name": "tbl", "database_id": get_example_database().id}, - ) - self.assertEqual(rv.status_code, 400) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - response["message"], - "Table could not be found, please check your " - "database connection, schema, and table name", - ) - - def test_get_or_create_table_creates_table(self): - self.login(username="admin") - - examples_db = get_example_database() - with examples_db.get_sqla_engine_with_context() as engine: - engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_api") - engine.execute("CREATE TABLE test_create_sqla_table_api AS SELECT 2 as col") - - rv = self.client.post( - "api/v1/datasource/table/get_or_create/", - json={ - "table_name": "test_create_sqla_table_api", - "database_id": examples_db.id, - "template_params": '{"param": 1}', - }, - ) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - table = ( - db.session.query(SqlaTable) - .filter_by(table_name="test_create_sqla_table_api") - .one() - ) - self.assertEqual(response["result"], {"table_id": table.id}) - self.assertEqual(table.template_params, '{"param": 1}') - - db.session.delete(table) - with examples_db.get_sqla_engine_with_context() as engine: - engine.execute("DROP TABLE test_create_sqla_table_api") - db.session.commit() diff --git a/tests/integration_tests/datasource/commands_tests.py b/tests/integration_tests/datasource/commands_tests.py deleted file mode 100644 index d4682e6ab0a13..0000000000000 --- a/tests/integration_tests/datasource/commands_tests.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest.mock import patch - -from superset import db, security_manager -from superset.connectors.sqla.models import SqlaTable -from superset.databases.commands.exceptions import DatabaseNotFoundError -from superset.datasource.commands.create_table import CreateSqlaTableCommand -from superset.datasource.commands.exceptions import GetTableFromDatabaseFailedError -from superset.utils.database import get_example_database -from tests.integration_tests.base_tests import SupersetTestCase - - -class TestCreateSqlaTableCommand(SupersetTestCase): - def test_database_not_found(self): - self.login(username="admin") - with self.assertRaises(DatabaseNotFoundError): - CreateSqlaTableCommand("table", 9999).run() - - @patch("superset.security.manager.g") - @patch("superset.models.core.Database.get_table") - def test_get_table_from_database_error(self, get_table_mock, mock_g): - mock_g.user = security_manager.find_user("admin") - get_table_mock.side_effect = Exception - with self.assertRaises(GetTableFromDatabaseFailedError): - CreateSqlaTableCommand("table", get_example_database().id).run() - - @patch("superset.security.manager.g") - @patch("superset.datasource.commands.create_table.g") - def test_create_sqla_table_command(self, mock_g, mock_g2): - mock_g.user = security_manager.find_user("admin") - mock_g2.user = mock_g.user - examples_db = get_example_database() - with examples_db.get_sqla_engine_with_context() as engine: - engine.execute("DROP TABLE IF EXISTS test_create_sqla_table_command") - engine.execute( - "CREATE TABLE test_create_sqla_table_command AS SELECT 2 as col" - ) - - table = CreateSqlaTableCommand( - "test_create_sqla_table_command", examples_db.id - ).run() - fetched_table = ( - db.session.query(SqlaTable) - .filter_by(table_name="test_create_sqla_table_command") - .one() - ) - self.assertEqual(table, fetched_table) - self.assertEqual([owner.username for owner in table.owners], ["admin"]) - - db.session.delete(table) - with examples_db.get_sqla_engine_with_context() as engine: - engine.execute("DROP TABLE test_create_sqla_table_command") - db.session.commit() From 734fd853143f5af0682d83e4158deb02fad9d63c Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Wed, 1 Feb 2023 09:49:21 -0800 Subject: [PATCH 4/8] Fix import order --- tests/integration_tests/datasets/commands_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index a4ff8e2aca5f1..0ce98477a0b2d 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -19,8 +19,8 @@ from unittest.mock import patch import pytest -from sqlalchemy.exc import SQLAlchemyError import yaml +from sqlalchemy.exc import SQLAlchemyError from superset import db, security_manager from superset.commands.exceptions import CommandInvalidError From 8481810cb9eb678f1d3cf3797b2083e632d897f5 Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Wed, 1 Feb 2023 10:14:00 -0800 Subject: [PATCH 5/8] Fix dataset permissions test --- tests/integration_tests/datasets/api_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 02010e4feac28..6e0551bd9f826 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -475,6 +475,7 @@ def test_info_security_dataset(self): "can_write", "can_export", "can_duplicate", + "can_get_or_create_dataset", } def test_create_dataset_item(self): From ea104216680fc2fc0986f4aaf3147ddc68a51fa6 Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Wed, 8 Feb 2023 15:13:01 -0800 Subject: [PATCH 6/8] Address comments --- superset/connectors/sqla/views.py | 2 -- superset/datasets/api.py | 38 +++++++++++++--------------- superset/datasets/commands/create.py | 2 -- superset/datasets/dao.py | 8 ++++++ superset/views/base.py | 6 ----- superset/views/core.py | 2 -- 6 files changed, 25 insertions(+), 33 deletions(-) diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index c502f527acd6f..86cb08bb8690d 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -35,7 +35,6 @@ from superset.superset_typing import FlaskResponse from superset.utils import core as utils from superset.views.base import ( - create_table_permissions, DatasourceFilter, DeleteMixin, ListWidgetWithCheckboxes, @@ -511,7 +510,6 @@ def post_add( # pylint: disable=arguments-differ ) -> None: if fetch_metadata: item.fetch_metadata() - create_table_permissions(item) if flash_message: flash( _( diff --git a/superset/datasets/api.py b/superset/datasets/api.py index f9a8b84ed6b0b..0372195a3341a 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -929,25 +929,21 @@ def get_or_create_dataset(self) -> Response: return self.response(400, message=ex.messages) table_name = body["table_name"] database_id = body["database_id"] - table = ( - db.session.query(SqlaTable) - .filter_by(database_id=database_id, table_name=table_name) - .one_or_none() - ) - if not table: - body["database"] = database_id - try: - table = CreateDatasetCommand(body).run() - except DatasetInvalidError as ex: - return self.response_422(message=ex.normalized_messages()) - except DatasetCreateFailedError as ex: - logger.error( - "Error creating model %s: %s", - self.__class__.__name__, - str(ex), - exc_info=True, - ) - return self.response_422(message=ex.message) + table = DatasetDAO.get_table_by_name(database_id, table_name) + if table: + return self.response(200, result={"table_id": table.id}) - payload = {"table_id": table.id} - return self.response(200, result=payload) + body["database"] = database_id + try: + table = CreateDatasetCommand(body).run() + return self.response(200, result={"table_id": table.id}) + except DatasetInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except DatasetCreateFailedError as ex: + logger.error( + "Error creating model %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=ex.message) diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 74c58afb324cc..809eec7a1159a 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -32,7 +32,6 @@ ) from superset.datasets.dao import DatasetDAO from superset.extensions import db -from superset.views.base import create_table_permissions logger = logging.getLogger(__name__) @@ -48,7 +47,6 @@ def run(self) -> Model: dataset = DatasetDAO.create(self._properties, commit=False) # Updates columns and metrics from the dataset dataset.fetch_metadata(commit=False) - create_table_permissions(dataset) db.session.commit() except (SQLAlchemyError, DAOCreateFailedError) as ex: logger.warning(ex, exc_info=True) diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index d260df3610002..191393f169e2d 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -367,6 +367,14 @@ def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None: db.session.rollback() raise ex + @staticmethod + def get_table_by_name(database_id: int, table_name: str) -> Optional[SqlaTable]: + return ( + db.session.query(SqlaTable) + .filter_by(database_id=database_id, table_name=table_name) + .one_or_none() + ) + class DatasetColumnDAO(BaseDAO): model_cls = TableColumn diff --git a/superset/views/base.py b/superset/views/base.py index ebccd0684b540..0d076e61b77ba 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -299,12 +299,6 @@ def validate_sqlatable(table: models.SqlaTable) -> None: ) from ex -def create_table_permissions(table: models.SqlaTable) -> None: - security_manager.add_permission_view_menu("datasource_access", table.get_perm()) - if table.schema: - security_manager.add_permission_view_menu("schema_access", table.schema_perm) - - class BaseSupersetView(BaseView): @staticmethod def json_response(obj: Any, status: int = 200) -> FlaskResponse: diff --git a/superset/views/core.py b/superset/views/core.py index 6f5255bb25366..55a5a3adeb7be 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -141,7 +141,6 @@ api, BaseSupersetView, common_bootstrap_payload, - create_table_permissions, CsvResponse, data_payload_response, deprecated, @@ -1953,7 +1952,6 @@ def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use db.session.add(table) table.fetch_metadata() - create_table_permissions(table) db.session.commit() return json_success(json.dumps({"table_id": table.id})) From 9ee4c3733b81241abf259c09ee2031bbd9e5dc05 Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Wed, 8 Feb 2023 15:18:46 -0800 Subject: [PATCH 7/8] Lint --- superset/datasets/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 0372195a3341a..baf2c774d60dc 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -28,7 +28,7 @@ from flask_babel import ngettext from marshmallow import ValidationError -from superset import db, event_logger, is_feature_enabled +from superset import event_logger, is_feature_enabled from superset.commands.importers.exceptions import NoValidFilesFoundError from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.connectors.sqla.models import SqlaTable From 4ea61a142054cf7b32be9418c099b658fa2544b7 Mon Sep 17 00:00:00 2001 From: Jack Fragassi Date: Thu, 9 Feb 2023 10:31:12 -0800 Subject: [PATCH 8/8] Fix mypy --- superset/datasets/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index baf2c774d60dc..d58a1dd3f6152 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -935,8 +935,8 @@ def get_or_create_dataset(self) -> Response: body["database"] = database_id try: - table = CreateDatasetCommand(body).run() - return self.response(200, result={"table_id": table.id}) + tbl = CreateDatasetCommand(body).run() + return self.response(200, result={"table_id": tbl.id}) except DatasetInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DatasetCreateFailedError as ex: