Skip to content

Commit

Permalink
feat: allow configuring an engine context manager (#30266)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Sep 23, 2024
1 parent ee3a567 commit 710406a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 34 deletions.
32 changes: 25 additions & 7 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
import re
import sys
from collections import OrderedDict
from contextlib import contextmanager
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
from importlib.resources import files
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
from typing import Any, Callable, Iterator, Literal, TYPE_CHECKING, TypedDict

import click
import pkg_resources
Expand Down Expand Up @@ -1146,16 +1147,18 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None


# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
def allowed_schemas_for_csv_upload( # pylint: disable=unused-argument
database: Database,
user: models.User,
) -> list[str]:
return [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else []

# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
)

ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_schemas_for_csv_upload

# Values that should be treated as nulls for the csv uploads.
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
Expand Down Expand Up @@ -1266,6 +1269,21 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None


# A context manager that wraps the call to `create_engine`. This can be used for many
# things, such as chrooting to prevent 3rd party drivers to access the filesystem, or
# setting up custom configuration for database drivers.
@contextmanager
def engine_context_manager( # pylint: disable=unused-argument
database: Database,
catalog: str | None,
schema: str | None,
) -> Iterator[None]:
yield None


ENGINE_CONTEXT_MANAGER = engine_context_manager

# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or
# arbitrary logic. For instance you can wire different users to
Expand Down
40 changes: 21 additions & 19 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,38 +418,40 @@ def get_sqla_engine( # pylint: disable=too-many-arguments
)

sqlalchemy_uri = self.sqlalchemy_uri_decrypted
engine_context = nullcontext()
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
database_id=self.id
)

if ssh_tunnel:
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(self.id)
ssh_context_manager = (
ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
sqlalchemy_database_uri=sqlalchemy_uri,
)
if ssh_tunnel
else nullcontext()
)

with engine_context as server_context:
if ssh_tunnel and server_context:
with ssh_context_manager as ssh_context:
if ssh_context:
logger.info(
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
"ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
server_context.local_bind_address,
ssh_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri,
server_context,
ssh_context,
)

yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)

def _get_sqla_engine( # pylint: disable=too-many-locals
self,
Expand Down
37 changes: 29 additions & 8 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Tests for ``get_prequeries``.
"""
mocker.patch.object(
Database,
"get_sqla_engine",
)
mocker.patch.object(Database, "get_sqla_engine")
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]

Expand Down Expand Up @@ -397,10 +394,7 @@ def test_get_sqla_engine(mocker: MockerFixture) -> None:

create_engine = mocker.patch("superset.models.core.create_engine")

database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
)
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
database._get_sqla_engine(nullpool=False)

create_engine.assert_called_with(
Expand Down Expand Up @@ -556,3 +550,30 @@ def test_get_schema_access_for_file_upload() -> None:
)

assert database.get_schema_access_for_file_upload() == {"public"}


def test_engine_context_manager(mocker: MockerFixture) -> None:
"""
Test the engine context manager.
"""
engine_context_manager = mocker.MagicMock()
mocker.patch(
"superset.models.core.config",
new={"ENGINE_CONTEXT_MANAGER": engine_context_manager},
)
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")

database = Database(database_name="my_db", sqlalchemy_uri="trino://")
with database.get_sqla_engine("catalog", "schema"):
pass

engine_context_manager.assert_called_once_with(database, "catalog", "schema")
engine_context_manager().__enter__.assert_called_once()
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
_get_sqla_engine.assert_called_once_with(
catalog="catalog",
schema="schema",
nullpool=True,
source=None,
sqlalchemy_uri="trino://",
)

0 comments on commit 710406a

Please sign in to comment.