From e470a65ec9b3303cb3ba861be135e56a86dacd50 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Fri, 4 Apr 2025 22:10:28 +0100 Subject: [PATCH 01/12] Fix WorkflowStateTrigger --- .../src/airflow/jobs/triggerer_job_runner.py | 15 +++- .../standard/sensors/external_task.py | 2 +- .../standard/triggers/external_task.py | 73 +++++++++++++++++-- .../src/airflow/sdk/execution_time/context.py | 63 +++++++++++++++- 4 files changed, 142 insertions(+), 11 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 520699e9e5093..25159d83e2eab 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -52,8 +52,10 @@ GetDRCount, GetVariable, GetXCom, +GetTICount, VariableResult, XComResult, + TICount, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from airflow.stats import Stats @@ -222,6 +224,7 @@ class TriggerStateSync(BaseModel): XComResult, DagRunStateResult, DRCount, + TICount, ErrorResponse, ], Field(discriminator="type"), @@ -233,7 +236,7 @@ class TriggerStateSync(BaseModel): ToTriggerSupervisor = Annotated[ - Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom, GetDagRunState, GetDRCount], + Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom, GetTICount, GetDagRunState, GetDRCount], Field(discriminator="type"), ] """ @@ -411,6 +414,16 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id) resp = DagRunStateResult.from_api_response(dr_resp).model_dump_json().encode() + elif isinstance(msg, GetTICount): + ti_count = self.client.task_instances.get_count( + dag_id=msg.dag_id, + task_ids=msg.task_ids, + task_group_id=msg.task_group_id, + logical_dates=msg.logical_dates, + run_ids=msg.run_ids, + states=msg.states, + ) + resp = ti_count.model_dump_json().encode() else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index dd5ad6c4e7a3a..003482d07137b 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -427,7 +427,7 @@ def execute(self, context: Context) -> None: external_dag_id=self.external_dag_id, external_task_group_id=self.external_task_group_id, external_task_ids=self.external_task_ids, - logical_dates=self._get_dttm_filter(context), + execution_dates=self._get_dttm_filter(context), allowed_states=self.allowed_states, poke_interval=self.poll_interval, soft_fail=self.soft_fail, diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index adf3a535e3ef2..554d4c510f89a 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -89,19 +89,77 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "allowed_states": self.allowed_states, "poke_interval": self.poke_interval, "soft_fail": self.soft_fail, + "execution_dates": self.execution_dates, } if AIRFLOW_V_3_0_PLUS: data["run_ids"] = self.run_ids - else: - data["execution_dates"] = self.execution_dates + # else: + # data["execution_dates"] = self.execution_dates return "airflow.providers.standard.triggers.external_task.WorkflowTrigger", data async def run(self) -> typing.AsyncIterator[TriggerEvent]: """Check periodically tasks, task group or dag status.""" + yield TriggerEvent({"status": "success"}) + if AIRFLOW_V_3_0_PLUS: + self._validate_count_af_3() + else: + while True: + if self.failed_states: + failed_count = await self._get_count(self.failed_states) + if failed_count > 0: + yield TriggerEvent({"status": "failed"}) + return + else: + yield TriggerEvent({"status": "success"}) + return + if self.skipped_states: + skipped_count = await self._get_count(self.skipped_states) + if skipped_count > 0: + yield TriggerEvent({"status": "skipped"}) + return + allowed_count = await self._get_count(self.allowed_states) + _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates + if allowed_count == len(_dates): # type: ignore[arg-type] + yield TriggerEvent({"status": "success"}) + return + self.log.info("Sleeping for %s seconds", self.poke_interval) + await asyncio.sleep(self.poke_interval) + + + async def _validate_count_af_3(self): + from airflow.sdk.execution_time.context import get_ti_count, get_dr_count + run_id_or_dates = self.run_ids or self.execution_dates or [] + print(f"run_id_or_dates: {run_id_or_dates}") + + async def get_count(states): + if self.external_task_ids or self.external_task_group_id: + count = await sync_to_async(get_ti_count)( + dag_id=self.external_dag_id, + task_ids=self.external_task_ids, + task_group_id=self.external_task_group_id, + logical_dates=self.execution_dates, + run_ids=self.run_ids, + states=states, + ) + else: + count = await sync_to_async(get_dr_count)( + dag_id=self.external_dag_id, + logical_dates=self.execution_dates, + run_ids=self.run_ids, + states=states, + ) + + if self.external_task_ids: + return count / len(self.external_task_ids) + elif self.external_task_group_id: + return count / len(run_id_or_dates) + else: + return count + while True: if self.failed_states: - failed_count = await self._get_count(self.failed_states) + failed_count = await get_count(self.failed_states) if failed_count > 0: yield TriggerEvent({"status": "failed"}) return @@ -109,18 +167,19 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: yield TriggerEvent({"status": "success"}) return if self.skipped_states: - skipped_count = await self._get_count(self.skipped_states) + skipped_count = await get_count(self.skipped_states) if skipped_count > 0: yield TriggerEvent({"status": "skipped"}) return - allowed_count = await self._get_count(self.allowed_states) - _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates - if allowed_count == len(_dates): # type: ignore[arg-type] + allowed_count = await get_count(self.allowed_states) + if allowed_count == len(run_id_or_dates): # type: ignore[arg-type] yield TriggerEvent({"status": "success"}) return self.log.info("Sleeping for %s seconds", self.poke_interval) await asyncio.sleep(self.poke_interval) + + @sync_to_async def _get_count(self, states: typing.Iterable[str] | None) -> int: """ diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 2ff7dbbda0871..74b31ac386e25 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -19,6 +19,7 @@ import collections import contextlib from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence +from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union @@ -53,8 +54,8 @@ AssetResult, ConnectionResult, PrevSuccessfulDagRunResponse, - VariableResult, - ) + VariableResult, TICount, GetDRCount, DRCount, GetTICount, +) from airflow.sdk.types import OutletEventAccessorsProtocol @@ -660,3 +661,61 @@ def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: except KeyError: outlet_events = context["outlet_events"] = OutletEventAccessors() return outlet_events + + +def get_ti_count( + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | None = None, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, +) -> int: + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + """Return the number of task instances matching the given criteria.""" + log = structlog.get_logger(logger_name="task") + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetTICount( + dag_id=dag_id, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() + + if TYPE_CHECKING: + assert isinstance(response, TICount) + + return response.count + + +def get_dr_count( + dag_id: str, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, +) -> int: + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + """Return the number of DAG runs matching the given criteria.""" + log = structlog.get_logger(logger_name="task") + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() + + if TYPE_CHECKING: + assert isinstance(response, DRCount) + + return response.count From 7c5e81f100ce2b340d9d024e93fa02fb87d355bf Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 5 Apr 2025 00:11:03 +0100 Subject: [PATCH 02/12] Fix WorkflowTrigger to work with TaskSDK --- .../tests/unit/jobs/test_triggerer_job.py | 101 +++++++++ .../standard/sensors/external_task.py | 2 + .../standard/triggers/external_task.py | 92 +++----- .../standard/triggers/test_external_task.py | 206 ++++++++++++++++++ .../src/airflow/sdk/execution_time/context.py | 64 +----- 5 files changed, 343 insertions(+), 122 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index af197d6f35bfa..ecb48c2a66448 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -767,3 +767,104 @@ async def test_trigger_can_fetch_trigger_dag_run_count_in_deferrable(session, da assert task_instance.state == TaskInstanceState.SCHEDULED assert task_instance.next_method != "__fail__" assert task_instance.next_kwargs == {"event": {"count": 1}} + + +class CustomTriggerWorkflowStateTrigger(BaseTrigger): + """Custom Trigger to check the triggerer can access the get_ti_count and get_dr_count.""" + + def __init__(self, external_dag_id, execution_dates, external_task_ids, allowed_states, run_ids): + self.external_dag_id = external_dag_id + self.execution_dates = execution_dates + self.external_task_ids = external_task_ids + self.allowed_states = allowed_states + self.run_ids = run_ids + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + f"{type(self).__module__}.{type(self).__qualname__}", + { + "external_dag_id": self.external_dag_id, + "execution_dates": self.execution_dates, + "external_task_ids": self.external_task_ids, + "allowed_states": self.allowed_states, + "run_ids": self.run_ids, + }, + ) + + async def run(self, **args) -> AsyncIterator[TriggerEvent]: + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + ti_count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( + dag_id=self.external_dag_id, + task_ids=self.external_task_ids, + task_group_id=None, + run_ids=self.run_ids, + logical_dates=self.execution_dates, + states=self.allowed_states, + ) + dr_count = await sync_to_async(RuntimeTaskInstance.get_dr_count)( + dag_id=self.external_dag_id, + run_ids=self.run_ids, + logical_dates=self.execution_dates, + states=["running"], + ) + yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count}) + + +@pytest.mark.xfail( + reason="We know that test is flaky and have no time to fix it before 3.0. " + "We should fix it later. TODO: AIP-72" +) +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=2, reruns_delay=10) +@pytest.mark.execution_timeout(30) +async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): + """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" + # Create the test DAG and task + with dag_maker(dag_id="parent_dag", session=session): + EmptyOperator(task_id="parent_task") + parent_dag_run = dag_maker.create_dagrun() + parent_task = parent_dag_run.task_instances[0] + parent_task.state = TaskInstanceState.SUCCESS + + with dag_maker(dag_id="trigger_can_fetch_dag_run_count_ti_count_in_deferrable", session=session): + EmptyOperator(task_id="dummy1") + dr = dag_maker.create_dagrun() + task_instance = dr.task_instances[0] + task_instance.state = TaskInstanceState.DEFERRED + + # Use the same dag run with states deferred to fetch the count + trigger = CustomTriggerWorkflowStateTrigger( + external_dag_id=parent_task.dag_id, + execution_dates=[parent_task.logical_date], + external_task_ids=[parent_task.task_id], + allowed_states=[State.SUCCESS], + run_ids=[parent_task.run_id], + ) + trigger_orm = Trigger( + classpath=trigger.serialize()[0], + kwargs={ + "external_dag_id": parent_dag_run.dag_id, + "execution_dates": [parent_dag_run.logical_date], + "external_task_ids": [parent_task.task_id], + "allowed_states": [State.SUCCESS], + "run_ids": [parent_dag_run.run_id], + }, + ) + trigger_orm.id = 1 + session.add(trigger_orm) + session.commit() + task_instance.trigger_id = trigger_orm.id + + job = Job() + session.add(job) + session.commit() + + supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None) + supervisor.run() + + parent_task.refresh_from_db() + task_instance.refresh_from_db() + assert task_instance.state == TaskInstanceState.SCHEDULED + assert task_instance.next_method != "__fail__" + assert task_instance.next_kwargs == {"event": {"ti_count": 1, "dr_count": 1}} diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 003482d07137b..a2bcdad876003 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -429,6 +429,8 @@ def execute(self, context: Context) -> None: external_task_ids=self.external_task_ids, execution_dates=self._get_dttm_filter(context), allowed_states=self.allowed_states, + failed_states=self.failed_states, + skipped_states=self.skipped_states, poke_interval=self.poll_interval, soft_fail=self.soft_fail, ), diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 554d4c510f89a..9675f99138d45 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -93,73 +93,17 @@ def serialize(self) -> tuple[str, dict[str, Any]]: } if AIRFLOW_V_3_0_PLUS: data["run_ids"] = self.run_ids - # else: - # data["execution_dates"] = self.execution_dates return "airflow.providers.standard.triggers.external_task.WorkflowTrigger", data async def run(self) -> typing.AsyncIterator[TriggerEvent]: """Check periodically tasks, task group or dag status.""" - yield TriggerEvent({"status": "success"}) - if AIRFLOW_V_3_0_PLUS: - self._validate_count_af_3() - else: - while True: - if self.failed_states: - failed_count = await self._get_count(self.failed_states) - if failed_count > 0: - yield TriggerEvent({"status": "failed"}) - return - else: - yield TriggerEvent({"status": "success"}) - return - if self.skipped_states: - skipped_count = await self._get_count(self.skipped_states) - if skipped_count > 0: - yield TriggerEvent({"status": "skipped"}) - return - allowed_count = await self._get_count(self.allowed_states) - _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates - if allowed_count == len(_dates): # type: ignore[arg-type] - yield TriggerEvent({"status": "success"}) - return - self.log.info("Sleeping for %s seconds", self.poke_interval) - await asyncio.sleep(self.poke_interval) - - - async def _validate_count_af_3(self): - from airflow.sdk.execution_time.context import get_ti_count, get_dr_count + get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else self._get_count run_id_or_dates = self.run_ids or self.execution_dates or [] - print(f"run_id_or_dates: {run_id_or_dates}") - - async def get_count(states): - if self.external_task_ids or self.external_task_group_id: - count = await sync_to_async(get_ti_count)( - dag_id=self.external_dag_id, - task_ids=self.external_task_ids, - task_group_id=self.external_task_group_id, - logical_dates=self.execution_dates, - run_ids=self.run_ids, - states=states, - ) - else: - count = await sync_to_async(get_dr_count)( - dag_id=self.external_dag_id, - logical_dates=self.execution_dates, - run_ids=self.run_ids, - states=states, - ) - - if self.external_task_ids: - return count / len(self.external_task_ids) - elif self.external_task_group_id: - return count / len(run_id_or_dates) - else: - return count while True: if self.failed_states: - failed_count = await get_count(self.failed_states) + failed_count = await get_count_func(self.failed_states) if failed_count > 0: yield TriggerEvent({"status": "failed"}) return @@ -167,18 +111,46 @@ async def get_count(states): yield TriggerEvent({"status": "success"}) return if self.skipped_states: - skipped_count = await get_count(self.skipped_states) + skipped_count = await get_count_func(self.skipped_states) if skipped_count > 0: yield TriggerEvent({"status": "skipped"}) return - allowed_count = await get_count(self.allowed_states) + allowed_count = await get_count_func(self.allowed_states) + if allowed_count == len(run_id_or_dates): # type: ignore[arg-type] yield TriggerEvent({"status": "success"}) return self.log.info("Sleeping for %s seconds", self.poke_interval) await asyncio.sleep(self.poke_interval) + async def _get_count_af_3(self, states): + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + run_id_or_dates = self.run_ids or self.execution_dates or [] + + if self.external_task_ids or self.external_task_group_id: + count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( + dag_id=self.external_dag_id, + task_ids=self.external_task_ids, + task_group_id=self.external_task_group_id, + logical_dates=self.execution_dates, + run_ids=self.run_ids, + states=states, + ) + else: + count = await sync_to_async(RuntimeTaskInstance.get_dr_count)( + dag_id=self.external_dag_id, + logical_dates=self.execution_dates, + run_ids=self.run_ids, + states=states, + ) + if self.external_task_ids: + return count / len(self.external_task_ids) + elif self.external_task_group_id: + return count / len(run_id_or_dates) + else: + return count @sync_to_async def _get_count(self, states: typing.Iterable[str] | None) -> int: diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 247085a407ff4..0828d091060f8 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -39,11 +39,217 @@ key, value = next(iter(_DATES.items())) +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3") class TestWorkflowTrigger: DAG_ID = "external_task" TASK_ID = "external_task_op" RUN_ID = "external_task_run_id" STATES = ["success", "fail"] + EXECUTION_DATE = timezone.datetime(2022, 1, 1) + + @pytest.mark.flaky(reruns=5) + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @pytest.mark.asyncio + async def test_task_workflow_trigger_success(self, mock_get_count): + """check the db count get called correctly.""" + mock_get_count.side_effect = mocked_get_count + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + external_task_ids=[self.TASK_ID], + allowed_states=self.STATES, + poke_interval=0.2, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) + await trigger_task + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() + result = trigger_task.result() + assert result.payload == {"status": "success"} + + mock_get_count.assert_called_once_with( + dag_id="external_task", + task_ids=["external_task_op"], + task_group_id=None, + logical_dates=[self.EXECUTION_DATE], + run_ids=None, + states=["success", "fail"], + ) + # test that it returns after yielding + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + @pytest.mark.flaky(reruns=5) + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @pytest.mark.asyncio + async def test_task_workflow_trigger_failed(self, mock_get_count): + mock_get_count.side_effect = mocked_get_count + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + run_ids=[self.RUN_ID], + external_task_ids=[self.TASK_ID], + failed_states=self.STATES, + poke_interval=0.2, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) + await trigger_task + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() + result = trigger_task.result() + assert isinstance(result, TriggerEvent) + assert result.payload == {"status": "failed"} + mock_get_count.assert_called_once_with( + dag_id="external_task", + task_ids=["external_task_op"], + task_group_id=None, + logical_dates=[self.EXECUTION_DATE], + run_ids=[self.RUN_ID], + states=["success", "fail"], + ) + # test that it returns after yielding + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @pytest.mark.asyncio + async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): + mock_get_count.return_value = 0 + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + run_ids=[self.RUN_ID], + external_task_ids=[self.TASK_ID], + failed_states=self.STATES, + poke_interval=0.2, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + await trigger_task + assert trigger_task.done() + result = trigger_task.result() + assert isinstance(result, TriggerEvent) + assert result.payload == {"status": "success"} + mock_get_count.assert_called_once_with( + dag_id="external_task", + task_ids=["external_task_op"], + task_group_id=None, + logical_dates=[self.EXECUTION_DATE], + run_ids=[self.RUN_ID], + states=["success", "fail"], + ) + # test that it returns after yielding + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + @pytest.mark.flaky(reruns=5) + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @pytest.mark.asyncio + async def test_task_workflow_trigger_skipped(self, mock_get_count): + mock_get_count.side_effect = mocked_get_count + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + external_task_ids=[self.TASK_ID], + skipped_states=self.STATES, + poke_interval=0.2, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + fake_task = asyncio.create_task(fake_async_fun()) + await trigger_task + assert fake_task.done() # confirm that get_count is done in an async fashion + assert trigger_task.done() + result = trigger_task.result() + assert isinstance(result, TriggerEvent) + assert result.payload == {"status": "skipped"} + mock_get_count.assert_called_once_with( + dag_id="external_task", + task_ids=["external_task_op"], + task_group_id=None, + logical_dates=[self.EXECUTION_DATE], + run_ids=None, + states=["success", "fail"], + ) + + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("asyncio.sleep") + @pytest.mark.asyncio + async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count): + mock_get_count.side_effect = [0, 1] + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + external_task_ids=[self.TASK_ID], + poke_interval=0.2, + ) + + gen = trigger.run() + trigger_task = asyncio.create_task(gen.__anext__()) + await trigger_task + assert trigger_task.done() + result = trigger_task.result() + assert isinstance(result, TriggerEvent) + assert result.payload == {"status": "success"} + mock_get_count.assert_called() + assert mock_get_count.call_count == 2 + + # test that it returns after yielding + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + mock_sleep.assert_awaited() + assert mock_sleep.await_count == 1 + + def test_serialization(self): + """ + Asserts that the WorkflowTrigger correctly serializes its arguments and classpath. + """ + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + execution_dates=[self.EXECUTION_DATE], + run_ids=[self.RUN_ID], + failed_states=["failed"], + skipped_states=["skipped"], + external_task_ids=[self.TASK_ID], + allowed_states=self.STATES, + poke_interval=5, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.standard.triggers.external_task.WorkflowTrigger" + assert kwargs == { + "external_dag_id": self.DAG_ID, + "execution_dates": [self.EXECUTION_DATE], + "run_ids": [self.RUN_ID], + "external_task_ids": [self.TASK_ID], + "external_task_group_id": None, + "failed_states": ["failed"], + "skipped_states": ["skipped"], + "allowed_states": self.STATES, + "poke_interval": 5, + "soft_fail": False, + } + + +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 2") +class TestWorkflowTriggerAF2: + DAG_ID = "external_task" + TASK_ID = "external_task_op" + RUN_ID = "external_task_run_id" + STATES = ["success", "fail"] @pytest.mark.flaky(reruns=5) @mock.patch("airflow.providers.standard.triggers.external_task._get_count") diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 74b31ac386e25..c3c95b148cd23 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -54,8 +54,8 @@ AssetResult, ConnectionResult, PrevSuccessfulDagRunResponse, - VariableResult, TICount, GetDRCount, DRCount, GetTICount, -) + VariableResult, + ) from airflow.sdk.types import OutletEventAccessorsProtocol @@ -599,8 +599,6 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format. :return: task_instance context as dict. """ - from datetime import datetime - from airflow import settings params = {} @@ -661,61 +659,3 @@ def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: except KeyError: outlet_events = context["outlet_events"] = OutletEventAccessors() return outlet_events - - -def get_ti_count( - dag_id: str, - task_ids: list[str] | None = None, - task_group_id: str | None = None, - logical_dates: list[datetime] | None = None, - run_ids: list[str] | None = None, - states: list[str] | None = None, -) -> int: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() - - if TYPE_CHECKING: - assert isinstance(response, TICount) - - return response.count - - -def get_dr_count( - dag_id: str, - logical_dates: list[datetime] | None = None, - run_ids: list[str] | None = None, - states: list[str] | None = None, -) -> int: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() - - if TYPE_CHECKING: - assert isinstance(response, DRCount) - - return response.count From bbf3f520f3f01bda5fa285347221db970f9320bd Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 5 Apr 2025 00:19:22 +0100 Subject: [PATCH 03/12] Fix WorkflowTrigger to work with TaskSDK --- task-sdk/src/airflow/sdk/execution_time/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index c3c95b148cd23..2ff7dbbda0871 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -19,7 +19,6 @@ import collections import contextlib from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence -from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union @@ -599,6 +598,8 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format. :return: task_instance context as dict. """ + from datetime import datetime + from airflow import settings params = {} From 7fe6077831b391a4697d432e9602497a5a0e51c9 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 19:14:28 +0100 Subject: [PATCH 04/12] Use async lock --- .../standard/triggers/external_task.py | 6 +- .../src/airflow/sdk/execution_time/context.py | 64 +++++++++++++++++++ .../airflow/sdk/execution_time/task_runner.py | 25 ++++---- 3 files changed, 80 insertions(+), 15 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 9675f99138d45..37d17055a448b 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -124,12 +124,12 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: await asyncio.sleep(self.poke_interval) async def _get_count_af_3(self, states): - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + from airflow.sdk.execution_time.context import get_dr_count, get_ti_count run_id_or_dates = self.run_ids or self.execution_dates or [] if self.external_task_ids or self.external_task_group_id: - count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( + count = await get_ti_count( dag_id=self.external_dag_id, task_ids=self.external_task_ids, task_group_id=self.external_task_group_id, @@ -138,7 +138,7 @@ async def _get_count_af_3(self, states): states=states, ) else: - count = await sync_to_async(RuntimeTaskInstance.get_dr_count)( + count = await get_dr_count( dag_id=self.external_dag_id, logical_dates=self.execution_dates, run_ids=self.run_ids, diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 2ff7dbbda0871..5fc2da2b96c6f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -19,6 +19,7 @@ import collections import contextlib from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence +from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union @@ -660,3 +661,66 @@ def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: except KeyError: outlet_events = context["outlet_events"] = OutletEventAccessors() return outlet_events + + +async def get_ti_count( + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | None = None, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, +) -> int: + """Return the number of task instances matching the given criteria.""" + from airflow.sdk.execution_time.comms import GetTICount, TICount + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + log = structlog.get_logger(logger_name="task") + + async with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetTICount( + dag_id=dag_id, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() + if not isinstance(response, TICount): + raise TypeError(f"Expected TICount, received: {type(response)} {response}") + + return response.count + + +async def get_dr_count( + dag_id: str, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, +) -> int: + """Return the number of DAG runs matching the given criteria.""" + from airflow.sdk.execution_time.comms import DRCount, GetDRCount + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + log = structlog.get_logger(logger_name="task") + + async with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() + + if not isinstance(response, DRCount): + raise TypeError(f"Expected DRCount, received: {type(response)} {response}") + + return response.count diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9cecdc5547d44..04b24b94d13e7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -413,18 +413,19 @@ def get_ti_count( """Return the number of task instances matching the given criteria.""" log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetTICount( + dag_id=dag_id, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TICount) From bd9b2bd241b20be130c0bf87eb8ff7b7f813221a Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 19:18:56 +0100 Subject: [PATCH 05/12] Use async lock --- .../tests/unit/standard/triggers/test_external_task.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 0828d091060f8..d2821110139ed 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -48,7 +48,7 @@ class TestWorkflowTrigger: EXECUTION_DATE = timezone.datetime(2022, 1, 1) @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_success(self, mock_get_count): """check the db count get called correctly.""" @@ -84,7 +84,7 @@ async def test_task_workflow_trigger_success(self, mock_get_count): await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_failed(self, mock_get_count): mock_get_count.side_effect = mocked_get_count @@ -119,7 +119,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): with pytest.raises(StopAsyncIteration): await gen.__anext__() - @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): mock_get_count.return_value = 0 @@ -153,7 +153,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_skipped(self, mock_get_count): mock_get_count.side_effect = mocked_get_count @@ -184,7 +184,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): states=["success", "fail"], ) - @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @mock.patch("asyncio.sleep") @pytest.mark.asyncio async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count): From dd116f17559a465e059eed18903397378fdc04cf Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 19:58:28 +0100 Subject: [PATCH 06/12] Fix tests --- .../tests/unit/standard/triggers/test_external_task.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index d2821110139ed..aa9fc01c0039e 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -66,6 +66,7 @@ async def test_task_workflow_trigger_success(self, mock_get_count): trigger_task = asyncio.create_task(gen.__anext__()) fake_task = asyncio.create_task(fake_async_fun()) await trigger_task + await fake_task assert fake_task.done() # confirm that get_count is done in an async fashion assert trigger_task.done() result = trigger_task.result() @@ -102,6 +103,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): trigger_task = asyncio.create_task(gen.__anext__()) fake_task = asyncio.create_task(fake_async_fun()) await trigger_task + await fake_task assert fake_task.done() # confirm that get_count is done in an async fashion assert trigger_task.done() result = trigger_task.result() @@ -119,8 +121,8 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): with pytest.raises(StopAsyncIteration): await gen.__anext__() - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") @pytest.mark.asyncio + @mock.patch("airflow.sdk.execution_time.context.get_ti_count") async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): mock_get_count.return_value = 0 @@ -170,6 +172,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): trigger_task = asyncio.create_task(gen.__anext__()) fake_task = asyncio.create_task(fake_async_fun()) await trigger_task + await fake_task assert fake_task.done() # confirm that get_count is done in an async fashion assert trigger_task.done() result = trigger_task.result() From fa29f1a506f5af8ad774599ebeb0ba140a9cf44a Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 21:08:39 +0100 Subject: [PATCH 07/12] Use RuntimeTaskInstance taskinstance methods --- .../standard/triggers/external_task.py | 6 +- .../standard/triggers/test_external_task.py | 10 +-- .../src/airflow/sdk/execution_time/context.py | 65 ------------------- .../airflow/sdk/execution_time/task_runner.py | 21 +++--- 4 files changed, 19 insertions(+), 83 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 37d17055a448b..9675f99138d45 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -124,12 +124,12 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: await asyncio.sleep(self.poke_interval) async def _get_count_af_3(self, states): - from airflow.sdk.execution_time.context import get_dr_count, get_ti_count + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance run_id_or_dates = self.run_ids or self.execution_dates or [] if self.external_task_ids or self.external_task_group_id: - count = await get_ti_count( + count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( dag_id=self.external_dag_id, task_ids=self.external_task_ids, task_group_id=self.external_task_group_id, @@ -138,7 +138,7 @@ async def _get_count_af_3(self, states): states=states, ) else: - count = await get_dr_count( + count = await sync_to_async(RuntimeTaskInstance.get_dr_count)( dag_id=self.external_dag_id, logical_dates=self.execution_dates, run_ids=self.run_ids, diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index aa9fc01c0039e..932640479e1ef 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -48,7 +48,7 @@ class TestWorkflowTrigger: EXECUTION_DATE = timezone.datetime(2022, 1, 1) @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_success(self, mock_get_count): """check the db count get called correctly.""" @@ -85,7 +85,7 @@ async def test_task_workflow_trigger_success(self, mock_get_count): await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_failed(self, mock_get_count): mock_get_count.side_effect = mocked_get_count @@ -122,7 +122,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): await gen.__anext__() @pytest.mark.asyncio - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): mock_get_count.return_value = 0 @@ -155,7 +155,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") @pytest.mark.asyncio async def test_task_workflow_trigger_skipped(self, mock_get_count): mock_get_count.side_effect = mocked_get_count @@ -187,7 +187,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): states=["success", "fail"], ) - @mock.patch("airflow.sdk.execution_time.context.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") @mock.patch("asyncio.sleep") @pytest.mark.asyncio async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count): diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 5fc2da2b96c6f..c3c95b148cd23 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -599,8 +599,6 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format. :return: task_instance context as dict. """ - from datetime import datetime - from airflow import settings params = {} @@ -661,66 +659,3 @@ def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: except KeyError: outlet_events = context["outlet_events"] = OutletEventAccessors() return outlet_events - - -async def get_ti_count( - dag_id: str, - task_ids: list[str] | None = None, - task_group_id: str | None = None, - logical_dates: list[datetime] | None = None, - run_ids: list[str] | None = None, - states: list[str] | None = None, -) -> int: - """Return the number of task instances matching the given criteria.""" - from airflow.sdk.execution_time.comms import GetTICount, TICount - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - log = structlog.get_logger(logger_name="task") - - async with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() - if not isinstance(response, TICount): - raise TypeError(f"Expected TICount, received: {type(response)} {response}") - - return response.count - - -async def get_dr_count( - dag_id: str, - logical_dates: list[datetime] | None = None, - run_ids: list[str] | None = None, - states: list[str] | None = None, -) -> int: - """Return the number of DAG runs matching the given criteria.""" - from airflow.sdk.execution_time.comms import DRCount, GetDRCount - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - log = structlog.get_logger(logger_name="task") - - async with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() - - if not isinstance(response, DRCount): - raise TypeError(f"Expected DRCount, received: {type(response)} {response}") - - return response.count diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 04b24b94d13e7..2d5a3add593bd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -442,16 +442,17 @@ def get_dr_count( """Return the number of DAG runs matching the given criteria.""" log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) + response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, DRCount) From f0844e398aeed76f6b26f9137c3a503385a3b90a Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 21:39:16 +0100 Subject: [PATCH 08/12] Fix trigger tests --- airflow-core/tests/unit/jobs/test_triggerer_job.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index ecb48c2a66448..50320681df1de 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -851,7 +851,6 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, d "run_ids": [parent_dag_run.run_id], }, ) - trigger_orm.id = 1 session.add(trigger_orm) session.commit() task_instance.trigger_id = trigger_orm.id From e30c7a12448561f3e6e8314783c23d6c4078462b Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Mon, 7 Apr 2025 22:46:06 +0100 Subject: [PATCH 09/12] Fix wrong import place --- task-sdk/src/airflow/sdk/execution_time/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index c3c95b148cd23..2ff7dbbda0871 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -19,7 +19,6 @@ import collections import contextlib from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence -from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union @@ -599,6 +598,8 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format. :return: task_instance context as dict. """ + from datetime import datetime + from airflow import settings params = {} From d76c0bafe0fa82f7bf7f6fc3275750d2772ece68 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Tue, 8 Apr 2025 17:08:28 +0100 Subject: [PATCH 10/12] Fix pre-commit errors --- .../src/airflow/jobs/triggerer_job_runner.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 25159d83e2eab..be7fe23004b1c 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -50,12 +50,12 @@ GetConnection, GetDagRunState, GetDRCount, + GetTICount, GetVariable, GetXCom, -GetTICount, + TICount, VariableResult, XComResult, - TICount, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from airflow.stats import Stats @@ -236,7 +236,15 @@ class TriggerStateSync(BaseModel): ToTriggerSupervisor = Annotated[ - Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom, GetTICount, GetDagRunState, GetDRCount], + Union[ + messages.TriggerStateChanges, + GetConnection, + GetVariable, + GetXCom, + GetTICount, + GetDagRunState, + GetDRCount, + ], Field(discriminator="type"), ] """ From 23655ac7b492367db3c23627f33227baa3521459 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Tue, 8 Apr 2025 18:59:39 +0100 Subject: [PATCH 11/12] Update with logical_dates --- .../standard/sensors/external_task.py | 6 ++++- .../standard/triggers/external_task.py | 21 +++++++++++----- .../standard/triggers/test_external_task.py | 24 +++++++++---------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index a2bcdad876003..1d0924604562a 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -421,18 +421,22 @@ def execute(self, context: Context) -> None: if not self.deferrable: super().execute(context) else: + dttm_filter = self._get_dttm_filter(context) + logical_or_execution_dates = ( + {"logical_dates": dttm_filter} if AIRFLOW_V_3_0_PLUS else {"execution_date": dttm_filter} + ) self.defer( timeout=self.execution_timeout, trigger=WorkflowTrigger( external_dag_id=self.external_dag_id, external_task_group_id=self.external_task_group_id, external_task_ids=self.external_task_ids, - execution_dates=self._get_dttm_filter(context), allowed_states=self.allowed_states, failed_states=self.failed_states, skipped_states=self.skipped_states, poke_interval=self.poll_interval, soft_fail=self.soft_fail, + **logical_or_execution_dates, ), method_name="execute_complete", ) diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 9675f99138d45..580ebb6409059 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -50,6 +50,7 @@ class WorkflowTrigger(BaseTrigger): :param allowed_states: States considered as successful for external tasks. :param poke_interval: The interval (in seconds) for poking the external tasks. :param soft_fail: If True, the trigger will not fail the entire dag on external task failure. + :param logical_dates: A list of logical dates for the external dag. """ def __init__( @@ -57,6 +58,7 @@ def __init__( external_dag_id: str, run_ids: list[str] | None = None, execution_dates: list[datetime] | None = None, + logical_dates: list[datetime] | None = None, external_task_ids: typing.Collection[str] | None = None, external_task_group_id: str | None = None, failed_states: typing.Iterable[str] | None = None, @@ -76,6 +78,7 @@ def __init__( self.poke_interval = poke_interval self.soft_fail = soft_fail self.execution_dates = execution_dates + self.logical_dates = logical_dates super().__init__(**kwargs) def serialize(self) -> tuple[str, dict[str, Any]]: @@ -89,17 +92,23 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "allowed_states": self.allowed_states, "poke_interval": self.poke_interval, "soft_fail": self.soft_fail, - "execution_dates": self.execution_dates, } if AIRFLOW_V_3_0_PLUS: data["run_ids"] = self.run_ids + data["logical_dates"] = self.logical_dates + else: + data["execution_dates"] = self.execution_dates return "airflow.providers.standard.triggers.external_task.WorkflowTrigger", data async def run(self) -> typing.AsyncIterator[TriggerEvent]: """Check periodically tasks, task group or dag status.""" - get_count_func = self._get_count_af_3 if AIRFLOW_V_3_0_PLUS else self._get_count - run_id_or_dates = self.run_ids or self.execution_dates or [] + if AIRFLOW_V_3_0_PLUS: + get_count_func = self._get_count_af_3 + run_id_or_dates = (self.run_ids or self.logical_dates) or [] + else: + get_count_func = self._get_count + run_id_or_dates = self.execution_dates or [] while True: if self.failed_states: @@ -126,21 +135,21 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: async def _get_count_af_3(self, states): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance - run_id_or_dates = self.run_ids or self.execution_dates or [] + run_id_or_dates = (self.run_ids or self.logical_dates) or [] if self.external_task_ids or self.external_task_group_id: count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( dag_id=self.external_dag_id, task_ids=self.external_task_ids, task_group_id=self.external_task_group_id, - logical_dates=self.execution_dates, + logical_dates=self.logical_dates, run_ids=self.run_ids, states=states, ) else: count = await sync_to_async(RuntimeTaskInstance.get_dr_count)( dag_id=self.external_dag_id, - logical_dates=self.execution_dates, + logical_dates=self.logical_dates, run_ids=self.run_ids, states=states, ) diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 932640479e1ef..aefeb4c4fe87d 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -45,7 +45,7 @@ class TestWorkflowTrigger: TASK_ID = "external_task_op" RUN_ID = "external_task_run_id" STATES = ["success", "fail"] - EXECUTION_DATE = timezone.datetime(2022, 1, 1) + LOGICAL_DATE = timezone.datetime(2022, 1, 1) @pytest.mark.flaky(reruns=5) @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") @@ -56,7 +56,7 @@ async def test_task_workflow_trigger_success(self, mock_get_count): trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], external_task_ids=[self.TASK_ID], allowed_states=self.STATES, poke_interval=0.2, @@ -76,7 +76,7 @@ async def test_task_workflow_trigger_success(self, mock_get_count): dag_id="external_task", task_ids=["external_task_op"], task_group_id=None, - logical_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=None, states=["success", "fail"], ) @@ -92,7 +92,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], external_task_ids=[self.TASK_ID], failed_states=self.STATES, @@ -113,7 +113,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): dag_id="external_task", task_ids=["external_task_op"], task_group_id=None, - logical_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], states=["success", "fail"], ) @@ -128,7 +128,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], external_task_ids=[self.TASK_ID], failed_states=self.STATES, @@ -146,7 +146,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): dag_id="external_task", task_ids=["external_task_op"], task_group_id=None, - logical_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], states=["success", "fail"], ) @@ -162,7 +162,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], external_task_ids=[self.TASK_ID], skipped_states=self.STATES, poke_interval=0.2, @@ -182,7 +182,7 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): dag_id="external_task", task_ids=["external_task_op"], task_group_id=None, - logical_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=None, states=["success", "fail"], ) @@ -195,7 +195,7 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], external_task_ids=[self.TASK_ID], poke_interval=0.2, ) @@ -223,7 +223,7 @@ def test_serialization(self): """ trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - execution_dates=[self.EXECUTION_DATE], + logical_dates=[self.LOGICAL_DATE], run_ids=[self.RUN_ID], failed_states=["failed"], skipped_states=["skipped"], @@ -235,7 +235,7 @@ def test_serialization(self): assert classpath == "airflow.providers.standard.triggers.external_task.WorkflowTrigger" assert kwargs == { "external_dag_id": self.DAG_ID, - "execution_dates": [self.EXECUTION_DATE], + "logical_dates": [self.LOGICAL_DATE], "run_ids": [self.RUN_ID], "external_task_ids": [self.TASK_ID], "external_task_group_id": None, From 177234d4608e73249ba937162c72b94d49e184bf Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Tue, 8 Apr 2025 21:26:20 +0100 Subject: [PATCH 12/12] Fix count calculation --- .../standard/sensors/external_task.py | 2 - .../standard/triggers/external_task.py | 4 - .../standard/triggers/test_external_task.py | 95 +++++++++++++++++++ 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 1d0924604562a..e64eb0d6763b0 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -346,8 +346,6 @@ def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> """Calculate the normalized count based on the type of check.""" if self.external_task_ids: return count / len(self.external_task_ids) - elif self.external_task_group_id: - return count / len(dttm_filter) else: return count diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 580ebb6409059..a5f8b67f54972 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -135,8 +135,6 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: async def _get_count_af_3(self, states): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance - run_id_or_dates = (self.run_ids or self.logical_dates) or [] - if self.external_task_ids or self.external_task_group_id: count = await sync_to_async(RuntimeTaskInstance.get_ti_count)( dag_id=self.external_dag_id, @@ -156,8 +154,6 @@ async def _get_count_af_3(self, states): if self.external_task_ids: return count / len(self.external_task_ids) - elif self.external_task_group_id: - return count / len(run_id_or_dates) else: return count diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index aefeb4c4fe87d..3aa581257d79e 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -217,6 +217,101 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co mock_sleep.assert_awaited() assert mock_sleep.await_count == 1 + @pytest.mark.parametrize( + "task_ids, task_group_id, states, logical_dates, mock_ti_count, mock_dag_count, expected", + [ + ( + ["task_id_one", "task_id_two"], + None, + ["success"], + [ + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + ], + 4, + 2, + 2, + ), + ( + [], + "task_group_id", + ["success"], + [ + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + ], + 2, + 2, + 2, + ), + ( + [], + None, + ["success"], + [ + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc), + ], + 2, + 2, + 2, + ), + ], + ids=[ + "with_task_ids", + "with task_group_id only", + "no task_ids or task_group_id", + ], + ) + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") + @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dr_count") + @pytest.mark.asyncio + async def test_get_count_af_3( + self, + mock_get_dr_count, + mock_get_ti_count, + task_ids, + task_group_id, + states, + logical_dates, + mock_ti_count, + mock_dag_count, + expected, + ): + """ + case1: when provided two task_ids, and two dag runs, the get_ti_count should return 4(each dag run returns two tasks) + and normalized count becomes 2 + case2: when provided task_group_id, and two dag runs, the get_ti_count should return 2(each dag run returns 1 task group) + case3: when not provided any task_ids or task_group_id, the get_dr_count should return 2(total dag runs 2) + """ + + mock_get_ti_count.return_value = mock_ti_count + mock_get_dr_count.return_value = mock_dag_count + + trigger = WorkflowTrigger( + external_dag_id=self.DAG_ID, + logical_dates=logical_dates, + external_task_ids=task_ids, + external_task_group_id=task_group_id, + allowed_states=states, + poke_interval=0.2, + ) + + get_count_af_3 = await trigger._get_count_af_3(states) + assert get_count_af_3 == expected + + if task_ids or task_group_id: + mock_get_ti_count.assert_called_once() + assert mock_get_ti_count.call_count == 1 + mock_get_dr_count.assert_not_called() + assert mock_get_dr_count.call_count == 0 + + if not task_ids and not task_group_id: + mock_get_dr_count.assert_called_once() + assert mock_get_dr_count.call_count == 1 + mock_get_ti_count.assert_not_called() + assert mock_get_ti_count.call_count == 0 + def test_serialization(self): """ Asserts that the WorkflowTrigger correctly serializes its arguments and classpath.