diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index a9ed4a5b48d20..31bfcb31a0a3a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -230,6 +230,7 @@ class GetXComSliceFilterParams(BaseModel): start: int | None = None stop: int | None = None step: int | None = None + include_prior_dates: bool = False @router.get( @@ -249,6 +250,7 @@ def get_mapped_xcom_by_slice( key=key, task_ids=task_id, dag_ids=dag_id, + include_prior_dates=params.include_prior_dates, session=session, ) query = query.order_by(None) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 0555c5fd46bda..1fe0b9d155fc6 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -24,11 +24,17 @@ from airflow.api_fastapi.execution_api.versions.v2025_08_10 import ( AddDagRunStateFieldAndPreviousEndpoint, AddDagVersionIdField, + AddIncludePriorDatesToGetXComSlice, ) bundle = VersionBundle( HeadVersion(), - Version("2025-08-10", AddDagVersionIdField, AddDagRunStateFieldAndPreviousEndpoint), + Version( + "2025-08-10", + AddDagVersionIdField, + AddDagRunStateFieldAndPreviousEndpoint, + AddIncludePriorDatesToGetXComSlice, + ), Version("2025-05-20", DowngradeUpstreamMapIndexes), Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py index 3b1aa885d8f73..c6c95c7dfdc38 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_08_10.py @@ -20,6 +20,7 @@ from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TaskInstance, TIRunContext +from airflow.api_fastapi.execution_api.routes.xcoms import GetXComSliceFilterParams class AddDagVersionIdField(VersionChange): @@ -45,3 +46,13 @@ def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[m """Remove the `state` field from the dag_run object when converting to the previous version.""" if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): response.body["dag_run"].pop("state", None) + + +class AddIncludePriorDatesToGetXComSlice(VersionChange): + """Add the `include_prior_dates` field to GetXComSliceFilterParams.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(GetXComSliceFilterParams).field("include_prior_dates").didnt_exist, + ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 1b10e81cd2338..2d03eb45367de 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -25,6 +25,7 @@ import pytest from fastapi import FastAPI, HTTPException, Path, Request, status +from airflow._shared.timezones import timezone from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.dagrun import DagRun from airflow.models.taskmap import TaskMap @@ -273,6 +274,54 @@ def __init__(self, *, x, **kwargs): assert response.status_code == 200 assert response.json() == ["f", "o", "b"][key] + @pytest.mark.parametrize( + "include_prior_dates, expected_xcoms", + [[True, ["earlier_value", "later_value"]], [False, ["later_value"]]], + ) + def test_xcom_get_slice_accepts_include_prior_dates( + self, client, dag_maker, session, include_prior_dates, expected_xcoms + ): + """Test that the slice endpoint accepts include_prior_dates parameter and works correctly.""" + + with dag_maker(dag_id="dag"): + EmptyOperator(task_id="task") + + earlier_run = dag_maker.create_dagrun( + run_id="earlier_run", logical_date=timezone.parse("2024-01-01T00:00:00Z") + ) + later_run = dag_maker.create_dagrun( + run_id="later_run", logical_date=timezone.parse("2024-01-02T00:00:00Z") + ) + + earlier_ti = earlier_run.get_task_instance("task") + later_ti = later_run.get_task_instance("task") + + earlier_xcom = XComModel( + key="test_key", + value="earlier_value", + dag_run_id=earlier_ti.dag_run.id, + run_id=earlier_ti.run_id, + task_id=earlier_ti.task_id, + dag_id=earlier_ti.dag_id, + ) + later_xcom = XComModel( + key="test_key", + value="later_value", + dag_run_id=later_ti.dag_run.id, + run_id=later_ti.run_id, + task_id=later_ti.task_id, + dag_id=later_ti.dag_id, + ) + session.add_all([earlier_xcom, later_xcom]) + session.commit() + + response = client.get( + f"/execution/xcoms/dag/later_run/task/test_key/slice?include_prior_dates={include_prior_dates}" + ) + assert response.status_code == 200 + + assert response.json() == expected_xcoms + class TestXComsSetEndpoint: @pytest.mark.parametrize( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 6ec13cd070242..ef7c896909cfe 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -496,6 +496,7 @@ def get_sequence_slice( start: int | None, stop: int | None, step: int | None, + include_prior_dates: bool = False, ) -> XComSequenceSliceResponse: params = {} if start is not None: @@ -504,6 +505,8 @@ def get_sequence_slice( params["stop"] = stop if step is not None: params["step"] = step + if include_prior_dates: + params["include_prior_dates"] = include_prior_dates resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params) return XComSequenceSliceResponse.model_validate_json(resp.read()) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index ce38a8679a37b..7fdc53aa2e51e 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -275,6 +275,7 @@ def get_all( dag_id: str, task_id: str, run_id: str, + include_prior_dates: bool = False, ) -> Any: """ Retrieve all XCom values for a task, typically from all map indexes. @@ -289,6 +290,9 @@ def get_all( :param run_id: DAG run ID for the task. :param dag_id: DAG ID to pull XComs from. :param task_id: Task ID to pull XComs from. + :param include_prior_dates: If *False* (default), only XComs from the + specified DAG run are returned. If *True*, the latest matching XComs are + returned regardless of the run they belong to. :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -303,6 +307,7 @@ def get_all( start=None, stop=None, step=None, + include_prior_dates=include_prior_dates, ), ) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 041f81fed067d..1e97ed71db7be 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -707,6 +707,7 @@ class GetXComSequenceSlice(BaseModel): start: int | None stop: int | None step: int | None + include_prior_dates: bool = False type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index cbfd8d207ecfc..b05919efac8e6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1132,7 +1132,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp = xcom elif isinstance(msg, GetXComSequenceSlice): xcoms = self.client.xcoms.get_sequence_slice( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.start, + msg.stop, + msg.step, + msg.include_prior_dates, ) resp = XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, DeferTask): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index be680a08d045f..897543507f939 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -362,6 +362,7 @@ def xcom_pull( key=key, task_id=t_id, dag_id=dag_id, + include_prior_dates=include_prior_dates, ) if values is None: diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 0331a7a2b9ca6..80395bcf5e597 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1893,10 +1893,11 @@ def watched_subprocess(self, mocker): start=None, stop=None, step=None, + include_prior_dates=False, ), {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", - ("test_dag", "test_run", "test_task", "test_key", None, None, None), + ("test_dag", "test_run", "test_task", "test_key", None, None, None, False), {}, XComSequenceSliceResult(root=["foo", "bar"]), None, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index a88b244897d0b..2ba524090a361 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2066,8 +2066,7 @@ def test_xcom_pull_from_custom_xcom_backend( class CustomOperator(BaseOperator): def execute(self, context): - value = context["ti"].xcom_pull(task_ids="pull_task", key="key") - print(f"Pulled XCom Value: {value}") + context["ti"].xcom_pull(task_ids="pull_task", key="key") task = CustomOperator(task_id="pull_task") runtime_ti = create_runtime_ti(task=task) @@ -2078,6 +2077,7 @@ def execute(self, context): dag_id="test_dag", task_id="pull_task", run_id="test_run", + include_prior_dates=False, ) assert not any( @@ -2094,6 +2094,57 @@ def execute(self, context): for x in mock_supervisor_comms.send.call_args_list ) + @pytest.mark.parametrize( + ("include_prior_dates", "expected_value"), + [ + pytest.param(True, True, id="include_prior_dates_true"), + pytest.param(False, False, id="include_prior_dates_false"), + pytest.param(None, False, id="include_prior_dates_default"), + ], + ) + def test_xcom_pull_with_include_prior_dates( + self, + create_runtime_ti, + mock_supervisor_comms, + include_prior_dates, + expected_value, + ): + """Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect.""" + task = BaseOperator(task_id="pull_task") + runtime_ti = create_runtime_ti(task=task) + + value = {"previous_run_data": "test_value"} + ser_value = BaseXCom.serialize_value(value) + + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + if isinstance(msg, GetXComSequenceSlice): + assert msg.include_prior_dates is expected_value, ( + f"include_prior_dates should be {expected_value} in GetXComSequenceSlice" + ) + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="test_key", value=None) + + mock_supervisor_comms.send.side_effect = mock_send_side_effect + kwargs = {"key": "test_key", "task_ids": "previous_task"} + if include_prior_dates is not None: + kwargs["include_prior_dates"] = include_prior_dates + result = runtime_ti.xcom_pull(**kwargs) + assert result == value + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComSequenceSlice( + key="test_key", + dag_id=runtime_ti.dag_id, + run_id=runtime_ti.run_id, + task_id="previous_task", + start=None, + stop=None, + step=None, + include_prior_dates=expected_value, + ), + ) + class TestDagParamRuntime: DEFAULT_ARGS = {