Skip to content

Commit

Permalink
feat: OAuth2 client initial work (#29109)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored and eschutho committed Jul 24, 2024
1 parent 1534965 commit 25cbbc2
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 27 deletions.
21 changes: 14 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,9 +1610,11 @@ def where_latest_partition( # pylint: disable=unused-argument
@classmethod
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
return [
literal_column(query_as)
if (query_as := c.get("query_as"))
else column(c["column_name"])
(
literal_column(query_as)
if (query_as := c.get("query_as"))
else column(c["column_name"])
)
for c in cols
]

Expand Down Expand Up @@ -1828,13 +1830,18 @@ def execute( # pylint: disable=unused-argument
cursor.arraysize = cls.arraysize
try:
cursor.execute(query)
except cls.oauth2_exception as ex:
if database.is_oauth2_enabled() and g and g.user:
cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex
except Exception as ex:
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex

@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
"""
Check if the exception is one that indicates OAuth2 is needed.
"""
return g and hasattr(g, "user") and isinstance(ex, cls.oauth2_exception)

@classmethod
def make_label_compatible(cls, label: str) -> str | quoted_name:
"""
Expand Down
54 changes: 35 additions & 19 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Callable, cast, TYPE_CHECKING

import numpy
import pandas as pd
Expand Down Expand Up @@ -78,7 +78,7 @@
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import DatasourceName, get_username
from superset.utils.oauth2 import get_oauth2_access_token
from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema

config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
Expand Down Expand Up @@ -554,17 +554,23 @@ def get_raw_connection(
nullpool=nullpool,
source=source,
) as engine:
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the
# future, the selected catalog
for prequery in self.db_engine_spec.get_prequeries(
catalog=catalog,
schema=schema,
):
cursor = conn.cursor()
cursor.execute(prequery)
try:
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the
# future, the selected catalog
for prequery in self.db_engine_spec.get_prequeries(
catalog=catalog,
schema=schema,
):
cursor = conn.cursor()
cursor.execute(prequery)

yield conn
yield conn

except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.db_engine_spec.start_oauth2_dance(self)
raise ex

def get_default_catalog(self) -> str | None:
"""
Expand Down Expand Up @@ -1063,20 +1069,30 @@ def is_oauth2_enabled(self) -> bool:
"""
Is OAuth2 enabled in the database for authentication?
Currently this looks for a global config at the DB engine spec level, but in the
future we want to be allow admins to create custom OAuth2 clients from the
Superset UI, and assign them to specific databases.
Currently this checks for configuration stored in the database `extra`, and then
for a global config at the DB engine spec level. In the future we want to allow
admins to create custom OAuth2 clients from the Superset UI, and assign them to
specific databases.
"""
return self.db_engine_spec.is_oauth2_enabled()
encrypted_extra = json.loads(self.encrypted_extra or "{}")
oauth2_client_info = encrypted_extra.get("oauth2_client_info", {})
return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled()

def get_oauth2_config(self) -> OAuth2ClientConfig | None:
"""
Return OAuth2 client configuration.
This includes client ID, client secret, scope, redirect URI, endpointsm etc.
Currently this reads the global DB engine spec config, but in the future it
should first check if there's a custom client assigned to the database.
Currently this checks for configuration stored in the database `extra`, and then
for a global config at the DB engine spec level. In the future we want to allow
admins to create custom OAuth2 clients from the Superset UI, and assign them to
specific databases.
"""
encrypted_extra = json.loads(self.encrypted_extra or "{}")
if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
schema = OAuth2ClientConfigSchema()
client_config = schema.load(oauth2_client_info)
return cast(OAuth2ClientConfig, client_config)

return self.db_engine_spec.get_oauth2_config()


Expand Down
14 changes: 13 additions & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import backoff
import jwt
from flask import current_app
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema

from superset import db
Expand Down Expand Up @@ -180,3 +180,15 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
state = oauth2_state_schema.load(payload)

return state


class OAuth2ClientConfigSchema(Schema):
id = fields.String(required=True)
secret = fields.String(required=True)
scope = fields.String(required=True)
redirect_uri = fields.String(
required=False,
load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
3 changes: 3 additions & 0 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
app.config["RATELIMIT_ENABLED"] = False
app.config["CACHE_CONFIG"] = {}
app.config["DATA_CACHE_CONFIG"] = {}
app.config["SERVER_NAME"] = "example.com"
app.config["APPLICATION_ROOT"] = "/"
app.config["PREFERRED_URL_SCHEME="] = "http"

# loop over extra configs passed in by tests
# and update the app config
Expand Down
83 changes: 83 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

# pylint: disable=import-outside-toplevel

from datetime import datetime

import pytest
Expand All @@ -24,11 +25,23 @@
from sqlalchemy.engine.url import make_url

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils import json
from tests.unit_tests.conftest import with_feature_flags

# sample config for OAuth2 tests
oauth2_client_info = {
"oauth2_client_info": {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
}
}


def test_get_metrics(mocker: MockFixture) -> None:
"""
Expand Down Expand Up @@ -378,3 +391,73 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None:
make_url("trino:///"),
connect_args={"user": "alice.doe", "source": "Apache Superset"},
)


def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)

assert not database.is_oauth2_enabled()

database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.is_oauth2_enabled()


def test_get_oauth2_config(app_context: None) -> None:
"""
Test the `get_oauth2_config` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)

assert database.get_oauth2_config() is None

database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
}


def test_raw_connection_oauth(mocker: MockFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"
)

with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."
62 changes: 62 additions & 0 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals

import json
from uuid import UUID

import sqlparse
from freezegun import freeze_time
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session

from superset import db
from superset.common.db_query_status import QueryStatus
from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import OAuth2Error
from superset.models.core import Database
from superset.sql_lab import get_sql_results
from superset.utils.core import override_user
from tests.unit_tests.models.core_test import oauth2_client_info


def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
Expand Down Expand Up @@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery(
query.executed_sql
== "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
)


@freeze_time("2021-04-01T00:00:00Z")
def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
"""
Test that `get_sql_results` works with OAuth2.
"""
app_context = app.test_request_context()
app_context.push()

mocker.patch(
"superset.db_engine_specs.base.uuid4",
return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
)

g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42

database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"
)

query = mocker.MagicMock()
query.database = database
mocker.patch("superset.sql_lab.get_query", return_value=query)

payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
assert payload == {
"status": QueryStatus.FAILED,
"error": "You don't have permission to access the data.",
"errors": [
{
"message": "You don't have permission to access the data.",
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
"level": ErrorLevel.WARNING,
"extra": {
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO1928flOEFQIqqg9ctiXjcM&redirect_uri=http%3A%2F%2Fexample.com%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent",
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
},
}
],
}

0 comments on commit 25cbbc2

Please sign in to comment.