diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 6aa86d09f22e7..b1a4bc84414fd 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -60,7 +60,7 @@ @xcom_router.get( - "/{xcom_key}", + "/{xcom_key:path}", responses=create_openapi_http_exception_doc( [ status.HTTP_400_BAD_REQUEST, @@ -292,7 +292,7 @@ def create_xcom_entry( @xcom_router.patch( - "/{xcom_key}", + "/{xcom_key:path}", status_code=status.HTTP_200_OK, responses=create_openapi_http_exception_doc( [ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index b2399635499cc..1408adcefee50 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -21,7 +21,7 @@ from typing import Annotated from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status -from pydantic import BaseModel, JsonValue, StringConstraints +from pydantic import BaseModel, JsonValue from sqlalchemy import delete from sqlalchemy.sql.selectable import Select @@ -41,7 +41,7 @@ async def has_xcom_access( dag_id: str, run_id: str, task_id: str, - xcom_key: Annotated[str, Path(alias="key")], + xcom_key: Annotated[str, Path(alias="key", min_length=1)], request: Request, token=JWTBearerDep, ) -> bool: @@ -88,111 +88,15 @@ async def xcom_query( return query -@router.head( - "/{dag_id}/{run_id}/{task_id}/{key}", - responses={ - status.HTTP_200_OK: { - "description": "Metadata about the number of matching XCom values", - "headers": { - "Content-Range": { - "schema": {"pattern": r"^map_indexes \d+$"}, - "description": "The number of (mapped) XCom values found for this task.", - }, - }, - }, - }, - description="Returns the count of mapped XCom values found in the `Content-Range` response header", -) -def head_xcom( - response: Response, - session: SessionDep, - xcom_query: Annotated[Select, Depends(xcom_query)], - map_index: Annotated[int | None, Query()] = None, -) -> None: - """Get the count of XComs from database - not other XCom Backends.""" - if map_index is not None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}, - ) - - count = get_query_count(xcom_query, session=session) - # Tell the caller how many items in this query. We define a custom range unit (HTTP spec only defines - # "bytes" but we can add our own) - response.headers["Content-Range"] = f"map_indexes {count}" - - -class GetXcomFilterParams(BaseModel): - """Class to house the params that can optionally be set for Get XCom.""" - - map_index: int = -1 - include_prior_dates: bool = False - offset: int | None = None - - @router.get( - "/{dag_id}/{run_id}/{task_id}/{key}", - description="Get a single XCom Value", -) -def get_xcom( - dag_id: str, - run_id: str, - task_id: str, - key: Annotated[str, StringConstraints(min_length=1)], - session: SessionDep, - params: Annotated[GetXcomFilterParams, Query()], -) -> XComResponse: - """Get an Airflow XCom from database - not other XCom Backends.""" - xcom_query = XComModel.get_many( - run_id=run_id, - key=key, - task_ids=task_id, - dag_ids=dag_id, - include_prior_dates=params.include_prior_dates, - ) - if params.offset is not None: - xcom_query = xcom_query.where(XComModel.value.is_not(None)).order_by(None) - if params.offset >= 0: - xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset) - else: - xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset) - else: - xcom_query = xcom_query.where(XComModel.map_index == params.map_index) - - # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. - # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead - # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` - # (which automatically deserializes using the backend), we avoid potential - # performance hits from retrieving large data files into the API server. - result = session.scalars(xcom_query).first() - if result is None: - if params.offset is None: - message = ( - f"XCom with {key=} map_index={params.map_index} not found for " - f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" - ) - else: - message = ( - f"XCom with {key=} offset={params.offset} not found for " - f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" - ) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"reason": "not_found", "message": message}, - ) - - return XComResponse(key=key, value=result.value) - - -@router.get( - "/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}", + "/{dag_id}/{run_id}/{task_id}/{key:path}/item/{offset}", description="Get a single XCom value from a mapped task by sequence index", ) def get_mapped_xcom_by_index( dag_id: str, run_id: str, task_id: str, - key: str, + key: Annotated[str, Path(min_length=1)], offset: int, session: SessionDep, ) -> XComSequenceIndexResponse: @@ -229,14 +133,14 @@ class GetXComSliceFilterParams(BaseModel): @router.get( - "/{dag_id}/{run_id}/{task_id}/{key}/slice", + "/{dag_id}/{run_id}/{task_id}/{key:path}/slice", description="Get XCom values from a mapped task by sequence slice", ) def get_mapped_xcom_by_slice( dag_id: str, run_id: str, task_id: str, - key: str, + key: Annotated[str, Path(min_length=1)], params: Annotated[GetXComSliceFilterParams, Query()], session: SessionDep, ) -> XComSequenceSliceResponse: @@ -310,17 +214,113 @@ def get_mapped_xcom_by_slice( return XComSequenceSliceResponse(values) +@router.head( + "/{dag_id}/{run_id}/{task_id}/{key:path}", + responses={ + status.HTTP_200_OK: { + "description": "Metadata about the number of matching XCom values", + "headers": { + "Content-Range": { + "schema": {"pattern": r"^map_indexes \d+$"}, + "description": "The number of (mapped) XCom values found for this task.", + }, + }, + }, + }, + description="Returns the count of mapped XCom values found in the `Content-Range` response header", +) +def head_xcom( + response: Response, + session: SessionDep, + xcom_query: Annotated[Select, Depends(xcom_query)], + map_index: Annotated[int | None, Query()] = None, +) -> None: + """Get the count of XComs from database - not other XCom Backends.""" + if map_index is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"}, + ) + + count = get_query_count(xcom_query, session=session) + # Tell the caller how many items in this query. We define a custom range unit (HTTP spec only defines + # "bytes" but we can add our own) + response.headers["Content-Range"] = f"map_indexes {count}" + + +class GetXcomFilterParams(BaseModel): + """Class to house the params that can optionally be set for Get XCom.""" + + map_index: int = -1 + include_prior_dates: bool = False + offset: int | None = None + + +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key:path}", + description="Get a single XCom Value", +) +def get_xcom( + dag_id: str, + run_id: str, + task_id: str, + key: Annotated[str, Path(min_length=1)], + session: SessionDep, + params: Annotated[GetXcomFilterParams, Query()], +) -> XComResponse: + """Get an Airflow XCom from database - not other XCom Backends.""" + xcom_query = XComModel.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + include_prior_dates=params.include_prior_dates, + ) + if params.offset is not None: + xcom_query = xcom_query.where(XComModel.value.is_not(None)).order_by(None) + if params.offset >= 0: + xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset) + else: + xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset) + else: + xcom_query = xcom_query.where(XComModel.map_index == params.map_index) + + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. + # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead + # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` + # (which automatically deserializes using the backend), we avoid potential + # performance hits from retrieving large data files into the API server. + result = session.scalars(xcom_query).first() + if result is None: + if params.offset is None: + message = ( + f"XCom with {key=} map_index={params.map_index} not found for " + f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" + ) + else: + message = ( + f"XCom with {key=} offset={params.offset} not found for " + f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}" + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": message}, + ) + + return XComResponse(key=key, value=result.value) + + # TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in # the token @router.post( - "/{dag_id}/{run_id}/{task_id}/{key}", + "/{dag_id}/{run_id}/{task_id}/{key:path}", status_code=status.HTTP_201_CREATED, ) def set_xcom( dag_id: str, run_id: str, task_id: str, - key: Annotated[str, StringConstraints(min_length=1)], + key: Annotated[str, Path(min_length=1)], session: SessionDep, value: Annotated[ JsonValue, @@ -431,7 +431,7 @@ def set_xcom( @router.delete( - "/{dag_id}/{run_id}/{task_id}/{key}", + "/{dag_id}/{run_id}/{task_id}/{key:path}", responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, description="Delete a single XCom Value", ) @@ -440,7 +440,7 @@ def delete_xcom( dag_id: str, run_id: str, task_id: str, - key: str, + key: Annotated[str, Path(min_length=1)], map_index: Annotated[int, Query()] = -1, ): """Delete a single XCom Value.""" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index d7a4837ddcc17..395e45f0a9e46 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -161,6 +161,18 @@ def test_should_raise_404_for_non_existent_xcom(self, test_client): assert response.status_code == 404 assert response.json()["detail"] == f"XCom entry with key: `{TEST_XCOM_KEY_2}` not found" + def test_should_respond_200_native_with_slash_key(self, test_client): + slash_key = "folder/sub/value" + self._create_xcom(slash_key, TEST_XCOM_VALUE) + # Use raw slash_key directly - FastAPI with :path converter handles it + response = test_client.get( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}" + ) + assert response.status_code == 200 + current_data = response.json() + assert current_data["key"] == slash_key + assert current_data["value"] == json.dumps(TEST_XCOM_VALUE) + @pytest.mark.parametrize( "params, expected_value", [ @@ -630,6 +642,23 @@ def test_should_respond_403(self, unauthorized_test_client): ) assert response.status_code == 403 + def test_create_xcom_entry_with_slash_key(self, test_client): + slash_key = "a/b/c" + body = XComCreateBody(key=slash_key, value=TEST_XCOM_VALUE) + response = test_client.post( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries", + json=body.dict(), + ) + assert response.status_code == 201 + assert response.json()["key"] == slash_key + # Verify retrieval via encoded path + get_resp = test_client.get( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}" + ) + assert get_resp.status_code == 200 + assert get_resp.json()["key"] == slash_key + assert get_resp.json()["value"] == json.dumps(TEST_XCOM_VALUE) + class TestPatchXComEntry(TestXComEndpoint): @pytest.mark.parametrize( @@ -657,7 +686,8 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai # Ensure the XCom entry exists before updating if expected_status != 404: self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE) - new_value = XComModel.serialize_value(patch_body["value"]) + # The value is double-serialized: first json.dumps(patch_body["value"]), then json.dumps() again + new_value = json.dumps(json.dumps(patch_body["value"])) response = test_client.patch( f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{key}", @@ -667,7 +697,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai assert response.status_code == expected_status if expected_status == 200: - assert response.json()["value"] == XComModel.serialize_value(new_value) + assert response.json()["value"] == new_value else: assert response.json()["detail"] == expected_detail check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None) @@ -685,3 +715,17 @@ def test_should_respond_403(self, unauthorized_test_client): json={}, ) assert response.status_code == 403 + + def test_patch_xcom_entry_with_slash_key(self, test_client, session): + slash_key = "x/y" + self._create_xcom(slash_key, TEST_XCOM_VALUE) + new_value = {"updated": True} + response = test_client.patch( + f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{slash_key}", + json={"value": new_value}, + ) + assert response.status_code == 200 + assert response.json()["key"] == slash_key + # The value is double-serialized: first json.dumps(new_value), then json.dumps() again + assert response.json()["value"] == json.dumps(json.dumps(new_value)) + check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 9728325cc2392..96128bc810de8 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2076,6 +2076,85 @@ def test_xcom_clearing_without_keys_to_clear(self, create_runtime_ti, mock_super mock_delete.assert_not_called() + def test_xcom_push_pull_with_slash_in_key(self, create_runtime_ti, mock_supervisor_comms): + """ + Ensure that XCom keys containing slashes are correctly quoted/unquoted + and do not break API routes (no 400/404). + """ + + class PushOperator(BaseOperator): + def execute(self, context): + context["ti"].xcom_push(key="some/key/with/slash", value="slash_value") + + task = PushOperator(task_id="push_task") + runtime_ti = create_runtime_ti(task=task, dag_id="test_dag") + + # Run the task (which should trigger xcom_push) + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + # Verify supervisor received a SetXCom with quoted key + called_args = [ + call.kwargs.get("msg") or call.args[0] for call in mock_supervisor_comms.send.call_args_list + ] + assert any(getattr(arg, "key", None) == "some/key/with/slash" for arg in called_args) + + ser_value = BaseXCom.serialize_value("slash_value") + mock_supervisor_comms.send.reset_mock() + mock_supervisor_comms.send.return_value = XComSequenceSliceResult( + key="some/key/with/slash", + root=[ser_value], + ) + + pulled_value = runtime_ti.xcom_pull(key="some/key/with/slash", task_ids="push_task") + assert pulled_value == "slash_value" + + # Key should NOT be quoted here - client API will handle encoding + mock_supervisor_comms.send.assert_any_call( + GetXComSequenceSlice( + key="some/key/with/slash", + dag_id="test_dag", + run_id="test_run", + task_id="push_task", + map_index=0, + include_prior_dates=False, + start=None, + stop=None, + step=None, + type="GetXComSequenceSlice", + ) + ) + + def test_taskflow_dict_return_with_slash_key(self, create_runtime_ti, mock_supervisor_comms): + """ + High-level: Ensure TaskFlow returning dict with slash in key doesn't 404 during XCom push. + """ + + @dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3)) + def dag_with_slash_key(): + @task_decorator + def dict_task(): + return {"key with slash /": "Some Value"} + + return dict_task() # returns XComArg + + dag_obj = dag_with_slash_key() + task_op = dag_obj.get_task("dict_task") + runtime_ti = create_runtime_ti(task=task_op, dag_id=dag_obj.dag_id) + + # Run task instance → should trigger TaskFlow dict expansion + XCom push + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + # Mock supervisor response to simulate retrieval + ser_value = BaseXCom.serialize_value("Some Value") + mock_supervisor_comms.send.reset_mock() + mock_supervisor_comms.send.return_value = XComSequenceSliceResult( + key="key/slash", + root=[ser_value], + ) + + pulled = runtime_ti.xcom_pull(key="key/slash", task_ids="dict_task") + assert pulled == "Some Value" + class TestXComAfterTaskExecution: @pytest.mark.parametrize(