diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dagrun.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dagrun.py index 1f73156cebf5a..7875d0fd43d13 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dagrun.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dagrun.py @@ -36,3 +36,9 @@ class DagRunStateResponse(BaseModel): """Schema for DAG Run State response.""" state: DagRunState + + +class DagRunCountResponse(BaseModel): + """Schema for DAG Count response.""" + + count: int diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index 3a680c1ef8c69..aeb4cd76f2034 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -18,14 +18,20 @@ from __future__ import annotations import logging +from datetime import datetime +from typing import Annotated -from fastapi import HTTPException, status -from sqlalchemy import select +from fastapi import HTTPException, Query, status +from sqlalchemy import func, select from airflow.api.common.trigger_dag import trigger_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload +from airflow.api_fastapi.execution_api.datamodels.dagrun import ( + DagRunCountResponse, + DagRunStateResponse, + TriggerDAGRunPayload, +) from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag @@ -150,3 +156,39 @@ def get_dagrun_state( ) return DagRunStateResponse(state=dag_run.state) + + +@router.get( + "/{dag_id}/count", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, 
+) +def get_dag_count( + dag_id: str, + session: SessionDep, + run_ids: Annotated[list[str] | None, Query()] = None, + states: Annotated[list[str] | None, Query()] = None, + logical_dates: Annotated[list[datetime] | None, Query()] = None, +) -> DagRunCountResponse: + """Get the count of DAG runs filtered by run_ids, states and logical_dates.""" + dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1)) + if not dm: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": f"DAG with dag_id: '{dag_id}' not found"}, + ) + + query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id) + + if run_ids: + query = query.where(DagRun.run_id.in_(run_ids)) + + if states: + query = query.where(DagRun.state.in_(states)) + + if logical_dates: + query = query.where(DagRun.logical_date.in_(logical_dates)) + + result = session.scalar(query) + return DagRunCountResponse(count=result or 0) diff --git a/airflow-core/src/airflow/exceptions.py b/airflow-core/src/airflow/exceptions.py index c659f6b91db42..1c836fbe5f05f 100644 --- a/airflow-core/src/airflow/exceptions.py +++ b/airflow-core/src/airflow/exceptions.py @@ -434,6 +434,7 @@ def __init__( allowed_states: list[str | DagRunState], failed_states: list[str | DagRunState], poke_interval: int, + deferrable: bool, ): super().__init__() self.trigger_dag_id = trigger_dag_id @@ -446,6 +447,29 @@ def __init__( self.allowed_states = allowed_states self.failed_states = failed_states self.poke_interval = poke_interval + self.deferrable = deferrable + + +class DagRunTriggerExecuteCompleteException(AirflowException): + """ + Signal raised by an operator's execute_complete method. + + Raised by ``TriggerDagRunOperator.execute_complete`` to make the task runner check the run states. 
+ """ + + def __init__( + self, + *, + trigger_dag_id: str, + run_ids: str, + allowed_states: list[str | DagRunState], + failed_states: list[str | DagRunState], + ): + super().__init__() + self.trigger_dag_id = trigger_dag_id + self.run_ids = run_ids + self.allowed_states = allowed_states + self.failed_states = failed_states class TaskDeferred(BaseException): diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 421bd28d031d7..d0053dd5ea754 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -44,8 +44,10 @@ from airflow.models.trigger import Trigger from airflow.sdk.execution_time.comms import ( ConnectionResult, + DagRunCountResult, ErrorResponse, GetConnection, + GetDagRunCount, GetVariable, GetXCom, VariableResult, @@ -212,6 +214,7 @@ class TriggerStateChanges(BaseModel): ConnectionResult, VariableResult, XComResult, + DagRunCountResult, ErrorResponse, ], Field(discriminator="type"), @@ -223,7 +226,7 @@ class TriggerStateChanges(BaseModel): ToTriggerSupervisor = Annotated[ - Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom], + Union[messages.TriggerStateChanges, GetConnection, GetVariable, GetXCom, GetDagRunCount], Field(discriminator="type"), ] """ @@ -333,7 +336,12 @@ def client(self) -> Client: return client def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] - from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse, XComResponse + from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + DagRunCountResponse, + VariableResponse, + XComResponse, + ) resp = None @@ -371,6 +379,18 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - resp = xcom_result.model_dump_json(exclude_unset=True).encode() else: resp = xcom.model_dump_json().encode() + + elif 
isinstance(msg, GetDagRunCount): + dr_resp = self.client.dag_runs.get_dag_run_count( + msg.dag_id, msg.run_ids, msg.states, msg.logical_dates + ) + + if isinstance(dr_resp, DagRunCountResponse): + dag_run_state_count_result = DagRunCountResult.from_api_response(dr_resp) + resp = dag_run_state_count_result.model_dump_json(exclude_unset=True).encode() + else: + resp = dr_resp.model_dump_json().encode() + else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index c5624f188b51a..881323365f5b5 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -218,3 +218,105 @@ def test_dag_run_not_found(self, client): response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear") assert response.status_code == 404 + + +class TestDagRunCount: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_dag_run_count(self, client, session, dag_maker): + dag_id = "test_dag_run_count_by_run_ids_and_states" + run_id1 = "test_run_id1" + run_id2 = "test_run_id2" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + dag_maker.create_dagrun( + run_id=run_id1, state=DagRunState.SUCCESS, logical_date=timezone.datetime(2025, 2, 20) + ) + dag_maker.create_dagrun( + run_id=run_id2, state=DagRunState.SUCCESS, logical_date=timezone.datetime(2025, 3, 20) + ) + + session.commit() + + response = client.get( + f"/execution/dag-runs/{dag_id}/count", + params={"run_ids": [run_id1, run_id2], "states": [DagRunState.SUCCESS]}, + ) + + assert response.status_code == 200 + assert response.json() == {"count": 2} + + def test_dag_run_count_by_run_ids_and_success_failure_states(self, client, 
session, dag_maker): + dag_id = "test_dag_run_count_by_run_ids_and_states" + run_id1 = "test_run_id3" + run_id2 = "test_run_id4" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + dag_maker.create_dagrun( + run_id=run_id1, state=DagRunState.SUCCESS, logical_date=timezone.datetime(2025, 4, 20) + ) + dag_maker.create_dagrun( + run_id=run_id2, state=DagRunState.FAILED, logical_date=timezone.datetime(2025, 5, 20) + ) + + session.commit() + + response = client.get( + f"/execution/dag-runs/{dag_id}/count", + params={"run_ids": [run_id1, run_id2], "states": [DagRunState.SUCCESS, DagRunState.FAILED]}, + ) + + assert response.status_code == 200 + assert response.json() == {"count": 2} + + def test_dag_run_count_by_logical_dates_and_success_failure_states(self, client, session, dag_maker): + dag_id = "dag_run_count_by_logical_dates_and_success_failure_states" + run_one_logical_date = timezone.datetime(2025, 4, 20) + run_two_logical_date = timezone.datetime(2025, 5, 20) + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + dag_maker.create_dagrun( + run_id="test_run_id5", state=DagRunState.SUCCESS, logical_date=timezone.datetime(2025, 4, 20) + ) + dag_maker.create_dagrun( + run_id="test_run_id6", state=DagRunState.FAILED, logical_date=timezone.datetime(2025, 5, 20) + ) + + session.commit() + + response = client.get( + f"/execution/dag-runs/{dag_id}/count", + params={ + "logical_dates": [run_one_logical_date, run_two_logical_date], + "states": [DagRunState.SUCCESS, DagRunState.FAILED], + }, + ) + + assert response.status_code == 200 + assert response.json() == {"count": 2} + + def test_dag_run_count_by_run_ids_and_states_dag_not_found(self, client): + dag_id = "dag_not_found" + + response = client.get( + f"/execution/dag-runs/{dag_id}/count", + params={"run_ids": ["test_run_id1"], "states": [DagRunState.SUCCESS]}, + ) + + assert response.status_code == 404 + 
assert response.json() == { + "detail": { + "message": "DAG with dag_id: 'dag_not_found' not found", + "reason": "not_found", + } + } diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index e53b893178df5..5a03b3195ec6b 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -734,3 +734,81 @@ async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_m "xcom": '"some_xcom_value"', } } + + +class CustomTriggerDagRun(BaseTrigger): + def __init__(self, trigger_dag_id, run_ids, states, logical_dates): + self.trigger_dag_id = trigger_dag_id + self.run_ids = run_ids + self.states = states + self.logical_dates = logical_dates + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + f"{type(self).__module__}.{type(self).__qualname__}", + { + "trigger_dag_id": self.trigger_dag_id, + "run_ids": self.run_ids, + "states": self.states, + "logical_dates": self.logical_dates, + }, + ) + + async def run(self, **args) -> AsyncIterator[TriggerEvent]: + from airflow.sdk.execution_time.context import get_dag_run_count + + print(self.trigger_dag_id, self.run_ids, self.states) + dag_run_states_count = await sync_to_async(get_dag_run_count)( + dag_id=self.trigger_dag_id, + run_ids=self.run_ids, + states=self.states, + logical_dates=self.logical_dates, + ) + yield TriggerEvent({"count": dag_run_states_count.count}) + + +@pytest.mark.xfail( + reason="We know that test is flaky and have no time to fix it before 3.0. " + "We should fix it later. 
TODO: AIP-72" +) +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=2, reruns_delay=10) +@pytest.mark.execution_timeout(30) +async def test_trigger_can_fetch_trigger_dag_run_count_in_deferrable(session, dag_maker): + """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" + # Create the test DAG and task + with dag_maker(dag_id="trigger_can_fetch_trigger_dag_run_count_in_deferrable", session=session): + EmptyOperator(task_id="dummy1") + dr = dag_maker.create_dagrun() + task_instance = dr.task_instances[0] + task_instance.state = TaskInstanceState.DEFERRED + + # Use the same dag run with states deferred to fetch the count + trigger = CustomTriggerDagRun( + trigger_dag_id=dr.dag_id, run_ids=[dr.run_id], states=[dr.state], logical_dates=[dr.logical_date] + ) + trigger_orm = Trigger( + classpath=trigger.serialize()[0], + kwargs={ + "trigger_dag_id": dr.dag_id, + "run_ids": [dr.run_id], + "states": [dr.state], + "logical_dates": [dr.logical_date], + }, + ) + trigger_orm.id = 1 + session.add(trigger_orm) + session.commit() + task_instance.trigger_id = trigger_orm.id + + job = Job() + session.add(job) + session.commit() + + supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None) + supervisor.run() + + task_instance.refresh_from_db() + assert task_instance.state == TaskInstanceState.SCHEDULED + assert task_instance.next_method != "__fail__" + assert task_instance.next_kwargs == {"event": {"count": 1}} diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 11977a081d3ff..fd37d9effb3ef 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -41,7 +41,7 @@ from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.providers.standard.version_compat 
import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -231,6 +231,7 @@ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date): allowed_states=self.allowed_states, failed_states=self.failed_states, poke_interval=self.poke_interval, + deferrable=self._defer, ) # TODO: Support deferral @@ -304,12 +305,27 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date): self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state) return + def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]): + if AIRFLOW_V_3_0_PLUS: + from airflow.exceptions import DagRunTriggerExecuteCompleteException + + raise DagRunTriggerExecuteCompleteException( + trigger_dag_id=self.trigger_dag_id, + run_ids=event[1]["run_ids"], + allowed_states=self.allowed_states, # type: ignore[arg-type] + failed_states=self.failed_states, # type: ignore[arg-type] + ) + else: + self._trigger_dag_run_af_2_execute_complete(event=event) + @provide_session - def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]): + def _trigger_dag_run_af_2_execute_complete( + self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION + ): # This logical_date is parsed from the return trigger event + provided_logical_date = event[1]["execution_dates"][0] try: - # Note: here execution fails on database isolation mode. 
Needs structural changes for AIP-72 dag_run = session.execute( select(DagRun).where( DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index 3c1e6747685e8..5b4cad3c70f55 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -160,8 +160,8 @@ def __init__( super().__init__() self.dag_id = dag_id self.states = states - self.run_ids = run_ids - self.execution_dates = execution_dates + self.run_ids = run_ids or [] + self.execution_dates = execution_dates or [] self.poll_interval = poll_interval def serialize(self) -> tuple[str, dict[str, typing.Any]]: @@ -181,18 +181,33 @@ def serialize(self) -> tuple[str, dict[str, typing.Any]]: async def run(self) -> typing.AsyncIterator[TriggerEvent]: """Check periodically if the dag run exists, and has hit one of the states yet, or not.""" + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.context import get_dag_run_count + + count_runs_ids_or_dates = len(self.run_ids) + else: + count_runs_ids_or_dates = len(self.execution_dates) + while True: - # mypy confuses typing here - num_dags = await self.count_dags() # type: ignore[call-arg] - _dates = self.run_ids if AIRFLOW_V_3_0_PLUS else self.execution_dates - if num_dags == len(_dates): # type: ignore[arg-type] + if AIRFLOW_V_3_0_PLUS: + dag_run_count_result = await sync_to_async(get_dag_run_count)( + dag_id=self.dag_id, + run_ids=self.run_ids, + states=self.states, # type: ignore[arg-type] + logical_dates=self.execution_dates, + ) + num_dags = dag_run_count_result.count + else: + num_dags = await self.count_dags() # type: ignore[call-arg] + + if num_dags == count_runs_ids_or_dates: yield TriggerEvent(self.serialize()) return await 
asyncio.sleep(self.poll_interval) @sync_to_async @provide_session - def count_dags(self, *, session: Session = NEW_SESSION) -> int | None: + def count_dags(self, *, session: Session = NEW_SESSION) -> int: """Count how many dag runs in the database match our criteria.""" _dag_run_date_condition = ( DagRun.run_id.in_(self.run_ids) diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 90280a4805a92..5fcce3ad67887 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -239,17 +239,18 @@ class TestDagStateTrigger: @pytest.mark.db_test @pytest.mark.asyncio + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 3 had a different implementation") async def test_dag_state_trigger(self, session): """ Assert that the DagStateTrigger only goes off on or after a DagRun reaches an allowed state (i.e. SUCCESS). 
""" dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1)) - run_id_or_execution_date = ( - {"run_id": "external_task_run_id"} - if AIRFLOW_V_3_0_PLUS - else {"execution_date": timezone.datetime(2022, 1, 1), "run_id": "external_task_run_id"} - ) + run_id_or_execution_date = { + "execution_date": timezone.datetime(2022, 1, 1), + "run_id": "external_task_run_id", + } + dag_run = DagRun(dag_id=dag.dag_id, run_type="manual", **run_id_or_execution_date) session.add(dag_run) session.commit() @@ -276,6 +277,55 @@ async def test_dag_state_trigger(self, session): # Prevents error when task is destroyed while in "pending" state asyncio.get_event_loop().stop() + @pytest.mark.db_test + @pytest.mark.asyncio + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 had a different implementation") + @mock.patch("airflow.sdk.execution_time.context.get_dag_run_count") + async def test_dag_state_trigger_af_3(self, mock_get_dag_run_count, session): + """ + Assert that the DagStateTrigger only goes off on or after a DagRun + reaches an allowed state (i.e. SUCCESS). 
+ """ + + # Mock get_dag_run_count to return a count of 0 the first time + mock_get_dag_run_count.return_value = mock.Mock(count=0) + dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1)) + + dag_run = DagRun( + dag_id=dag.dag_id, + run_type="manual", + run_id="external_task_run_id", + logical_date=timezone.datetime(2022, 1, 1), + ) + session.add(dag_run) + session.commit() + + trigger = DagStateTrigger( + dag_id=dag.dag_id, + states=self.STATES, + run_ids=["external_task_run_id"], + poll_interval=0.2, + execution_dates=[timezone.datetime(2022, 1, 1)], + ) + + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # It should not have produced a result + assert task.done() is False + + # Progress the dag to a "success" state so that the trigger yields a TriggerEvent + dag_run.state = DagRunState.SUCCESS + session.commit() + + # Mock get_dag_run_count to return a count of 1 the second time + mock_get_dag_run_count.return_value = mock.Mock(count=1) + await asyncio.sleep(0.5) + assert task.done() is True + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + def test_serialization(self): """Asserts that the DagStateTrigger correctly serializes its arguments and classpath.""" trigger = DagStateTrigger( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index cfffd2b3823f0..f9fd7c258e5fa 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -38,6 +38,7 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + DagRunCountResponse, DagRunStateResponse, DagRunType, PrevSuccessfulDagRunResponse, @@ -452,6 +453,28 @@ def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse: resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state") return DagRunStateResponse.model_validate_json(resp.read()) + def get_dag_run_count( + self, + dag_id: str, + 
run_ids: list[str] | None = None, + states: list[str] | None = None, + logical_dates: list[datetime] | None = None, + ) -> DagRunCountResponse: + """Get the count of dag runs by run ids, states, logical_dates via the API server.""" + params = {} + + if run_ids: + params.update({"run_ids": run_ids}) + + if states: + params.update({"states": states}) + + if logical_dates is not None: + params.update({"logical_dates": [date.isoformat() for date in logical_dates]}) + + resp = self.client.get(f"dag-runs/{dag_id}/count", params=params) + return DagRunCountResponse.model_validate_json(resp.read()) + class BearerAuth(httpx.Auth): def __init__(self, token: str): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index d0c7ae17a37ed..c7af36d546f0a 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -96,6 +96,14 @@ class DagRunAssetReference(BaseModel): data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None +class DagRunCountResponse(BaseModel): + """ + Schema for DAG Count response. + """ + + count: Annotated[int, Field(title="Count")] + + class DagRunState(str, Enum): """ All possible states that a DagRun can be in. 
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d579aa8fb4866..a6bf42c266e33 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -55,6 +55,7 @@ AssetResponse, BundleInfo, ConnectionResponse, + DagRunCountResponse, DagRunStateResponse, PrevSuccessfulDagRunResponse, TaskInstance, @@ -201,6 +202,21 @@ def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStat return cls(**dr_state_response.model_dump(exclude_defaults=True), type="DagRunStateResult") +class DagRunCountResult(DagRunCountResponse): + type: Literal["DagRunCountResult"] = "DagRunCountResult" + + @classmethod + def from_api_response(cls, dag_run_count_response: DagRunCountResponse) -> DagRunCountResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we need to convert it to Result + for communication between the Supervisor and the task process since it needs a + discriminator field. 
+ """ + return cls(**dag_run_count_response.model_dump(exclude_defaults=True), type="DagRunCountResult") + + class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse): type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult" @@ -239,6 +255,7 @@ class OKResponse(BaseModel): AssetEventsResult, ConnectionResult, DagRunStateResult, + DagRunCountResult, ErrorResponse, PrevSuccessfulDagRunResult, StartupDetails, @@ -413,6 +430,14 @@ class GetDagRunState(BaseModel): type: Literal["GetDagRunState"] = "GetDagRunState" +class GetDagRunCount(BaseModel): + dag_id: str + run_ids: list[str] | None = None + states: list[str] | None = None + logical_dates: list[datetime] | None = None + type: Literal["GetDagRunCount"] = "GetDagRunCount" + + class GetAssetByName(BaseModel): name: str type: Literal["GetAssetByName"] = "GetAssetByName" @@ -455,6 +480,7 @@ class GetTaskRescheduleStartDate(BaseModel): GetAssetEventByAssetAlias, GetConnection, GetDagRunState, + GetDagRunCount, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, GetVariable, diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 654f52e2f4a1f..b5360e8b6184b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -18,6 +18,7 @@ import contextlib from collections.abc import Generator, Iterator, Mapping +from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Union @@ -50,6 +51,7 @@ AssetEventsResult, AssetResult, ConnectionResult, + DagRunCountResult, PrevSuccessfulDagRunResponse, VariableResult, ) @@ -583,3 +585,30 @@ def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: except KeyError: outlet_events = context["outlet_events"] = OutletEventAccessors() return outlet_events + + +def get_dag_run_count( + dag_id: str, + states: list[str] | None = None, + run_ids: list[str] | None = None, + 
logical_dates: list[datetime] | None = None, +) -> DagRunCountResult: + from airflow.sdk.execution_time.comms import GetDagRunCount + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetDagRunCount( + dag_id=dag_id, + states=states, + run_ids=run_ids, + logical_dates=logical_dates, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + + if TYPE_CHECKING: + assert isinstance(msg, DagRunCountResult) + + return msg diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 860756fa06a6a..fc77e99b4cac2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -67,6 +67,7 @@ AssetEventsResult, AssetResult, ConnectionResult, + DagRunCountResult, DagRunStateResult, DeferTask, DeleteXCom, @@ -76,6 +77,7 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDagRunCount, GetDagRunState, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, @@ -984,6 +986,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): elif isinstance(msg, GetDagRunState): dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id) resp = DagRunStateResult.from_api_response(dr_resp).model_dump_json().encode() + elif isinstance(msg, GetDagRunCount): + dr_resp = self.client.dag_runs.get_dag_run_count( + msg.dag_id, msg.run_ids, msg.states, msg.logical_dates + ) + resp = DagRunCountResult.from_api_response(dr_resp).model_dump_json().encode() + elif isinstance(msg, GetTaskRescheduleStartDate): tr_resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number) resp = tr_resp.model_dump_json().encode() diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d5d738f52546b..d98678911f56e 100644 --- 
a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -94,7 +94,7 @@ from pendulum.datetime import DateTime from structlog.typing import FilteringBoundLogger as Logger - from airflow.exceptions import DagRunTriggerException + from airflow.exceptions import DagRunTriggerException, DagRunTriggerExecuteCompleteException, TaskDeferred from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.context import Context from airflow.sdk.types import OutletEventAccessorsProtocol @@ -628,6 +628,24 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv return None +def _defer_task( + defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger +) -> tuple[ToSupervisor, IntermediateTIState]: + log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) + classpath, trigger_kwargs = defer.trigger.serialize() + + msg = DeferTask( + classpath=classpath, + trigger_kwargs=trigger_kwargs, + trigger_timeout=defer.timeout, + next_method=defer.method_name, + next_kwargs=defer.kwargs or {}, + ) + state = IntermediateTIState.DEFERRED + + return msg, state + + def run( ti: RuntimeTaskInstance, context: Context, @@ -643,6 +661,7 @@ def run( AirflowTaskTerminated, AirflowTaskTimeout, DagRunTriggerException, + DagRunTriggerExecuteCompleteException, DownstreamTasksSkipped, TaskDeferred, ) @@ -687,19 +706,10 @@ def run( msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) + except DagRunTriggerExecuteCompleteException as drte: + msg, state = _handle_trigger_dag_run_execute_complete(drte, context, ti, log) except TaskDeferred as defer: - # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? - log.info("Pausing task as DEFERRED. 
", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) - classpath, trigger_kwargs = defer.trigger.serialize() - - msg = DeferTask( - classpath=classpath, - trigger_kwargs=trigger_kwargs, - trigger_timeout=defer.timeout, - next_method=defer.method_name, - next_kwargs=defer.kwargs or {}, - ) - state = IntermediateTIState.DEFERRED + msg, state = _defer_task(defer, ti, log) except AirflowSkipException as e: if e.args: log.info("Skipping task.", reason=e.args[0]) @@ -783,7 +793,7 @@ def _handle_current_task_failed( def _handle_trigger_dag_run( drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger -) -> tuple[ToSupervisor, TerminalTIState]: +) -> tuple[ToSupervisor, IntermediateTIState | TerminalTIState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) SUPERVISOR_COMMS.send_request( @@ -819,7 +829,23 @@ def _handle_trigger_dag_run( # be used when creating the extra link on the webserver. 
ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) - if drte.wait_for_completion: + if drte.deferrable: + from airflow.exceptions import TaskDeferred + from airflow.providers.standard.triggers.external_task import DagStateTrigger + + defer = TaskDeferred( + trigger=DagStateTrigger( + dag_id=drte.trigger_dag_id, + states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type] + execution_dates=[drte.logical_date], # type: ignore[list-item] + run_ids=[drte.dag_run_id], + poll_interval=drte.poke_interval, + ), + method_name="execute_complete", + ) + return _defer_task(defer, ti, log) + + elif drte.wait_for_completion: while True: log.info( "Waiting for dag run to complete execution in allowed state.", @@ -856,6 +882,44 @@ def _handle_trigger_dag_run( return _handle_current_task_success(context, ti) +def _handle_trigger_dag_run_execute_complete( + drte: DagRunTriggerExecuteCompleteException, context: Context, ti: RuntimeTaskInstance, log: Logger +) -> tuple[ToSupervisor, TerminalTIState]: + run_ids = drte.run_ids + failed_run_id_conditions = [] + + for run_id in run_ids: + SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=run_id)) + comms_msg = SUPERVISOR_COMMS.get_message() + + if TYPE_CHECKING: + assert isinstance(comms_msg, DagRunStateResult) + + if comms_msg.state in drte.failed_states: + failed_run_id_conditions.append(run_id) + continue + if comms_msg.state in drte.allowed_states: + log.info( + "%s finished with allowed state %s for run_id %s", + drte.trigger_dag_id, + comms_msg.state, + run_id, + ) + + if failed_run_id_conditions: + log.error( + "%s failed with failed states %s for run_ids %s", + drte.trigger_dag_id, + drte.failed_states, + failed_run_id_conditions, + ) + msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) + state = TerminalTIState.FAILED + return msg, state + + return _handle_current_task_success(context, ti) + + def _run_task_state_change_callbacks( 
task: BaseOperator, kind: Literal[ diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index f0092f8aea994..2ef5548ef0735 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -30,6 +30,7 @@ from airflow.sdk.api.datamodels._generated import ( AssetResponse, ConnectionResponse, + DagRunCountResponse, DagRunState, DagRunStateResponse, VariableResponse, @@ -904,6 +905,44 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result == DagRunStateResponse(state=DagRunState.RUNNING) + def test_get_dag_count_by_run_ids_and_states(self): + """Test that the client can get the count of dag runs by run ids and states""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_state/count": + return httpx.Response( + status_code=200, + json={"count": 1}, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_dag_run_count( + dag_id="test_state", + run_ids=["test_run_id"], + states=[DagRunState.RUNNING], + ) + assert result == DagRunCountResponse(count=1) + + def test_get_dag_count_by_logical_dates_and_states(self): + """Test that the client can get the count of dag runs by logical dates and states""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dag-runs/test_state/count": + return httpx.Response( + status_code=200, + json={"count": 1}, + ) + return httpx.Response(status_code=422) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dag_runs.get_dag_run_count( + dag_id="test_state", + logical_dates=[timezone.datetime(2022, 1, 1)], + states=[DagRunState.RUNNING], + ) + assert result == DagRunCountResponse(count=1) + class TestTaskRescheduleOperations: def test_get_start_date(self): diff --git 
a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 9598c7599e81b..78f36ae3ed617 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -57,6 +57,7 @@ AssetEventsResult, AssetResult, ConnectionResult, + DagRunCountResult, DagRunStateResult, DeferTask, DeleteXCom, @@ -66,6 +67,7 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDagRunCount, GetDagRunState, GetPrevSuccessfulDagRun, GetTaskRescheduleStartDate, @@ -1340,6 +1342,36 @@ def watched_subprocess(self, mocker): DagRunStateResult(state=DagRunState.RUNNING), id="get_dag_run_state", ), + pytest.param( + GetDagRunCount(dag_id="test_dag", run_ids=["test_run1", "test_run2"], states=["success"]), + b'{"count":2,"type":"DagRunCountResult"}\n', + "dag_runs.get_dag_run_count", + ("test_dag", ["test_run1", "test_run2"], ["success"], None), + {}, + DagRunCountResult(count=2), + id="get_dag_run_count", + ), + pytest.param( + GetDagRunCount( + dag_id="test_dag", + logical_dates=[ + timezone.parse("2025-01-10T12:00:00Z"), + timezone.parse("2025-02-10T12:00:00Z"), + ], + states=["success"], + ), + b'{"count":2,"type":"DagRunCountResult"}\n', + "dag_runs.get_dag_run_count", + ( + "test_dag", + None, + ["success"], + [timezone.parse("2025-01-10T12:00:00Z"), timezone.parse("2025-02-10T12:00:00Z")], + ), + {}, + DagRunCountResult(count=2), + id="get_dag_run_count_by_logical_date_states", + ), pytest.param( GetTaskRescheduleStartDate(ti_id=TI_ID), b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n', diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 150b691fc179d..70aad58dafb9e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2325,3 +2325,41 @@ def 
test_handle_trigger_dag_run_wait_for_completion( ), ] mock_supervisor_comms.assert_has_calls(expected_calls) + + @pytest.mark.parametrize( + ["allowed_states", "failed_states", "intermediate_state"], + [ + ([DagRunState.SUCCESS], None, IntermediateTIState.DEFERRED), + ], + ) + def test_handle_trigger_dag_run_deferred( + self, + allowed_states, + failed_states, + intermediate_state, + create_runtime_ti, + mock_supervisor_comms, + ): + """ + Test that TriggerDagRunOperator defers when the deferrable flag is set to True + + """ + from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator + + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id="test_dag", + trigger_run_id="test_run_id", + poke_interval=5, + wait_for_completion=False, + allowed_states=allowed_states, + failed_states=failed_states, + deferrable=True, + ) + ti = create_runtime_ti(dag_id="test_handle_trigger_dag_run_deferred", run_id="test_run", task=task) + + log = mock.MagicMock() + with mock.patch("time.sleep", return_value=None): + state, msg, _ = run(ti, ti.get_template_context(), log) + + assert state == intermediate_state