diff --git a/superset/config.py b/superset/config.py index 34c36c210fb9f..a7b7e655ecbb9 100644 --- a/superset/config.py +++ b/superset/config.py @@ -38,6 +38,7 @@ List, Literal, Optional, + Set, Type, TYPE_CHECKING, Union, @@ -1107,6 +1108,12 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # in security manager EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None +# For database connections, this dictionary will remove engines from the available +# list/dropdown if you do not want these dbs to show as available. +# The available list is generated by driver installed, and some engines have multiple +# drivers. +# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": ("pyhive", "pyodbc")} +DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {} # This auth provider is used by background (offline) tasks that need to access # protected resources. Can be overridden by end users in order to support diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index 29e4877337b61..f19dffd4a3bbe 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -41,6 +41,7 @@ from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.url import URL +from superset import app from superset.db_engine_specs.base import BaseEngineSpec logger = logging.getLogger(__name__) @@ -72,7 +73,6 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]: for attr in module.__dict__ if is_engine_spec(getattr(module, attr)) ) - # load additional engines from external modules for ep in iter_entry_points("superset.db_engine_specs"): try: @@ -170,6 +170,17 @@ def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]: for engine_spec in load_engine_specs(): driver = drivers[engine_spec.engine] + # do not add denied db engine specs to available list + dbs_denylist = app.config["DBS_AVAILABLE_DENYLIST"] + dbs_denylist_engines = dbs_denylist.keys() + + if ( + engine_spec.engine in dbs_denylist_engines + and hasattr(engine_spec, "default_driver") + and engine_spec.default_driver in dbs_denylist[engine_spec.engine] + ): + continue + # lookup driver by engine aliases. if not driver and engine_spec.engine_aliases: for alias in engine_spec.engine_aliases: diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 935e128dc9b99..5ad6af888a17f 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Iterator import pytest +from _pytest.fixtures import SubRequest from pytest_mock import MockFixture from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -68,7 +69,7 @@ def session(get_session) -> Iterator[Session]: @pytest.fixture(scope="module") -def app() -> Iterator[SupersetApp]: +def app(request: SubRequest) -> Iterator[SupersetApp]: """ A fixture that generates a Superset app. """ @@ -82,6 +83,11 @@ def app() -> Iterator[SupersetApp]: app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False app.config["TESTING"] = True + # loop over extra configs passed in by tests + if request and hasattr(request, "param"): + for key, val in request.param.items(): + app.config[key] = val + # ``superset.extensions.appbuilder`` is a singleton, and won't rebuild the # routes when this fixture is called multiple times; we need to clear the # registered views to ensure the initialization can happen more than once. diff --git a/tests/unit_tests/db_engine_specs/test_init.py b/tests/unit_tests/db_engine_specs/test_init.py new file mode 100644 index 0000000000000..3189256c70f12 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_init.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import pytest +from pkg_resources import EntryPoint +from pytest_mock import MockFixture + +from superset.db_engine_specs import get_available_engine_specs + + +def test_get_available_engine_specs(mocker: MockFixture) -> None: + """ + get_available_engine_specs should return all engine specs + """ + from superset.db_engine_specs.databricks import ( + DatabricksHiveEngineSpec, + DatabricksNativeEngineSpec, + DatabricksODBCEngineSpec, + ) + + mocker.patch( + "superset.db_engine_specs.load_engine_specs", + return_value=iter( + [ + DatabricksHiveEngineSpec, + DatabricksNativeEngineSpec, + DatabricksODBCEngineSpec, + ] + ), + ) + + assert list(get_available_engine_specs().keys()) == [ + DatabricksHiveEngineSpec, + DatabricksNativeEngineSpec, + DatabricksODBCEngineSpec, + ] + + +@pytest.mark.parametrize( + "app", + [{"DBS_AVAILABLE_DENYLIST": {"databricks": {"pyhive", "pyodbc"}}}], + indirect=True, +) +def test_get_available_engine_specs_with_denylist(mocker: MockFixture) -> None: + """ + The denylist removes items from the db engine spec list + """ + from superset.db_engine_specs.databricks import ( + DatabricksHiveEngineSpec, + DatabricksNativeEngineSpec, + DatabricksODBCEngineSpec, + ) + + mocker.patch( + "superset.db_engine_specs.load_engine_specs", + return_value=iter( + [ + DatabricksHiveEngineSpec, + DatabricksNativeEngineSpec, + DatabricksODBCEngineSpec, + ] + ), + ) + available = get_available_engine_specs() + assert list(available.keys()) == [DatabricksNativeEngineSpec]