diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 149b1466acd2a..ab759f72ca9e4 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -342,3 +342,9 @@ class TIRuntimeCheckPayload(StrictBaseModel): inlets: list[AssetProfile] | None = None outlets: list[AssetProfile] | None = None + + +class TaskStatesResponse(BaseModel): + """Response for task states with run_id, task and state.""" + + task_states: dict[str, Any] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 68f3aa7125552..ab6117996bc22 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -19,7 +19,8 @@ import json import logging -from typing import Annotated +from collections import defaultdict +from typing import Annotated, Any from uuid import UUID from cadwyn import VersionedAPIRouter @@ -34,6 +35,7 @@ from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( PrevSuccessfulDagRunResponse, + TaskStatesResponse, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, @@ -607,36 +609,7 @@ def get_count( query = query.where(TI.run_id.in_(run_ids)) if task_group_id: - # Get all tasks in the task group - dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session) - if not dag: - raise HTTPException( - status.HTTP_404_NOT_FOUND, - detail={ - "reason": "not_found", - "message": f"DAG {dag_id} not found", - }, - ) - - task_group = dag.task_group_dict.get(task_group_id) - if not task_group: - raise HTTPException( - status.HTTP_404_NOT_FOUND, - detail={ - 
"reason": "not_found", - "message": f"Task group {task_group_id} not found in DAG {dag_id}", - }, - ) - - # First get all task instances to get the task_id, map_index pairs - group_tasks = session.scalars( - select(TI).where( - TI.dag_id == dag_id, - TI.task_id.in_(task.task_id for task in task_group.iter_tasks()), - *([TI.logical_date.in_(logical_dates)] if logical_dates else []), - *([TI.run_id.in_(run_ids)] if run_ids else []), - ) - ).all() + group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids) # Get unique (task_id, map_index) pairs task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks] @@ -659,9 +632,45 @@ def get_count( query = query.where(TI.state.in_(states)) count = session.scalar(query) + return count or 0 +@router.get("/states", status_code=status.HTTP_200_OK) +def get_task_states( + dag_id: str, + session: SessionDep, + task_ids: Annotated[list[str] | None, Query()] = None, + task_group_id: Annotated[str | None, Query()] = None, + logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, + run_ids: Annotated[list[str] | None, Query()] = None, +) -> TaskStatesResponse: + """Get the task states for the given criteria.""" + run_id_task_state_map: dict[str, dict[str, Any]] = defaultdict(dict) + + query = select(TI).where(TI.dag_id == dag_id) + + if task_ids: + query = query.where(TI.task_id.in_(task_ids)) + + if logical_dates: + query = query.where(TI.logical_date.in_(logical_dates)) + + if run_ids: + query = query.where(TI.run_id.in_(run_ids)) + + results = session.scalars(query).all() + + [run_id_task_state_map[task.run_id].update({task.task_id: task.state}) for task in results] + + if task_group_id: + group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids) + + [run_id_task_state_map[task.run_id].update({task.task_id: task.state}) for task in group_tasks] + + return TaskStatesResponse(task_states=run_id_task_state_map) + + @ti_id_router.only_exists_in_older_versions 
@ti_id_router.post( "/{task_instance_id}/runtime-checks", @@ -702,5 +711,40 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: return max_tries != 0 and try_number <= max_tries +def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logical_dates=None, run_ids=None): + # Get all tasks in the task group + dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session) + if not dag: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"DAG {dag_id} not found", + }, + ) + + task_group = dag.task_group_dict.get(task_group_id) + if not task_group: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Task group {task_group_id} not found in DAG {dag_id}", + }, + ) + + # First get all task instances to get the task_id, map_index pairs + group_tasks = session.scalars( + select(TI).where( + TI.dag_id == dag_id, + TI.task_id.in_(task.task_id for task in task_group.iter_tasks()), + *([TI.logical_date.in_(logical_dates)] if logical_dates else []), + *([TI.run_id.in_(run_ids)] if run_ids else []), + ) + ).all() + + return group_tasks + + # This line should be at the end of the file to ensure all routes are registered router.include_router(ti_id_router) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index bef75cf515251..9a0765bc3ca9c 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -50,9 +50,11 @@ GetConnection, GetDagRunState, GetDRCount, + GetTaskStates, GetTICount, GetVariable, GetXCom, + TaskStatesResult, TICount, VariableResult, XComResult, @@ -225,6 +227,7 @@ class TriggerStateSync(BaseModel): DagRunStateResult, DRCount, TICount, + TaskStatesResult, ErrorResponse, ], Field(discriminator="type"), @@ -242,6 +245,7 @@ class TriggerStateSync(BaseModel): GetVariable, GetXCom, 
GetTICount, + GetTaskStates, GetDagRunState, GetDRCount, ], @@ -360,7 +364,12 @@ def client(self) -> Client: return client def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] - from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse, XComResponse + from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + TaskStatesResponse, + VariableResponse, + XComResponse, + ) resp: BaseModel | None = None dump_opts = {} @@ -435,6 +444,19 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - run_ids=msg.run_ids, states=msg.states, ) + + elif isinstance(msg, GetTaskStates): + run_id_task_state_map = self.client.task_instances.get_task_states( + dag_id=msg.dag_id, + task_ids=msg.task_ids, + task_group_id=msg.task_group_id, + logical_dates=msg.logical_dates, + run_ids=msg.run_ids, + ) + if isinstance(run_id_task_state_map, TaskStatesResponse): + resp = TaskStatesResult.from_api_response(run_id_task_state_map) + else: + resp = run_id_task_state_map else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 869fffe7dc55e..4d36a2f5515cc 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1385,3 +1385,201 @@ def test_get_count_with_mixed_states(self, client, session, create_task_instance ) assert response.status_code == 200 assert response.json() == 2 + + +class TestGetTaskStates: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_task_states_basic(self, client, session, create_task_instance): + create_task_instance(task_id="test_task", 
state=State.SUCCESS) + session.commit() + + response = client.get("/execution/task-instances/states", params={"dag_id": "dag"}) + assert response.status_code == 200 + assert response.json() == {"task_states": {"test": {"test_task": "success"}}} + + def test_get_task_states_group_id_basic(self, client, dag_maker, session): + with dag_maker(dag_id="test_dag", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + + dag_maker.create_dagrun(session=session) + session.commit() + + response = client.get( + "/execution/task-instances/states", + params={"dag_id": "test_dag", "task_group_id": "group1"}, + ) + assert response.status_code == 200 + assert response.json() == { + "task_states": { + "test": { + "group1.task1": None, + }, + }, + } + + def test_get_task_states_with_task_group_id_and_task_id(self, client, session, dag_maker): + with dag_maker("test_get_task_group_states_with_multiple_task_tasks", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2") + + dr = dag_maker.create_dagrun() + + tis = dr.get_task_instances() + + # Set different states for the task instances + for ti, state in zip(tis, [State.SUCCESS, State.FAILED]): + ti.state = state + session.merge(ti) + session.commit() + + response = client.get( + "/execution/task-instances/states", + params={ + "dag_id": "test_get_task_group_states_with_multiple_task_tasks", + "task_group_id": "group1", + }, + ) + assert response.status_code == 200 + assert response.json() == { + "task_states": { + "test": { + "group1.task1": "success", + "group1.task2": "failed", + }, + }, + } + + def test_get_task_group_states_with_multiple_task(self, client, session, dag_maker): + with dag_maker("test_get_task_group_states_with_multiple_task_tasks", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2") + EmptyOperator(task_id="task3") + + dr = dag_maker.create_dagrun() + + tis = 
dr.get_task_instances() + + # Set different states for the task instances + for ti, state in zip(tis, [State.SUCCESS, State.FAILED, State.SKIPPED]): + ti.state = state + session.merge(ti) + session.commit() + + response = client.get( + "/execution/task-instances/states", + params={ + "dag_id": "test_get_task_group_states_with_multiple_task_tasks", + "task_group_id": "group1", + }, + ) + assert response.status_code == 200 + assert response.json() == { + "task_states": { + "test": { + "group1.task1": "success", + "group1.task2": "failed", + "group1.task3": "skipped", + }, + }, + } + + def test_get_task_group_states_with_logical_dates(self, client, session, dag_maker): + with dag_maker("test_get_task_group_states_with_logical_dates", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + + date1 = timezone.datetime(2025, 1, 1) + date2 = timezone.datetime(2025, 1, 2) + + dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1) + dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2) + + session.commit() + + response = client.get( + "/execution/task-instances/states", + params={ + "dag_id": "test_get_task_group_states_with_logical_dates", + "logical_dates": [date1.isoformat(), date2.isoformat()], + "task_group_id": "group1", + }, + ) + assert response.status_code == 200 + assert response.json() == { + "task_states": { + "test_run_id1": { + "group1.task1": None, + }, + "test_run_id2": { + "group1.task1": None, + }, + }, + } + + def test_get_task_group_states_with_run_ids(self, client, session, dag_maker): + with dag_maker("test_get_task_group_states_with_run_ids", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + + dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1)) + dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2)) + + session.commit() + + response = client.get( + "/execution/task-instances/states", + params={ + 
"dag_id": "test_get_task_group_states_with_run_ids", + "run_ids": ["run1", "run2"], + "task_group_id": "group1", + }, + ) + assert response.status_code == 200 + assert response.json() == { + "task_states": { + "run1": { + "group1.task1": None, + }, + "run2": { + "group1.task1": None, + }, + }, + } + + def test_get_task_states_task_group_not_found(self, client, session, dag_maker): + with dag_maker(dag_id="test_get_task_states_task_group_not_found", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + dag_maker.create_dagrun(session=session) + + response = client.get( + "/execution/task-instances/states", + params={ + "dag_id": "test_get_task_states_task_group_not_found", + "task_group_id": "non_existent_group", + }, + ) + assert response.status_code == 404 + assert response.json()["detail"] == { + "reason": "not_found", + "message": "Task group non_existent_group not found in DAG test_get_task_states_task_group_not_found", + } + + def test_get_task_states_dag_not_found(self, client, session): + response = client.get( + "/execution/task-instances/states", + params={"dag_id": "non_existent_dag", "task_group_id": "group1"}, + ) + assert response.status_code == 404 + assert response.json()["detail"] == { + "reason": "not_found", + "message": "DAG non_existent_dag not found", + } diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 87302a4e6f7eb..d3c9ae6f27dc9 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -812,7 +812,14 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: logical_dates=self.execution_dates, states=["running"], ) - yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count}) + task_states = await sync_to_async(RuntimeTaskInstance.get_task_states)( + dag_id=self.external_dag_id, + task_ids=self.external_task_ids, + run_ids=self.run_ids, + task_group_id=None, + 
logical_dates=self.execution_dates, + ) + yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count, "task_states": task_states}) @pytest.mark.xfail( @@ -823,7 +830,7 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: @pytest.mark.flaky(reruns=2, reruns_delay=10) @pytest.mark.execution_timeout(30) async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): - """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" + """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task with dag_maker(dag_id="parent_dag", session=session): EmptyOperator(task_id="parent_task") @@ -870,4 +877,6 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, d task_instance.refresh_from_db() assert task_instance.state == TaskInstanceState.SCHEDULED assert task_instance.next_method != "__fail__" - assert task_instance.next_kwargs == {"event": {"ti_count": 1, "dr_count": 1}} + assert task_instance.next_kwargs == { + "event": {"ti_count": 1, "dr_count": 1, "task_states": {"test": {"parent_task": "success"}}} + } diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index e64eb0d6763b0..012abb272a010 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -303,6 +303,8 @@ def poke(self, context: Context) -> bool: return self._poke_af2(dttm_filter) def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool: + from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states + self._has_checked_existence = True ti = context["ti"] @@ -315,12 +317,12 @@ def _get_count(states: list[str]) -> int: states=states, ) elif self.external_task_group_id: - 
return ti.get_ti_count( + run_id_task_state_map = ti.get_task_states( dag_id=self.external_dag_id, task_group_id=self.external_task_group_id, logical_dates=dttm_filter, - states=states, ) + return _get_count_by_matched_states(run_id_task_state_map, states) else: return ti.get_dr_count( dag_id=self.external_dag_id, diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index a5f8b67f54972..d3db7027fefb9 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -133,15 +133,27 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: await asyncio.sleep(self.poke_interval) async def _get_count_af_3(self, states): + from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance - if self.external_task_ids or self.external_task_group_id: + params = { + "dag_id": self.external_dag_id, + "logical_dates": self.logical_dates, + "run_ids": self.run_ids, + } + if self.external_task_ids: count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( - dag_id=self.external_dag_id, - task_ids=self.external_task_ids, + task_ids=self.external_task_ids, # type: ignore[arg-type] + states=states, + **params, + ) + elif self.external_task_group_id: + run_id_task_state_map = await sync_to_async(RuntimeTaskInstance.get_task_states)( task_group_id=self.external_task_group_id, - logical_dates=self.logical_dates, - run_ids=self.run_ids, + **params, + ) + count = await sync_to_async(_get_count_by_matched_states)( + run_id_task_state_map=run_id_task_state_map, states=states, ) else: @@ -151,7 +163,6 @@ async def _get_count_af_3(self, states): run_ids=self.run_ids, states=states, ) - if self.external_task_ids: return count / len(self.external_task_ids) else: diff --git 
a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py index 17d54e371bcb2..89808b9db7c52 100644 --- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py +++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import func, select, tuple_ @@ -124,3 +124,14 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter # returning default task_id as group_id itself, this will avoid any failure in case of # 'check_existence=False' and will fail on timeout return [(external_task_group_id, -1)] + + +def _get_count_by_matched_states( + run_id_task_state_map: dict[str, dict[str, Any]], + states: list[str], +): + count = 0 + for _, task_states in run_id_task_state_map.items(): + if all(state in states for state in task_states.values() if state): + count += 1 + return count diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 439cdb0d8f7a3..efa33f5a5583f 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -1108,13 +1108,12 @@ def test_external_task_sensor_task_group(self, dag_maker): allowed_states=["success"], ) - self.context["ti"].get_ti_count.return_value = 1 + self.context["ti"].get_task_states.return_value = {"run_id": {"test_group.task_id": State.SUCCESS}} op.execute(context=self.context) - self.context["ti"].get_ti_count.assert_called_once_with( + self.context["ti"].get_task_states.assert_called_once_with( dag_id="test_dag_parent", logical_dates=[DEFAULT_DATE], - states=["success"], 
task_group_id="test_group", ) @@ -1226,15 +1225,14 @@ def test_external_task_sensor_task_group_failed_states(self, dag_maker): failed_states=[State.FAILED], ) - self.context["ti"].get_ti_count.return_value = 1 + self.context["ti"].get_task_states.return_value = {"run_id": {"test_group.task_id": State.FAILED}} with pytest.raises(AirflowException): op.execute(context=self.context) - self.context["ti"].get_ti_count.assert_called_once_with( + self.context["ti"].get_task_states.assert_called_once_with( dag_id="test_dag_parent", logical_dates=[DEFAULT_DATE], - states=[State.FAILED], task_group_id="test_group", ) diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 3aa581257d79e..5f5b4861b04e9 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -75,7 +75,6 @@ async def test_task_workflow_trigger_success(self, mock_get_count): mock_get_count.assert_called_once_with( dag_id="external_task", task_ids=["external_task_op"], - task_group_id=None, logical_dates=[self.LOGICAL_DATE], run_ids=None, states=["success", "fail"], @@ -112,7 +111,6 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): mock_get_count.assert_called_once_with( dag_id="external_task", task_ids=["external_task_op"], - task_group_id=None, logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], states=["success", "fail"], @@ -145,7 +143,6 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): mock_get_count.assert_called_once_with( dag_id="external_task", task_ids=["external_task_op"], - task_group_id=None, logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], states=["success", "fail"], @@ -181,7 +178,6 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): mock_get_count.assert_called_once_with( dag_id="external_task", 
task_ids=["external_task_op"], - task_group_id=None, logical_dates=[self.LOGICAL_DATE], run_ids=None, states=["success", "fail"], @@ -218,7 +214,7 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co assert mock_sleep.await_count == 1 @pytest.mark.parametrize( - "task_ids, task_group_id, states, logical_dates, mock_ti_count, mock_dag_count, expected", + "task_ids, task_group_id, states, logical_dates, mock_ti_count, mock_task_states, mock_dag_count, expected", [ ( ["task_id_one", "task_id_two"], @@ -229,7 +225,8 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), ], 4, - 2, + None, + None, 2, ), ( @@ -240,19 +237,53 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), ], - 2, - 2, + None, + {"run_id_one": {"group1.task_id": "success"}, "run_id_two": {"group1.task_id": "success"}}, + None, 2, ), ( [], + "task_group_id", + ["success"], + [ + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + ], None, + { + "run_id_one": {"task_group_id.task_id": "success"}, + "run_id_two": {"task_group_id.task_id": "failed"}, + }, + None, + 1, + ), + ( + [], + "task_group_id", ["success"], [ timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), ], - 2, + None, + { + "run_id_one": {"task_group_id.task_id": "success"}, + "run_id_two": {"task_group_id.task_id": "skipped"}, + }, + None, + 1, + ), + ( + [], + None, + ["success"], + [ + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + ], + None, + None, 2, 2, ), @@ -260,21 +291,26 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co ids=[ "with_task_ids", "with 
task_group_id only", + "with task_group_id and some task_ids failed", + "with task_group_id and some task_ids skipped", "no task_ids or task_group_id", ], ) @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states") @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dr_count") @pytest.mark.asyncio async def test_get_count_af_3( self, mock_get_dr_count, + mock_get_task_states, mock_get_ti_count, task_ids, task_group_id, states, logical_dates, mock_ti_count, + mock_task_states, mock_dag_count, expected, ): @@ -287,6 +323,7 @@ async def test_get_count_af_3( mock_get_ti_count.return_value = mock_ti_count mock_get_dr_count.return_value = mock_dag_count + mock_get_task_states.return_value = mock_task_states trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, @@ -300,9 +337,19 @@ async def test_get_count_af_3( get_count_af_3 = await trigger._get_count_af_3(states) assert get_count_af_3 == expected - if task_ids or task_group_id: + if task_ids: mock_get_ti_count.assert_called_once() assert mock_get_ti_count.call_count == 1 + mock_get_task_states.assert_not_called() + assert mock_get_task_states.call_count == 0 + mock_get_dr_count.assert_not_called() + assert mock_get_dr_count.call_count == 0 + + elif task_group_id: + mock_get_task_states.assert_called_once() + assert mock_get_task_states.call_count == 1 + mock_get_ti_count.assert_not_called() + assert mock_get_ti_count.call_count == 0 mock_get_dr_count.assert_not_called() assert mock_get_dr_count.call_count == 0 @@ -311,6 +358,8 @@ async def test_get_count_af_3( assert mock_get_dr_count.call_count == 1 mock_get_ti_count.assert_not_called() assert mock_get_ti_count.call_count == 0 + mock_get_task_states.assert_not_called() + assert mock_get_task_states.call_count == 0 def test_serialization(self): """ diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py index df7679cab8741..a3840ba4a15d3 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -41,6 +41,7 @@ DagRunStateResponse, DagRunType, PrevSuccessfulDagRunResponse, + TaskStatesResponse, TerminalStateNonSuccess, TerminalTIState, TIDeferredStatePayload, @@ -233,6 +234,29 @@ def get_count( resp = self.client.get("task-instances/count", params=params) return TICount(count=resp.json()) + def get_task_states( + self, + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | None = None, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + ) -> TaskStatesResponse: + """Get task states given criteria.""" + params = { + "dag_id": dag_id, + "task_ids": task_ids, + "task_group_id": task_group_id, + "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None, + "run_ids": run_ids, + } + + # Remove None values from params + params = {k: v for k, v in params.items() if v is not None} + + resp = self.client.get("task-instances/states", params=params) + return TaskStatesResponse.model_validate_json(resp.read()) + class ConnectionOperations: __slots__ = ("client",) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 746af72c137e6..0f551aa665425 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -283,6 +283,14 @@ class TITargetStatePayload(BaseModel): state: IntermediateTIState +class TaskStatesResponse(BaseModel): + """ + Response for task states with run_id, task and state. + """ + + task_states: Annotated[dict[str, Any], Field(title="Task States")] + + class TerminalStateNonSuccess(str, Enum): """ TaskInstance states that can be reported without extra information. 
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index a9c2793457798..fbaee7de4e512 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -63,6 +63,7 @@ DagRunStateResponse, PrevSuccessfulDagRunResponse, TaskInstance, + TaskStatesResponse, TerminalTIState, TIDeferredStatePayload, TIRescheduleStatePayload, @@ -304,6 +305,21 @@ class TICount(BaseModel): type: Literal["TICount"] = "TICount" +class TaskStatesResult(TaskStatesResponse): + type: Literal["TaskStatesResult"] = "TaskStatesResult" + + @classmethod + def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we need to convert it to Result + for communication between the Supervisor and the task process since it needs a + discriminator field. + """ + return cls(**task_states_response.model_dump(exclude_defaults=True), type="TaskStatesResult") + + class DRCount(BaseModel): """Response containing count of DAG Runs matching certain filters.""" @@ -334,6 +350,7 @@ class OKResponse(BaseModel): StartupDetails, TaskRescheduleStartDate, TICount, + TaskStatesResult, VariableResult, XComResult, XComCountResponse, @@ -546,6 +563,15 @@ class GetTICount(BaseModel): type: Literal["GetTICount"] = "GetTICount" +class GetTaskStates(BaseModel): + dag_id: str + task_ids: list[str] | None = None + task_group_id: str | None = None + logical_dates: list[AwareDatetime] | None = None + run_ids: list[str] | None = None + type: Literal["GetTaskStates"] = "GetTaskStates" + + class GetDRCount(BaseModel): dag_id: str logical_dates: list[AwareDatetime] | None = None @@ -568,6 +594,7 @@ class GetDRCount(BaseModel): GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetTICount, + GetTaskStates, GetVariable, GetXCom, GetXComCount, diff --git 
a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 7ade78b4ec10b..a48b49574e5f8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -56,6 +56,7 @@ ConnectionResponse, IntermediateTIState, TaskInstance, + TaskStatesResponse, TerminalTIState, VariableResponse, ) @@ -77,6 +78,7 @@ GetDRCount, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, + GetTaskStates, GetTICount, GetVariable, GetXCom, @@ -91,6 +93,7 @@ StartupDetails, SucceedTask, TaskState, + TaskStatesResult, ToSupervisor, TriggerDagRun, VariableResult, @@ -1035,6 +1038,18 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): run_ids=msg.run_ids, states=msg.states, ) + elif isinstance(msg, GetTaskStates): + task_states_map = self.client.task_instances.get_task_states( + dag_id=msg.dag_id, + task_ids=msg.task_ids, + task_group_id=msg.task_group_id, + logical_dates=msg.logical_dates, + run_ids=msg.run_ids, + ) + if isinstance(task_states_map, TaskStatesResponse): + resp = TaskStatesResult.from_api_response(task_states_map) + else: + resp = task_states_map elif isinstance(msg, GetDRCount): resp = self.client.dag_runs.get_count( dag_id=msg.dag_id, diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 102ca78b5248c..f6758d7ca59e8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -64,6 +64,7 @@ GetDagRunState, GetDRCount, GetTaskRescheduleStartDate, + GetTaskStates, GetTICount, RescheduleTask, RetryTask, @@ -73,6 +74,7 @@ SucceedTask, TaskRescheduleStartDate, TaskState, + TaskStatesResult, TICount, ToSupervisor, ToTask, @@ -435,6 +437,35 @@ def get_ti_count( return response.count + @staticmethod + def get_task_states( + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | 
None = None, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + ) -> dict[str, Any]: + """Return the task states matching the given criteria.""" + log = structlog.get_logger(logger_name="task") + + with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetTaskStates( + dag_id=dag_id, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + ), + ) + response = SUPERVISOR_COMMS.get_message() + + if TYPE_CHECKING: + assert isinstance(response, TaskStatesResult) + + return response.task_states + @staticmethod def get_dr_count( dag_id: str, diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index a7749ccd2308b..ae21ed5599f36 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -95,6 +95,15 @@ def get_ti_count( states: list[str] | None = None, ) -> int: ... + @staticmethod + def get_task_states( + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | None = None, + logical_dates: list[AwareDatetime] | None = None, + run_ids: list[str] | None = None, + ) -> dict[str, Any]: ... 
+ @staticmethod def get_dr_count( dag_id: str, diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index c626749602987..a2c1bc107db2e 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -464,6 +464,48 @@ def handle_request(request: httpx.Request) -> httpx.Response: ) assert result.count == 10 + def test_get_task_states_basic(self): + """Test basic get_task_states functionality with just dag_id.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/task-instances/states" + assert request.url.params.get("dag_id") == "test_dag" + assert request.url.params.get("task_group_id") == "group1" + return httpx.Response( + 200, json={"task_states": {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}} + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_instances.get_task_states(dag_id="test_dag", task_group_id="group1") + assert result.task_states == {"run_id": {"group1.task1": "success", "group1.task2": "failed"}} + + def test_get_task_states_with_all_params(self): + """Test get_task_states with all optional parameters.""" + + logical_dates_str = ["2024-01-01T00:00:00+00:00", "2024-01-02T00:00:00+00:00"] + logical_dates = [timezone.parse(d) for d in logical_dates_str] + + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/task-instances/states" + assert request.method == "GET" + params = request.url.params + assert params["dag_id"] == "test_dag" + assert params["task_group_id"] == "group1" + assert params.get_list("logical_dates") == logical_dates_str + assert params.get_list("task_ids") == [] + assert params.get_list("run_ids") == [] + return httpx.Response( + 200, json={"task_states": {"run_id": {"group1.task1": "success", "group1.task2": "failed"}}} + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) 
+ result = client.task_instances.get_task_states( + dag_id="test_dag", + task_group_id="group1", + logical_dates=logical_dates, + ) + assert result.task_states == {"run_id": {"group1.task1": "success", "group1.task2": "failed"}} + class TestVariableOperations: """ diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index f8baf2e03ee72..454a742bb5687 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -72,6 +72,7 @@ GetDRCount, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, + GetTaskStates, GetTICount, GetVariable, GetXCom, @@ -85,6 +86,7 @@ SucceedTask, TaskRescheduleStartDate, TaskState, + TaskStatesResult, TICount, TriggerDagRun, VariableResult, @@ -1459,6 +1461,21 @@ def watched_subprocess(self, mocker): DRCount(count=2), id="get_dr_count", ), + pytest.param( + GetTaskStates(dag_id="test_dag", task_group_id="test_group"), + b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n', + "task_instances.get_task_states", + (), + { + "dag_id": "test_dag", + "task_ids": None, + "logical_dates": None, + "run_ids": None, + "task_group_id": "test_group", + }, + TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}), + id="get_task_states", + ), ], ) def test_handle_requests( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index acdace8344362..d3f6cd90bde11 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -72,6 +72,7 @@ GetConnection, GetDagRunState, GetDRCount, + GetTaskStates, GetTICount, GetVariable, GetXCom, @@ -84,6 +85,7 @@ SucceedTask, TaskRescheduleStartDate, TaskState, + TaskStatesResult, TICount, TriggerDagRun, VariableResult, @@ -1471,6 +1473,28 @@ def 
test_get_dagrun_state(self, mock_supervisor_comms): ) assert state == "running" + def test_get_task_states(self, mock_supervisor_comms): + """Test that get_task_states sends the correct request and returns the states.""" + mock_supervisor_comms.get_message.return_value = TaskStatesResult( + task_states={"run1": {"task1": "running"}} + ) + + states = RuntimeTaskInstance.get_task_states( + dag_id="test_dag", + task_ids=["task1"], + run_ids=["run1"], + ) + + mock_supervisor_comms.send_request.assert_called_once_with( + log=mock.ANY, + msg=GetTaskStates( + dag_id="test_dag", + task_ids=["task1"], + run_ids=["run1"], + ), + ) + assert states == {"run1": {"task1": "running"}} + class TestXComAfterTaskExecution: @pytest.mark.parametrize(