Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import time
from contextlib import AsyncExitStack
from functools import cached_property
from typing import TYPE_CHECKING, Any
Expand All @@ -29,6 +30,7 @@
)
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from airflow.api_fastapi.auth.tokens import (
JWTGenerator,
Expand All @@ -39,6 +41,7 @@

if TYPE_CHECKING:
import httpx
from fastapi import Response
from fastapi.routing import APIRoute

import structlog
Expand Down Expand Up @@ -96,6 +99,39 @@ async def lifespan(app: FastAPI, registry: svcs.Registry):
yield


class JWTReissueMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
from airflow.configuration import conf

response: Response = await call_next(request)

refreshed_token: str | None = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1]
try:
async with svcs.Container(request.app.state.svcs_registry) as services:
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})

now = int(time.time())
validity = conf.getint("execution_api", "jwt_expiration_time")
refresh_when_less_than = max(int(validity * 0.20), 30)
valid_left = int(claims.get("exp", 0)) - now
if valid_left <= refresh_when_less_than:
generator: JWTGenerator = await services.aget(JWTGenerator)
refreshed_token = generator.generate(claims)
except Exception as err:
# Do not block the response if refreshing fails; log a warning for visibility
logger.warning(
"JWT reissue middleware failed to refresh token", error=str(err), exc_info=True
)

if refreshed_token:
response.headers["Refreshed-API-Token"] = refreshed_token
return response


class CadwynWithOpenAPICustomization(Cadwyn):
# Workaround lack of customzation https://github.com/zmievsa/cadwyn/issues/255
async def openapi_jsons(self, req: Request) -> JSONResponse:
Expand Down Expand Up @@ -179,6 +215,7 @@ def custom_generate_unique_id(route: APIRoute):
versions=bundle,
)

app.add_middleware(JWTReissueMiddleware)
app.generate_and_include_versioned_routers(execution_api_router)

# As we are mounted as a sub app, we don't get any logs for unhandled exceptions without this!
Expand Down Expand Up @@ -233,7 +270,6 @@ def app(self):
from airflow.api_fastapi.execution_api.deps import (
JWTBearerDep,
JWTBearerTIPathDep,
JWTRefresherDep,
)
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
Expand All @@ -248,7 +284,6 @@ async def always_allow(): ...

self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
self._app.dependency_overrides[JWTRefresherDep.dependency] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
Expand Down
62 changes: 2 additions & 60 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
# Disable future annotations in this file to work around https://github.com/fastapi/fastapi/issues/13056
# ruff: noqa: I002

import sys
import time
from typing import Any

import structlog
import svcs
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
from starlette.exceptions import HTTPException as StarletteHTTPException

from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.datamodels.token import TIToken

log = structlog.get_logger(logger_name=__name__)
Expand Down Expand Up @@ -98,58 +95,3 @@ async def __call__( # type: ignore[override]

# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))


class JWTReissuer:
"""Re-issue JWTs to requests when they are about to run out."""

def __init__(self):
from airflow.configuration import conf

self.refresh_when_less_than = max(
# Issue a new token to a task when the current one is valid for only either 20% of the total validity,
# or 30s
int(conf.getint("execution_api", "jwt_expiration_time") * 0.20),
30,
)

async def __call__(
self,
response: Response,
token=JWTBearerDep,
services=DepContainer,
):
try:
yield
finally:
# We want to run this even in the case of 404 errors etc
now = int(time.time())

try:
valid_left = token.claims["exp"] - now
if valid_left <= self.refresh_when_less_than:
generator: JWTGenerator = await services.aget(JWTGenerator)
new = generator.generate(token.claims)
response.headers["Refreshed-API-Token"] = new
log.debug(
"Refreshed token issued to Task",
valid_left=valid_left,
refresh_when_less_than=self.refresh_when_less_than,
)

exc, val, _ = sys.exc_info()
if val and isinstance(val, StarletteHTTPException):
# If there is an exception thrown, we need to set the headers there instead
if val.headers is None:
val.headers = {}

# Defined as a "mapping type", but 99.9% of the time it's a mutable dict. We catch
# errors if not
val.headers["Refreshed-API-Token"] = new # type: ignore[index]

except Exception as e:
# Don't 500 if there's a problem
log.warning("Error refreshing Task JWT", err=f"{type(e).__name__}: {e}")


JWTRefresherDep = Depends(JWTReissuer())
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cadwyn import VersionedAPIRouter
from fastapi import APIRouter

from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTRefresherDep
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
Expand All @@ -37,7 +37,7 @@
execution_api_router.include_router(health.router, prefix="/health", tags=["Health"])

# _Every_ single endpoint under here must be authenticated. Some do further checks on top of these
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep, JWTRefresherDep]) # type: ignore[list-item]
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item]

authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
from airflow.api_fastapi.execution_api.deps import JWTRefresherDep, JWTReissuer

from tests_common.test_utils.config import conf_vars

Expand Down Expand Up @@ -57,15 +56,14 @@ def test_expiring_token_is_reissued(
"exp": moment + validity,
}

with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
exec_app.dependency_overrides[JWTRefresherDep.dependency] = JWTReissuer()

time_machine.move_to(moment + age, tick=False)

# Inject our fake JWTValidator object. Can be over-ridden by tests if they want
lifespan.registry.register_value(JWTValidator, auth)
# In order to test this we need any endpoint to hit. The easiest one to use is variable get
response = client.get("/execution/variables/key1")

with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
response = client.get("/execution/variables/key1", headers={"Authorization": "Bearer dummy"})

if expect_refreshed_token:
assert "Refreshed-API-Token" in response.headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def side_effect(cred, validators):
raise RuntimeError("Fake auth denied")
return claims

# validator.avalidated_claims.side_effect = [{}, RuntimeError("fail for tests"), claims, claims]
validator.avalidated_claims.side_effect = side_effect

lifespan.registry.register_value(JWTValidator, validator)
Expand All @@ -113,7 +112,7 @@ def side_effect(cred, validators):

resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload)
assert resp.status_code == 403
validator.avalidated_claims.assert_called_with(
assert validator.avalidated_claims.call_args_list[1] == mock.call(
mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}}
)
validator.avalidated_claims.reset_mock()
Expand Down