Skip to content

Commit

Permalink
feat: default ports for SSH tunnel (apache#32403)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Feb 27, 2025
1 parent 74733ae commit f4105e9
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 15 deletions.
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +73,9 @@ def validate(self) -> None:
"private_key_password"
)
url = make_url_safe(self._database.sqlalchemy_uri)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/database/ssh_tunnel/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
from superset.utils.ssh_tunnel import get_default_port

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,5 +76,7 @@ def validate(self) -> None:
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)
if not url.port:
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
5 changes: 4 additions & 1 deletion superset/extensions/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ def create_tunnel(
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> sshtunnel.SSHTunnelForwarder:
from superset.utils.ssh_tunnel import get_default_port

url = make_url_safe(sqlalchemy_database_uri)
backend = url.get_backend_name()
params = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (url.host, url.port),
"remote_bind_address": (url.host, url.port or get_default_port(backend)),
"local_bind_address": (self.local_bind_address,),
"debug_level": logging.getLogger("flask_appbuilder").level,
}
Expand Down
14 changes: 14 additions & 0 deletions superset/utils/ssh_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from superset.constants import PASSWORD_MASK
from superset.databases.ssh_tunnel.models import SSHTunnel

DEFAULT_PORTS: dict[str, int] = {
"postgresql": 5432,
"mysql": 3306,
"oracle": 1521,
"mssql": 1433,
}


def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
Expand All @@ -41,3 +48,10 @@ def unmask_password_info(
if ssh_tunnel.get("private_key_password") == PASSWORD_MASK:
ssh_tunnel["private_key_password"] = model.private_key_password
return ssh_tunnel


def get_default_port(backend: str) -> int | None:
"""
Get the default port for the given backend.
"""
return DEFAULT_PORTS.get(backend)
121 changes: 115 additions & 6 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,18 @@ def test_create_database_with_ssh_tunnel(
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_missing_port_raises_error(
def test_create_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
Database API: Test create with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
Expand All @@ -369,13 +369,58 @@ def test_create_database_with_missing_port_raises_error(
"username": "foo",
"password": "bar",
}

database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
Expand Down Expand Up @@ -459,7 +504,71 @@ def test_update_database_with_ssh_tunnel(
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_missing_port_raises_error(
def test_update_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test update Database with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201

uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200

model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()

@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
Expand All @@ -477,7 +586,7 @@ def test_update_database_with_missing_port_raises_error(
if example_db.backend == "sqlite":
return

modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"

ssh_tunnel_properties = {
"server_address": "123.132.123.1",
Expand Down
57 changes: 51 additions & 6 deletions tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@


import pytest
from sqlalchemy.orm.session import Session

from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)


def test_create_ssh_tunnel_command() -> None:
def test_create_ssh_tunnel_command(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
Expand All @@ -49,12 +53,15 @@ def test_create_ssh_tunnel_command() -> None:
assert isinstance(result, SSHTunnel)


def test_create_ssh_tunnel_command_invalid_params() -> None:
def test_create_ssh_tunnel_command_invalid_params(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
Expand All @@ -76,12 +83,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")


def test_create_ssh_tunnel_command_no_port() -> None:
def test_create_ssh_tunnel_command_no_port(session: Session) -> None:
"""
Test that SSH Tunnel can be created without explicit port but with a default one.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
Expand All @@ -94,6 +108,37 @@ def test_create_ssh_tunnel_command_no_port() -> None:
"password": "bar",
}

result = CreateSSHTunnelCommand(database, properties).run()

assert result is not None
assert isinstance(result, SSHTunnel)


def test_create_ssh_tunnel_command_no_port_no_default(session: Session) -> None:
"""
Test that error is raised when creating SSH Tunnel without explicit/default ports.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database

engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="weird+db://u:p@localhost/db",
)

properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}

command = CreateSSHTunnelCommand(database, properties)

with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,37 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
"""
Test that SSH Tunnel can be updated without explicit port but with a default one.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel

result = DatabaseDAO.get_ssh_tunnel(1)

assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address

update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()

result = DatabaseDAO.get_ssh_tunnel(1)

assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address


@pytest.mark.parametrize(
"session_with_data", ["weird+db://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port_no_default(session_with_data: Session) -> None:
"""
Test that error is raised when updating SSH Tunnel without explicit/default ports.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
Expand Down

0 comments on commit f4105e9

Please sign in to comment.