diff --git a/superset/config.py b/superset/config.py index 228a9f2e4c98..2637c0032bef 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1316,6 +1316,7 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument GUEST_TOKEN_JWT_ALGO = "HS256" GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes +GUEST_TOKEN_JWT_AUDIENCE = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., diff --git a/superset/security/manager.py b/superset/security/manager.py index 0bed4476e526..ac494a183782 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -74,6 +74,7 @@ GuestUser, ) from superset.utils.core import DatasourceName, RowLevelSecurityFilterType +from superset.utils.urls import get_url_host if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -1308,6 +1309,7 @@ def create_guest_access_token( secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] + aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() # calculate expiration time now = self._get_current_epoch_time() @@ -1319,6 +1321,8 @@ def create_guest_access_token( # standard jwt claims: "iat": now, # issued at "exp": exp, # expiration time + "aud": aud, + "type": "guest", } token = jwt.encode(claims, secret, algorithm=algo) return token @@ -1344,6 +1348,8 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]: raise ValueError("Guest token does not contain a resources claim") if token.get("rls_rules") is None: raise ValueError("Guest token does not contain an rls_rules claim") + if token.get("type") != "guest": + raise ValueError("This is not a guest token.") except Exception: # pylint: disable=broad-except # The login manager will handle sending 401s. # We don't need to send a special error message. @@ -1366,7 +1372,8 @@ def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - return jwt.decode(raw_token, secret, algorithms=[algo]) + aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud) @staticmethod def is_guest_user(user: Optional[Any] = None) -> bool: diff --git a/tests/integration_tests/security/api_tests.py b/tests/integration_tests/security/api_tests.py index d7b365985d9b..86be5e7da58e 100644 --- a/tests/integration_tests/security/api_tests.py +++ b/tests/integration_tests/security/api_tests.py @@ -22,6 +22,7 @@ from tests.integration_tests.base_tests import SupersetTestCase from flask_wtf.csrf import generate_csrf +from superset.utils.urls import get_url_host class TestSecurityCsrfApi(SupersetTestCase): @@ -90,6 +91,8 @@ def test_post_guest_token_authorized(self): self.assert200(response) token = json.loads(response.data)["token"] - decoded_token = jwt.decode(token, self.app.config["GUEST_TOKEN_JWT_SECRET"]) + decoded_token = jwt.decode( + token, self.app.config["GUEST_TOKEN_JWT_SECRET"], audience=get_url_host() + ) self.assertEqual(user, decoded_token["user"]) self.assertEqual(resource, decoded_token["resources"][0]) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 46ca679deebc..efcd191ffafc 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -44,6 +44,7 @@ get_example_default_schema, ) from superset.utils.database import get_example_database +from superset.utils.urls import get_url_host from superset.views.access_requests import AccessRequestsModelView from .base_tests import SupersetTestCase @@ -1177,17 +1178,20 @@ def test_create_guest_access_token(self, get_time_mock): resources = [{"some": "resource"}] rls = [{"dataset": 1, "clause": "access = 1"}] token = security_manager.create_guest_access_token(user, resources, rls) - + aud = get_url_host() # unfortunately we cannot mock time in the jwt lib decoded_token = jwt.decode( token, self.app.config["GUEST_TOKEN_JWT_SECRET"], algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], + audience=aud, ) self.assertEqual(user, decoded_token["user"]) self.assertEqual(resources, decoded_token["resources"]) self.assertEqual(now, decoded_token["iat"]) + self.assertEqual(aud, decoded_token["aud"]) + self.assertEqual("guest", decoded_token["type"]) self.assertEqual( now + (self.app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] * 1000), decoded_token["exp"], @@ -1241,3 +1245,57 @@ def test_get_guest_user_no_resource(self): self.assertRaisesRegex( ValueError, "Guest token does not contain a resources claim" ) + + def test_get_guest_user_not_guest_type(self): + now = time.time() + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + aud = get_url_host() + + claims = { + "user": user, + "resources": resources, + "rls_rules": [], + # standard jwt claims: + "aud": aud, + "iat": now, # issued at + "type": "not_guest", + } + token = jwt.encode( + claims, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithm=self.app.config["GUEST_TOKEN_JWT_ALGO"], + ) + fake_request = FakeRequest() + fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token + guest_user = security_manager.get_guest_user_from_request(fake_request) + + self.assertIsNone(guest_user) + self.assertRaisesRegex(ValueError, "This is not a guest token.") + + def test_get_guest_user_bad_audience(self): + now = time.time() + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + aud = get_url_host() + + claims = { + "user": user, + "resources": resources, + "rls_rules": [], + # standard jwt claims: + "aud": "bad_audience", + "iat": now, # issued at + "type": "guest", + } + token = jwt.encode( + claims, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithm=self.app.config["GUEST_TOKEN_JWT_ALGO"], + ) + fake_request = FakeRequest() + fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token + guest_user = security_manager.get_guest_user_from_request(fake_request) + + self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience") + self.assertIsNone(guest_user)