diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index df2ed70340fd7..a956496c56a34 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import time from contextlib import AsyncExitStack from functools import cached_property from typing import TYPE_CHECKING, Any @@ -29,6 +30,7 @@ ) from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware from airflow.api_fastapi.auth.tokens import ( JWTGenerator, @@ -39,6 +41,7 @@ if TYPE_CHECKING: import httpx + from fastapi import Response from fastapi.routing import APIRoute import structlog @@ -96,6 +99,39 @@ async def lifespan(app: FastAPI, registry: svcs.Registry): yield +class JWTReissueMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + from airflow.configuration import conf + + response: Response = await call_next(request) + + refreshed_token: str | None = None + auth_header = request.headers.get("authorization") + if auth_header and auth_header.lower().startswith("bearer "): + token = auth_header.split(" ", 1)[1] + try: + async with svcs.Container(request.app.state.svcs_registry) as services: + validator: JWTValidator = await services.aget(JWTValidator) + claims = await validator.avalidated_claims(token, {}) + + now = int(time.time()) + validity = conf.getint("execution_api", "jwt_expiration_time") + refresh_when_less_than = max(int(validity * 0.20), 30) + valid_left = int(claims.get("exp", 0)) - now + if valid_left <= refresh_when_less_than: + generator: JWTGenerator = await services.aget(JWTGenerator) + refreshed_token = generator.generate(claims) + except Exception as err: + # Do not block the response if refreshing fails; log a warning for visibility + logger.warning( + "JWT reissue middleware failed to refresh token", error=str(err), exc_info=True + ) + + if refreshed_token: + response.headers["Refreshed-API-Token"] = refreshed_token + return response + + class CadwynWithOpenAPICustomization(Cadwyn): # Workaround lack of customzation https://github.com/zmievsa/cadwyn/issues/255 async def openapi_jsons(self, req: Request) -> JSONResponse: @@ -179,6 +215,7 @@ def custom_generate_unique_id(route: APIRoute): versions=bundle, ) + app.add_middleware(JWTReissueMiddleware) app.generate_and_include_versioned_routers(execution_api_router) # As we are mounted as a sub app, we don't get any logs for unhandled exceptions without this! @@ -233,7 +270,6 @@ def app(self): from airflow.api_fastapi.execution_api.deps import ( JWTBearerDep, JWTBearerTIPathDep, - JWTRefresherDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access @@ -248,7 +284,6 @@ async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow - self._app.dependency_overrides[JWTRefresherDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index 2648a64ffad7a..d247a31f5f4fa 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -18,17 +18,14 @@ # Disable future annotations in this file to work around https://github.com/fastapi/fastapi/issues/13056 # ruff: noqa: I002 -import sys -import time from typing import Any import structlog import svcs -from fastapi import Depends, HTTPException, Request, Response, status +from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPBearer -from starlette.exceptions import HTTPException as StarletteHTTPException -from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator +from airflow.api_fastapi.auth.tokens import JWTValidator from airflow.api_fastapi.execution_api.datamodels.token import TIToken log = structlog.get_logger(logger_name=__name__) @@ -98,58 +95,3 @@ async def __call__( # type: ignore[override] # This checks that the UUID in the url matches the one in the token for us. JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) - - -class JWTReissuer: - """Re-issue JWTs to requests when they are about to run out.""" - - def __init__(self): - from airflow.configuration import conf - - self.refresh_when_less_than = max( - # Issue a new token to a task when the current one is valid for only either 20% of the total validity, - # or 30s - int(conf.getint("execution_api", "jwt_expiration_time") * 0.20), - 30, - ) - - async def __call__( - self, - response: Response, - token=JWTBearerDep, - services=DepContainer, - ): - try: - yield - finally: - # We want to run this even in the case of 404 errors etc - now = int(time.time()) - - try: - valid_left = token.claims["exp"] - now - if valid_left <= self.refresh_when_less_than: - generator: JWTGenerator = await services.aget(JWTGenerator) - new = generator.generate(token.claims) - response.headers["Refreshed-API-Token"] = new - log.debug( - "Refreshed token issued to Task", - valid_left=valid_left, - refresh_when_less_than=self.refresh_when_less_than, - ) - - exc, val, _ = sys.exc_info() - if val and isinstance(val, StarletteHTTPException): - # If there is an exception thrown, we need to set the headers there instead - if val.headers is None: - val.headers = {} - - # Defined as a "mapping type", but 99.9% of the time it's a mutable dict. We catch - # errors if not - val.headers["Refreshed-API-Token"] = new # type: ignore[index] - - except Exception as e: - # Don't 500 if there's a problem - log.warning("Error refreshing Task JWT", err=f"{type(e).__name__}: {e}") - - -JWTRefresherDep = Depends(JWTReissuer()) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 89d96083876db..562b8588fbf2c 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -19,7 +19,7 @@ from cadwyn import VersionedAPIRouter from fastapi import APIRouter -from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTRefresherDep +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.api_fastapi.execution_api.routes import ( asset_events, assets, @@ -37,7 +37,7 @@ execution_api_router.include_router(health.router, prefix="/health", tags=["Health"]) # _Every_ single endpoint under here must be authenticated. Some do further checks on top of these -authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep, JWTRefresherDep]) # type: ignore[list-item] +authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item] authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py index e4aadcead6cc1..85f7df4691566 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py @@ -24,7 +24,6 @@ from airflow.api_fastapi.auth.tokens import JWTValidator from airflow.api_fastapi.execution_api.app import lifespan -from airflow.api_fastapi.execution_api.deps import JWTRefresherDep, JWTReissuer from tests_common.test_utils.config import conf_vars @@ -57,15 +56,14 @@ def test_expiring_token_is_reissued( "exp": moment + validity, } - with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}): - exec_app.dependency_overrides[JWTRefresherDep.dependency] = JWTReissuer() - time_machine.move_to(moment + age, tick=False) # Inject our fake JWTValidator object. Can be over-ridden by tests if they want lifespan.registry.register_value(JWTValidator, auth) # In order to test this we need any endpoint to hit. The easiest one to use is variable get - response = client.get("/execution/variables/key1") + + with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}): + response = client.get("/execution/variables/key1", headers={"Authorization": "Bearer dummy"}) if expect_refreshed_token: assert "Refreshed-API-Token" in response.headers diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 793e6ffac8d57..3fbaebbb73ebc 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -98,7 +98,6 @@ def side_effect(cred, validators): raise RuntimeError("Fake auth denied") return claims - # validator.avalidated_claims.side_effect = [{}, RuntimeError("fail for tests"), claims, claims] validator.avalidated_claims.side_effect = side_effect lifespan.registry.register_value(JWTValidator, validator) @@ -113,7 +112,7 @@ def side_effect(cred, validators): resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload) assert resp.status_code == 403 - validator.avalidated_claims.assert_called_with( + assert validator.avalidated_claims.call_args_list[1] == mock.call( mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}} ) validator.avalidated_claims.reset_mock()