diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 80c8d485df994..8fe00520d89c4 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -5403,9 +5403,7 @@ paths: description: Successful Response content: application/json: - schema: - type: 'null' - title: Response Delete Task Instance + schema: {} '401': content: application/json: @@ -7743,9 +7741,7 @@ paths: description: Successful Response content: application/json: - schema: - type: 'null' - title: Response Reparse Dag File + schema: {} '401': content: application/json: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 57b83f0dbb3a2..5ae77e5000fcd 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -18,6 +18,7 @@ import copy from typing import Annotated +from urllib.parse import unquote from fastapi import Depends, HTTPException, Query, status from sqlalchemy import and_, select @@ -80,6 +81,7 @@ def get_xcom_entry( stringify: Annotated[bool, Query()] = False, ) -> XComResponseNative | XComResponseString: """Get an XCom entry.""" + xcom_key = unquote(xcom_key) xcom_query = XComModel.get_many( run_id=dag_run_id, key=xcom_key, @@ -156,6 +158,7 @@ def get_xcom_entries( This endpoint allows specifying `~` as the dag_id, dag_run_id, task_id to retrieve XCom entries for all DAGs. """ + xcom_key = unquote(xcom_key) if xcom_key else None query = select(XComModel) if dag_id != "~": query = query.where(XComModel.dag_id == dag_id) @@ -242,6 +245,7 @@ def create_xcom_entry( ) # Check existing XCom + request_body.key = unquote(request_body.key) already_existing_query = XComModel.get_many( key=request_body.key, task_ids=task_id, @@ -315,6 +319,7 @@ def update_xcom_entry( ) -> XComResponseNative: """Update an existing XCom entry.""" # Check if XCom entry exists + xcom_key = unquote(xcom_key) xcom_new_value = XComModel.serialize_value(patch_body.value) xcom_entry = session.scalar( select(XComModel) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index b2399635499cc..846e51c10c7b8 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -19,6 +19,7 @@ import logging from typing import Annotated +from urllib.parse import unquote from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status from pydantic import BaseModel, JsonValue, StringConstraints @@ -78,6 +79,7 @@ async def xcom_query( key: str, map_index: Annotated[int | None, Query()] = None, ) -> Select: + key = unquote(key) query = XComModel.get_many( run_id=run_id, key=key, @@ -143,6 +145,7 @@ def get_xcom( params: Annotated[GetXcomFilterParams, Query()], ) -> XComResponse: """Get an Airflow XCom from database - not other XCom Backends.""" + key = unquote(key) xcom_query = XComModel.get_many( run_id=run_id, key=key, @@ -196,6 +199,7 @@ def get_mapped_xcom_by_index( offset: int, session: SessionDep, ) -> XComSequenceIndexResponse: + key = unquote(key) xcom_query = XComModel.get_many( run_id=run_id, key=key, @@ -240,6 +244,7 @@ def get_mapped_xcom_by_slice( params: Annotated[GetXComSliceFilterParams, Query()], session: SessionDep, ) -> XComSequenceSliceResponse: + key = unquote(key) query = XComModel.get_many( run_id=run_id, key=key, @@ -360,7 +365,7 @@ def set_xcom( "message": "XCom key must be a non-empty string.", }, ) - + key = unquote(key) if mapped_length is not None: task_map = TaskMap( dag_id=dag_id, @@ -444,6 +449,7 @@ def delete_xcom( map_index: Annotated[int, Query()] = -1, ): """Delete a single XCom Value.""" + key = unquote(key) query = delete(XComModel).where( XComModel.key == key, XComModel.run_id == run_id, diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 0a353271c1025..f9e2a7f4118da 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -24,6 +24,7 @@ from functools import cache from http import HTTPStatus from typing import TYPE_CHECKING, Any, TypeVar +from urllib.parse import quote import certifi import httpx @@ -418,6 +419,7 @@ def __init__(self, client: Client): def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse: """Get the number of mapped XCom values.""" + key = quote(key, safe="") resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}") # content_range: str | None @@ -444,6 +446,7 @@ def get( params.update({"map_index": map_index}) if include_prior_dates: params.update({"include_prior_dates": include_prior_dates}) + key = quote(key, safe="") try: resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) except ServerResponseError as e: @@ -483,6 +486,7 @@ def set( params = {"map_index": map_index} if mapped_length is not None and mapped_length >= 0: params["mapped_length"] = mapped_length + key = quote(key, safe="") self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value) # Any error from the server will anyway be propagated down to the supervisor, # so we choose to send a generic response to the supervisor over the server response to @@ -501,6 +505,7 @@ def delete( params = {} if map_index is not None and map_index >= 0: params = {"map_index": map_index} + key = quote(key, safe="") self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) # Any error from the server will anyway be propagated down to the supervisor, # so we choose to send a generic response to the supervisor over the server response to @@ -515,6 +520,7 @@ def get_sequence_item( key: str, offset: int, ) -> XComSequenceIndexResponse | ErrorResponse: + key = quote(key, safe="") try: resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}") except ServerResponseError as e: @@ -553,6 +559,7 @@ def get_sequence_slice( step: int | None, include_prior_dates: bool = False, ) -> XComSequenceSliceResponse: + key = quote(key, safe="") params = {} if start is not None: params["start"] = start diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bb92ff001544b..b86d47f82b2f0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -336,6 +336,7 @@ def xcom_pull( a non-str iterable), a list of matching XComs is returned. Elements in the list is ordered by item ordering in ``task_id`` and ``map_index``. """ + key = quote(key, safe="") if dag_id is None: dag_id = self.dag_id if run_id is None: @@ -1363,8 +1364,10 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): "Returned dictionary keys must be strings when using " f"multiple_outputs, found {key} ({type(key)}) instead" ) + for k, v in result.items(): - ti.xcom_push(k, v) + encoded_key = quote(k, safe="") + ti.xcom_push(encoded_key, v) _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 525d137a48108..b45ff4479b8dd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch +from urllib.parse import quote import pandas as pd import pytest @@ -1986,6 +1987,85 @@ def test_xcom_clearing_without_keys_to_clear(self, create_runtime_ti, mock_super mock_delete.assert_not_called() + def test_xcom_push_pull_with_slash_in_key(self, create_runtime_ti, mock_supervisor_comms): + """ + Ensure that XCom keys containing slashes are correctly quoted/unquoted + and do not break API routes (no 400/404). + """ + + class PushOperator(BaseOperator): + def execute(self, context): + context["ti"].xcom_push(key="some/key/with/slash", value="slash_value") + + task = PushOperator(task_id="push_task") + runtime_ti = create_runtime_ti(task=task, dag_id="test_dag") + + # Run the task (which should trigger xcom_push) + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + # Verify supervisor received a SetXCom with quoted key + called_args = [ + call.kwargs.get("msg") or call.args[0] for call in mock_supervisor_comms.send.call_args_list + ] + assert any(getattr(arg, "key", None) == "some/key/with/slash" for arg in called_args) + + ser_value = BaseXCom.serialize_value("slash_value") + mock_supervisor_comms.send.reset_mock() + mock_supervisor_comms.send.return_value = XComSequenceSliceResult( + key="some/key/with/slash", + root=[ser_value], + ) + + pulled_value = runtime_ti.xcom_pull(key="some/key/with/slash", task_ids="push_task") + assert pulled_value == "slash_value" + + expected_key = quote("some/key/with/slash", safe="") + mock_supervisor_comms.send.assert_any_call( + GetXComSequenceSlice( + key=expected_key, + dag_id="test_dag", + run_id="test_run", + task_id="push_task", + map_index=0, + include_prior_dates=False, + start=None, + stop=None, + step=None, + type="GetXComSequenceSlice", + ) + ) + + def test_taskflow_dict_return_with_slash_key(self, create_runtime_ti, mock_supervisor_comms): + """ + High-level: Ensure TaskFlow returning dict with slash in key doesn't 404 during XCom push. + """ + + @dag_decorator(schedule=None, start_date=timezone.datetime(2024, 12, 3)) + def dag_with_slash_key(): + @task_decorator + def dict_task(): + return {"key with slash /": "Some Value"} + + return dict_task() # returns XComArg + + dag_obj = dag_with_slash_key() + task_op = dag_obj.get_task("dict_task") + runtime_ti = create_runtime_ti(task=task_op, dag_id=dag_obj.dag_id) + + # Run task instance → should trigger TaskFlow dict expansion + XCom push + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + # Mock supervisor response to simulate retrieval + ser_value = BaseXCom.serialize_value("Some Value") + mock_supervisor_comms.send.reset_mock() + mock_supervisor_comms.send.return_value = XComSequenceSliceResult( + key="key/slash", + root=[ser_value], + ) + + pulled = runtime_ti.xcom_pull(key="key/slash", task_ids="dict_task") + assert pulled == "Some Value" + class TestXComAfterTaskExecution: @pytest.mark.parametrize(