diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py index 2e3a215f1753b..706d18de6f458 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py @@ -21,7 +21,7 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import is_safe_url +from airflow.api_fastapi.core_api.security import AuthManagerDep, is_safe_url from airflow.configuration import conf auth_router = AirflowRouter(tags=["Login"], prefix="/auth") @@ -31,9 +31,9 @@ "/login", responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]), ) -def login(request: Request, next: None | str = None) -> RedirectResponse: +def login(request: Request, auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse: """Redirect to the login URL depending on the AuthManager configured.""" - login_url = request.app.state.auth_manager.get_url_login() + login_url = auth_manager.get_url_login() if next and not is_safe_url(next, request=request): raise HTTPException(status_code=400, detail="Invalid or unsafe next URL") @@ -48,12 +48,12 @@ def login(request: Request, next: None | str = None) -> RedirectResponse: "/logout", responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]), ) -def logout(request: Request, next: None | str = None) -> RedirectResponse: +def logout(auth_manager: AuthManagerDep, next: None | str = None) -> RedirectResponse: """Logout the user.""" - logout_url = request.app.state.auth_manager.get_url_logout() + logout_url = auth_manager.get_url_logout() if not logout_url: - logout_url = request.app.state.auth_manager.get_url_login() + logout_url = auth_manager.get_url_login() return RedirectResponse(logout_url) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index 7bc7b155b68eb..385ffc90a3398 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -27,7 +27,10 @@ from pydantic import NonNegativeInt from airflow.api_fastapi.app import get_auth_manager -from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN +from airflow.api_fastapi.auth.managers.base_auth_manager import ( + COOKIE_NAME_JWT_TOKEN, + BaseAuthManager, +) from airflow.api_fastapi.auth.managers.models.base_user import BaseUser from airflow.api_fastapi.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, @@ -69,7 +72,20 @@ from fastapi.security import HTTPAuthorizationCredentials from sqlalchemy.sql import Select - from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod + from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod + + +def auth_manager_from_app(request: Request) -> BaseAuthManager: + """ + FastAPI dependency resolver that returns the shared AuthManager instance from app.state. + + This ensures that all API routes using AuthManager via dependency injection receive the same + singleton instance that was initialized at app startup. + """ + return request.app.state.auth_manager + + +AuthManagerDep = Annotated[BaseAuthManager, Depends(auth_manager_from_app)] auth_description = ( "To authenticate Airflow API requests, clients must include a JWT (JSON Web Token) in " @@ -194,7 +210,7 @@ def to_orm(self, select: Select) -> Select: def permitted_dag_filter_factory( method: ResourceMethod, filter_class=PermittedDagFilter -) -> Callable[[Request, BaseUser], PermittedDagFilter]: +) -> Callable[[BaseUser, BaseAuthManager], PermittedDagFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. @@ -203,10 +219,9 @@ def permitted_dag_filter_factory( """ def depends_permitted_dags_filter( - request: Request, user: GetUserDep, + auth_manager: AuthManagerDep, ) -> PermittedDagFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager authorized_dags: set[str] = auth_manager.get_authorized_dag_ids(user=user, method=method) return filter_class(authorized_dags) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py index 96d7451c8f26f..1b7b15d1e6557 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py @@ -475,3 +475,40 @@ async def test_requires_access_pool_bulk(self, mock_get_auth_manager): ], user=user, ) + + +class TestAuthManagerDependency: + """Test the auth_manager_from_app dependency function.""" + + def test_auth_manager_from_app_returns_instance_from_state(self): + """Test that auth_manager_from_app correctly retrieves auth_manager from app.state.""" + from airflow.api_fastapi.core_api.security import auth_manager_from_app + + # Create a mock auth manager + mock_auth_manager = Mock() + + # Create a mock request with app.state.auth_manager + mock_request = Mock() + mock_request.app.state.auth_manager = mock_auth_manager + + # Call the dependency function + result = auth_manager_from_app(mock_request) + + # Assert it returns the correct auth manager + assert result is mock_auth_manager + + def test_auth_manager_from_app_integration_with_test_client(self, test_client): + """Test that auth_manager_from_app works with the test client setup.""" + from airflow.api_fastapi.core_api.security import auth_manager_from_app + + # Create a mock request using the test client's app + mock_request = Mock() + mock_request.app = test_client.app + + # Get the auth manager + auth_manager = auth_manager_from_app(mock_request) + + # Verify it's not None (should be SimpleAuthManager from test fixture) + assert auth_manager is not None + assert hasattr(auth_manager, "get_url_login") + assert hasattr(auth_manager, "get_url_logout")