Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c9bb323
refactor: Fix logout route in Keycloak provider also so the KeycloakP…
dabla Dec 13, 2025
8f2b4dd
refactor: Fixed static checks
dabla Dec 13, 2025
6b043f9
refactor: Fixed refresh_token invocations
dabla Dec 13, 2025
3340f4d
refactor: Must call refresh_user in refresh route
dabla Dec 13, 2025
db2bc99
refactor: refresh_token must always return a dict
dabla Dec 13, 2025
3f9e648
Merge branch 'main' into fix/login-refresh-route-keycloak-provider
dabla Dec 13, 2025
986c3b8
refactor: Added test when keycloak client raises KeycloakPostError wh…
dabla Dec 13, 2025
32dc04a
refactor: Fixed some additional static checks
dabla Dec 13, 2025
8763f39
refactor: Refactored refresh_user
dabla Dec 13, 2025
a0fa31d
Merge branch 'main' into fix/login-refresh-route-keycloak-provider
dabla Dec 13, 2025
12dff72
refactor: Reformatted imports
dabla Dec 13, 2025
3c34eed
refactor: Fixed mocking in refresh test
dabla Dec 13, 2025
5303bf6
refactor: Removed unused mocking of keycloak client in test_refresh_t…
dabla Dec 13, 2025
7d84193
refactor: Fixed mock get_auth_manager and added missing import Keyclo…
dabla Dec 13, 2025
9af6eaa
refactor: Refresh token route calls refresh_user instead of refresh_t…
dabla Dec 13, 2025
c07a6b7
refactor: Changed assert on refresh user
dabla Dec 14, 2025
5ac54cd
Merge branch 'main' into fix/login-refresh-route-keycloak-provider
dabla Dec 14, 2025
b22af36
Update providers/keycloak/src/airflow/providers/keycloak/auth_manager…
dabla Dec 15, 2025
fbb8326
Merge branch 'main' into fix/login-refresh-route-keycloak-provider
dabla Dec 15, 2025
a30eb51
refactor: Fixed calls to refresh_tokens instead of refresh_token
dabla Dec 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,29 @@ def get_url_logout(self) -> str | None:

def refresh_user(self, *, user: KeycloakAuthManagerUser) -> KeycloakAuthManagerUser | None:
if self._token_expired(user.access_token):
try:
log.debug("Refreshing the token")
client = self.get_keycloak_client()
tokens = client.refresh_token(user.refresh_token)
tokens = self.refresh_tokens(user=user)

if tokens:
user.refresh_token = tokens["refresh_token"]
user.access_token = tokens["access_token"]
return user
except KeycloakPostError as exc:
log.warning(
"KeycloakPostError encountered during token refresh. "
"Suppressing the exception and returning None.",
exc_info=exc,
)

return None

def refresh_tokens(self, *, user: KeycloakAuthManagerUser) -> dict[str, str]:
try:
log.debug("Refreshing the token")
client = self.get_keycloak_client()
return client.refresh_token(user.refresh_token)
except KeycloakPostError as exc:
log.warning(
"KeycloakPostError encountered during token refresh. "
"Suppressing the exception and returning None.",
exc_info=exc,
)

return {}

def is_authorized_configuration(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import logging
from typing import Annotated
from typing import Annotated, cast

from fastapi import Depends, Request
from starlette.responses import HTMLResponse, RedirectResponse
Expand Down Expand Up @@ -83,14 +83,18 @@ def login_callback(request: Request):
@login_router.get("/logout")
def logout(request: Request, user: Annotated[KeycloakAuthManagerUser, Depends(get_user)]):
"""Log out the user from Keycloak."""
client = KeycloakAuthManager.get_keycloak_client()
keycloak_config = client.well_known()
auth_manager = cast("KeycloakAuthManager", get_auth_manager())
keycloak_config = auth_manager.get_keycloak_client().well_known()
end_session_endpoint = keycloak_config["end_session_endpoint"]

# Use the refresh flow to get the id token, it avoids us to save the id token
tokens = client.refresh_token(user.refresh_token)
token_id = auth_manager.refresh_tokens(user=user).get("id_token")
post_logout_redirect_uri = request.url_for("logout_callback")
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={tokens['id_token']}"

if token_id:
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={token_id}"
else:
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}"

return RedirectResponse(logout_url)

Expand Down Expand Up @@ -118,16 +122,14 @@ def refresh(
request: Request, user: Annotated[KeycloakAuthManagerUser, Depends(get_user)]
) -> RedirectResponse:
"""Refresh the token."""
client = KeycloakAuthManager.get_keycloak_client()

tokens = client.refresh_token(user.refresh_token)
user.refresh_token = tokens["refresh_token"]
user.access_token = tokens["access_token"]
token = get_auth_manager().generate_jwt(user)

auth_manager = cast("KeycloakAuthManager", get_auth_manager())
refreshed_user = auth_manager.refresh_user(user=user)
redirect_url = request.query_params.get("next", conf.get("api", "base_url", fallback="/"))
response = RedirectResponse(url=redirect_url, status_code=303)
secure = bool(conf.get("api", "ssl_cert", fallback=""))

response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
if refreshed_user:
token = auth_manager.generate_jwt(refreshed_user)
secure = bool(conf.get("api", "ssl_cert", fallback=""))
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)

return response
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from unittest.mock import ANY, Mock, patch

from keycloak import KeycloakPostError

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser


class TestLoginRouter:
Expand Down Expand Up @@ -90,17 +93,33 @@ def test_logout(self, mock_get_keycloak_client, client):
)
mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token")

@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_refresh_token(self, mock_get_keycloak_client, mock_get_auth_manager, client):
def test_logout_when_keycloak_client_raises_keycloak_post_error(self, mock_get_keycloak_client, client):
mock_keycloak_client = Mock()
mock_keycloak_client.refresh_token.return_value = {
"access_token": "new_access_token",
"refresh_token": "new_refresh_token",
}
mock_keycloak_client.well_known.return_value = {"end_session_endpoint": "logout_url"}
mock_keycloak_client.refresh_token.side_effect = KeycloakPostError(
response_code=400,
response_body=b'{"error":"invalid_grant","error_description":"Token is not active"}',
)
mock_get_keycloak_client.return_value = mock_keycloak_client
response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/logout", follow_redirects=False)
assert response.status_code == 307
assert "location" in response.headers
assert (
response.headers["location"]
== "logout_url?post_logout_redirect_uri=http://testserver/auth/logout_callback"
)
mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token")

@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
def test_refresh_token(self, mock_get_auth_manager, client):
mock_auth_manager = Mock()
mock_auth_manager.refresh_user.return_value = KeycloakAuthManagerUser(
user_id="user_id",
name="name",
access_token="new_access_token",
refresh_token="new_refresh_token",
)
mock_auth_manager.generate_jwt.return_value = "token"
mock_get_auth_manager.return_value = mock_auth_manager

Expand All @@ -110,5 +129,5 @@ def test_refresh_token(self, mock_get_keycloak_client, mock_get_auth_manager, cl
assert response.headers["location"] == "/"
assert "_token" in response.cookies
assert response.cookies["_token"] == "token"
mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token")
mock_auth_manager.refresh_user.assert_called_once()
mock_auth_manager.generate_jwt.assert_called_once()