diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index cd8287be97b73..f68395ac8c611 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -310,7 +310,7 @@ class TIRunContext(BaseModel): connections: Annotated[list[ConnectionResponse], Field(default_factory=list)] """Connections that can be accessed by the task instance.""" - upstream_map_indexes: dict[str, int] | None = None + upstream_map_indexes: dict[str, int | list[int] | None] | None = None next_method: str | None = None """Method to call. Set when task resumes from a trigger.""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index f97a684e3ffed..00a5ea9a5e627 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -19,7 +19,8 @@ import json from collections import defaultdict -from typing import Annotated, Any +from collections.abc import Iterator +from typing import TYPE_CHECKING, Annotated, Any from uuid import UUID import structlog @@ -55,9 +56,14 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.xcom import XComModel +from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState +if TYPE_CHECKING: + from airflow.sdk.types import Operator + + router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( @@ -82,7 +88,10 @@ response_model_exclude_unset=True, ) def ti_run( - task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep + task_instance_id: UUID, + ti_run_payload: Annotated[TIEnterRunningPayload, Body()], + session: SessionDep, + request: Request, ) -> TIRunContext: """ Run a TaskInstance. @@ -233,6 +242,11 @@ def ti_run( or 0 ) + if dag := request.app.state.dag_bag.get_dag(ti.dag_id): + upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index)) + else: + upstream_map_indexes = None + context = TIRunContext( dag_run=dr, task_reschedule_count=task_reschedule_count, @@ -242,6 +256,7 @@ def ti_run( connections=[], xcom_keys_to_clear=xcom_keys, should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), + upstream_map_indexes=upstream_map_indexes, ) # Only set if they are non-null @@ -257,6 +272,27 @@ def ti_run( ) +def _get_upstream_map_indexes( + task: Operator, ti_map_index: int +) -> Iterator[tuple[str, int | list[int] | None]]: + for upstream_task in task.upstream_list: + map_indexes: int | list[int] | None + if not isinstance(upstream_task.task_group, MappedTaskGroup): + # regular tasks or non-mapped task groups + map_indexes = None + elif task.task_group == upstream_task.task_group: + # tasks in the same mapped task group + # the task should use the map_index as the previous task in the same mapped task group + map_indexes = ti_map_index + else: + # tasks not in the same mapped task group + # the upstream mapped task group should combine the xcom as a list and return it + mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count() + map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None + + yield upstream_task.task_id, map_indexes + + @ti_id_router.patch( "/{task_instance_id}/state", status_code=status.HTTP_204_NO_CONTENT, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 8af11c04df960..4884c38e51dab 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -178,6 +178,7 @@ def test_ti_run_state_to_running( "consumed_asset_events": [], }, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": max_tries, "should_retry": should_retry, "variables": [], @@ -256,6 +257,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance, assert response.json() == { "dag_run": mock.ANY, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": 0, "should_retry": False, "variables": [], @@ -317,6 +319,7 @@ def test_next_kwargs_determines_start_date_update(self, client, session, create_ assert response.json() == { "dag_run": mock.ANY, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": 0, "should_retry": False, "variables": [], diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index ef2f434fddb96..c0193ccd5df0b 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2093,7 +2093,7 @@ def _create_task_instance( run_type: str = "manual", try_number: int = 1, map_index: int | None = -1, - upstream_map_indexes: dict[str, int] | None = None, + upstream_map_indexes: dict[str, int | list[int] | None] | None = None, task_reschedule_count: int = 0, ti_id: UUID | None = None, conf: dict[str, Any] | None = None, @@ -2148,6 +2148,7 @@ def _create_task_instance( task_reschedule_count=task_reschedule_count, max_tries=task_retries if max_tries is None else max_tries, should_retry=should_retry if should_retry is not None else try_number <= task_retries, + upstream_map_indexes=upstream_map_indexes, ) if upstream_map_indexes is not None: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 7c7647635ea05..9ecae0cb1bec4 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -477,7 +477,9 @@ class TIRunContext(BaseModel): max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None - upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None + upstream_map_indexes: Annotated[ + dict[str, int | list[int] | None] | None, Field(title="Upstream Map Indexes") + ] = None next_method: Annotated[str | None, Field(title="Next Method")] = None next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 1adcb7efaa7a6..2a93585304cb0 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -339,22 +339,17 @@ def resolve(self, context: Mapping[str, Any]) -> Any: if self.operator.is_mapped: return LazyXComSequence(xcom_arg=self, ti=ti) tg = self.operator.get_closest_mapped_task_group() - result = None if tg is None: - # regular task - result = ti.xcom_pull( - task_ids=task_id, - key=self.key, - default=NOTSET, - map_indexes=None, - ) + map_indexes = None else: - # task from a task group - result = ti.xcom_pull( - task_ids=task_id, - key=self.key, - default=NOTSET, - ) + upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {}) + map_indexes = upstream_map_indexes.get(task_id, None) + result = ti.xcom_pull( + task_ids=task_id, + key=self.key, + default=NOTSET, + map_indexes=map_indexes, + ) if not isinstance(result, ArgNotSet): return result if self.key == XCOM_RETURN_KEY: diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index d6b8ca8da4e96..6cdb4b520f21e 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -627,9 +627,22 @@ def xcom_get(): "tg.t2": range(3), "t3": [None], } + upstream_map_indexes_per_task_id = { + ("tg.t1", 0): {}, + ("tg.t1", 1): {}, + ("tg.t1", 2): {}, + ("tg.t2", 0): {"tg.t1": 0}, + ("tg.t2", 1): {"tg.t1": 1}, + ("tg.t2", 2): {"tg.t1": 2}, + ("t3", None): {"tg.t2": [0, 1, 2]}, + } for task in dag.tasks: for map_index in expansion_per_task_id[task.task_id]: - mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + mapped_ti = create_runtime_ti( + task=task.prepare_for_execution(), + map_index=map_index, + upstream_map_indexes=upstream_map_indexes_per_task_id[(task.task_id, map_index)], + ) context = mapped_ti.get_template_context() mapped_ti.task.render_template_fields(context) mapped_ti.task.execute(context)