Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from functools import cache
from typing import TYPE_CHECKING, Annotated, Any, Callable
from typing import TYPE_CHECKING, Annotated, Callable

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
Expand Down Expand Up @@ -47,9 +47,7 @@ def get_signer() -> JWTSigner:

def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
try:
signer = get_signer()
payload: dict[str, Any] = signer.verify_token(token_str)
return get_auth_manager().deserialize_user(payload)
return get_auth_manager().get_user_from_token(token_str)
except InvalidTokenError:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")

Expand Down
32 changes: 26 additions & 6 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
# under the License.
from __future__ import annotations

import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from jwt import InvalidTokenError
from sqlalchemy import select

from airflow.auth.managers.models.base_user import BaseUser
Expand Down Expand Up @@ -61,6 +63,7 @@
# TODO: Move this inside once all providers drop Airflow 2.x support.
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]

log = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseUser)


Expand Down Expand Up @@ -102,14 +105,18 @@ def deserialize_user(self, token: dict[str, Any]) -> T:
def serialize_user(self, user: T) -> dict[str, Any]:
"""Create a dict from a user object."""

def get_user_from_token(self, token: str) -> BaseUser:
"""Verify the JWT token is valid and create a user object from it if valid."""
try:
payload: dict[str, Any] = self._get_token_signer().verify_token(token)
return self.deserialize_user(payload)
except InvalidTokenError as e:
log.error("JWT token is not valid")
raise e

def get_jwt_token(self, user: T) -> str:
"""Return the JWT token from a user object."""
signer = JWTSigner(
secret_key=conf.get("api", "auth_jwt_secret"),
expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
audience="front-apis",
)
return signer.generate_signed_token(self.serialize_user(user))
return self._get_token_signer().generate_signed_token(self.serialize_user(user))

def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
Expand Down Expand Up @@ -437,3 +444,16 @@ def get_fastapi_app(self) -> FastAPI | None:

def register_views(self) -> None:
"""Register views specific to the auth manager."""

@staticmethod
def _get_token_signer():
"""
Return the signer used to sign JWT token.

:meta private:
"""
return JWTSigner(
secret_key=conf.get("api", "auth_jwt_secret"),
expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
audience="front-apis",
)
25 changes: 9 additions & 16 deletions tests/api_fastapi/core_api/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,32 @@ def setup_class(cls):
):
create_app()

@patch("airflow.api_fastapi.core_api.security.get_signer")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_get_user(self, mock_get_auth_manager, mock_get_signer):
def test_get_user(self, mock_get_auth_manager):
token_str = "test-token"
user_dict = {"user": "XXXXXXXXX"}
user = SimpleAuthManagerUser(username="username", role="admin")

auth_manager = Mock()
auth_manager.deserialize_user.return_value = user
auth_manager.get_user_from_token.return_value = user
mock_get_auth_manager.return_value = auth_manager

signer = Mock()
signer.verify_token.return_value = user_dict
mock_get_signer.return_value = signer

result = get_user(token_str)

signer.verify_token.assert_called_once_with(token_str)
auth_manager.deserialize_user.assert_called_once_with(user_dict)
auth_manager.get_user_from_token.assert_called_once_with(token_str)
assert result == user

@patch("airflow.api_fastapi.core_api.security.get_signer")
def test_get_user_unsuccessful(self, mock_get_signer):
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_get_user_unsuccessful(self, mock_get_auth_manager):
token_str = "test-token"

signer = Mock()
signer.verify_token.side_effect = InvalidTokenError()
mock_get_signer.return_value = signer
auth_manager = Mock()
auth_manager.get_user_from_token.side_effect = InvalidTokenError()
mock_get_auth_manager.return_value = auth_manager

with pytest.raises(HTTPException, match="Forbidden"):
get_user(token_str)

signer.verify_token.assert_called_once_with(token_str)
auth_manager.get_user_from_token.assert_called_once_with(token_str)

@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_dag_authorized(self, mock_get_auth_manager):
Expand Down
17 changes: 17 additions & 0 deletions tests/auth/managers/test_base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,23 @@ def test_get_user_id_raise_exception_when_no_user(self, auth_manager):
def test_get_url_user_profile_return_none(self, auth_manager):
assert auth_manager.get_url_user_profile() is None

@patch("airflow.auth.managers.base_auth_manager.JWTSigner")
@patch.object(EmptyAuthManager, "deserialize_user")
def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer, auth_manager):
token = "token"
payload = {}
user = BaseAuthManagerUserTest(name="test")
signer = Mock()
signer.verify_token.return_value = payload
mock_jwt_signer.return_value = signer
mock_deserialize_user.return_value = user

result = auth_manager.get_user_from_token(token)

mock_deserialize_user.assert_called_once_with(payload)
signer.verify_token.assert_called_once_with(token)
assert result == user

@patch("airflow.auth.managers.base_auth_manager.JWTSigner")
@patch.object(EmptyAuthManager, "serialize_user")
def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer, auth_manager):
Expand Down