Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@


@xcom_router.get(
"/{xcom_key}",
"/{xcom_key:path}",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -292,7 +292,7 @@ def create_xcom_entry(


@xcom_router.patch(
"/{xcom_key}",
"/{xcom_key:path}",
status_code=status.HTTP_200_OK,
responses=create_openapi_http_exception_doc(
[
Expand Down
212 changes: 106 additions & 106 deletions airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Annotated

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status
from pydantic import BaseModel, JsonValue, StringConstraints
from pydantic import BaseModel, JsonValue
from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select

Expand All @@ -41,7 +41,7 @@ async def has_xcom_access(
dag_id: str,
run_id: str,
task_id: str,
xcom_key: Annotated[str, Path(alias="key")],
xcom_key: Annotated[str, Path(alias="key", min_length=1)],
request: Request,
token=JWTBearerDep,
) -> bool:
Expand Down Expand Up @@ -88,111 +88,15 @@ async def xcom_query(
return query


@router.head(
    "/{dag_id}/{run_id}/{task_id}/{key}",
    responses={
        status.HTTP_200_OK: {
            "description": "Metadata about the number of matching XCom values",
            "headers": {
                "Content-Range": {
                    "schema": {"pattern": r"^map_indexes \d+$"},
                    "description": "The number of (mapped) XCom values found for this task.",
                },
            },
        },
    },
    description="Returns the count of mapped XCom values found in the `Content-Range` response header",
)
def head_xcom(
    response: Response,
    session: SessionDep,
    xcom_query: Annotated[Select, Depends(xcom_query)],
    map_index: Annotated[int | None, Query()] = None,
) -> None:
    """
    Report the number of matching XCom rows through the ``Content-Range`` response header.

    The count is taken from the database only - not from other XCom Backends.

    :param response: Outgoing response whose headers receive the count.
    :param session: Database session used to run the count.
    :param xcom_query: Selectable produced by the ``xcom_query`` dependency.
    :param map_index: Must be left unset; any value is rejected.
    :raises HTTPException: 400 with reason ``invalid_request`` when ``map_index`` is supplied.
    """
    map_index_supplied = map_index is not None
    if map_index_supplied:
        rejection = {"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=rejection)

    # "map_indexes" is a custom range unit: the HTTP spec only defines "bytes", but other units
    # may be added, and this one tells the caller how many items the query matched.
    total = get_query_count(xcom_query, session=session)
    response.headers["Content-Range"] = f"map_indexes {total}"


class GetXcomFilterParams(BaseModel):
    """Class to house the params that can optionally be set for Get XCom."""

    # Map index that the row must match exactly; ``get_xcom`` only applies it when ``offset`` is unset.
    map_index: int = -1
    # Passed straight through to ``XComModel.get_many`` as its ``include_prior_dates`` argument.
    include_prior_dates: bool = False
    # Position among the non-null values ordered by map index; negative values are resolved
    # against descending order (-1 skips nothing, -2 skips one row, ...).
    offset: int | None = None


@router.get(
    "/{dag_id}/{run_id}/{task_id}/{key}",
    description="Get a single XCom Value",
)
def get_xcom(
    dag_id: str,
    run_id: str,
    task_id: str,
    key: Annotated[str, StringConstraints(min_length=1)],
    session: SessionDep,
    params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
    """
    Return a single XCom value read from the database - not from other XCom Backends.

    :param dag_id: DAG the XCom belongs to.
    :param run_id: DAG run the XCom belongs to.
    :param task_id: Task that pushed the XCom.
    :param key: XCom key to look up.
    :param session: Database session used to run the query.
    :param params: Optional filters; ``offset`` takes precedence over ``map_index`` when set.
    :return: The key together with the value stored on the matching row.
    :raises HTTPException: 404 with reason ``not_found`` when no row matches.
    """
    offset = params.offset

    # The statement comes from ``XComModel.get_many`` and is executed directly on the session, so
    # the value stored in the database row is returned as-is. Going through ``XCom.get_many`` or
    # ``XCom.get_one`` instead would deserialize through the XCom backend (e.g. pull a large file
    # from remote storage such as S3 into the API server), which is a potential performance hit.
    stmt = XComModel.get_many(
        run_id=run_id,
        key=key,
        task_ids=task_id,
        dag_ids=dag_id,
        include_prior_dates=params.include_prior_dates,
    )
    if offset is None:
        stmt = stmt.where(XComModel.map_index == params.map_index)
    else:
        # Offsets index into the non-null values only; ``order_by(None)`` drops any ordering
        # already on the statement so that the map-index ordering below is the one that applies.
        stmt = stmt.where(XComModel.value.is_not(None)).order_by(None)
        if offset < 0:
            # Negative offsets walk the rows in descending map-index order: -1 skips none.
            stmt = stmt.order_by(XComModel.map_index.desc()).offset(-1 - offset)
        else:
            stmt = stmt.order_by(XComModel.map_index.asc()).offset(offset)

    row = session.scalars(stmt).first()
    if row is not None:
        return XComResponse(key=key, value=row.value)

    # Nothing matched: name whichever selector (map_index or offset) the caller used.
    if offset is None:
        message = (
            f"XCom with {key=} map_index={params.map_index} not found for "
            f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
        )
    else:
        message = (
            f"XCom with {key=} offset={params.offset} not found for "
            f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
        )
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"reason": "not_found", "message": message},
    )


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
"/{dag_id}/{run_id}/{task_id}/{key:path}/item/{offset}",
description="Get a single XCom value from a mapped task by sequence index",
)
def get_mapped_xcom_by_index(
dag_id: str,
run_id: str,
task_id: str,
key: str,
key: Annotated[str, Path(min_length=1)],
offset: int,
session: SessionDep,
) -> XComSequenceIndexResponse:
Expand Down Expand Up @@ -229,14 +133,14 @@ class GetXComSliceFilterParams(BaseModel):


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}/slice",
"/{dag_id}/{run_id}/{task_id}/{key:path}/slice",
description="Get XCom values from a mapped task by sequence slice",
)
def get_mapped_xcom_by_slice(
dag_id: str,
run_id: str,
task_id: str,
key: str,
key: Annotated[str, Path(min_length=1)],
params: Annotated[GetXComSliceFilterParams, Query()],
session: SessionDep,
) -> XComSequenceSliceResponse:
Expand Down Expand Up @@ -310,17 +214,113 @@ def get_mapped_xcom_by_slice(
return XComSequenceSliceResponse(values)


@router.head(
    "/{dag_id}/{run_id}/{task_id}/{key:path}",
    responses={
        status.HTTP_200_OK: {
            "description": "Metadata about the number of matching XCom values",
            "headers": {
                "Content-Range": {
                    "schema": {"pattern": r"^map_indexes \d+$"},
                    "description": "The number of (mapped) XCom values found for this task.",
                },
            },
        },
    },
    description="Returns the count of mapped XCom values found in the `Content-Range` response header",
)
def head_xcom(
    response: Response,
    session: SessionDep,
    xcom_query: Annotated[Select, Depends(xcom_query)],
    map_index: Annotated[int | None, Query()] = None,
) -> None:
    """
    Report the number of matching XCom rows through the ``Content-Range`` response header.

    The count is taken from the database only - not from other XCom Backends.

    :param response: Outgoing response whose headers receive the count.
    :param session: Database session used to run the count.
    :param xcom_query: Selectable produced by the ``xcom_query`` dependency.
    :param map_index: Must be left unset; any value is rejected.
    :raises HTTPException: 400 with reason ``invalid_request`` when ``map_index`` is supplied.
    """
    map_index_supplied = map_index is not None
    if map_index_supplied:
        rejection = {"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=rejection)

    # "map_indexes" is a custom range unit: the HTTP spec only defines "bytes", but other units
    # may be added, and this one tells the caller how many items the query matched.
    total = get_query_count(xcom_query, session=session)
    response.headers["Content-Range"] = f"map_indexes {total}"


class GetXcomFilterParams(BaseModel):
    """Class to house the params that can optionally be set for Get XCom."""

    # Map index that the row must match exactly; ``get_xcom`` only applies it when ``offset`` is unset.
    map_index: int = -1
    # Passed straight through to ``XComModel.get_many`` as its ``include_prior_dates`` argument.
    include_prior_dates: bool = False
    # Position among the non-null values ordered by map index; negative values are resolved
    # against descending order (-1 skips nothing, -2 skips one row, ...).
    offset: int | None = None


@router.get(
    "/{dag_id}/{run_id}/{task_id}/{key:path}",
    description="Get a single XCom Value",
)
def get_xcom(
    dag_id: str,
    run_id: str,
    task_id: str,
    key: Annotated[str, Path(min_length=1)],
    session: SessionDep,
    params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
    """
    Return a single XCom value read from the database - not from other XCom Backends.

    :param dag_id: DAG the XCom belongs to.
    :param run_id: DAG run the XCom belongs to.
    :param task_id: Task that pushed the XCom.
    :param key: XCom key to look up; the ``:path`` route converter lets it contain ``/``.
    :param session: Database session used to run the query.
    :param params: Optional filters; ``offset`` takes precedence over ``map_index`` when set.
    :return: The key together with the value stored on the matching row.
    :raises HTTPException: 404 with reason ``not_found`` when no row matches.
    """
    offset = params.offset

    # The statement comes from ``XComModel.get_many`` and is executed directly on the session, so
    # the value stored in the database row is returned as-is. Going through ``XCom.get_many`` or
    # ``XCom.get_one`` instead would deserialize through the XCom backend (e.g. pull a large file
    # from remote storage such as S3 into the API server), which is a potential performance hit.
    stmt = XComModel.get_many(
        run_id=run_id,
        key=key,
        task_ids=task_id,
        dag_ids=dag_id,
        include_prior_dates=params.include_prior_dates,
    )
    if offset is None:
        stmt = stmt.where(XComModel.map_index == params.map_index)
    else:
        # Offsets index into the non-null values only; ``order_by(None)`` drops any ordering
        # already on the statement so that the map-index ordering below is the one that applies.
        stmt = stmt.where(XComModel.value.is_not(None)).order_by(None)
        if offset < 0:
            # Negative offsets walk the rows in descending map-index order: -1 skips none.
            stmt = stmt.order_by(XComModel.map_index.desc()).offset(-1 - offset)
        else:
            stmt = stmt.order_by(XComModel.map_index.asc()).offset(offset)

    row = session.scalars(stmt).first()
    if row is not None:
        return XComResponse(key=key, value=row.value)

    # Nothing matched: name whichever selector (map_index or offset) the caller used.
    if offset is None:
        message = (
            f"XCom with {key=} map_index={params.map_index} not found for "
            f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
        )
    else:
        message = (
            f"XCom with {key=} offset={params.offset} not found for "
            f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
        )
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"reason": "not_found", "message": message},
    )


# TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in
# the token
@router.post(
"/{dag_id}/{run_id}/{task_id}/{key}",
"/{dag_id}/{run_id}/{task_id}/{key:path}",
status_code=status.HTTP_201_CREATED,
)
def set_xcom(
dag_id: str,
run_id: str,
task_id: str,
key: Annotated[str, StringConstraints(min_length=1)],
key: Annotated[str, Path(min_length=1)],
session: SessionDep,
value: Annotated[
JsonValue,
Expand Down Expand Up @@ -431,7 +431,7 @@ def set_xcom(


@router.delete(
"/{dag_id}/{run_id}/{task_id}/{key}",
"/{dag_id}/{run_id}/{task_id}/{key:path}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
description="Delete a single XCom Value",
)
Expand All @@ -440,7 +440,7 @@ def delete_xcom(
dag_id: str,
run_id: str,
task_id: str,
key: str,
key: Annotated[str, Path(min_length=1)],
map_index: Annotated[int, Query()] = -1,
):
"""Delete a single XCom Value."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ def test_should_raise_404_for_non_existent_xcom(self, test_client):
assert response.status_code == 404
assert response.json()["detail"] == f"XCom entry with key: `{TEST_XCOM_KEY_2}` not found"

def test_should_respond_200_native_with_slash_key(self, test_client):
    """An XCom whose key contains ``/`` is retrievable with the key sent unescaped in the URL."""
    slash_key = "folder/sub/value"
    self._create_xcom(slash_key, TEST_XCOM_VALUE)

    # The key goes into the URL as-is (slashes not percent-encoded); the route's ``:path``
    # converter is expected to capture all of its segments as one key.
    url = f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}"
    response = test_client.get(url)

    assert response.status_code == 200
    payload = response.json()
    assert payload["key"] == slash_key
    assert payload["value"] == json.dumps(TEST_XCOM_VALUE)

@pytest.mark.parametrize(
"params, expected_value",
[
Expand Down Expand Up @@ -630,6 +642,23 @@ def test_should_respond_403(self, unauthorized_test_client):
)
assert response.status_code == 403

def test_create_xcom_entry_with_slash_key(self, test_client):
    """An XCom created with a key containing ``/`` can be read back through the GET route."""
    slash_key = "a/b/c"
    body = XComCreateBody(key=slash_key, value=TEST_XCOM_VALUE)
    response = test_client.post(
        f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries",
        # ``model_dump()`` is the Pydantic v2 spelling; ``dict()`` is deprecated there.
        json=body.model_dump(),
    )
    assert response.status_code == 201
    assert response.json()["key"] == slash_key

    # Read the entry back with the raw key in the URL: the slashes are NOT percent-encoded,
    # so this passes only if the route captures the remaining path segments as the key.
    get_resp = test_client.get(
        f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}"
    )
    assert get_resp.status_code == 200
    fetched = get_resp.json()
    assert fetched["key"] == slash_key
    assert fetched["value"] == json.dumps(TEST_XCOM_VALUE)


class TestPatchXComEntry(TestXComEndpoint):
@pytest.mark.parametrize(
Expand Down Expand Up @@ -657,7 +686,8 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai
# Ensure the XCom entry exists before updating
if expected_status != 404:
self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE)
new_value = XComModel.serialize_value(patch_body["value"])
# The value is double-serialized: first json.dumps(patch_body["value"]), then json.dumps() again
new_value = json.dumps(json.dumps(patch_body["value"]))

response = test_client.patch(
f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{key}",
Expand All @@ -667,7 +697,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai
assert response.status_code == expected_status

if expected_status == 200:
assert response.json()["value"] == XComModel.serialize_value(new_value)
assert response.json()["value"] == new_value
else:
assert response.json()["detail"] == expected_detail
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Expand All @@ -685,3 +715,17 @@ def test_should_respond_403(self, unauthorized_test_client):
json={},
)
assert response.status_code == 403

def test_patch_xcom_entry_with_slash_key(self, test_client, session):
    """PATCH reaches an existing XCom whose key contains ``/`` and records an audit log entry."""
    slash_key = "x/y"
    self._create_xcom(slash_key, TEST_XCOM_VALUE)

    new_value = {"updated": True}
    url = f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}"
    response = test_client.patch(url, json={"value": new_value})

    assert response.status_code == 200
    payload = response.json()
    assert payload["key"] == slash_key
    # The returned value is expected to be JSON-encoded twice: json.dumps(new_value), then
    # json.dumps() of that string again.
    assert payload["value"] == json.dumps(json.dumps(new_value))
    check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Loading