Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"JWKS",
"JWTGenerator",
"JWTValidator",
"TOKEN_SCOPE_WORKLOAD",
"generate_private_key",
"get_sig_validation_args",
"get_signing_args",
Expand All @@ -54,6 +55,8 @@
"key_to_jwk_dict",
]

TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload"


class InvalidClaimError(ValueError):
"""Raised when a claim in the JWT is invalid."""
Expand Down Expand Up @@ -434,15 +437,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
assert self._secret_key
return self._secret_key

def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str:
"""Generate a signed JWT for the subject."""
def generate(
self,
extras: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
valid_for: int | None = None,
) -> str:
"""
Generate a signed JWT.

Args:
extras: Additional claims to include in the token. These are merged with default claims.
headers: Additional headers to include in the JWT.
valid_for: Optional custom validity duration in seconds. If not provided, uses self.valid_for.
"""
now = int(datetime.now(tz=timezone.utc).timestamp())
token_valid_for = valid_for if valid_for is not None else self.valid_for
claims = {
"jti": uuid.uuid4().hex,
"iss": self.issuer,
"aud": self.audience,
"nbf": now,
"exp": int(now + self.valid_for),
"exp": int(now + token_valid_for),
"iat": now,
}

Expand All @@ -458,6 +474,23 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any]
headers["kid"] = self.kid
return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)

def generate_workload_token(self, sub: str) -> str:
"""
Generate a long-lived workload token for task execution.

Workload tokens have a special 'scope' claim that restricts them to the /run endpoint only.
They are valid for longer (default 24h) to survive executor queue wait times.
"""
from airflow.configuration import conf

workload_valid_for = conf.getint(
"execution_api", "jwt_workload_token_expiration_time", fallback=86400
)
return self.generate(
extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD},
valid_for=workload_valid_for,
)


def generate_private_key(key_type: str = "RSA", key_size: int = 2048):
"""
Expand Down
48 changes: 48 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import secrets
import time
from contextlib import AsyncExitStack
from functools import cached_property
Expand Down Expand Up @@ -303,18 +304,64 @@ def app(self):
JWTBearerTIPathDep,
)
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
from airflow.configuration import conf

# Ensure JWT secret is available for in-process execution.
# The /run endpoint needs JWTGenerator to issue execution tokens.
# If the config option is empty, generate a random one for the duration of this process.
if not conf.get("api_auth", "jwt_secret", fallback=None):
logger.debug(
"`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution"
)
conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(16))

self._app = create_task_execution_api_app()

# Set up dag_bag in app state for dependency injection
self._app.state.dag_bag = create_dag_bag()

self._app.state.jwt_generator = _jwt_generator()
self._app.state.jwt_validator = _jwt_validator()

# Why InProcessContainer instead of lifespan.registry or svcs.Container?
#
# The normal app uses @svcs.fastapi.lifespan which manages the registry lifecycle.
# In tests (conftest.py), lifespan.registry.register_value() works because the
# TestClient initializes the lifespan before requests. However, in InProcessExecutionAPI,
# the lifespan runs later (when transport is accessed), but services may be needed
# before that. Using lifespan.registry fails in CI with ServiceNotFoundError.
#
# This minimal container bypasses the svcs lifecycle and directly returns pre-created
# service instances from app.state. If you add new services, update this class.
from airflow.api_fastapi.execution_api.deps import _container

class InProcessContainer:
"""Minimal container for in-process execution, bypassing svcs lifecycle."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we use svcs? Why do we need to implement our version of it?

Copy link
Contributor Author

@anishgirianish anishgirianish Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still getting familiar with the codebase, but from what I understand, I added this to fix ServiceNotFoundError failures in CI. With InProcessExecutionAPI, the svcs lifespan runs later (when transport is accessed), but services like JWTGenerator are needed before that. This container bypasses the lifecycle and returns pre-created instances from app.state. I may well be missing something - if there's a cleaner pattern you'd recommend, I'd really appreciate the guidance.


def __init__(self, app_state):
self._services = {
JWTGenerator: app_state.jwt_generator,
JWTValidator: app_state.jwt_validator,
}

async def aget(self, svc_type):
if svc_type not in self._services:
raise KeyError(f"{svc_type} not registered in InProcessContainer")
return self._services[svc_type]

async def _inprocess_container():
yield InProcessContainer(self._app.state)

self._app.dependency_overrides[_container] = _inprocess_container

async def always_allow(): ...

self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
Expand All @@ -337,6 +384,7 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
self._cm = AsyncExitStack()

asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop)

return httpx.WSGITransport(app=middleware) # type: ignore[arg-type]

@cached_property
Expand Down
64 changes: 51 additions & 13 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from fastapi.security import HTTPBearer
from sqlalchemy import select

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_WORKLOAD, JWTValidator
from airflow.api_fastapi.common.db.common import AsyncSessionDep
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.configuration import conf
Expand All @@ -46,14 +46,12 @@ async def _container(request: Request):
DepContainer: svcs.Container = Depends(_container)


class JWTBearer(HTTPBearer):
class _BaseJWTBearer(HTTPBearer):
"""
A FastAPI security dependency that validates JWT tokens using for the Execution API.
Base class for JWT validation in the Execution API.

This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that.

The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other
validated claims.
Validates JWT tokens are properly signed and extracts claims. Subclasses
handle scope-specific validation.
"""

