Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Remove database ID dependency for SSH Tunnel creation #26989

Merged
merged 14 commits into from
Feb 7, 2024
59 changes: 32 additions & 27 deletions superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
from superset.models.core import Database

logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
Expand Down Expand Up @@ -76,34 +77,22 @@
"{}",
)

try:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = None

ssh_tunnel = None
try:
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex

database = self._do_create_database()
geido marked this conversation as resolved.
Show resolved Hide resolved
ssh_tunnel = CreateSSHTunnelCommand(
database, ssh_tunnel_properties
).run()
else:
database = self._do_create_database()

db.session.add(database)
db.session.commit()

# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
Expand All @@ -112,9 +101,22 @@
"schema_access", security_manager.get_schema_perm(database, schema)
)

db.session.commit()

except DAOCreateFailedError as ex:
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except (

Check warning on line 116 in superset/commands/database/create.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/database/create.py#L116

Added line #L116 was not covered by tests
DAOCreateFailedError,
DatabaseInvalidError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
Expand Down Expand Up @@ -150,3 +152,6 @@
)
)
raise exception

def _do_create_database(self) -> Database:
return DatabaseDAO.create(attributes=self._properties, commit=False, add=False)
23 changes: 7 additions & 16 deletions superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,39 @@
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.extensions import db, event_logger
from superset.extensions import event_logger
from superset.models.core import Database

logger = logging.getLogger(__name__)


class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: dict[str, Any]):
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
self._properties["database"] = database

def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
geido marked this conversation as resolved.
Show resolved Hide resolved
return ssh_tunnel
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex

def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER

exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
Expand Down
2 changes: 1 addition & 1 deletion superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run(self) -> Model:
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
Expand Down
6 changes: 4 additions & 2 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def create(
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
add: bool = True,
geido marked this conversation as resolved.
Show resolved Hide resolved
) -> T:
"""
Create an object from the specified item and/or attributes.
Expand All @@ -151,9 +152,10 @@ def create(
setattr(item, key, value)

try:
db.session.add(item)
if add:
db.session.add(item)

if commit:
if add and commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
Expand Down
1 change: 1 addition & 0 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def create(
item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
add: bool = True,
) -> Any:
"""
Use EmbeddedDashboardDAO.upsert() instead.
Expand Down
1 change: 1 addition & 0 deletions superset/daos/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def create(
item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
add: bool = True,
) -> ReportSchedule:
"""
Create a report schedule with nested recipients.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,6 @@
from tests.integration_tests.base_tests import SupersetTestCase


class TestCreateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_create_invalid_database_id(self, mock_g):
mock_g.user = security_manager.find_user("admin")
command = CreateSSHTunnelCommand(
None,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")


class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None:
"password": "bar",
}

result = CreateSSHTunnelCommand(db.id, properties).run()
result = CreateSSHTunnelCommand(db, properties).run()

assert result is not None
assert isinstance(result, SSHTunnel)
Expand All @@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database_id": db.id,
"database": db,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}

command = CreateSSHTunnelCommand(db.id, properties)
command = CreateSSHTunnelCommand(db, properties)

with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
Expand Down
Loading