diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index c62eb29bd71bf..2ec4f20bf72d9 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -149,22 +149,29 @@ def get_url_logout(self) -> str | None: def refresh_user(self, *, user: KeycloakAuthManagerUser) -> KeycloakAuthManagerUser | None: if self._token_expired(user.access_token): - try: - log.debug("Refreshing the token") - client = self.get_keycloak_client() - tokens = client.refresh_token(user.refresh_token) + tokens = self.refresh_tokens(user=user) + + if tokens: user.refresh_token = tokens["refresh_token"] user.access_token = tokens["access_token"] return user - except KeycloakPostError as exc: - log.warning( - "KeycloakPostError encountered during token refresh. " - "Suppressing the exception and returning None.", - exc_info=exc, - ) return None + def refresh_tokens(self, *, user: KeycloakAuthManagerUser) -> dict[str, str]: + try: + log.debug("Refreshing the token") + client = self.get_keycloak_client() + return client.refresh_token(user.refresh_token) + except KeycloakPostError as exc: + log.warning( + "KeycloakPostError encountered during token refresh. " + "Suppressing the exception and returning None.", + exc_info=exc, + ) + + return {} + def is_authorized_configuration( self, *, diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/routes/login.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/routes/login.py index 247f0c4cbd9f4..48cec7dc89424 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/routes/login.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/routes/login.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import Annotated +from typing import Annotated, cast from fastapi import Depends, Request from starlette.responses import HTMLResponse, RedirectResponse @@ -83,14 +83,18 @@ def login_callback(request: Request): @login_router.get("/logout") def logout(request: Request, user: Annotated[KeycloakAuthManagerUser, Depends(get_user)]): """Log out the user from Keycloak.""" - client = KeycloakAuthManager.get_keycloak_client() - keycloak_config = client.well_known() + auth_manager = cast("KeycloakAuthManager", get_auth_manager()) + keycloak_config = auth_manager.get_keycloak_client().well_known() end_session_endpoint = keycloak_config["end_session_endpoint"] # Use the refresh flow to get the id token, it avoids us to save the id token - tokens = client.refresh_token(user.refresh_token) + token_id = auth_manager.refresh_tokens(user=user).get("id_token") post_logout_redirect_uri = request.url_for("logout_callback") - logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={tokens['id_token']}" + + if token_id: + logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={token_id}" + else: + logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}" return RedirectResponse(logout_url) @@ -118,16 +122,14 @@ def refresh( request: Request, user: Annotated[KeycloakAuthManagerUser, Depends(get_user)] ) -> RedirectResponse: """Refresh the token.""" - client = KeycloakAuthManager.get_keycloak_client() - - tokens = client.refresh_token(user.refresh_token) - user.refresh_token = tokens["refresh_token"] - user.access_token = tokens["access_token"] - token = get_auth_manager().generate_jwt(user) - + auth_manager = cast("KeycloakAuthManager", get_auth_manager()) + refreshed_user = auth_manager.refresh_user(user=user) redirect_url = request.query_params.get("next", conf.get("api", "base_url", fallback="/")) response = RedirectResponse(url=redirect_url, status_code=303) - secure = bool(conf.get("api", "ssl_cert", fallback="")) - response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure) + if refreshed_user: + token = auth_manager.generate_jwt(refreshed_user) + secure = bool(conf.get("api", "ssl_cert", fallback="")) + response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure) + return response diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py index 091408839790b..735f12cf711b5 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py @@ -18,7 +18,10 @@ from unittest.mock import ANY, Mock, patch +from keycloak import KeycloakPostError + from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX +from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser class TestLoginRouter: @@ -90,17 +93,33 @@ def test_logout(self, mock_get_keycloak_client, client): ) mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token") - @patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager") @patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client") - def test_refresh_token(self, mock_get_keycloak_client, mock_get_auth_manager, client): + def test_logout_when_keycloak_client_raises_keycloak_post_error(self, mock_get_keycloak_client, client): mock_keycloak_client = Mock() - mock_keycloak_client.refresh_token.return_value = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } + mock_keycloak_client.well_known.return_value = {"end_session_endpoint": "logout_url"} + mock_keycloak_client.refresh_token.side_effect = KeycloakPostError( + response_code=400, + response_body=b'{"error":"invalid_grant","error_description":"Token is not active"}', + ) mock_get_keycloak_client.return_value = mock_keycloak_client + response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/logout", follow_redirects=False) + assert response.status_code == 307 + assert "location" in response.headers + assert ( + response.headers["location"] + == "logout_url?post_logout_redirect_uri=http://testserver/auth/logout_callback" + ) + mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token") + @patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager") + def test_refresh_token(self, mock_get_auth_manager, client): mock_auth_manager = Mock() + mock_auth_manager.refresh_user.return_value = KeycloakAuthManagerUser( + user_id="user_id", + name="name", + access_token="new_access_token", + refresh_token="new_refresh_token", + ) mock_auth_manager.generate_jwt.return_value = "token" mock_get_auth_manager.return_value = mock_auth_manager @@ -110,5 +129,5 @@ def test_refresh_token(self, mock_get_keycloak_client, mock_get_auth_manager, cl assert response.headers["location"] == "/" assert "_token" in response.cookies assert response.cookies["_token"] == "token" - mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token") + mock_auth_manager.refresh_user.assert_called_once() mock_auth_manager.generate_jwt.assert_called_once()