diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index baa852a4c54c4..1477b1b715e6b 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import cache -from typing import TYPE_CHECKING, Annotated, Any, Callable +from typing import TYPE_CHECKING, Annotated, Callable from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer @@ -47,9 +47,7 @@ def get_signer() -> JWTSigner: def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser: try: - signer = get_signer() - payload: dict[str, Any] = signer.verify_token(token_str) - return get_auth_manager().deserialize_user(payload) + return get_auth_manager().get_user_from_token(token_str) except InvalidTokenError: raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden") diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index fe86bc8f05acf..4cbc5044be6e8 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -17,9 +17,11 @@ # under the License. from __future__ import annotations +import logging from abc import abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar +from jwt import InvalidTokenError from sqlalchemy import select from airflow.auth.managers.models.base_user import BaseUser @@ -61,6 +63,7 @@ # TODO: Move this inside once all providers drop Airflow 2.x support. ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] +log = logging.getLogger(__name__) T = TypeVar("T", bound=BaseUser) @@ -102,14 +105,18 @@ def deserialize_user(self, token: dict[str, Any]) -> T: def serialize_user(self, user: T) -> dict[str, Any]: """Create a dict from a user object.""" + def get_user_from_token(self, token: str) -> BaseUser: + """Verify the JWT token is valid and create a user object from it if valid.""" + try: + payload: dict[str, Any] = self._get_token_signer().verify_token(token) + return self.deserialize_user(payload) + except InvalidTokenError as e: + log.error("JWT token is not valid") + raise e + def get_jwt_token(self, user: T) -> str: """Return the JWT token from a user object.""" - signer = JWTSigner( - secret_key=conf.get("api", "auth_jwt_secret"), - expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), - audience="front-apis", - ) - return signer.generate_signed_token(self.serialize_user(user)) + return self._get_token_signer().generate_signed_token(self.serialize_user(user)) def get_user_id(self) -> str | None: """Return the user ID associated to the user in session.""" @@ -437,3 +444,16 @@ def get_fastapi_app(self) -> FastAPI | None: def register_views(self) -> None: """Register views specific to the auth manager.""" + + @staticmethod + def _get_token_signer(): + """ + Return the signer used to sign JWT token. + + :meta private: + """ + return JWTSigner( + secret_key=conf.get("api", "auth_jwt_secret"), + expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py index 90bf3f647bba3..b9e1c58aa2010 100644 --- a/tests/api_fastapi/core_api/test_security.py +++ b/tests/api_fastapi/core_api/test_security.py @@ -43,39 +43,32 @@ def setup_class(cls): ): create_app() - @patch("airflow.api_fastapi.core_api.security.get_signer") @patch("airflow.api_fastapi.core_api.security.get_auth_manager") - def test_get_user(self, mock_get_auth_manager, mock_get_signer): + def test_get_user(self, mock_get_auth_manager): token_str = "test-token" - user_dict = {"user": "XXXXXXXXX"} user = SimpleAuthManagerUser(username="username", role="admin") auth_manager = Mock() - auth_manager.deserialize_user.return_value = user + auth_manager.get_user_from_token.return_value = user mock_get_auth_manager.return_value = auth_manager - signer = Mock() - signer.verify_token.return_value = user_dict - mock_get_signer.return_value = signer - result = get_user(token_str) - signer.verify_token.assert_called_once_with(token_str) - auth_manager.deserialize_user.assert_called_once_with(user_dict) + auth_manager.get_user_from_token.assert_called_once_with(token_str) assert result == user - @patch("airflow.api_fastapi.core_api.security.get_signer") - def test_get_user_unsuccessful(self, mock_get_signer): + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_get_user_unsuccessful(self, mock_get_auth_manager): token_str = "test-token" - signer = Mock() - signer.verify_token.side_effect = InvalidTokenError() - mock_get_signer.return_value = signer + auth_manager = Mock() + auth_manager.get_user_from_token.side_effect = InvalidTokenError() + mock_get_auth_manager.return_value = auth_manager with pytest.raises(HTTPException, match="Forbidden"): get_user(token_str) - signer.verify_token.assert_called_once_with(token_str) + auth_manager.get_user_from_token.assert_called_once_with(token_str) @patch("airflow.api_fastapi.core_api.security.get_auth_manager") def test_requires_access_dag_authorized(self, mock_get_auth_manager): diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 370e401da0609..9e4889d3d3d5d 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -190,6 +190,23 @@ def test_get_user_id_raise_exception_when_no_user(self, auth_manager): def test_get_url_user_profile_return_none(self, auth_manager): assert auth_manager.get_url_user_profile() is None + @patch("airflow.auth.managers.base_auth_manager.JWTSigner") + @patch.object(EmptyAuthManager, "deserialize_user") + def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer, auth_manager): + token = "token" + payload = {} + user = BaseAuthManagerUserTest(name="test") + signer = Mock() + signer.verify_token.return_value = payload + mock_jwt_signer.return_value = signer + mock_deserialize_user.return_value = user + + result = auth_manager.get_user_from_token(token) + + mock_deserialize_user.assert_called_once_with(payload) + signer.verify_token.assert_called_once_with(token) + assert result == user + @patch("airflow.auth.managers.base_auth_manager.JWTSigner") @patch.object(EmptyAuthManager, "serialize_user") def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer, auth_manager):