Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class GetXComSliceFilterParams(BaseModel):
start: int | None = None
stop: int | None = None
step: int | None = None
include_prior_dates: bool = False


@router.get(
Expand All @@ -249,6 +250,7 @@ def get_mapped_xcom_by_slice(
key=key,
task_ids=task_id,
dag_ids=dag_id,
include_prior_dates=params.include_prior_dates,
session=session,
)
query = query.order_by(None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@
from airflow.api_fastapi.execution_api.versions.v2025_08_10 import (
AddDagRunStateFieldAndPreviousEndpoint,
AddDagVersionIdField,
AddIncludePriorDatesToGetXComSlice,
)

bundle = VersionBundle(
HeadVersion(),
Version("2025-08-10", AddDagVersionIdField, AddDagRunStateFieldAndPreviousEndpoint),
Version(
"2025-08-10",
AddDagVersionIdField,
AddDagRunStateFieldAndPreviousEndpoint,
AddIncludePriorDatesToGetXComSlice,
),
Version("2025-05-20", DowngradeUpstreamMapIndexes),
Version("2025-04-28", AddRenderedMapIndexField),
Version("2025-04-11"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TaskInstance, TIRunContext
from airflow.api_fastapi.execution_api.routes.xcoms import GetXComSliceFilterParams


class AddDagVersionIdField(VersionChange):
Expand All @@ -45,3 +46,13 @@ def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[m
"""Remove the `state` field from the dag_run object when converting to the previous version."""
if "dag_run" in response.body and isinstance(response.body["dag_run"], dict):
response.body["dag_run"].pop("state", None)


class AddIncludePriorDatesToGetXComSlice(VersionChange):
"""Add the `include_prior_dates` field to GetXComSliceFilterParams."""

description = __doc__

instructions_to_migrate_to_previous_version = (
schema(GetXComSliceFilterParams).field("include_prior_dates").didnt_exist,
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
from fastapi import FastAPI, HTTPException, Path, Request, status

from airflow._shared.timezones import timezone
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.models.dagrun import DagRun
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -273,6 +274,54 @@ def __init__(self, *, x, **kwargs):
assert response.status_code == 200
assert response.json() == ["f", "o", "b"][key]

@pytest.mark.parametrize(
"include_prior_dates, expected_xcoms",
[[True, ["earlier_value", "later_value"]], [False, ["later_value"]]],
)
def test_xcom_get_slice_accepts_include_prior_dates(
self, client, dag_maker, session, include_prior_dates, expected_xcoms
):
"""Test that the slice endpoint accepts include_prior_dates parameter and works correctly."""

with dag_maker(dag_id="dag"):
EmptyOperator(task_id="task")

earlier_run = dag_maker.create_dagrun(
run_id="earlier_run", logical_date=timezone.parse("2024-01-01T00:00:00Z")
)
later_run = dag_maker.create_dagrun(
run_id="later_run", logical_date=timezone.parse("2024-01-02T00:00:00Z")
)

earlier_ti = earlier_run.get_task_instance("task")
later_ti = later_run.get_task_instance("task")

earlier_xcom = XComModel(
key="test_key",
value="earlier_value",
dag_run_id=earlier_ti.dag_run.id,
run_id=earlier_ti.run_id,
task_id=earlier_ti.task_id,
dag_id=earlier_ti.dag_id,
)
later_xcom = XComModel(
key="test_key",
value="later_value",
dag_run_id=later_ti.dag_run.id,
run_id=later_ti.run_id,
task_id=later_ti.task_id,
dag_id=later_ti.dag_id,
)
session.add_all([earlier_xcom, later_xcom])
session.commit()

response = client.get(
f"/execution/xcoms/dag/later_run/task/test_key/slice?include_prior_dates={include_prior_dates}"
)
assert response.status_code == 200

assert response.json() == expected_xcoms


class TestXComsSetEndpoint:
@pytest.mark.parametrize(
Expand Down
3 changes: 3 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def get_sequence_slice(
start: int | None,
stop: int | None,
step: int | None,
include_prior_dates: bool = False,
) -> XComSequenceSliceResponse:
params = {}
if start is not None:
Expand All @@ -504,6 +505,8 @@ def get_sequence_slice(
params["stop"] = stop
if step is not None:
params["step"] = step
if include_prior_dates:
params["include_prior_dates"] = include_prior_dates
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
return XComSequenceSliceResponse.model_validate_json(resp.read())

Expand Down
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def get_all(
dag_id: str,
task_id: str,
run_id: str,
include_prior_dates: bool = False,
) -> Any:
"""
Retrieve all XCom values for a task, typically from all map indexes.
Expand All @@ -289,6 +290,9 @@ def get_all(
:param run_id: DAG run ID for the task.
:param dag_id: DAG ID to pull XComs from.
:param task_id: Task ID to pull XComs from.
:param include_prior_dates: If *False* (default), only XComs from the
specified DAG run are returned. If *True*, the latest matching XComs are
returned regardless of the run they belong to.
:return: List of all XCom values if found.
"""
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
Expand All @@ -303,6 +307,7 @@ def get_all(
start=None,
stop=None,
step=None,
include_prior_dates=include_prior_dates,
),
)

Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ class GetXComSequenceSlice(BaseModel):
start: int | None
stop: int | None
step: int | None
include_prior_dates: bool = False
type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"


Expand Down
9 changes: 8 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
resp = xcom
elif isinstance(msg, GetXComSequenceSlice):
xcoms = self.client.xcoms.get_sequence_slice(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step
msg.dag_id,
msg.run_id,
msg.task_id,
msg.key,
msg.start,
msg.stop,
msg.step,
msg.include_prior_dates,
)
resp = XComSequenceSliceResult.from_response(xcoms)
elif isinstance(msg, DeferTask):
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def xcom_pull(
key=key,
task_id=t_id,
dag_id=dag_id,
include_prior_dates=include_prior_dates,
)

if values is None:
Expand Down
3 changes: 2 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1893,10 +1893,11 @@ def watched_subprocess(self, mocker):
start=None,
stop=None,
step=None,
include_prior_dates=False,
),
{"root": ["foo", "bar"], "type": "XComSequenceSliceResult"},
"xcoms.get_sequence_slice",
("test_dag", "test_run", "test_task", "test_key", None, None, None),
("test_dag", "test_run", "test_task", "test_key", None, None, None, False),
{},
XComSequenceSliceResult(root=["foo", "bar"]),
None,
Expand Down
55 changes: 53 additions & 2 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2066,8 +2066,7 @@ def test_xcom_pull_from_custom_xcom_backend(

class CustomOperator(BaseOperator):
def execute(self, context):
value = context["ti"].xcom_pull(task_ids="pull_task", key="key")
print(f"Pulled XCom Value: {value}")
context["ti"].xcom_pull(task_ids="pull_task", key="key")

task = CustomOperator(task_id="pull_task")
runtime_ti = create_runtime_ti(task=task)
Expand All @@ -2078,6 +2077,7 @@ def execute(self, context):
dag_id="test_dag",
task_id="pull_task",
run_id="test_run",
include_prior_dates=False,
)

assert not any(
Expand All @@ -2094,6 +2094,57 @@ def execute(self, context):
for x in mock_supervisor_comms.send.call_args_list
)

@pytest.mark.parametrize(
("include_prior_dates", "expected_value"),
[
pytest.param(True, True, id="include_prior_dates_true"),
pytest.param(False, False, id="include_prior_dates_false"),
pytest.param(None, False, id="include_prior_dates_default"),
],
)
def test_xcom_pull_with_include_prior_dates(
self,
create_runtime_ti,
mock_supervisor_comms,
include_prior_dates,
expected_value,
):
"""Test that xcom_pull with include_prior_dates parameter correctly behaves as we expect."""
task = BaseOperator(task_id="pull_task")
runtime_ti = create_runtime_ti(task=task)

value = {"previous_run_data": "test_value"}
ser_value = BaseXCom.serialize_value(value)

def mock_send_side_effect(*args, **kwargs):
msg = kwargs.get("msg") or args[0]
if isinstance(msg, GetXComSequenceSlice):
assert msg.include_prior_dates is expected_value, (
f"include_prior_dates should be {expected_value} in GetXComSequenceSlice"
)
return XComSequenceSliceResult(root=[ser_value])
return XComResult(key="test_key", value=None)

mock_supervisor_comms.send.side_effect = mock_send_side_effect
kwargs = {"key": "test_key", "task_ids": "previous_task"}
if include_prior_dates is not None:
kwargs["include_prior_dates"] = include_prior_dates
result = runtime_ti.xcom_pull(**kwargs)
assert result == value

mock_supervisor_comms.send.assert_called_once_with(
msg=GetXComSequenceSlice(
key="test_key",
dag_id=runtime_ti.dag_id,
run_id=runtime_ti.run_id,
task_id="previous_task",
start=None,
stop=None,
step=None,
include_prior_dates=expected_value,
),
)


class TestDagParamRuntime:
DEFAULT_ARGS = {
Expand Down