Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow configuring an engine context manager #30266

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
import re
import sys
from collections import OrderedDict
from contextlib import contextmanager
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
from importlib.resources import files
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
from typing import Any, Callable, Iterator, Literal, TYPE_CHECKING, TypedDict

import click
import pkg_resources
Expand Down Expand Up @@ -1142,16 +1143,18 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None


# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
def allowed_schemas_for_csv_upload( # pylint: disable=unused-argument
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote this as a function since it's easier to read and removes the lambda complaint from the linter.

database: Database,
user: models.User,
) -> list[str]:
return [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else []

# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
)

ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_schemas_for_csv_upload

# Values that should be treated as nulls for the csv uploads.
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
Expand Down Expand Up @@ -1262,6 +1265,21 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None


# A context manager that wraps the call to `create_engine`. This can be used for many
# things, such as chrooting to prevent 3rd party drivers to access the filesystem, or
# setting up custom configuration for database drivers.
@contextmanager
def engine_context_manager( # pylint: disable=unused-argument
database: Database,
catalog: str | None,
schema: str | None,
) -> Iterator[None]:
yield None


ENGINE_CONTEXT_MANAGER = engine_context_manager

# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or
# arbitrary logic. For instance you can wire different users to
Expand Down
40 changes: 21 additions & 19 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,38 +418,40 @@ def get_sqla_engine( # pylint: disable=too-many-arguments
)

sqlalchemy_uri = self.sqlalchemy_uri_decrypted
engine_context = nullcontext()
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
database_id=self.id
)

if ssh_tunnel:
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(self.id)
ssh_context_manager = (
ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
sqlalchemy_database_uri=sqlalchemy_uri,
)
if ssh_tunnel
else nullcontext()
)

with engine_context as server_context:
if ssh_tunnel and server_context:
with ssh_context_manager as ssh_context:
if ssh_context:
logger.info(
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
"ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
server_context.local_bind_address,
ssh_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri,
server_context,
ssh_context,
)

yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)

def _get_sqla_engine( # pylint: disable=too-many-locals
self,
Expand Down
37 changes: 29 additions & 8 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Tests for ``get_prequeries``.
"""
mocker.patch.object(
Database,
"get_sqla_engine",
)
mocker.patch.object(Database, "get_sqla_engine")
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]

Expand Down Expand Up @@ -397,10 +394,7 @@ def test_get_sqla_engine(mocker: MockerFixture) -> None:

create_engine = mocker.patch("superset.models.core.create_engine")

database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
)
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
database._get_sqla_engine(nullpool=False)

create_engine.assert_called_with(
Expand Down Expand Up @@ -556,3 +550,30 @@ def test_get_schema_access_for_file_upload() -> None:
)

assert database.get_schema_access_for_file_upload() == {"public"}


def test_engine_context_manager(mocker: MockerFixture) -> None:
"""
Test the engine context manager.
"""
engine_context_manager = mocker.MagicMock()
mocker.patch(
"superset.models.core.config",
new={"ENGINE_CONTEXT_MANAGER": engine_context_manager},
)
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")

database = Database(database_name="my_db", sqlalchemy_uri="trino://")
with database.get_sqla_engine("catalog", "schema"):
pass

engine_context_manager.assert_called_once_with(database, "catalog", "schema")
engine_context_manager().__enter__.assert_called_once()
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
_get_sqla_engine.assert_called_once_with(
catalog="catalog",
schema="schema",
nullpool=True,
source=None,
sqlalchemy_uri="trino://",
)
Loading