From 38f009e7a3ccf09aecd6a4fab8d5782ca69c284d Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 08:53:52 +0100 Subject: [PATCH 01/14] Change order of operations --- superset/commands/database/create.py | 6 ++++-- superset/commands/database/ssh_tunnel/create.py | 9 --------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index a012e9b2a5768..b6ad4772dceec 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -80,14 +80,14 @@ def run(self) -> Model: database = DatabaseDAO.create(attributes=self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) + db.session.add(database) + ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): if not is_feature_enabled("SSH_TUNNELING"): db.session.rollback() raise SSHTunnelingNotEnabledError() try: - # So database.id is not None - db.session.flush() ssh_tunnel = CreateSSHTunnelCommand( database.id, ssh_tunnel_properties ).run() @@ -96,6 +96,7 @@ def run(self) -> Model: action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], ) + db.session.rollback() # So we can show the original message raise ex except Exception as ex: @@ -103,6 +104,7 @@ def run(self) -> Model: action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], ) + db.session.rollback() raise DatabaseCreateFailedError() from ex # adding a new database we always want to force refresh schema list diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 07209f010ba1d..fc4d667e3513f 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -40,20 +40,11 @@ def __init__(self, database_id: int, data: dict[str, Any]): def run(self) -> Model: try: - # Start nested transaction since we are always creating the tunnel - # through a DB command (Create or Update). Without this, we cannot - # safely rollback changes to databases if any, i.e, things like - # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail - db.session.begin_nested() self.validate() return SSHTunnelDAO.create(attributes=self._properties, commit=False) except DAOCreateFailedError as ex: - # Rollback nested transaction - db.session.rollback() raise SSHTunnelCreateFailedError() from ex except SSHTunnelInvalidError as ex: - # Rollback nested transaction - db.session.rollback() raise ex def validate(self) -> None: From abff83571cb2266f647b862165534bb780764aa7 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 09:08:07 +0100 Subject: [PATCH 02/14] Pass database ref --- superset/commands/database/create.py | 2 +- superset/commands/database/ssh_tunnel/create.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index b6ad4772dceec..a6ae868009f62 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -89,7 +89,7 @@ def run(self) -> Model: raise SSHTunnelingNotEnabledError() try: ssh_tunnel = CreateSSHTunnelCommand( - database.id, ssh_tunnel_properties + database, ssh_tunnel_properties ).run() except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: event_logger.log_with_context( diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index fc4d667e3513f..aeeadd4b42eb6 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -28,15 +28,16 @@ ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError -from superset.extensions import db, event_logger +from superset.extensions import event_logger +from superset.models.core import Database logger = logging.getLogger(__name__) class CreateSSHTunnelCommand(BaseCommand): - def __init__(self, database_id: int, data: dict[str, Any]): + def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() - self._properties["database_id"] = database_id + self._properties["database"] = database def run(self) -> Model: try: @@ -59,8 +60,6 @@ def validate(self) -> None: private_key_password: Optional[str] = self._properties.get( "private_key_password" ) - if not database_id: - exceptions.append(SSHTunnelRequiredFieldValidationError("database_id")) if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) if not server_port: From aae33b88a89be8b3a884ecff77877d6ebb76785d Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 10:02:58 +0100 Subject: [PATCH 03/14] Update tests --- superset/commands/database/ssh_tunnel/create.py | 1 - .../ssh_tunnel/commands/commands_tests.py | 17 ----------------- .../ssh_tunnel/commands/create_test.py | 4 ++-- 3 files changed, 2 insertions(+), 20 deletions(-) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index aeeadd4b42eb6..ae7b6cc1b910d 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -52,7 +52,6 @@ def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER exceptions: list[ValidationError] = [] - database_id: Optional[int] = self._properties.get("database_id") server_address: Optional[str] = self._properties.get("server_address") server_port: Optional[int] = self._properties.get("server_port") username: Optional[str] = self._properties.get("username") diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py index 1cd9afcc809c7..f6e5ca9d09681 100644 --- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -30,23 +30,6 @@ from tests.integration_tests.base_tests import SupersetTestCase -class TestCreateSSHTunnelCommand(SupersetTestCase): - @mock.patch("superset.utils.core.g") - def test_create_invalid_database_id(self, mock_g): - mock_g.user = security_manager.find_user("admin") - command = CreateSSHTunnelCommand( - None, - { - "server_address": "127.0.0.1", - "server_port": 5432, - "username": "test_user", - }, - ) - with pytest.raises(SSHTunnelInvalidError) as excinfo: - command.run() - assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") - - class TestUpdateSSHTunnelCommand(SupersetTestCase): @mock.patch("superset.utils.core.g") def test_update_ssh_tunnel_not_found(self, mock_g): diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index bd891b64f05ec..4e6411816332f 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None: "password": "bar", } - result = CreateSSHTunnelCommand(db.id, properties).run() + result = CreateSSHTunnelCommand(db, properties).run() assert result is not None assert isinstance(result, SSHTunnel) @@ -53,7 +53,7 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: # If we are trying to create a tunnel with a private_key_password # then a private_key is mandatory properties = { - "database_id": db.id, + "database": db, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", From 3ce71f8cae8ab8b86aa320606ddd55286509c27c Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 13:59:31 +0200 Subject: [PATCH 04/14] Fix exceptions --- superset/commands/database/create.py | 41 +++++++++++----------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index a6ae868009f62..c2b234b998653 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -76,36 +76,21 @@ def run(self) -> Model: "{}", ) + ssh_tunnel = None + try: database = DatabaseDAO.create(attributes=self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - db.session.add(database) - - ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() raise SSHTunnelingNotEnabledError() - try: - ssh_tunnel = CreateSSHTunnelCommand( - database, ssh_tunnel_properties - ).run() - except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: - event_logger.log_with_context( - action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", - engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], - ) - db.session.rollback() - # So we can show the original message - raise ex - except Exception as ex: - event_logger.log_with_context( - action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", - engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], - ) - db.session.rollback() - raise DatabaseCreateFailedError() from ex + + ssh_tunnel = CreateSSHTunnelCommand( + database, ssh_tunnel_properties + ).run() + + db.session.commit() # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) @@ -114,8 +99,14 @@ def run(self) -> Model: "schema_access", security_manager.get_schema_perm(database, schema) ) - db.session.commit() - + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError) as ex: + db.session.rollback() + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + # So we can show the original message + raise ex except DAOCreateFailedError as ex: db.session.rollback() event_logger.log_with_context( From 90094e87096882b9109e589c4eefbc056c34d09d Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 14:10:38 +0200 Subject: [PATCH 05/14] Lint --- superset/commands/database/create.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index c2b234b998653..8d1742388572d 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -99,7 +99,11 @@ def run(self) -> Model: "schema_access", security_manager.get_schema_perm(database, schema) ) - except (SSHTunnelInvalidError, SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError) as ex: + except ( + SSHTunnelInvalidError, + SSHTunnelCreateFailedError, + SSHTunnelingNotEnabledError, + ) as ex: db.session.rollback() event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", From b78545057d863463099743c8884f14968e218d68 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 15:45:43 +0200 Subject: [PATCH 06/14] Conditionally create database --- superset/commands/database/create.py | 28 +++++--- .../commands/database/ssh_tunnel/create.py | 72 +++++++++++-------- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 8d1742388572d..686627e70d3dc 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -30,7 +30,10 @@ DatabaseInvalidError, DatabaseRequiredFieldValidationError, ) -from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand +from superset.commands.database.ssh_tunnel.create import ( + CreateSSHTunnelCommand, + validate_ssh_tunnel, +) from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError, @@ -41,6 +44,7 @@ from superset.daos.exceptions import DAOCreateFailedError from superset.exceptions import SupersetErrorsException from superset.extensions import db, event_logger, security_manager +from superset.models.core import Database logger = logging.getLogger(__name__) stats_logger = current_app.config["STATS_LOGGER"] @@ -79,18 +83,23 @@ def run(self) -> Model: ssh_tunnel = None try: - database = DatabaseDAO.create(attributes=self._properties, commit=False) - database.set_sqlalchemy_uri(database.sqlalchemy_uri) - if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() - ssh_tunnel = CreateSSHTunnelCommand( - database, ssh_tunnel_properties - ).run() + # pre-validate the SSH tunnel properties + is_tunnel_valid = validate_ssh_tunnel(ssh_tunnel_properties) - db.session.commit() + if is_tunnel_valid: + database = self._do_create_database() + ssh_tunnel = CreateSSHTunnelCommand( + database, ssh_tunnel_properties + ).run() + + db.session.commit() + else: + database = self._do_create_database() + db.session.commit() # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) @@ -147,3 +156,6 @@ def validate(self) -> None: ) ) raise exception + + def _do_create_database(self) -> Database: + return DatabaseDAO.create(attributes=self._properties, commit=False) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index ae7b6cc1b910d..e295be1c552f1 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -28,12 +28,47 @@ ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError -from superset.extensions import event_logger +from superset.extensions import db, event_logger from superset.models.core import Database logger = logging.getLogger(__name__) +def validate_ssh_tunnel(properties: dict[str, Any]) -> bool: + # TODO(hughhh): check to make sure the server port is not localhost + # using the config.SSH_TUNNEL_MANAGER + + tunnel_properties = properties + + exceptions: list[ValidationError] = [] + server_address: Optional[str] = tunnel_properties.get("server_address") + server_port: Optional[int] = tunnel_properties.get("server_port") + username: Optional[str] = tunnel_properties.get("username") + private_key: Optional[str] = tunnel_properties.get("private_key") + private_key_password: Optional[str] = tunnel_properties.get("private_key_password") + if not server_address: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) + if not server_port: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) + if not username: + exceptions.append(SSHTunnelRequiredFieldValidationError("username")) + if private_key_password and private_key is None: + exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) + if exceptions: + exception = SSHTunnelInvalidError() + exception.extend(exceptions) + event_logger.log_with_context( + # pylint: disable=consider-using-f-string + action="ssh_tunnel_creation_failed.{}.{}".format( + exception.__class__.__name__, + ".".join(exception.get_list_classnames()), + ) + ) + raise exception + + return True + + class CreateSSHTunnelCommand(BaseCommand): def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() @@ -44,37 +79,14 @@ def run(self) -> Model: self.validate() return SSHTunnelDAO.create(attributes=self._properties, commit=False) except DAOCreateFailedError as ex: + db.session.rollback() raise SSHTunnelCreateFailedError() from ex except SSHTunnelInvalidError as ex: + db.session.rollback() raise ex def validate(self) -> None: - # TODO(hughhh): check to make sure the server port is not localhost - # using the config.SSH_TUNNEL_MANAGER - exceptions: list[ValidationError] = [] - server_address: Optional[str] = self._properties.get("server_address") - server_port: Optional[int] = self._properties.get("server_port") - username: Optional[str] = self._properties.get("username") - private_key: Optional[str] = self._properties.get("private_key") - private_key_password: Optional[str] = self._properties.get( - "private_key_password" - ) - if not server_address: - exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) - if not server_port: - exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) - if not username: - exceptions.append(SSHTunnelRequiredFieldValidationError("username")) - if private_key_password and private_key is None: - exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) - if exceptions: - exception = SSHTunnelInvalidError() - exception.extend(exceptions) - event_logger.log_with_context( - # pylint: disable=consider-using-f-string - action="ssh_tunnel_creation_failed.{}.{}".format( - exception.__class__.__name__, - ".".join(exception.get_list_classnames()), - ) - ) - raise exception + try: + validate_ssh_tunnel(self._properties) + except SSHTunnelInvalidError as ex: + raise ex From 3ad72de317d0f7e8eb6f646c6579302ffb0cd5c3 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 15:49:01 +0200 Subject: [PATCH 07/14] Clean up --- superset/commands/database/ssh_tunnel/create.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index e295be1c552f1..d5a5e043dad9d 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -38,14 +38,12 @@ def validate_ssh_tunnel(properties: dict[str, Any]) -> bool: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER - tunnel_properties = properties - exceptions: list[ValidationError] = [] - server_address: Optional[str] = tunnel_properties.get("server_address") - server_port: Optional[int] = tunnel_properties.get("server_port") - username: Optional[str] = tunnel_properties.get("username") - private_key: Optional[str] = tunnel_properties.get("private_key") - private_key_password: Optional[str] = tunnel_properties.get("private_key_password") + server_address: Optional[str] = properties.get("server_address") + server_port: Optional[int] = properties.get("server_port") + username: Optional[str] = properties.get("username") + private_key: Optional[str] = properties.get("private_key") + private_key_password: Optional[str] = properties.get("private_key_password") if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) if not server_port: From 83ef2a8836e070e0b11032eb95e059ec1c5a0945 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 16:52:57 +0200 Subject: [PATCH 08/14] Update tests --- superset/commands/database/create.py | 13 ++++++++----- superset/commands/database/update.py | 2 +- .../databases/ssh_tunnel/commands/create_test.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 686627e70d3dc..4672d00521097 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -98,8 +98,7 @@ def run(self) -> Model: db.session.commit() else: - database = self._do_create_database() - db.session.commit() + database = self._do_create_database(commit=True) # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) @@ -120,7 +119,11 @@ def run(self) -> Model: ) # So we can show the original message raise ex - except DAOCreateFailedError as ex: + except ( + DAOCreateFailedError, + DatabaseInvalidError, + Exception, + ) as ex: db.session.rollback() event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", @@ -157,5 +160,5 @@ def validate(self) -> None: ) raise exception - def _do_create_database(self) -> Database: - return DatabaseDAO.create(attributes=self._properties, commit=False) + def _do_create_database(self, commit: Optional[bool] = False) -> Database: + return DatabaseDAO.create(attributes=self._properties, commit=commit) diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 039d731d72d04..edc0ba1b989d3 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -78,7 +78,7 @@ def run(self) -> Model: if existing_ssh_tunnel_model is None: # We couldn't found an existing tunnel so we need to create one try: - CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run() + CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: # So we can show the original message raise ex diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 4e6411816332f..1777bdc2e10dc 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -60,7 +60,7 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: "private_key_password": "bar", } - command = CreateSSHTunnelCommand(db.id, properties) + command = CreateSSHTunnelCommand(db, properties) with pytest.raises(SSHTunnelInvalidError) as excinfo: command.run() From 0b56028ee736812a1a9f86167a3df17b5d7aedcf Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 18:25:09 +0200 Subject: [PATCH 09/14] Revert tunnel params validation --- superset/commands/database/create.py | 26 +++---- .../commands/database/ssh_tunnel/create.py | 72 +++++++++---------- 2 files changed, 41 insertions(+), 57 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 4672d00521097..686060e1c3fc7 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -30,10 +30,7 @@ DatabaseInvalidError, DatabaseRequiredFieldValidationError, ) -from superset.commands.database.ssh_tunnel.create import ( - CreateSSHTunnelCommand, - validate_ssh_tunnel, -) +from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError, @@ -87,18 +84,14 @@ def run(self) -> Model: if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() - # pre-validate the SSH tunnel properties - is_tunnel_valid = validate_ssh_tunnel(ssh_tunnel_properties) - - if is_tunnel_valid: - database = self._do_create_database() - ssh_tunnel = CreateSSHTunnelCommand( - database, ssh_tunnel_properties - ).run() - - db.session.commit() + database = self._do_create_database(commit=False) + ssh_tunnel = CreateSSHTunnelCommand( + database, ssh_tunnel_properties + ).run() else: - database = self._do_create_database(commit=True) + database = self._do_create_database(commit=False) + + db.session.commit() # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) @@ -122,7 +115,6 @@ def run(self) -> Model: except ( DAOCreateFailedError, DatabaseInvalidError, - Exception, ) as ex: db.session.rollback() event_logger.log_with_context( @@ -160,5 +152,5 @@ def validate(self) -> None: ) raise exception - def _do_create_database(self, commit: Optional[bool] = False) -> Database: + def _do_create_database(self, commit: bool = False) -> Database: return DatabaseDAO.create(attributes=self._properties, commit=commit) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index d5a5e043dad9d..839f8e81cc0c5 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -34,39 +34,6 @@ logger = logging.getLogger(__name__) -def validate_ssh_tunnel(properties: dict[str, Any]) -> bool: - # TODO(hughhh): check to make sure the server port is not localhost - # using the config.SSH_TUNNEL_MANAGER - - exceptions: list[ValidationError] = [] - server_address: Optional[str] = properties.get("server_address") - server_port: Optional[int] = properties.get("server_port") - username: Optional[str] = properties.get("username") - private_key: Optional[str] = properties.get("private_key") - private_key_password: Optional[str] = properties.get("private_key_password") - if not server_address: - exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) - if not server_port: - exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) - if not username: - exceptions.append(SSHTunnelRequiredFieldValidationError("username")) - if private_key_password and private_key is None: - exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) - if exceptions: - exception = SSHTunnelInvalidError() - exception.extend(exceptions) - event_logger.log_with_context( - # pylint: disable=consider-using-f-string - action="ssh_tunnel_creation_failed.{}.{}".format( - exception.__class__.__name__, - ".".join(exception.get_list_classnames()), - ) - ) - raise exception - - return True - - class CreateSSHTunnelCommand(BaseCommand): def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() @@ -75,16 +42,41 @@ def __init__(self, database: Database, data: dict[str, Any]): def run(self) -> Model: try: self.validate() - return SSHTunnelDAO.create(attributes=self._properties, commit=False) + ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False) + return ssh_tunnel except DAOCreateFailedError as ex: - db.session.rollback() raise SSHTunnelCreateFailedError() from ex except SSHTunnelInvalidError as ex: - db.session.rollback() raise ex def validate(self) -> None: - try: - validate_ssh_tunnel(self._properties) - except SSHTunnelInvalidError as ex: - raise ex + # TODO(hughhh): check to make sure the server port is not localhost + # using the config.SSH_TUNNEL_MANAGER + + exceptions: list[ValidationError] = [] + server_address: Optional[str] = self._properties.get("server_address") + server_port: Optional[int] = self._properties.get("server_port") + username: Optional[str] = self._properties.get("username") + private_key: Optional[str] = self._properties.get("private_key") + private_key_password: Optional[str] = self._properties.get( + "private_key_password" + ) + if not server_address: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) + if not server_port: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) + if not username: + exceptions.append(SSHTunnelRequiredFieldValidationError("username")) + if private_key_password and private_key is None: + exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) + if exceptions: + exception = SSHTunnelInvalidError() + exception.extend(exceptions) + event_logger.log_with_context( + # pylint: disable=consider-using-f-string + action="ssh_tunnel_creation_failed.{}.{}".format( + exception.__class__.__name__, + ".".join(exception.get_list_classnames()), + ) + ) + raise exception From 99ae82fce76d8827c584916ab73c48dca9d321c6 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 18:51:21 +0200 Subject: [PATCH 10/14] Move add database after tunnel op --- superset/commands/database/create.py | 9 +++++---- superset/daos/base.py | 6 ++++-- superset/daos/dashboard.py | 1 + superset/daos/report.py | 1 + 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 686060e1c3fc7..a5a61da3d6cf5 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -84,13 +84,14 @@ def run(self) -> Model: if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() - database = self._do_create_database(commit=False) + database = self._do_create_database() ssh_tunnel = CreateSSHTunnelCommand( database, ssh_tunnel_properties ).run() else: - database = self._do_create_database(commit=False) + database = self._do_create_database() + db.session.add(database) db.session.commit() # adding a new database we always want to force refresh schema list @@ -152,5 +153,5 @@ def validate(self) -> None: ) raise exception - def _do_create_database(self, commit: bool = False) -> Database: - return DatabaseDAO.create(attributes=self._properties, commit=commit) + def _do_create_database(self) -> Database: + return DatabaseDAO.create(attributes=self._properties, commit=False, add=False) diff --git a/superset/daos/base.py b/superset/daos/base.py index 1133a76a1ed06..f8ab23b472b06 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -133,6 +133,7 @@ def create( item: T | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, + add: bool = True, ) -> T: """ Create an object from the specified item and/or attributes. @@ -151,9 +152,10 @@ def create( setattr(item, key, value) try: - db.session.add(item) + if add: + db.session.add(item) - if commit: + if add and commit: db.session.commit() except SQLAlchemyError as ex: # pragma: no cover db.session.rollback() diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index eef46362e2d9a..85d4dbf095a7f 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -373,6 +373,7 @@ def create( item: EmbeddedDashboardDAO | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, + add: bool = True, ) -> Any: """ Use EmbeddedDashboardDAO.upsert() instead. diff --git a/superset/daos/report.py b/superset/daos/report.py index b5db391ec4880..1e0d42d4f42a9 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -138,6 +138,7 @@ def create( item: ReportSchedule | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, + add: bool = True, ) -> ReportSchedule: """ Create a report schedule with nested recipients. From d22f6b943597ab2301acad0eedcb2db3c8f41ae5 Mon Sep 17 00:00:00 2001 From: geido Date: Fri, 2 Feb 2024 18:59:07 +0200 Subject: [PATCH 11/14] Clean up --- superset/commands/database/ssh_tunnel/create.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 839f8e81cc0c5..cbfee3ce2ae4c 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -28,7 +28,7 @@ ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError -from superset.extensions import db, event_logger +from superset.extensions import event_logger from superset.models.core import Database logger = logging.getLogger(__name__) From 185deb43b0bba666852896c7799125e9dc7dfa2c Mon Sep 17 00:00:00 2001 From: geido Date: Tue, 6 Feb 2024 16:51:50 +0200 Subject: [PATCH 12/14] Revert changes to BaseDAO --- superset/commands/database/create.py | 9 ++++--- superset/daos/base.py | 6 ++--- superset/daos/dashboard.py | 1 - superset/daos/report.py | 1 - .../integration_tests/databases/api_tests.py | 24 +++++++++---------- 5 files changed, 17 insertions(+), 24 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index a5a61da3d6cf5..17bb38098b352 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -84,14 +84,13 @@ def run(self) -> Model: if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() - database = self._do_create_database() + database = self._create_database() ssh_tunnel = CreateSSHTunnelCommand( database, ssh_tunnel_properties ).run() else: - database = self._do_create_database() + database = self._create_database() - db.session.add(database) db.session.commit() # adding a new database we always want to force refresh schema list @@ -153,5 +152,5 @@ def validate(self) -> None: ) raise exception - def _do_create_database(self) -> Database: - return DatabaseDAO.create(attributes=self._properties, commit=False, add=False) + def _create_database(self) -> Database: + return DatabaseDAO.create(attributes=self._properties, commit=False) diff --git a/superset/daos/base.py b/superset/daos/base.py index f8ab23b472b06..1133a76a1ed06 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -133,7 +133,6 @@ def create( item: T | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, - add: bool = True, ) -> T: """ Create an object from the specified item and/or attributes. @@ -152,10 +151,9 @@ def create( setattr(item, key, value) try: - if add: - db.session.add(item) + db.session.add(item) - if add and commit: + if commit: db.session.commit() except SQLAlchemyError as ex: # pragma: no cover db.session.rollback() diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index 85d4dbf095a7f..eef46362e2d9a 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -373,7 +373,6 @@ def create( item: EmbeddedDashboardDAO | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, - add: bool = True, ) -> Any: """ Use EmbeddedDashboardDAO.upsert() instead. diff --git a/superset/daos/report.py b/superset/daos/report.py index 1e0d42d4f42a9..b5db391ec4880 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -138,7 +138,6 @@ def create( item: ReportSchedule | None = None, attributes: dict[str, Any] | None = None, commit: bool = True, - add: bool = True, ) -> ReportSchedule: """ Create a report schedule with nested recipients. diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 0bc1f245a1f7b..3e27fec5aae0a 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -538,26 +538,28 @@ def test_cascade_delete_ssh_tunnel( @mock.patch( "superset.models.core.Database.get_all_schema_names", ) + @mock.patch("superset.extensions.db.session.rollback") def test_do_not_create_database_if_ssh_tunnel_creation_fails( self, + mock_rollback, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_get_all_schema_names, ): """ - Database API: Test Database is not created if SSH Tunnel creation fails + Database API: Test rollback is called if SSH Tunnel creation fails """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") - example_db = get_example_database() - if example_db.backend == "sqlite": - return + # example_db = get_example_database() + # if example_db.backend == "sqlite": + # return ssh_tunnel_properties = { "server_address": "123.132.123.1", } database_data = { "database_name": "test-db-failure-ssh-tunnel", - "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "sqlalchemy_uri": "hive://", "ssh_tunnel": ssh_tunnel_properties, } fail_message = {"message": "SSH Tunnel parameters are invalid."} @@ -566,6 +568,7 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) + model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) @@ -573,14 +576,9 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( ) assert model_ssh_tunnel is None self.assertEqual(response, fail_message) - # Cleanup - model = ( - db.session.query(Database) - .filter(Database.database_name == "test-db-failure-ssh-tunnel") - .one_or_none() - ) - # the DB should not be created - assert model is None + + # Check that rollback was called + mock_rollback.assert_called() @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", From 0ead57d1dafe90bc71e89de96e894bbc0962cdaf Mon Sep 17 00:00:00 2001 From: geido Date: Wed, 7 Feb 2024 15:50:24 +0200 Subject: [PATCH 13/14] Clean up --- superset/commands/database/create.py | 5 ++--- tests/integration_tests/databases/api_tests.py | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 17bb38098b352..cf97ea001e431 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -80,16 +80,15 @@ def run(self) -> Model: ssh_tunnel = None try: + database = self._create_database() + if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() - database = self._create_database() ssh_tunnel = CreateSSHTunnelCommand( database, ssh_tunnel_properties ).run() - else: - database = self._create_database() db.session.commit() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 3e27fec5aae0a..f7b8cc0ec8cd2 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -551,15 +551,15 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") - # example_db = get_example_database() - # if example_db.backend == "sqlite": - # return + example_db = get_example_database() + if example_db.backend == "sqlite": + return ssh_tunnel_properties = { "server_address": "123.132.123.1", } database_data = { "database_name": "test-db-failure-ssh-tunnel", - "sqlalchemy_uri": "hive://", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } fail_message = {"message": "SSH Tunnel parameters are invalid."} From 9a93efa5075d50f0f31d9d3916fde4489c77d521 Mon Sep 17 00:00:00 2001 From: geido Date: Wed, 7 Feb 2024 16:46:37 +0200 Subject: [PATCH 14/14] Set sqlalchemy uri --- superset/commands/database/create.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index cf97ea001e431..cde9dd8e884b2 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -114,6 +114,7 @@ def run(self) -> Model: except ( DAOCreateFailedError, DatabaseInvalidError, + Exception, ) as ex: db.session.rollback() event_logger.log_with_context( @@ -152,4 +153,6 @@ def validate(self) -> None: raise exception def _create_database(self) -> Database: - return DatabaseDAO.create(attributes=self._properties, commit=False) + database = DatabaseDAO.create(attributes=self._properties, commit=False) + database.set_sqlalchemy_uri(database.sqlalchemy_uri) + return database