Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Remove database ID dependency for SSH Tunnel creation #26989

Merged
merged 14 commits into from
Feb 7, 2024
58 changes: 32 additions & 26 deletions superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
from superset.models.core import Database

logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
Expand Down Expand Up @@ -76,34 +77,20 @@
"{}",
)

ssh_tunnel = None

try:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
database = self._create_database()

ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex

ssh_tunnel = CreateSSHTunnelCommand(
database, ssh_tunnel_properties
).run()

db.session.commit()

# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
Expand All @@ -112,9 +99,23 @@
"schema_access", security_manager.get_schema_perm(database, schema)
)

db.session.commit()

except DAOCreateFailedError as ex:
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except (

Check warning on line 114 in superset/commands/database/create.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/database/create.py#L114

Added line #L114 was not covered by tests
DAOCreateFailedError,
DatabaseInvalidError,
Exception,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
Expand Down Expand Up @@ -150,3 +151,8 @@
)
)
raise exception

def _create_database(self) -> Database:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database
23 changes: 7 additions & 16 deletions superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,39 @@
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.extensions import db, event_logger
from superset.extensions import event_logger
from superset.models.core import Database

logger = logging.getLogger(__name__)


class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: dict[str, Any]):
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
self._properties["database"] = database

def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
geido marked this conversation as resolved.
Show resolved Hide resolved
return ssh_tunnel
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex

def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER

exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
Expand Down
2 changes: 1 addition & 1 deletion superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run(self) -> Model:
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
Expand Down
16 changes: 7 additions & 9 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,16 @@ def test_cascade_delete_ssh_tunnel(
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
mock_rollback,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test Database is not created if SSH Tunnel creation fails
Database API: Test rollback is called if SSH Tunnel creation fails
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
Expand All @@ -566,21 +568,17 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails(
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)

model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
.one_or_none()
)
# the DB should not be created
assert model is None

# Check that rollback was called
mock_rollback.assert_called()
geido marked this conversation as resolved.
Show resolved Hide resolved

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,6 @@
from tests.integration_tests.base_tests import SupersetTestCase


class TestCreateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_create_invalid_database_id(self, mock_g):
mock_g.user = security_manager.find_user("admin")
command = CreateSSHTunnelCommand(
None,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")


class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None:
"password": "bar",
}

result = CreateSSHTunnelCommand(db.id, properties).run()
result = CreateSSHTunnelCommand(db, properties).run()

assert result is not None
assert isinstance(result, SSHTunnel)
Expand All @@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database_id": db.id,
"database": db,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}

command = CreateSSHTunnelCommand(db.id, properties)
command = CreateSSHTunnelCommand(db, properties)

with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
Expand Down
Loading