Skip to content

Commit

Permalink
chore: Remove database ID dependency for SSH Tunnel creation (#26989)
Browse files Browse the repository at this point in the history
  • Loading branch information
geido authored Feb 7, 2024
1 parent 43e1dc4 commit d8e26cf
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 72 deletions.
58 changes: 32 additions & 26 deletions superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
from superset.models.core import Database

logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
Expand Down Expand Up @@ -76,34 +77,20 @@ def run(self) -> Model:
"{}",
)

ssh_tunnel = None

try:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
database = self._create_database()

ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex

ssh_tunnel = CreateSSHTunnelCommand(
database, ssh_tunnel_properties
).run()

db.session.commit()

# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
Expand All @@ -112,9 +99,23 @@ def run(self) -> Model:
"schema_access", security_manager.get_schema_perm(database, schema)
)

db.session.commit()

except DAOCreateFailedError as ex:
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except (
DAOCreateFailedError,
DatabaseInvalidError,
Exception,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
Expand Down Expand Up @@ -150,3 +151,8 @@ def validate(self) -> None:
)
)
raise exception

def _create_database(self) -> Database:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database
23 changes: 7 additions & 16 deletions superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,39 @@
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.extensions import db, event_logger
from superset.extensions import event_logger
from superset.models.core import Database

logger = logging.getLogger(__name__)


class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: dict[str, Any]):
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
self._properties["database"] = database

def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
return ssh_tunnel
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex

def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER

exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
Expand Down
2 changes: 1 addition & 1 deletion superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run(self) -> Model:
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
Expand Down
16 changes: 7 additions & 9 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,16 @@ def test_cascade_delete_ssh_tunnel(
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
mock_rollback,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test Database is not created if SSH Tunnel creation fails
Database API: Test rollback is called if SSH Tunnel creation fails
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
Expand All @@ -566,21 +568,17 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails(
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)

model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
.one_or_none()
)
# the DB should not be created
assert model is None

# Check that rollback was called
mock_rollback.assert_called()

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,6 @@
from tests.integration_tests.base_tests import SupersetTestCase


class TestCreateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_create_invalid_database_id(self, mock_g):
mock_g.user = security_manager.find_user("admin")
command = CreateSSHTunnelCommand(
None,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")


class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None:
"password": "bar",
}

result = CreateSSHTunnelCommand(db.id, properties).run()
result = CreateSSHTunnelCommand(db, properties).run()

assert result is not None
assert isinstance(result, SSHTunnel)
Expand All @@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database_id": db.id,
"database": db,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}

command = CreateSSHTunnelCommand(db.id, properties)
command = CreateSSHTunnelCommand(db, properties)

with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
Expand Down

0 comments on commit d8e26cf

Please sign in to comment.