Skip to content

Commit

Permalink
feat: confirm overwrite when importing
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Dec 9, 2020
1 parent 2297f9b commit 4ac8cdd
Show file tree
Hide file tree
Showing 17 changed files with 425 additions and 109 deletions.
19 changes: 15 additions & 4 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,11 +845,19 @@ def import_(self) -> Response:
---
post:
requestBody:
required: true
content:
application/zip:
multipart/form-data:
schema:
type: string
format: binary
type: object
properties:
formData:
type: string
format: binary
passwords:
type: string
overwrite:
type: bool
responses:
200:
description: Chart import result
Expand Down Expand Up @@ -883,8 +891,11 @@ def import_(self) -> Response:
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportChartsCommand(contents, passwords=passwords)
command = ImportChartsCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
Expand Down
7 changes: 5 additions & 2 deletions superset/charts/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ImportChartsCommand(ImportModelsCommand):

dao = ChartDAO
model_name = "chart"
prefix = "charts/"
schemas: Dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"datasets/": ImportV1DatasetSchema(),
Expand All @@ -45,7 +46,9 @@ class ImportChartsCommand(ImportModelsCommand):
import_error = ChartImportError

@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with charts
dataset_uuids: Set[str] = set()
for file_name, config in configs.items():
Expand Down Expand Up @@ -88,4 +91,4 @@ def _import(session: Session, configs: Dict[str, Any]) -> None:
):
# update datasource id, type, and name
config.update(dataset_info[config["dataset_uuid"]])
import_chart(session, config, overwrite=True)
import_chart(session, config, overwrite=overwrite)
51 changes: 39 additions & 12 deletions superset/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
Expand All @@ -55,25 +55,33 @@ class ImportModelsCommand(BaseCommand):

dao = BaseDAO
model_name = "model"
prefix = ""
schemas: Dict[str, Schema] = {}
import_error = CommandException

# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
self.overwrite: bool = kwargs.get("overwrite", False)
self._configs: Dict[str, Any] = {}

@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
raise NotImplementedError("Subclasses MUSC implement _import")
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
raise NotImplementedError("Subclasses MUST implement _import")

@classmethod
def _get_uuids(cls) -> Set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}

def run(self) -> None:
self.validate()

# rollback to prevent partial imports
try:
self._import(db.session, self._configs)
self._import(db.session, self._configs, self.overwrite)
db.session.commit()
except Exception:
db.session.rollback()
Expand All @@ -97,6 +105,15 @@ def validate(self) -> None:
exceptions.append(exc)
metadata = None

# validate that the type declared in METADATA_FILE_NAME is correct
if metadata:
type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore
try:
type_validator(metadata["type"])
except ValidationError as exc:
exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
exceptions.append(exc)

# validate objects
for file_name, content in self.contents.items():
prefix = file_name.split("/")[0]
Expand All @@ -117,14 +134,24 @@ def validate(self) -> None:
exc.messages = {file_name: exc.messages}
exceptions.append(exc)

# validate that the type declared in METADATA_FILE_NAME is correct
if metadata:
type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore
try:
type_validator(metadata["type"])
except ValidationError as exc:
exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
exceptions.append(exc)
# check if the object exists and shouldn't be overwritten
if not self.overwrite:
existing_uuids = self._get_uuids()
for file_name, config in self._configs.items():
if (
file_name.startswith(self.prefix)
and config["uuid"] in existing_uuids
):
exceptions.append(
ValidationError(
{
file_name: (
f"{self.model_name.title()} already exists "
"and `overwrite=true` was not passed"
),
}
)
)

if exceptions:
exception = CommandInvalidError(f"Error importing {self.model_name}")
Expand Down
19 changes: 15 additions & 4 deletions superset/dashboards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,11 +665,19 @@ def import_(self) -> Response:
---
post:
requestBody:
required: true
content:
application/zip:
multipart/form-data:
schema:
type: string
format: binary
type: object
properties:
formData:
type: string
format: binary
passwords:
type: string
overwrite:
type: bool
responses:
200:
description: Dashboard import result
Expand Down Expand Up @@ -703,8 +711,11 @@ def import_(self) -> Response:
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportDashboardsCommand(contents, passwords=passwords)
command = ImportDashboardsCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
Expand Down
7 changes: 5 additions & 2 deletions superset/dashboards/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ImportDashboardsCommand(ImportModelsCommand):

dao = DashboardDAO
model_name = "dashboard"
prefix = "dashboards/"
schemas: Dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
Expand All @@ -63,7 +64,9 @@ class ImportDashboardsCommand(ImportModelsCommand):
# TODO (betodealmeida): refactor to use code from other commands
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover charts associated with dashboards
chart_uuids: Set[str] = set()
for file_name, config in configs.items():
Expand Down Expand Up @@ -125,7 +128,7 @@ def _import(session: Session, configs: Dict[str, Any]) -> None:
dashboard_chart_ids: List[Tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
dashboard = import_dashboard(session, config, overwrite=True)
dashboard = import_dashboard(session, config, overwrite=overwrite)

for uuid in find_chart_uuids(config["position"]):
chart_id = chart_ids[uuid]
Expand Down
19 changes: 15 additions & 4 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,19 @@ def import_(self) -> Response:
---
post:
requestBody:
required: true
content:
application/zip:
multipart/form-data:
schema:
type: string
format: binary
type: object
properties:
formData:
type: string
format: binary
passwords:
type: string
overwrite:
type: bool
responses:
200:
description: Database import result
Expand Down Expand Up @@ -782,8 +790,11 @@ def import_(self) -> Response:
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportDatabasesCommand(contents, passwords=passwords)
command = ImportDatabasesCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
Expand Down
7 changes: 5 additions & 2 deletions superset/databases/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ class ImportDatabasesCommand(ImportModelsCommand):

dao = DatabaseDAO
model_name = "database"
prefix = "databases/"
schemas: Dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
import_error = DatabaseImportError

@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# first import databases
database_ids: Dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=True)
database = import_database(session, config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id

# import related datasets
Expand Down
19 changes: 15 additions & 4 deletions superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,11 +624,19 @@ def import_(self) -> Response:
---
post:
requestBody:
required: true
content:
application/zip:
multipart/form-data:
schema:
type: string
format: binary
type: object
properties:
formData:
type: string
format: binary
passwords:
type: string
overwrite:
type: bool
responses:
200:
description: Dataset import result
Expand Down Expand Up @@ -662,8 +670,11 @@ def import_(self) -> Response:
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"

command = ImportDatasetsCommand(contents, passwords=passwords)
command = ImportDatasetsCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
Expand Down
7 changes: 5 additions & 2 deletions superset/datasets/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@ class ImportDatasetsCommand(ImportModelsCommand):

dao = DatasetDAO
model_name = "dataset"
prefix = "datasets/"
schemas: Dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
import_error = DatasetImportError

@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with datasets
database_uuids: Set[str] = set()
for file_name, config in configs.items():
Expand All @@ -63,4 +66,4 @@ def _import(session: Session, configs: Dict[str, Any]) -> None:
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
import_dataset(session, config, overwrite=True)
import_dataset(session, config, overwrite=overwrite)
Loading

0 comments on commit 4ac8cdd

Please sign in to comment.