diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 720d5653f32ff..9b0e2dc7a27b7 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -593,6 +593,7 @@ def get_previous_successful_dagrun( def get_task_instance_count( dag_id: str, session: SessionDep, + map_index: Annotated[int | None, Query()] = None, task_ids: Annotated[list[str] | None, Query()] = None, task_group_id: Annotated[str | None, Query()] = None, logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, @@ -605,6 +606,9 @@ def get_task_instance_count( if task_ids: query = query.where(TI.task_id.in_(task_ids)) + if map_index is not None: + query = query.where(TI.map_index == map_index) + if logical_dates: query = query.where(TI.logical_date.in_(logical_dates)) @@ -615,7 +619,12 @@ def get_task_instance_count( group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids) # Get unique (task_id, map_index) pairs + task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks] + + if map_index is not None: + task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks if ti.map_index == map_index] + if not task_map_pairs: # If no task group tasks found, default to checking the task group ID itself # This matches the behavior in _get_external_task_group_task_ids @@ -643,6 +652,7 @@ def get_task_instance_count( def get_task_instance_states( dag_id: str, session: SessionDep, + map_index: Annotated[int | None, Query()] = None, task_ids: Annotated[list[str] | None, Query()] = None, task_group_id: Annotated[str | None, Query()] = None, logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, @@ -664,12 +674,21 @@ def get_task_instance_states( results = session.scalars(query).all() - [run_id_task_state_map[task.run_id].update({task.task_id: task.state}) for task in results] - if task_group_id: group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids) - [run_id_task_state_map[task.run_id].update({task.task_id: task.state}) for task in group_tasks] + results = results + group_tasks if task_ids else group_tasks + + if map_index is not None: + results = [task for task in results if task.map_index == map_index] + [ + run_id_task_state_map[task.run_id].update( + {task.task_id: task.state} + if task.map_index < 0 + else {f"{task.task_id}_{task.map_index}": task.state} + ) + for task in results + ] return TaskStatesResponse(task_states=run_id_task_state_map) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 0752b59a01ec9..5df6a03261523 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -430,6 +430,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - elif isinstance(msg, GetTICount): resp = self.client.task_instances.get_count( dag_id=msg.dag_id, + map_index=msg.map_index, task_ids=msg.task_ids, task_group_id=msg.task_group_id, logical_dates=msg.logical_dates, @@ -440,6 +441,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - elif isinstance(msg, GetTaskStates): run_id_task_state_map = self.client.task_instances.get_task_states( dag_id=msg.dag_id, + map_index=msg.map_index, task_ids=msg.task_ids, task_group_id=msg.task_group_id, logical_dates=msg.logical_dates, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 49d4f42b9a677..8af11c04df960 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1447,6 +1447,156 @@ def test_get_count_with_mixed_states(self, client, session, create_task_instance assert response.status_code == 200 assert response.json() == 2 + def test_get_count_with_map_index_less_than_zero(self, client, session, create_task_instance): + create_task_instance(task_id="task1", state=State.SUCCESS, run_id="runid1", dag_id="map_index_test") + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "map_index_test", "states": [State.SUCCESS], "map_index": -1}, + ) + assert response.status_code == 200 + assert response.json() == 1 + + def test_get_count_with_multiple_tasks_and_map_index_less_than_zero( + self, dag_maker, client, session, create_task_instance + ): + with dag_maker("test_get_count_with_multiple_tasks_and_map_index_less_than_zero"): + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2") + EmptyOperator(task_id="task3") + + dr = dag_maker.create_dagrun() + + tis = dr.get_task_instances() + + # Set different states for the task instances + for ti, state in zip(tis, [State.SUCCESS, State.FAILED, State.SKIPPED]): + ti.state = state + session.merge(ti) + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={ + "dag_id": "test_get_count_with_multiple_tasks_and_map_index_less_than_zero", + "map_index": -1, + }, + ) + assert response.status_code == 200 + assert response.json() == 3 + + @pytest.mark.parametrize( + ["map_index", "dynamic_task_args", "expected_count"], + ( + pytest.param(None, [1, 2, 3], 4, id="use-default-map-index"), + pytest.param(-1, [1, 2, 3], 1, id="map-index-(-1)"), + pytest.param(0, [1, 2, 3], 1, id="map-index-0"), + pytest.param(1, [1, 2, 3], 1, id="map-index-1"), + pytest.param(2, [1, 2, 3], 1, id="map-index-2"), + ), + ) + def test_get_count_for_dynamic_task_mapping( + self, dag_maker, client, session, map_index, dynamic_task_args, expected_count + ): + """ + case 1: map_index is None, it should fetch all the tasks + other cases: when map index is provided, it should return the count of tasks that falls under the map index + """ + with dag_maker(session=session) as dag: + EmptyOperator(task_id="op1") + + @dag.task() + def add_one(x): + return [x + 1] + + add_one.expand(x=dynamic_task_args) + + dr = dag_maker.create_dagrun() + + tis = dr.get_task_instances() + + for ti in tis: + ti.state = State.SUCCESS + session.merge(ti) + session.commit() + + map_index = {} if map_index is None else {"map_index": map_index} + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": dr.dag_id, "run_ids": [dr.run_id], **map_index}, + ) + assert response.status_code == 200 + assert response.json() == expected_count + + @pytest.mark.parametrize( + [ + "map_index", + "dynamic_task_args", + "task_ids", + "task_group_name", + "expected_count", + ], + ( + pytest.param(None, [1, 2, 3], None, None, 5, id="use-default-map-index-None"), + pytest.param(-1, [1, 2, 3], ["task1"], None, 1, id="with-task-ids-and-map-index-(-1)"), + pytest.param(None, [1, 2, 3], None, "group1", 4, id="with-task-group-id-and-map-index-None"), + pytest.param(0, [1, 2, 3], None, "group1", 1, id="with-task-group-id-and-map-index-0"), + pytest.param(-1, [1, 2, 3], None, "group1", 1, id="with-task-group-id-and-map-index-(-1)"), + ), + ) + def test_get_count_mix_of_task_and_task_group_dynamic_task_mapping( + self, + dag_maker, + client, + session, + map_index, + dynamic_task_args, + task_ids, + task_group_name, + expected_count, + ): + """ + case 1: map_index is None, task_ids is None, task_group_name is None, it should fetch all the tasks + case 2: when map index -1 and provided task_ids, it should return the count of task_ids + case 3: when map index is None and provided task_group_id, it should return the count of tasks under the task group + case 4: when map index is 0 and provided task_group_id, it should return the count of tasks under the task group that falls map index =0 + case 5: when map index is -1 and provided task_group_id, it should return the count of tasks under the task group that falls map index =-1 i.e this task is not mapped + """ + + with dag_maker(session=session, serialized=True) as dag: + EmptyOperator(task_id="task1") + + with TaskGroup("group1"): + + @dag.task() + def add_one(x): + return [x + 1] + + add_one.expand(x=dynamic_task_args) + + EmptyOperator(task_id="task2") + + dr = dag_maker.create_dagrun(session=session) + + session.commit() + params = {} + + if task_ids: + params["task_ids"] = task_ids + if task_group_name: + params["task_group_id"] = task_group_name + if map_index is not None: + params["map_index"] = map_index + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": dr.dag_id, "run_ids": [dr.run_id], **params}, + ) + assert response.status_code == 200 + assert response.json() == expected_count + class TestGetTaskStates: def setup_method(self): @@ -1512,7 +1662,6 @@ def test_get_task_states_with_task_group_id_and_task_id(self, client, session, d "task_states": { "test": { "group1.task1": "success", - "task2": "failed", }, }, } @@ -1644,3 +1793,180 @@ def test_get_task_states_dag_not_found(self, client, session): "reason": "not_found", "message": "DAG non_existent_dag not found", } + + @pytest.mark.parametrize( + ["map_index", "dynamic_task_args", "states", "expected"], + ( + pytest.param( + None, + [1, 2, 3], + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.SUCCESS, "2": State.SUCCESS}, + {"task1": "success", "add_one_0": "success", "add_one_1": "success", "add_one_2": "success"}, + id="with-default-map-index-None", + ), + pytest.param( + 0, + [1, 2, 3], + {"-1": State.SUCCESS, "0": State.FAILED, "1": State.SUCCESS, "2": State.SUCCESS}, + {"add_one_0": "failed"}, + id="with-map-index-0", + ), + pytest.param( + 1, + [1, 2, 3], + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.FAILED, "2": State.SUCCESS}, + {"add_one_1": "failed"}, + id="with-map-index-1", + ), + ), + ) + def test_get_task_states_for_dynamic_task_mapping( + self, dag_maker, client, session, map_index, dynamic_task_args, states, expected + ): + """ + case 1: map_index is None, it should fetch all the tasks + other cases: when map index is provided, it should return the count of tasks that falls under the map index + """ + with dag_maker(session=session, serialized=True) as dag: + EmptyOperator(task_id="task1") + + @dag.task() + def add_one(x): + return [x + 1] + + add_one.expand(x=dynamic_task_args) + + dr = dag_maker.create_dagrun() + + tis = dr.get_task_instances() + for ti in tis: + ti.state = states.get(str(ti.map_index)) + session.merge(ti) + session.commit() + + map_index = {} if map_index is None else {"map_index": map_index} + + response = client.get("/execution/task-instances/states", params={"dag_id": dr.dag_id, **map_index}) + assert response.status_code == 200 + assert response.json() == {"task_states": {dr.run_id: expected}} + + @pytest.mark.parametrize( + [ + "map_index", + "dynamic_task_args", + "task_ids", + "task_group_name", + "states", + "expected", + ], + ( + pytest.param( + None, + [1, 2, 3], + None, + None, + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.SUCCESS, "2": State.SUCCESS}, + { + "group1.add_one_0": "success", + "group1.add_one_1": "success", + "group1.add_one_2": "success", + "group1.task2": "success", + "task1": "success", + }, + id="with-default-map-index-None", + ), + pytest.param( + -1, + [1, 2, 3], + ["task1"], + None, + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.SUCCESS, "2": State.SUCCESS}, + {"task1": "success"}, + id="with-task-ids-map-index-(-1)", + ), + pytest.param( + None, + [1, 2, 3], + None, + "group1", + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.SUCCESS, "2": State.SUCCESS}, + { + "group1.task2": "success", + "group1.add_one_0": "success", + "group1.add_one_1": "success", + "group1.add_one_2": "success", + }, + id="with-task-group-id-and-map-index-None", + ), + pytest.param( + 0, + [1, 2, 3], + None, + "group1", + {"-1": State.SUCCESS, "0": State.FAILED, "1": State.SUCCESS, "2": State.SUCCESS}, + {"group1.add_one_0": "failed"}, + id="with-task-group-id-and-map-index-0", + ), + pytest.param( + -1, + [1, 2, 3], + ["task1"], + "group1", + {"-1": State.SUCCESS, "0": State.SUCCESS, "1": State.SUCCESS, "2": State.SUCCESS}, + {"task1": "success", "group1.task2": "success"}, + id="with-task-id-and-task-group-map-index-(-1)", + ), + ), + ) + def test_get_task_states_mix_of_task_and_task_group_dynamic_task_mapping( + self, + dag_maker, + client, + session, + map_index, + dynamic_task_args, + task_ids, + task_group_name, + states, + expected, + ): + """ + case1: map_index is None, task_ids is None, task_group_name is None, it should fetch all the task states + case2: when map index -1 and provided task_ids, it should return the task states of task_ids + case3: when map index is None and provided task_group_id, it should return the task states of tasks under the task group and normal task states under task group + case4: when map index is 0 and provided task_group_id, it should return the task states of tasks under the task group that falls under map index = 0 + case5: when map index is -1 and provided both task_id and task_group_id, it should return the task states of tasks under the task group that falls under map index = -1 and normal task_ids states + """ + + with dag_maker(session=session, serialized=True) as dag: + EmptyOperator(task_id="task1") + + with TaskGroup("group1"): + + @dag.task() + def add_one(x): + return [x + 1] + + add_one.expand(x=dynamic_task_args) + + EmptyOperator(task_id="task2") + + dr = dag_maker.create_dagrun(session=session) + + tis = dr.get_task_instances() + for ti in tis: + ti.state = states.get(str(ti.map_index)) + session.merge(ti) + session.commit() + params = {} + + if task_ids: + params["task_ids"] = task_ids + if task_group_name: + params["task_group_id"] = task_group_name + if map_index is not None: + params["map_index"] = map_index + + response = client.get("/execution/task-instances/states", params={"dag_id": dr.dag_id, **params}) + assert response.status_code == 200 + assert response.json() == {"task_states": {dr.run_id: expected}} diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 7e0f4e657ae93..c64f721b9ae61 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -212,6 +212,7 @@ def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskR def get_count( self, dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[datetime] | None = None, @@ -231,12 +232,16 @@ def get_count( # Remove None values from params params = {k: v for k, v in params.items() if v is not None} + if map_index is not None and map_index >= 0: + params.update({"map_index": map_index}) # type: ignore[dict-item] + resp = self.client.get("task-instances/count", params=params) return TICount(count=resp.json()) def get_task_states( self, dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[datetime] | None = None, @@ -254,6 +259,9 @@ def get_task_states( # Remove None values from params params = {k: v for k, v in params.items() if v is not None} + if map_index is not None and map_index >= 0: + params.update({"map_index": map_index}) # type: ignore[dict-item] + resp = self.client.get("task-instances/states", params=params) return TaskStatesResponse.model_validate_json(resp.read()) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 8583f0cee7743..b4d68086b0c4d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -560,6 +560,7 @@ class GetTaskRescheduleStartDate(BaseModel): class GetTICount(BaseModel): dag_id: str + map_index: int | None = None task_ids: list[str] | None = None task_group_id: str | None = None logical_dates: list[AwareDatetime] | None = None @@ -570,6 +571,7 @@ class GetTICount(BaseModel): class GetTaskStates(BaseModel): dag_id: str + map_index: int | None = None task_ids: list[str] | None = None task_group_id: str | None = None logical_dates: list[AwareDatetime] | None = None diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 0b77156578bb0..c90d6ea5a0241 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1099,6 +1099,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): elif isinstance(msg, GetTICount): resp = self.client.task_instances.get_count( dag_id=msg.dag_id, + map_index=msg.map_index, task_ids=msg.task_ids, task_group_id=msg.task_group_id, logical_dates=msg.logical_dates, @@ -1108,6 +1109,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): elif isinstance(msg, GetTaskStates): task_states_map = self.client.task_instances.get_task_states( dag_id=msg.dag_id, + map_index=msg.map_index, task_ids=msg.task_ids, task_group_id=msg.task_group_id, logical_dates=msg.logical_dates, diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index a3dd7f56bb54e..00c09528c19a1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -413,6 +413,7 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: @staticmethod def get_ti_count( dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[datetime] | None = None, @@ -427,6 +428,7 @@ def get_ti_count( log=log, msg=GetTICount( dag_id=dag_id, + map_index=map_index, task_ids=task_ids, task_group_id=task_group_id, logical_dates=logical_dates, @@ -444,6 +446,7 @@ def get_ti_count( @staticmethod def get_task_states( dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[datetime] | None = None, @@ -457,6 +460,7 @@ def get_task_states( log=log, msg=GetTaskStates( dag_id=dag_id, + map_index=map_index, task_ids=task_ids, task_group_id=task_group_id, logical_dates=logical_dates, diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index ae21ed5599f36..8bd0ea0db8d4d 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -88,6 +88,7 @@ def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: . @staticmethod def get_ti_count( dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[AwareDatetime] | None = None, @@ -98,6 +99,7 @@ def get_ti_count( @staticmethod def get_task_states( dag_id: str, + map_index: int | None = None, task_ids: list[str] | None = None, task_group_id: str | None = None, logical_dates: list[AwareDatetime] | None = None, diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 1eacfe1e632b7..e1f678bb1f77f 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -452,11 +452,13 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert params.get_list("logical_dates") == logical_dates_str assert params.get_list("run_ids") == [] assert params.get_list("states") == states + assert params["map_index"] == "0" return httpx.Response(200, json=10) client = make_client(transport=httpx.MockTransport(handle_request)) result = client.task_instances.get_count( dag_id="test_dag", + map_index=0, task_ids=task_ids, task_group_id="group1", logical_dates=logical_dates, @@ -494,6 +496,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert params.get_list("logical_dates") == logical_dates_str assert params.get_list("task_ids") == [] assert params.get_list("run_ids") == [] + assert params.get("map_index") == "0" return httpx.Response( 200, json={"task_states": {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}} ) @@ -501,6 +504,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) result = client.task_instances.get_task_states( dag_id="test_dag", + map_index=0, task_group_id="group1", logical_dates=logical_dates, ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index f4e20111128ca..4aff50a5fd486 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1396,6 +1396,7 @@ def watched_subprocess(self, mocker): (), { "dag_id": "test_dag", + "map_index": None, "logical_dates": None, "run_ids": None, "states": None, @@ -1426,6 +1427,7 @@ def watched_subprocess(self, mocker): (), { "dag_id": "test_dag", + "map_index": None, "task_ids": None, "logical_dates": None, "run_ids": None,