diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 276ae17153da0..3e82671799efc 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -46,6 +46,7 @@ "JWKS", "JWTGenerator", "JWTValidator", + "TOKEN_SCOPE_WORKLOAD", "generate_private_key", "get_sig_validation_args", "get_signing_args", @@ -54,6 +55,8 @@ "key_to_jwk_dict", ] +TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload" + class InvalidClaimError(ValueError): """Raised when a claim in the JWT is invalid.""" @@ -434,15 +437,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str: - """Generate a signed JWT for the subject.""" + def generate( + self, + extras: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + valid_for: int | None = None, + ) -> str: + """ + Generate a signed JWT. + + Args: + extras: Additional claims to include in the token. These are merged with default claims. + headers: Additional headers to include in the JWT. + valid_for: Optional custom validity duration in seconds. If not provided, uses self.valid_for. + """ now = int(datetime.now(tz=timezone.utc).timestamp()) + token_valid_for = valid_for if valid_for is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + self.valid_for), + "exp": int(now + token_valid_for), "iat": now, } @@ -458,6 +474,23 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] headers["kid"] = self.kid return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + def generate_workload_token(self, sub: str) -> str: + """ + Generate a long-lived workload token for task execution. + + Workload tokens have a special 'scope' claim that restricts them to the /run endpoint only. + They are valid for longer (default 24h) to survive executor queue wait times. + """ + from airflow.configuration import conf + + workload_valid_for = conf.getint( + "execution_api", "jwt_workload_token_expiration_time", fallback=86400 + ) + return self.generate( + extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD}, + valid_for=workload_valid_for, + ) + def generate_private_key(key_type: str = "RSA", key_size: int = 2048): """ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 9d93f3bf84daf..fbfd27a54a254 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import secrets import time from contextlib import AsyncExitStack from functools import cached_property @@ -303,18 +304,64 @@ def app(self): JWTBearerTIPathDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access + from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.configuration import conf + + # Ensure JWT secret is available for in-process execution. + # The /run endpoint needs JWTGenerator to issue execution tokens. + # If the config option is empty, generate a random one for the duration of this process. + if not conf.get("api_auth", "jwt_secret", fallback=None): + logger.debug( + "`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution" + ) + conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(16)) self._app = create_task_execution_api_app() # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() + self._app.state.jwt_generator = _jwt_generator() + self._app.state.jwt_validator = _jwt_validator() + + # Why InProcessContainer instead of lifespan.registry or svcs.Container? + # + # The normal app uses @svcs.fastapi.lifespan which manages the registry lifecycle. + # In tests (conftest.py), lifespan.registry.register_value() works because the + # TestClient initializes the lifespan before requests. However, in InProcessExecutionAPI, + # the lifespan runs later (when transport is accessed), but services may be needed + # before that. Using lifespan.registry fails in CI with ServiceNotFoundError. + # + # This minimal container bypasses the svcs lifecycle and directly returns pre-created + # service instances from app.state. If you add new services, update this class. + from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """Minimal container for in-process execution, bypassing svcs lifecycle.""" + + def __init__(self, app_state): + self._services = { + JWTGenerator: app_state.jwt_generator, + JWTValidator: app_state.jwt_validator, + } + + async def aget(self, svc_type): + if svc_type not in self._services: + raise KeyError(f"{svc_type} not registered in InProcessContainer") + return self._services[svc_type] + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container + async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow + self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow @@ -337,6 +384,7 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) + return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] @cached_property diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index fce188d48ed6d..4195bbddcdec3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -26,7 +26,7 @@ from fastapi.security import HTTPBearer from sqlalchemy import select -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_WORKLOAD, JWTValidator from airflow.api_fastapi.common.db.common import AsyncSessionDep from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.configuration import conf @@ -46,14 +46,12 @@ async def _container(request: Request): DepContainer: svcs.Container = Depends(_container) -class JWTBearer(HTTPBearer): +class _BaseJWTBearer(HTTPBearer): """ - A FastAPI security dependency that validates JWT tokens using for the Execution API. + Base class for JWT validation in the Execution API. - This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that. - - The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other - validated claims. + Validates JWT tokens are properly signed and extracts claims. Subclasses + handle scope-specific validation. """ def __init__( @@ -77,7 +75,6 @@ async def __call__( # type: ignore[override] validator: JWTValidator = await services.aget(JWTValidator) try: - # Example: Validate "task_instance_id" component of the path matches the one in the token if self.path_param_name: id = request.path_params[self.path_param_name] validators: dict[str, Any] = { @@ -87,15 +84,56 @@ async def __call__( # type: ignore[override] else: validators = self.required_claims claims = await validator.avalidated_claims(creds.credentials, validators) + + # Let subclasses validate scope + self._check_scope(claims) + return TIToken(id=claims["sub"], claims=claims) + except HTTPException: + raise except Exception as err: - log.warning( - "Failed to validate JWT", - exc_info=True, - token=creds.credentials, - ) + log.warning("Failed to validate JWT", exc_info=True) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + def _check_scope(self, claims: dict[str, Any]) -> None: + """Override in subclasses to validate scope. Raise HTTPException if invalid.""" + pass + + +class JWTBearer(_BaseJWTBearer): + """ + JWT validation that rejects workload-scoped tokens. + + Used for most Execution API endpoints. Workload-scoped tokens can only be used + on the /run endpoint, which exchanges them for short-lived execution tokens. + """ + + def _check_scope(self, claims: dict[str, Any]) -> None: + if claims.get("scope"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Scoped tokens cannot access this endpoint. Use the token from /run response.", + ) + + +class JWTBearerWorkloadScope(_BaseJWTBearer): + """ + JWT validation that ONLY accepts workload-scoped tokens. + + Used exclusively by the /run endpoint. Workload tokens have scope="ExecuteTaskWorkload" + and are long-lived to survive executor queue wait times. The /run endpoint validates + the workload token and issues a short-lived execution token for subsequent API calls. + """ + + def _check_scope(self, claims: dict[str, Any]) -> None: + scope = claims.get("scope") + # Reject if scope is missing or not the workload scope + if scope != TOKEN_SCOPE_WORKLOAD: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="This endpoint requires a workload-scoped token", + ) + JWTBearerDep: TIToken = Depends(JWTBearer()) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 562b8588fbf2c..bb2e7c33b7ae9 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -52,3 +52,8 @@ authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"]) execution_api_router.include_router(authenticated_router) + +# ti_run_router: /run endpoint - requires workload-scoped tokens (JWTBearerWorkloadDep) +execution_api_router.include_router( + task_instances.ti_run_router, prefix="/task-instances", tags=["Task Instances"] +) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index a729a1ee83b30..3594b1bf6f328 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, Depends, HTTPException, Query, Response, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult, Row @@ -38,6 +38,7 @@ from structlog.contextvars import bind_contextvars from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime @@ -59,7 +60,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerTIPathDep, JWTBearerWorkloadScope from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -78,6 +79,12 @@ if TYPE_CHECKING: from sqlalchemy.sql.dml import Update +log = structlog.get_logger(__name__) + +JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id")) + +ti_run_router = VersionedAPIRouter(dependencies=[JWTBearerWorkloadDep]) + router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( @@ -88,10 +95,7 @@ ) -log = structlog.get_logger(__name__) - - -@ti_id_router.patch( +@ti_run_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, responses={ @@ -101,11 +105,13 @@ }, response_model_exclude_unset=True, ) -def ti_run( +async def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep, dag_bag: DagBagDep, + response: Response, + services=DepContainer, ) -> TIRunContext: """ Run a TaskInstance. @@ -279,6 +285,11 @@ def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs + # Generate short-lived execution token for subsequent API calls + generator: JWTGenerator = await services.aget(JWTGenerator) + execution_token = generator.generate(extras={"sub": ti_id_str}) + response.headers["X-Execution-Token"] = execution_token + return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 71eeea1af180f..c5b025ad86ec4 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1831,6 +1831,17 @@ execution_api: type: integer example: ~ default: "600" + jwt_workload_token_expiration_time: + description: | + Number in seconds until the workload JWT token expires. Workload tokens are long-lived tokens + sent with task workloads to executors (e.g., Celery). They can only be used to call + the /run endpoint, which then issues a short-lived execution token. + + This should be set long enough to cover the maximum expected queue wait time. + version_added: 3.1.7 + type: integer + example: ~ + default: "86400" jwt_audience: version_added: 3.0.0 description: | diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 7cf1aae60ff21..c73231dc0571b 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -45,7 +45,13 @@ class BaseWorkload(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" + """ + Generate a workload-scoped token for this workload. + + Workload tokens are long-lived and can only be used on the /run endpoint, + which exchanges them for short-lived execution tokens. + """ + return generator.generate_workload_token(sub_id) if generator else "" class BundleInfo(BaseModel): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index b17c8147dae24..ed7d51b7fad9e 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -30,6 +30,7 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.auth.tokens import ( JWKS, + TOKEN_SCOPE_WORKLOAD, InvalidClaimError, JWTGenerator, JWTValidator, @@ -238,3 +239,42 @@ def rsa_private_key(): @pytest.fixture(scope="session") def ed25519_private_key(): return generate_private_key(key_type="Ed25519") + + +async def test_generate_workload_token(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): + """Test that generate_workload_token creates tokens with workload scope and longer expiration.""" + token = jwt_generator.generate_workload_token("test_subject") + + claims = await jwt_validator.avalidated_claims( + token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + # Verify workload scope is set + assert claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" + + # Verify the token has extended expiration (default 24h = 86400s) + nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc) + exp = datetime.fromtimestamp(claims["exp"], timezone.utc) + expiration_seconds = (exp - nbf).total_seconds() + + # Should be around 24 hours (86400 seconds) - allow some tolerance + assert expiration_seconds >= 86000, "Workload token should have extended expiration (~24h)" + assert expiration_seconds <= 90000, "Workload token expiration should not exceed expected duration" + + +async def test_workload_token_vs_regular_token_scope( + jwt_generator: JWTGenerator, jwt_validator: JWTValidator +): + """Test that regular tokens don't have scope claim while workload tokens do.""" + regular_token = jwt_generator.generate({"sub": "test_subject"}) + workload_token = jwt_generator.generate_workload_token("test_subject") + + regular_claims = await jwt_validator.avalidated_claims( + regular_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + workload_claims = await jwt_validator.avalidated_claims( + workload_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + assert "scope" not in regular_claims, "Regular token should not have scope claim" + assert workload_claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9e26937b63c06..c9ae4ea364681 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,14 +16,22 @@ # under the License. from __future__ import annotations -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep + + +def _always_allow(ti_id: str | None = None) -> TIToken: + """Return a mock TIToken for bypassing auth in tests.""" + return TIToken(id=ti_id or "00000000-0000-0000-0000-000000000000", claims={}) @pytest.fixture @@ -63,4 +71,29 @@ def smart_validated_claims(cred, validators=None): auth.avalidated_claims.side_effect = smart_validated_claims lifespan.registry.register_value(JWTValidator, auth) + # Mock JWTGenerator for /run endpoint that returns execution tokens + jwt_generator = MagicMock(spec=JWTGenerator) + jwt_generator.generate.return_value = "mock-execution-token" + lifespan.registry.register_value(JWTGenerator, jwt_generator) + + jwt_bearer_instance = JWTBearerDep.dependency + jwt_bearer_ti_path_instance = JWTBearerTIPathDep.dependency + jwt_bearer_workload_instance = JWTBearerWorkloadDep.dependency + + execution_app = None + for route in app.routes: + if hasattr(route, "path") and route.path == "/execution": + execution_app = route.app + break + + if execution_app: + execution_app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_workload_instance] = lambda: _always_allow() + yield client + + if execution_app: + execution_app.dependency_overrides.pop(jwt_bearer_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_workload_instance, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index e99b7d9380f03..be66f2af0e126 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -29,8 +29,6 @@ from sqlalchemy.orm import Session from airflow._shared.timezones import timezone -from airflow.api_fastapi.auth.tokens import JWTValidator -from airflow.api_fastapi.execution_api.app import lifespan from airflow.exceptions import AirflowSkipException from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel @@ -79,50 +77,35 @@ def _create_asset_aliases(session, num: int = 2) -> None: def client_with_extra_route(): ... -def test_id_matches_sub_claim(client, session, create_task_instance): - # Test that this is validated at the router level, so we don't have to test it in each component - # We validate it is set correctly, and test it once +def test_run_endpoint_returns_execution_token(client, session, create_task_instance, time_machine): + """Test that /run endpoint returns an execution token in the response header.""" + instant = timezone.parse("2024-09-30T12:00:00Z") + time_machine.move_to(instant, tick=False) ti = create_task_instance( - task_id="test_ti_run_state_conflict_if_not_queued", - state="queued", + task_id="test_run_endpoint_returns_execution_token", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), ) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - claims = {"sub": ti.id} - - def side_effect(cred, validators): - if not validators: - return claims - if validators["sub"]["value"] != ti.id: - raise RuntimeError("Fake auth denied") - return claims - - validator.avalidated_claims.side_effect = side_effect - - lifespan.registry.register_value(JWTValidator, validator) - payload = { "state": "running", "hostname": "random-hostname", "unixname": "random-unixname", "pid": 100, - "start_date": "2024-10-31T12:00:00Z", + "start_date": "2024-09-30T12:00:00Z", } - resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload) - assert resp.status_code == 403 - assert validator.avalidated_claims.call_args_list[1] == mock.call( - mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}} - ) - validator.avalidated_claims.reset_mock() - resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload) + assert resp.status_code == 200 - assert resp.status_code == 200, resp.json() - - validator.avalidated_claims.assert_awaited() + # Verify execution token is returned in header + assert "X-Execution-Token" in resp.headers + assert resp.headers["X-Execution-Token"] == "mock-execution-token" class TestTIRunState: diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 2b9ed36c7cb4d..04bddc46bc97f 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -233,7 +233,7 @@ def set_instance_attrs(self) -> Generator: @pytest.fixture def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4e95ed3a4eea7..6280bf03f0ee0 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -58,7 +58,7 @@ def __init__(self, do_update=True, *args, **kwargs): # Mock JWT generator for token generation mock_jwt_generator = MagicMock() - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" self.jwt_generator = mock_jwt_generator diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 2a38ef2fad7b3..c6e70a91306c6 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -928,7 +928,12 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * ) def _update_auth(self, response: httpx.Response): - if new_token := response.headers.get("Refreshed-API-Token"): + # Check for execution token from /run endpoint (replaces queue token with short-lived execution token) + if new_token := response.headers.get("X-Execution-Token"): + log.debug("Received execution token from /run endpoint") + self.auth = BearerAuth(new_token) + # Check for refreshed token from heartbeat/other endpoints + elif new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token)