diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 548fb390d8f89..cd37e4e60218e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1610,9 +1610,11 @@ def where_latest_partition( # pylint: disable=unused-argument @classmethod def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: return [ - literal_column(query_as) - if (query_as := c.get("query_as")) - else column(c["column_name"]) + ( + literal_column(query_as) + if (query_as := c.get("query_as")) + else column(c["column_name"]) + ) for c in cols ] @@ -1828,13 +1830,18 @@ def execute( # pylint: disable=unused-argument cursor.arraysize = cls.arraysize try: cursor.execute(query) - except cls.oauth2_exception as ex: - if database.is_oauth2_enabled() and g and g.user: - cls.start_oauth2_dance(database) - raise cls.get_dbapi_mapped_exception(ex) from ex except Exception as ex: + if database.is_oauth2_enabled() and cls.needs_oauth2(ex): + cls.start_oauth2_dance(database) raise cls.get_dbapi_mapped_exception(ex) from ex + @classmethod + def needs_oauth2(cls, ex: Exception) -> bool: + """ + Check if the exception is one that indicates OAuth2 is needed. 
+ """ + return g and hasattr(g, "user") and isinstance(ex, cls.oauth2_exception) + @classmethod def make_label_compatible(cls, label: str) -> str | quoted_name: """ diff --git a/superset/models/core.py b/superset/models/core.py index e6d97a197b04d..c8c875e4358f6 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -29,7 +29,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, cast, TYPE_CHECKING import numpy import pandas as pd @@ -78,7 +78,7 @@ from superset.utils import cache as cache_util, core as utils, json from superset.utils.backports import StrEnum from superset.utils.core import DatasourceName, get_username -from superset.utils.oauth2 import get_oauth2_access_token +from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema config = app.config custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] @@ -554,17 +554,23 @@ def get_raw_connection( nullpool=nullpool, source=source, ) as engine: - with closing(engine.raw_connection()) as conn: - # pre-session queries are used to set the selected schema and, in the - # future, the selected catalog - for prequery in self.db_engine_spec.get_prequeries( - catalog=catalog, - schema=schema, - ): - cursor = conn.cursor() - cursor.execute(prequery) + try: + with closing(engine.raw_connection()) as conn: + # pre-session queries are used to set the selected schema and, in the + # future, the selected catalog + for prequery in self.db_engine_spec.get_prequeries( + catalog=catalog, + schema=schema, + ): + cursor = conn.cursor() + cursor.execute(prequery) - yield conn + yield conn + + except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.db_engine_spec.start_oauth2_dance(self) + raise ex def get_default_catalog(self) -> str | None: """ @@ -1063,20 +1069,30 @@ def is_oauth2_enabled(self) -> bool: """ Is 
OAuth2 enabled in the database for authentication? - Currently this looks for a global config at the DB engine spec level, but in the - future we want to be allow admins to create custom OAuth2 clients from the - Superset UI, and assign them to specific databases. + Currently this checks for configuration stored in the database `encrypted_extra`, and then + for a global config at the DB engine spec level. In the future we want to allow + admins to create custom OAuth2 clients from the Superset UI, and assign them to + specific databases. """ - return self.db_engine_spec.is_oauth2_enabled() + encrypted_extra = json.loads(self.encrypted_extra or "{}") + oauth2_client_info = encrypted_extra.get("oauth2_client_info", {}) + return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled() def get_oauth2_config(self) -> OAuth2ClientConfig | None: """ Return OAuth2 client configuration. - This includes client ID, client secret, scope, redirect URI, endpointsm etc. - Currently this reads the global DB engine spec config, but in the future it - should first check if there's a custom client assigned to the database. + Currently this checks for configuration stored in the database `encrypted_extra`, and then + for a global config at the DB engine spec level. In the future we want to allow + admins to create custom OAuth2 clients from the Superset UI, and assign them to + specific databases. 
""" + encrypted_extra = json.loads(self.encrypted_extra or "{}") + if oauth2_client_info := encrypted_extra.get("oauth2_client_info"): + schema = OAuth2ClientConfigSchema() + client_config = schema.load(oauth2_client_info) + return cast(OAuth2ClientConfig, client_config) + return self.db_engine_spec.get_oauth2_config() diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 9cc58a0b7ffca..bc4805fd81924 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -22,7 +22,7 @@ import backoff import jwt -from flask import current_app +from flask import current_app, url_for from marshmallow import EXCLUDE, fields, post_load, Schema from superset import db @@ -180,3 +180,15 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State: state = oauth2_state_schema.load(payload) return state + + +class OAuth2ClientConfigSchema(Schema): + id = fields.String(required=True) + secret = fields.String(required=True) + scope = fields.String(required=True) + redirect_uri = fields.String( + required=False, + load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True), + ) + authorization_request_uri = fields.String(required=True) + token_request_uri = fields.String(required=True) diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 3905a15b32b0c..2b8f39d6dd09b 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -90,6 +90,9 @@ def app(request: SubRequest) -> Iterator[SupersetApp]: app.config["RATELIMIT_ENABLED"] = False app.config["CACHE_CONFIG"] = {} app.config["DATA_CACHE_CONFIG"] = {} + app.config["SERVER_NAME"] = "example.com" + app.config["APPLICATION_ROOT"] = "/" + app.config["PREFERRED_URL_SCHEME="] = "http" # loop over extra configs passed in by tests # and update the app config diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 2004ff482fe21..c4d642baf5b69 100644 --- a/tests/unit_tests/models/core_test.py +++ 
b/tests/unit_tests/models/core_test.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=import-outside-toplevel + from datetime import datetime import pytest @@ -24,11 +25,23 @@ from sqlalchemy.engine.url import make_url from superset.connectors.sqla.models import SqlaTable, TableColumn +from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.models.core import Database from superset.sql_parse import Table from superset.utils import json from tests.unit_tests.conftest import with_feature_flags +# sample config for OAuth2 tests +oauth2_client_info = { + "oauth2_client_info": { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:SYSADMIN", + } +} + def test_get_metrics(mocker: MockFixture) -> None: """ @@ -378,3 +391,73 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None: make_url("trino:///"), connect_args={"user": "alice.doe", "source": "Apache Superset"}, ) + + +def test_is_oauth2_enabled() -> None: + """ + Test the `is_oauth2_enabled` method. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + + assert not database.is_oauth2_enabled() + + database.encrypted_extra = json.dumps(oauth2_client_info) + assert database.is_oauth2_enabled() + + +def test_get_oauth2_config(app_context: None) -> None: + """ + Test the `get_oauth2_config` method. 
+ """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + + assert database.get_oauth2_config() is None + + database.encrypted_extra = json.dumps(oauth2_client_info) + assert database.get_oauth2_config() == { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:SYSADMIN", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + } + + +def test_raw_connection_oauth(mocker: MockFixture) -> None: + """ + Test that we can start OAuth2 from `raw_connection()` errors. + + Some databases that use OAuth2 need to trigger the flow when the connection is + created, rather than when the query runs. This happens when the SQLAlchemy engine + URI cannot be built without the user's personal token. + + This test verifies that the exception is captured and raised correctly so that the + frontend can trigger the OAuth2 dance. + """ + g = mocker.patch("superset.db_engine_specs.base.g") + g.user = mocker.MagicMock() + g.user.id = 42 + + database = Database( + id=1, + database_name="my_db", + sqlalchemy_uri="sqlite://", + encrypted_extra=json.dumps(oauth2_client_info), + ) + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore + get_sqla_engine = mocker.patch.object(database, "get_sqla_engine") + get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error( + "OAuth2 required" + ) + + with pytest.raises(OAuth2RedirectError) as excinfo: + with database.get_raw_connection() as conn: + conn.cursor() + assert str(excinfo.value) == "You don't have permission to access the data." 
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 3b2e7690e141a..6ac055df13480 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -16,12 +16,22 @@ # under the License. # pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals +import json +from uuid import UUID + import sqlparse +from freezegun import freeze_time from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session from superset import db +from superset.common.db_query_status import QueryStatus +from superset.errors import ErrorLevel, SupersetErrorType +from superset.exceptions import OAuth2Error +from superset.models.core import Database +from superset.sql_lab import get_sql_results from superset.utils.core import override_user +from tests.unit_tests.models.core_test import oauth2_client_info def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: @@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery( query.executed_sql == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6" ) + + +@freeze_time("2021-04-01T00:00:00Z") +def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: + """ + Test that `get_sql_results` works with OAuth2. 
+ """ + app_context = app.test_request_context() + app_context.push() + + mocker.patch( + "superset.db_engine_specs.base.uuid4", + return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"), + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user = mocker.MagicMock() + g.user.id = 42 + + database = Database( + id=1, + database_name="my_db", + sqlalchemy_uri="sqlite://", + encrypted_extra=json.dumps(oauth2_client_info), + ) + database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore + get_sqla_engine = mocker.patch.object(database, "get_sqla_engine") + get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error( + "OAuth2 required" + ) + + query = mocker.MagicMock() + query.database = database + mocker.patch("superset.sql_lab.get_query", return_value=query) + + payload = get_sql_results(query_id=1, rendered_query="SELECT 1") + assert payload == { + "status": QueryStatus.FAILED, + "error": "You don't have permission to access the data.", + "errors": [ + { + "message": "You don't have permission to access the data.", + "error_type": SupersetErrorType.OAUTH2_REDIRECT, + "level": ErrorLevel.WARNING, + "extra": { + "url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO1928flOEFQIqqg9ctiXjcM&redirect_uri=http%3A%2F%2Fexample.com%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent", + "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + }, + } + ], + }