From 4ac8cdda706e096fba07413f42d86038decb09f5 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Sun, 6 Dec 2020 18:00:42 -0800 Subject: [PATCH] feat: confirm overwrite when importing --- superset/charts/api.py | 19 +++- .../charts/commands/importers/v1/__init__.py | 7 +- superset/commands/importers/v1/__init__.py | 51 ++++++--- superset/dashboards/api.py | 19 +++- .../commands/importers/v1/__init__.py | 7 +- superset/databases/api.py | 19 +++- .../commands/importers/v1/__init__.py | 7 +- superset/datasets/api.py | 19 +++- .../commands/importers/v1/__init__.py | 7 +- tests/charts/api_tests.py | 87 ++++++++++++--- tests/charts/commands_tests.py | 2 +- tests/dashboards/api_tests.py | 101 ++++++++++++++---- tests/dashboards/commands_tests.py | 2 +- tests/databases/api_tests.py | 84 ++++++++++++--- tests/databases/commands_tests.py | 6 +- tests/datasets/api_tests.py | 91 +++++++++++++--- tests/datasets/commands_tests.py | 6 +- 17 files changed, 425 insertions(+), 109 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index 59a8dfc62dd27..a3a2737aaee90 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -845,11 +845,19 @@ def import_(self) -> Response: --- post: requestBody: + required: true content: - application/zip: + multipart/form-data: schema: - type: string - format: binary + type: object + properties: + formData: + type: string + format: binary + passwords: + type: string + overwrite: + type: bool responses: 200: description: Chart import result @@ -883,8 +891,11 @@ def import_(self) -> Response: if "passwords" in request.form else None ) + overwrite = request.form.get("overwrite") == "true" - command = ImportChartsCommand(contents, passwords=passwords) + command = ImportChartsCommand( + contents, passwords=passwords, overwrite=overwrite + ) try: command.run() return self.response(200, message="OK") diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py index 62dde9ff0e3a4..4b3f443306b5b 100644 --- a/superset/charts/commands/importers/v1/__init__.py +++ b/superset/charts/commands/importers/v1/__init__.py @@ -37,6 +37,7 @@ class ImportChartsCommand(ImportModelsCommand): dao = ChartDAO model_name = "chart" + prefix = "charts/" schemas: Dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "datasets/": ImportV1DatasetSchema(), @@ -45,7 +46,9 @@ class ImportChartsCommand(ImportModelsCommand): import_error = ChartImportError @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import( + session: Session, configs: Dict[str, Any], overwrite: bool = False + ) -> None: # discover datasets associated with charts dataset_uuids: Set[str] = set() for file_name, config in configs.items(): @@ -88,4 +91,4 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: ): # update datasource id, type, and name config.update(dataset_info[config["dataset_uuid"]]) - import_chart(session, config, overwrite=True) + import_chart(session, config, overwrite=overwrite) diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index 16d3314e11616..9637ef94da7aa 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -31,7 +31,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set from marshmallow import Schema, validate from marshmallow.exceptions import ValidationError @@ -55,6 +55,7 @@ class ImportModelsCommand(BaseCommand): dao = BaseDAO model_name = "model" + prefix = "" schemas: Dict[str, Schema] = {} import_error = CommandException @@ -62,18 +63,25 @@ class ImportModelsCommand(BaseCommand): def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.passwords: Dict[str, str] = kwargs.get("passwords") or {} + self.overwrite: bool = kwargs.get("overwrite", False) self._configs: Dict[str, Any] = {} @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: - raise NotImplementedError("Subclasses MUSC implement _import") + def _import( + session: Session, configs: Dict[str, Any], overwrite: bool = False + ) -> None: + raise NotImplementedError("Subclasses MUST implement _import") + + @classmethod + def _get_uuids(cls) -> Set[str]: + return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()} def run(self) -> None: self.validate() # rollback to prevent partial imports try: - self._import(db.session, self._configs) + self._import(db.session, self._configs, self.overwrite) db.session.commit() except Exception: db.session.rollback() @@ -97,6 +105,15 @@ def validate(self) -> None: exceptions.append(exc) metadata = None + # validate that the type declared in METADATA_FILE_NAME is correct + if metadata: + type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore + try: + type_validator(metadata["type"]) + except ValidationError as exc: + exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} + exceptions.append(exc) + # validate objects for file_name, content in self.contents.items(): prefix = file_name.split("/")[0] @@ -117,14 +134,24 @@ def validate(self) -> None: exc.messages = {file_name: exc.messages} exceptions.append(exc) - # validate that the type declared in METADATA_FILE_NAME is correct - if metadata: - type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore - try: - type_validator(metadata["type"]) - except ValidationError as exc: - exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} - exceptions.append(exc) + # check if the object exists and shouldn't be overwritten + if not self.overwrite: + existing_uuids = self._get_uuids() + for file_name, config in self._configs.items(): + if ( + file_name.startswith(self.prefix) + and config["uuid"] in existing_uuids + ): + exceptions.append( + ValidationError( + { + file_name: ( + f"{self.model_name.title()} already exists " + "and `overwrite=true` was not passed" + ), + } + ) + ) if exceptions: exception = CommandInvalidError(f"Error importing {self.model_name}") diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 3f8e3ba377b5e..929406a4686d2 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -665,11 +665,19 @@ def import_(self) -> Response: --- post: requestBody: + required: true content: - application/zip: + multipart/form-data: schema: - type: string - format: binary + type: object + properties: + formData: + type: string + format: binary + passwords: + type: string + overwrite: + type: bool responses: 200: description: Dashboard import result @@ -703,8 +711,11 @@ def import_(self) -> Response: if "passwords" in request.form else None ) + overwrite = request.form.get("overwrite") == "true" - command = ImportDashboardsCommand(contents, passwords=passwords) + command = ImportDashboardsCommand( + contents, passwords=passwords, overwrite=overwrite + ) try: command.run() return self.response(200, message="OK") diff --git a/superset/dashboards/commands/importers/v1/__init__.py b/superset/dashboards/commands/importers/v1/__init__.py index 0b7b235d310ab..1c40a40512545 100644 --- a/superset/dashboards/commands/importers/v1/__init__.py +++ b/superset/dashboards/commands/importers/v1/__init__.py @@ -52,6 +52,7 @@ class ImportDashboardsCommand(ImportModelsCommand): dao = DashboardDAO model_name = "dashboard" + prefix = "dashboards/" schemas: Dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), @@ -63,7 +64,9 @@ class ImportDashboardsCommand(ImportModelsCommand): # TODO (betodealmeida): refactor to use code from other commands # pylint: disable=too-many-branches, too-many-locals @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import( + session: Session, configs: Dict[str, Any], overwrite: bool = False + ) -> None: # discover charts associated with dashboards chart_uuids: Set[str] = set() for file_name, config in configs.items(): @@ -125,7 +128,7 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: dashboard_chart_ids: List[Tuple[int, int]] = [] for file_name, config in configs.items(): if file_name.startswith("dashboards/"): - dashboard = import_dashboard(session, config, overwrite=True) + dashboard = import_dashboard(session, config, overwrite=overwrite) for uuid in find_chart_uuids(config["position"]): chart_id = chart_ids[uuid] diff --git a/superset/databases/api.py b/superset/databases/api.py index 707c9ae4cf497..89abb27f0bc9d 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -744,11 +744,19 @@ def import_(self) -> Response: --- post: requestBody: + required: true content: - application/zip: + multipart/form-data: schema: - type: string - format: binary + type: object + properties: + formData: + type: string + format: binary + passwords: + type: string + overwrite: + type: bool responses: 200: description: Database import result @@ -782,8 +790,11 @@ def import_(self) -> Response: if "passwords" in request.form else None ) + overwrite = request.form.get("overwrite") == "true" - command = ImportDatabasesCommand(contents, passwords=passwords) + command = ImportDatabasesCommand( + contents, passwords=passwords, overwrite=overwrite + ) try: command.run() return self.response(200, message="OK") diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py index 6453b877deb21..239bd0977f784 100644 --- a/superset/databases/commands/importers/v1/__init__.py +++ b/superset/databases/commands/importers/v1/__init__.py @@ -35,6 +35,7 @@ class ImportDatabasesCommand(ImportModelsCommand): dao = DatabaseDAO model_name = "database" + prefix = "databases/" schemas: Dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), @@ -42,12 +43,14 @@ class ImportDatabasesCommand(ImportModelsCommand): import_error = DatabaseImportError @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import( + session: Session, configs: Dict[str, Any], overwrite: bool = False + ) -> None: # first import databases database_ids: Dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): - database = import_database(session, config, overwrite=True) + database = import_database(session, config, overwrite=overwrite) database_ids[str(database.uuid)] = database.id # import related datasets diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 855b6eb4289a9..a9a210e6a86c3 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -624,11 +624,19 @@ def import_(self) -> Response: --- post: requestBody: + required: true content: - application/zip: + multipart/form-data: schema: - type: string - format: binary + type: object + properties: + formData: + type: string + format: binary + passwords: + type: string + overwrite: + type: bool responses: 200: description: Dataset import result @@ -662,8 +670,11 @@ def import_(self) -> Response: if "passwords" in request.form else None ) + overwrite = request.form.get("overwrite") == "true" - command = ImportDatasetsCommand(contents, passwords=passwords) + command = ImportDatasetsCommand( + contents, passwords=passwords, overwrite=overwrite + ) try: command.run() return self.response(200, message="OK") diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py index 81f363165fc01..e73213319db6f 100644 --- a/superset/datasets/commands/importers/v1/__init__.py +++ b/superset/datasets/commands/importers/v1/__init__.py @@ -35,6 +35,7 @@ class ImportDatasetsCommand(ImportModelsCommand): dao = DatasetDAO model_name = "dataset" + prefix = "datasets/" schemas: Dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), @@ -42,7 +43,9 @@ class ImportDatasetsCommand(ImportModelsCommand): import_error = DatasetImportError @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import( + session: Session, configs: Dict[str, Any], overwrite: bool = False + ) -> None: # discover databases associated with datasets database_uuids: Set[str] = set() for file_name, config in configs.items(): @@ -63,4 +66,4 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: and config["database_uuid"] in database_ids ): config["database_id"] = database_ids[config["database_uuid"]] - import_dataset(session, config, overwrite=True) + import_dataset(session, config, overwrite=overwrite) diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 747976943764e..5d69e28fe6ae0 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -173,6 +173,22 @@ def add_dashboard_to_chart(self): db.session.delete(self.chart) db.session.commit() + def create_chart_import(self): + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("chart_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_metadata_config).encode()) + with bundle.open( + "chart_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("chart_export/datasets/imported_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + with bundle.open("chart_export/charts/imported_chart.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_config).encode()) + buf.seek(0) + return buf + def test_delete_chart(self): """ Chart API: Test delete @@ -1314,20 +1330,7 @@ def test_import_chart(self): self.login(username="admin") uri = "api/v1/chart/import/" - buf = BytesIO() - with ZipFile(buf, "w") as bundle: - with bundle.open("chart_export/metadata.yaml", "w") as fp: - fp.write(yaml.safe_dump(chart_metadata_config).encode()) - with bundle.open( - "chart_export/databases/imported_database.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("chart_export/datasets/imported_dataset.yaml", "w") as fp: - fp.write(yaml.safe_dump(dataset_config).encode()) - with bundle.open("chart_export/charts/imported_chart.yaml", "w") as fp: - fp.write(yaml.safe_dump(chart_config).encode()) - buf.seek(0) - + buf = self.create_chart_import() form_data = { "formData": (buf, "chart_export.zip"), } @@ -1355,6 +1358,62 @@ def test_import_chart(self): db.session.delete(database) db.session.commit() + def test_import_chart_overwrite(self): + """ + Chart API: Test import existing chart + """ + self.login(username="admin") + uri = "api/v1/chart/import/" + + buf = self.create_chart_import() + form_data = { + "formData": (buf, "chart_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # import again without overwrite flag + buf = self.create_chart_import() + form_data = { + "formData": (buf, "chart_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": { + "charts/imported_chart.yaml": "Chart already exists and `overwrite=true` was not passed", + } + } + + # import with overwrite flag + buf = self.create_chart_import() + form_data = { + "formData": (buf, "chart_export.zip"), + "overwrite": "true", + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # clean up + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + dataset = database.tables[0] + chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one() + + db.session.delete(chart) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + def test_import_chart_invalid(self): """ Chart API: Test import invalid chart diff --git a/tests/charts/commands_tests.py b/tests/charts/commands_tests.py index c3522077c073c..5931d86824310 100644 --- a/tests/charts/commands_tests.py +++ b/tests/charts/commands_tests.py @@ -190,7 +190,7 @@ def test_import_v1_chart_multiple(self): "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), "charts/imported_chart.yaml": yaml.safe_dump(chart_config), } - command = ImportChartsCommand(contents) + command = ImportChartsCommand(contents, overwrite=True) command.run() command.run() diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index aba79f9b93f39..3855c3ace34f1 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -434,6 +434,28 @@ def test_get_dashboards_not_favorite_filter(self): expected_model.dashboard_title == data["result"][i]["dashboard_title"] ) + def create_dashboard_import(self): + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("dashboard_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(dashboard_metadata_config).encode()) + with bundle.open( + "dashboard_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open( + "dashboard_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + with bundle.open("dashboard_export/charts/imported_chart.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_config).encode()) + with bundle.open( + "dashboard_export/dashboards/imported_dashboard.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dashboard_config).encode()) + buf.seek(0) + return buf + def test_get_dashboards_no_data_access(self): """ Dashboard API: Test get dashboards no data access @@ -1165,26 +1187,7 @@ def test_import_dashboard(self): self.login(username="admin") uri = "api/v1/dashboard/import/" - buf = BytesIO() - with ZipFile(buf, "w") as bundle: - with bundle.open("dashboard_export/metadata.yaml", "w") as fp: - fp.write(yaml.safe_dump(dashboard_metadata_config).encode()) - with bundle.open( - "dashboard_export/databases/imported_database.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open( - "dashboard_export/datasets/imported_dataset.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(dataset_config).encode()) - with bundle.open("dashboard_export/charts/imported_chart.yaml", "w") as fp: - fp.write(yaml.safe_dump(chart_config).encode()) - with bundle.open( - "dashboard_export/dashboards/imported_dashboard.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(dashboard_config).encode()) - buf.seek(0) - + buf = self.create_dashboard_import() form_data = { "formData": (buf, "dashboard_export.zip"), } @@ -1215,6 +1218,64 @@ def test_import_dashboard(self): db.session.delete(database) db.session.commit() + def test_import_dashboard_overwrite(self): + """ + Dashboard API: Test import existing dashboard + """ + self.login(username="admin") + uri = "api/v1/dashboard/import/" + + buf = self.create_dashboard_import() + form_data = { + "formData": (buf, "dashboard_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # import again without overwrite flag + buf = self.create_dashboard_import() + form_data = { + "formData": (buf, "dashboard_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": { + "dashboards/imported_dashboard.yaml": "Dashboard already exists and `overwrite=true` was not passed" + } + } + + # import with overwrite flag + buf = self.create_dashboard_import() + form_data = { + "formData": (buf, "dashboard_export.zip"), + "overwrite": "true", + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # cleanup + dashboard = ( + db.session.query(Dashboard).filter_by(uuid=dashboard_config["uuid"]).one() + ) + chart = dashboard.slices[0] + dataset = chart.table + database = dataset.database + + db.session.delete(dashboard) + db.session.delete(chart) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + def test_import_dashboard_invalid(self): """ Dataset API: Test import invalid dashboard diff --git a/tests/dashboards/commands_tests.py b/tests/dashboards/commands_tests.py index b081a14f38801..d8eccc5790398 100644 --- a/tests/dashboards/commands_tests.py +++ b/tests/dashboards/commands_tests.py @@ -339,7 +339,7 @@ def test_import_v1_dashboard_multiple(self): "charts/imported_chart.yaml": yaml.safe_dump(chart_config), "dashboards/imported_dashboard.yaml": yaml.safe_dump(dashboard_config), } - command = v1.ImportDashboardsCommand(contents) + command = v1.ImportDashboardsCommand(contents, overwrite=True) command.run() command.run() diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 71c96b27c70da..14b0a50d13d32 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -91,6 +91,22 @@ def create_database_with_report(self): db.session.delete(database) db.session.commit() + def create_database_import(self): + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("database_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open( + "database_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open( + "database_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + return buf + def test_get_items(self): """ Database API: Test get items @@ -879,20 +895,7 @@ def test_import_database(self): self.login(username="admin") uri = "api/v1/database/import/" - buf = BytesIO() - with ZipFile(buf, "w") as bundle: - with bundle.open("database_export/metadata.yaml", "w") as fp: - fp.write(yaml.safe_dump(database_metadata_config).encode()) - with bundle.open( - "database_export/databases/imported_database.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open( - "database_export/datasets/imported_dataset.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(dataset_config).encode()) - buf.seek(0) - + buf = self.create_database_import() form_data = { "formData": (buf, "database_export.zip"), } @@ -916,6 +919,59 @@ def test_import_database(self): db.session.delete(database) db.session.commit() + def test_import_database_overwrite(self): + """ + Database API: Test import existing database + """ + self.login(username="admin") + uri = "api/v1/database/import/" + + buf = self.create_database_import() + form_data = { + "formData": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # import again without overwrite flag + buf = self.create_database_import() + form_data = { + "formData": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": { + "databases/imported_database.yaml": "Database already exists and `overwrite=true` was not passed" + } + } + + # import with overwrite flag + buf = self.create_database_import() + form_data = { + "formData": (buf, "database_export.zip"), + "overwrite": "true", + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # clean up + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + dataset = database.tables[0] + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + def test_import_database_invalid(self): """ Database API: Test import invalid database diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index 3ace131ada2aa..cddbf0d141ee1 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -314,7 +314,7 @@ def test_import_v1_database_multiple(self): "databases/imported_database.yaml": yaml.safe_dump(database_config), "metadata.yaml": yaml.safe_dump(database_metadata_config), } - command = ImportDatabasesCommand(contents) + command = ImportDatabasesCommand(contents, overwrite=True) # import twice command.run() @@ -332,7 +332,7 @@ def test_import_v1_database_multiple(self): "databases/imported_database.yaml": yaml.safe_dump(new_config), "metadata.yaml": yaml.safe_dump(database_metadata_config), } - command = ImportDatabasesCommand(contents) + command = ImportDatabasesCommand(contents, overwrite=True) command.run() database = ( @@ -389,7 +389,7 @@ def test_import_v1_database_with_dataset_multiple(self): "datasets/imported_dataset.yaml": yaml.safe_dump(new_config), "metadata.yaml": yaml.safe_dump(database_metadata_config), } - command = ImportDatabasesCommand(contents) + command = ImportDatabasesCommand(contents, overwrite=True) command.run() # the underlying dataset should not be modified by the second import, since diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index bea5fadfdb866..684e8e795f1e9 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -137,6 +137,22 @@ def get_energy_usage_dataset(): .one() ) + def create_dataset_import(self): + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("dataset_export/metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_metadata_config).encode()) + with bundle.open( + "dataset_export/databases/imported_database.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open( + "dataset_export/datasets/imported_dataset.yaml", "w" + ) as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + return buf + def test_get_dataset_list(self): """ Dataset API: Test get dataset list @@ -1214,27 +1230,14 @@ def test_get_datasets_custom_filter_sql(self): for table_name in self.fixture_tables_names: assert table_name in [ds["table_name"] for ds in data["result"]] - def test_imported_dataset(self): + def test_import_dataset(self): """ Dataset API: Test import dataset """ self.login(username="admin") uri = "api/v1/dataset/import/" - buf = BytesIO() - with ZipFile(buf, "w") as bundle: - with bundle.open("dataset_export/metadata.yaml", "w") as fp: - fp.write(yaml.safe_dump(dataset_metadata_config).encode()) - with bundle.open( - "dataset_export/databases/imported_database.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open( - "dataset_export/datasets/imported_dataset.yaml", "w" - ) as fp: - fp.write(yaml.safe_dump(dataset_config).encode()) - buf.seek(0) - + buf = self.create_dataset_import() form_data = { "formData": (buf, "dataset_export.zip"), } @@ -1258,7 +1261,61 @@ def test_imported_dataset(self): db.session.delete(database) db.session.commit() - def test_imported_dataset_invalid(self): + def test_import_dataset_overwrite(self): + """ + Dataset API: Test import existing dataset + """ + self.login(username="admin") + uri = "api/v1/dataset/import/" + + buf = self.create_dataset_import() + form_data = { + "formData": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # import again without overwrite flag + buf = self.create_dataset_import() + form_data = { + "formData": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": { + "datasets/imported_dataset.yaml": "Dataset already exists and `overwrite=true` was not passed" + } + } + + # import with overwrite flag + buf = self.create_dataset_import() + form_data = { + "formData": (buf, "dataset_export.zip"), + "overwrite": "true", + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + # clean up + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + dataset = database.tables[0] + + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_dataset_invalid(self): """ Dataset API: Test import invalid dataset """ @@ -1290,7 +1347,7 @@ def test_imported_dataset_invalid(self): "message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}} } - def test_imported_dataset_invalid_v0_validation(self): + def test_import_dataset_invalid_v0_validation(self): """ Dataset API: Test import invalid dataset """ diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py index 2e7249b9d9b3f..42cea1e12f851 100644 --- a/tests/datasets/commands_tests.py +++ b/tests/datasets/commands_tests.py @@ -341,7 +341,7 @@ def test_import_v1_dataset_multiple(self): "databases/imported_database.yaml": yaml.safe_dump(database_config), "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), } - command = v1.ImportDatasetsCommand(contents) + command = v1.ImportDatasetsCommand(contents, overwrite=True) command.run() command.run() dataset = ( @@ -359,7 +359,7 @@ def test_import_v1_dataset_multiple(self): "databases/imported_database.yaml": yaml.safe_dump(database_config), "datasets/imported_dataset.yaml": yaml.safe_dump(new_config), } - command = v1.ImportDatasetsCommand(contents) + command = v1.ImportDatasetsCommand(contents, overwrite=True) command.run() dataset = ( db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() @@ -443,7 +443,7 @@ def test_import_v1_dataset_existing_database(self): "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), } - command = v1.ImportDatasetsCommand(contents) + command = v1.ImportDatasetsCommand(contents, overwrite=True) command.run() database = (