def __init__(
Expand All @@ -77,7 +75,6 @@ async def __call__( # type: ignore[override]
validator: JWTValidator = await services.aget(JWTValidator)

try:
# Example: Validate "task_instance_id" component of the path matches the one in the token
if self.path_param_name:
id = request.path_params[self.path_param_name]
validators: dict[str, Any] = {
Expand All @@ -87,15 +84,56 @@ async def __call__( # type: ignore[override]
else:
validators = self.required_claims
claims = await validator.avalidated_claims(creds.credentials, validators)

# Let subclasses validate scope
self._check_scope(claims)

return TIToken(id=claims["sub"], claims=claims)
except HTTPException:
raise
except Exception as err:
log.warning(
"Failed to validate JWT",
exc_info=True,
token=creds.credentials,
)
log.warning("Failed to validate JWT", exc_info=True)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")

def _check_scope(self, claims: dict[str, Any]) -> None:
"""Override in subclasses to validate scope. Raise HTTPException if invalid."""
pass


class JWTBearer(_BaseJWTBearer):
"""
JWT validation that rejects workload-scoped tokens.

Used for most Execution API endpoints. Workload-scoped tokens can only be used
on the /run endpoint, which exchanges them for short-lived execution tokens.
"""

def _check_scope(self, claims: dict[str, Any]) -> None:
if claims.get("scope"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Scoped tokens cannot access this endpoint. Use the token from /run response.",
)


class JWTBearerWorkloadScope(_BaseJWTBearer):
"""
JWT validation that ONLY accepts workload-scoped tokens.

Used exclusively by the /run endpoint. Workload tokens have scope="ExecuteTaskWorkload"
and are long-lived to survive executor queue wait times. The /run endpoint validates
the workload token and issues a short-lived execution token for subsequent API calls.
"""

def _check_scope(self, claims: dict[str, Any]) -> None:
scope = claims.get("scope")
# Reject if scope is missing or not the workload scope
if scope != TOKEN_SCOPE_WORKLOAD:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="This endpoint requires a workload-scoped token",
)


JWTBearerDep: TIToken = Depends(JWTBearer())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,8 @@
authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"])

execution_api_router.include_router(authenticated_router)

# ti_run_router: /run endpoint - requires workload-scoped tokens (JWTBearerWorkloadDep)
execution_api_router.include_router(
task_instances.ti_run_router, prefix="/task-instances", tags=["Task Instances"]
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
from fastapi import Body, Depends, HTTPException, Query, Response, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.engine import CursorResult, Row
Expand All @@ -38,6 +38,7 @@
from structlog.contextvars import bind_contextvars

from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
Expand All @@ -59,7 +60,7 @@
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerTIPathDep, JWTBearerWorkloadScope
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
Expand All @@ -78,6 +79,12 @@
if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update

log = structlog.get_logger(__name__)

JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id"))

ti_run_router = VersionedAPIRouter(dependencies=[JWTBearerWorkloadDep])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need this extra router, when we already have router, and we can then declare the JWTBearerWorkloadDep on the path


router = VersionedAPIRouter()

ti_id_router = VersionedAPIRouter(
Expand All @@ -88,10 +95,7 @@
)


log = structlog.get_logger(__name__)


@ti_id_router.patch(
@ti_run_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
responses={
Expand All @@ -101,11 +105,13 @@
},
response_model_exclude_unset=True,
)
def ti_run(
async def ti_run(
task_instance_id: UUID,
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
session: SessionDep,
dag_bag: DagBagDep,
response: Response,
services=DepContainer,
) -> TIRunContext:
"""
Run a TaskInstance.
Expand Down Expand Up @@ -279,6 +285,11 @@ def ti_run(
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs

# Generate short-lived execution token for subsequent API calls
generator: JWTGenerator = await services.aget(JWTGenerator)
execution_token = generator.generate(extras={"sub": ti_id_str})
response.headers["X-Execution-Token"] = execution_token

return context
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
Expand Down
11 changes: 11 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,17 @@ execution_api:
type: integer
example: ~
default: "600"
jwt_workload_token_expiration_time:
description: |
Number in seconds until the workload JWT token expires. Workload tokens are long-lived tokens
sent with task workloads to executors (e.g., Celery). They can only be used to call
the /run endpoint, which then issues a short-lived execution token.

This should be set long enough to cover the maximum expected queue wait time.
version_added: 3.1.7
type: integer
example: ~
default: "86400"
jwt_audience:
version_added: 3.0.0
description: |
Expand Down
8 changes: 7 additions & 1 deletion airflow-core/src/airflow/executors/workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ class BaseWorkload(BaseModel):

@staticmethod
def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str:
return generator.generate({"sub": sub_id}) if generator else ""
"""
Generate a workload-scoped token for this workload.

Workload tokens are long-lived and can only be used on the /run endpoint,
which exchanges them for short-lived execution tokens.
"""
return generator.generate_workload_token(sub_id) if generator else ""


class BundleInfo(BaseModel):
Expand Down
Loading
Loading