diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index 3a680c1ef8c69..a439860113862 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -18,13 +18,15 @@ from __future__ import annotations import logging +from typing import Annotated -from fastapi import HTTPException, status -from sqlalchemy import select +from fastapi import HTTPException, Query, status +from sqlalchemy import func, select from airflow.api.common.trigger_dag import trigger_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel @@ -150,3 +152,27 @@ def get_dagrun_state( ) return DagRunStateResponse(state=dag_run.state) + + +@router.get("/count", status_code=status.HTTP_200_OK) +def get_dr_count( + dag_id: str, + session: SessionDep, + logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, + run_ids: Annotated[list[str] | None, Query()] = None, + states: Annotated[list[str] | None, Query()] = None, +) -> int: + """Get the count of DAG runs matching the given criteria.""" + query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id) + + if logical_dates: + query = query.where(DagRun.logical_date.in_(logical_dates)) + + if run_ids: + query = query.where(DagRun.run_id.in_(run_ids)) + + if states: + query = query.where(DagRun.state.in_(states)) + + count = session.scalar(query) + return count or 0 diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index ff0ca51631463..6dd0732c0778e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -23,13 +23,14 @@ from uuid import UUID from cadwyn import VersionedAPIRouter -from fastapi import Body, Depends, HTTPException, status +from fastapi import Body, Depends, HTTPException, Query, status from pydantic import JsonValue -from sqlalchemy import func, tuple_, update +from sqlalchemy import func, or_, tuple_, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.sql import select from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( PrevSuccessfulDagRunResponse, TIDeferredStatePayload, @@ -45,6 +46,7 @@ TITerminalStatePayload, ) from airflow.api_fastapi.execution_api.deps import JWTBearer +from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, _update_rtif from airflow.models.taskreschedule import TaskReschedule @@ -53,7 +55,9 @@ from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState -router = VersionedAPIRouter( +router = VersionedAPIRouter() + +ti_id_router = VersionedAPIRouter( dependencies=[ # This checks that the UUID in the url matches the one in the token for us. 
Depends(JWTBearer(path_param_name="task_instance_id")), @@ -64,7 +68,7 @@ log = logging.getLogger(__name__) -@router.patch( +@ti_id_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, responses={ @@ -243,7 +247,7 @@ def ti_run( ) -@router.patch( +@ti_id_router.patch( "/{task_instance_id}/state", status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -404,7 +408,7 @@ def ti_update_state( ) -@router.patch( +@ti_id_router.patch( "/{task_instance_id}/skip-downstream", status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -436,7 +440,7 @@ def ti_skip_downstream( log.info("TI %s updated the state of %s task(s) to skipped", ti_id_str, result.rowcount) -@router.put( +@ti_id_router.put( "/{task_instance_id}/heartbeat", status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -498,7 +502,7 @@ def ti_heartbeat( log.debug("Task with %s state heartbeated", previous_state) -@router.put( +@ti_id_router.put( "/{task_instance_id}/rtif", status_code=status.HTTP_201_CREATED, # TODO: Add description to the operation @@ -528,7 +532,7 @@ def ti_put_rtif( return {"message": "Rendered task instance fields successfully set"} -@router.get( +@ti_id_router.get( "/{task_instance_id}/previous-successful-dagrun", status_code=status.HTTP_200_OK, responses={ @@ -564,8 +568,86 @@ def get_previous_successful_dagrun( return PrevSuccessfulDagRunResponse.model_validate(dag_run) -@router.only_exists_in_older_versions -@router.post( +@router.get("/count", status_code=status.HTTP_200_OK) +def get_count( + dag_id: str, + session: SessionDep, + task_ids: Annotated[list[str] | None, Query()] = None, + task_group_id: Annotated[str | None, Query()] = None, + logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None, + run_ids: Annotated[list[str] | None, Query()] = None, + states: Annotated[list[str] | None, Query()] = None, +) -> int: + """Get the count of task instances matching the given criteria.""" + query = select(func.count()).select_from(TI).where(TI.dag_id == dag_id) + + if task_ids: + query = query.where(TI.task_id.in_(task_ids)) + + if logical_dates: + query = query.where(TI.logical_date.in_(logical_dates)) + + if run_ids: + query = query.where(TI.run_id.in_(run_ids)) + + if task_group_id: + # Get all tasks in the task group + dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session) + if not dag: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"DAG {dag_id} not found", + }, + ) + + task_group = dag.task_group_dict.get(task_group_id) + if not task_group: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Task group {task_group_id} not found in DAG {dag_id}", + }, + ) + + # First get all task instances to get the task_id, map_index pairs + group_tasks = session.scalars( + select(TI).where( + TI.dag_id == dag_id, + TI.task_id.in_(task.task_id for task in task_group.iter_tasks()), + *([TI.logical_date.in_(logical_dates)] if logical_dates else []), + *([TI.run_id.in_(run_ids)] if run_ids else []), + ) + ).all() + + # Get unique (task_id, map_index) pairs + task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks] + if not task_map_pairs: + # If no task group tasks found, default to checking the task group ID itself + # This matches the behavior in _get_external_task_group_task_ids + task_map_pairs = [(task_group_id, -1)] + + # Update query to use task_id, map_index pairs + query = query.where(tuple_(TI.task_id, TI.map_index).in_(task_map_pairs)) + + if states: + if "null" in 
states: + not_none_states = [s for s in states if s != "null"] + if not_none_states: + query = query.where(or_(TI.state.is_(None), TI.state.in_(not_none_states))) + else: + query = query.where(TI.state.is_(None)) + else: + query = query.where(TI.state.in_(states)) + + count = session.scalar(query) + return count or 0 + + +@ti_id_router.only_exists_in_older_versions +@ti_id_router.post( "/{task_instance_id}/runtime-checks", status_code=status.HTTP_204_NO_CONTENT, # TODO: Add description to the operation @@ -602,3 +684,7 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for # retries from the task SDK now, we can handle using max_tries return max_tries != 0 and try_number <= max_tries + + +# This line should be at the end of the file to ensure all routes are registered +router.include_router(ti_id_router) diff --git a/airflow-core/tests/conftest.py b/airflow-core/tests/conftest.py index c5affb469f496..c605cc7648e73 100644 --- a/airflow-core/tests/conftest.py +++ b/airflow-core/tests/conftest.py @@ -78,17 +78,6 @@ def clear_all_logger_handlers(): remove_all_non_pytest_log_handlers() -@pytest.fixture -def testing_dag_bundle(): - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session - - with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: - testing = DagBundleModel(name="testing") - session.add(testing) - - @contextmanager def _config_bundles(bundles: dict[str, Path | str]): from tests_common.test_utils.config import conf_vars diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index c5624f188b51a..f9f8d489d3d26 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -23,7 +23,7 @@ from airflow.models.dagrun import DagRun from airflow.providers.standard.operators.empty import EmptyOperator from airflow.utils import timezone -from airflow.utils.state import DagRunState +from airflow.utils.state import DagRunState, State from tests_common.test_utils.db import clear_db_runs @@ -218,3 +218,99 @@ def test_dag_run_not_found(self, client): response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear") assert response.status_code == 404 + + +class TestGetDagRunCount: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_count_basic(self, client, session, dag_maker): + with dag_maker("test_dag"): + pass + dag_maker.create_dagrun() + session.commit() + + response = client.get("/execution/dag-runs/count", params={"dag_id": "test_dag"}) + assert response.status_code == 200 + assert response.json() == 1 + + def test_get_count_with_states(self, client, session, dag_maker): + """Test counting DAG runs in specific states.""" + with dag_maker("test_get_count_with_states"): + pass + + # Create DAG runs with different states + dag_maker.create_dagrun( + state=State.SUCCESS, logical_date=timezone.datetime(2025, 1, 1), run_id="test_run_id1" + ) + dag_maker.create_dagrun( + state=State.FAILED, logical_date=timezone.datetime(2025, 1, 2), run_id="test_run_id2" + ) + dag_maker.create_dagrun( + state=State.RUNNING, logical_date=timezone.datetime(2025, 1, 3), 
run_id="test_run_id3" + ) + session.commit() + + response = client.get( + "/execution/dag-runs/count", + params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_logical_dates(self, client, session, dag_maker): + with dag_maker("test_get_count_with_logical_dates"): + pass + + date1 = timezone.datetime(2025, 1, 1) + date2 = timezone.datetime(2025, 1, 2) + + dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1) + dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2) + session.commit() + + response = client.get( + "/execution/dag-runs/count", + params={ + "dag_id": "test_get_count_with_logical_dates", + "logical_dates": [date1.isoformat(), date2.isoformat()], + }, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_run_ids(self, client, session, dag_maker): + with dag_maker("test_get_count_with_run_ids"): + pass + + dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1)) + dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2)) + session.commit() + + response = client.get( + "/execution/dag-runs/count", + params={"dag_id": "test_get_count_with_run_ids", "run_ids": ["run1", "run2"]}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_mixed_states(self, client, session, dag_maker): + with dag_maker("test_get_count_with_mixed"): + pass + dag_maker.create_dagrun( + state=State.SUCCESS, run_id="runid1", logical_date=timezone.datetime(2025, 1, 1) + ) + dag_maker.create_dagrun( + state=State.QUEUED, run_id="runid2", logical_date=timezone.datetime(2025, 1, 2) + ) + session.commit() + + response = client.get( + "/execution/dag-runs/count", + params={"dag_id": "test_get_count_with_mixed", "states": [State.SUCCESS, State.QUEUED]}, + ) + assert response.status_code == 200 + assert response.json() == 2 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 81c99b271be30..147209e967abc 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -33,6 +33,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk import TaskGroup from airflow.utils import timezone from airflow.utils.state import State, TaskInstanceState, TerminalTIState @@ -1223,3 +1224,167 @@ def test_get_start_date_with_try_number(self, client, session, create_task_insta response = client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2") assert response.status_code == 200 assert response.json() == "2024-01-02T00:00:00Z" + + +class TestGetCount: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_get_count_basic(self, client, session, create_task_instance): + create_task_instance(task_id="test_task", state=State.SUCCESS) + session.commit() + + response = client.get("/execution/task-instances/count", params={"dag_id": "dag"}) + assert response.status_code == 200 + assert response.json() == 1 + + def test_get_count_with_task_ids(self, client, 
session, create_task_instance): + for i in range(3): + create_task_instance( + task_id=f"task{i}", + state=State.SUCCESS, + dag_id="test_get_count_with_task_ids", + run_id=f"test_run_id{i}", + ) + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "test_get_count_with_task_ids", "task_ids": ["task1", "task2"]}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_states(self, client, session, dag_maker): + """Test counting tasks in specific states.""" + with dag_maker("test_get_count_with_states"): + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2") + EmptyOperator(task_id="task3") + + dr = dag_maker.create_dagrun() + + tis = dr.get_task_instances() + + # Set different states for the task instances + for ti, state in zip(tis, [State.SUCCESS, State.FAILED, State.SKIPPED]): + ti.state = state + session.merge(ti) + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_logical_dates(self, client, session, dag_maker): + with dag_maker("test_get_count_with_logical_dates"): + EmptyOperator(task_id="task1") + + date1 = timezone.datetime(2025, 1, 1) + date2 = timezone.datetime(2025, 1, 2) + + dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1) + dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2) + + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={ + "dag_id": "test_get_count_with_logical_dates", + "logical_dates": [date1.isoformat(), date2.isoformat()], + }, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_run_ids(self, client, session, dag_maker): + with dag_maker("test_get_count_with_run_ids"): + EmptyOperator(task_id="task1") + + dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1)) + dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2)) + + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "test_get_count_with_run_ids", "run_ids": ["run1", "run2"]}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_with_task_group(self, client, session, dag_maker): + with dag_maker(dag_id="test_dag", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2") + + with TaskGroup("group2"): + EmptyOperator(task_id="task3") + + dag_maker.create_dagrun(session=session) + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "test_dag", "task_group_id": "group1"}, + ) + assert response.status_code == 200 + assert response.json() == 2 + + def test_get_count_task_group_not_found(self, client, session, dag_maker): + with dag_maker(dag_id="test_get_count_task_group_not_found", serialized=True): + with TaskGroup("group1"): + EmptyOperator(task_id="task1") + dag_maker.create_dagrun(session=session) + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "test_get_count_task_group_not_found", "task_group_id": "non_existent_group"}, + ) + assert response.status_code == 404 + assert response.json()["detail"] == { + "reason": "not_found", + "message": "Task group non_existent_group not found in DAG 
test_get_count_task_group_not_found", + } + + def test_get_count_dag_not_found(self, client, session): + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "non_existent_dag", "task_group_id": "group1"}, + ) + assert response.status_code == 404 + assert response.json()["detail"] == { + "reason": "not_found", + "message": "DAG non_existent_dag not found", + } + + def test_get_count_with_none_state(self, client, session, create_task_instance): + create_task_instance(task_id="task1", dag_id="get_count_with_none", state=None) + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "get_count_with_none", "states": ["null"]}, + ) + assert response.status_code == 200 + assert response.json() == 1 + + def test_get_count_with_mixed_states(self, client, session, create_task_instance): + create_task_instance(task_id="task1", state=State.SUCCESS, run_id="runid1", dag_id="mixed_states") + create_task_instance(task_id="task2", state=None, run_id="runid2", dag_id="mixed_states") + session.commit() + + response = client.get( + "/execution/task-instances/count", + params={"dag_id": "mixed_states", "states": [State.SUCCESS, "null"]}, + ) + assert response.status_code == 200 + assert response.json() == 2 diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 3cc66ac9f9679..b92c27141d7b0 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2318,3 +2318,14 @@ def __call__( def mock_xcom_backend(): with mock.patch("airflow.sdk.execution_time.task_runner.XCom", create=True) as xcom_backend: yield xcom_backend + + +@pytest.fixture +def testing_dag_bundle(): + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 0a34cc5d48f2c..dd5ad6c4e7a3a 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -25,27 +25,32 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.external_task import WorkflowTrigger from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.sensors.base import BaseSensorOperator from airflow.utils.file import correct_maybe_zipped from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State, TaskInstanceState +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.bases.sensor import BaseSensorOperator +else: + from airflow.sensors.base import BaseSensorOperator + if TYPE_CHECKING: from sqlalchemy.orm import Session from airflow.models.taskinstancekey import TaskInstanceKey try: + from airflow.sdk import BaseOperator from 
airflow.sdk.definitions.context import Context except ImportError: # TODO: Remove once provider drops support for Airflow 2 + from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context @@ -65,15 +70,16 @@ class ExternalDagLink(BaseOperatorLink): name = "External DAG" def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: - from airflow.models.renderedtifields import RenderedTaskInstanceFields - if TYPE_CHECKING: assert isinstance(operator, (ExternalTaskMarker, ExternalTaskSensor)) - if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key): - external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id) - else: - external_dag_id = operator.external_dag_id + external_dag_id = operator.external_dag_id + + if not AIRFLOW_V_3_0_PLUS: + from airflow.models.renderedtifields import RenderedTaskInstanceFields + + if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key): + external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id) # type: ignore[no-redef] if AIRFLOW_V_3_0_PLUS: from airflow.utils.helpers import build_airflow_dagrun_url @@ -86,9 +92,7 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: return build_airflow_url_with_query(query) -# TODO: Remove BaseOperator from inheritance in https://github.com/apache/airflow/issues/47447 -# It is only temporary until we refactor the code to not directly go to the DB. -class ExternalTaskSensor(BaseSensorOperator, BaseOperator): +class ExternalTaskSensor(BaseSensorOperator): """ Waits for a different DAG, task group, or task to complete for a specific logical date. @@ -247,16 +251,22 @@ def __init__( self.poll_interval = poll_interval def _get_dttm_filter(self, context): + logical_date = context.get("logical_date") + if logical_date is None: + dag_run = context.get("dag_run") + if TYPE_CHECKING: + assert dag_run + + logical_date = dag_run.run_after if self.execution_delta: - dttm = context["logical_date"] - self.execution_delta + dttm = logical_date - self.execution_delta elif self.execution_date_fn: dttm = self._handle_execution_date_fn(context=context) else: - dttm = context["logical_date"] + dttm = logical_date return dttm if isinstance(dttm, list) else [dttm] - @provide_session - def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: + def poke(self, context: Context) -> bool: # delay check to poke rather than __init__ in case it was supplied as XComArgs if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)): raise ValueError("Duplicate task_ids passed in external_task_ids parameter") @@ -287,15 +297,62 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: serialized_dttm_filter, ) - # In poke mode this will check dag existence only once - if self.check_existence and not self._has_checked_existence: - self._check_for_existence(session=session) + if AIRFLOW_V_3_0_PLUS: + return self._poke_af3(context, dttm_filter) + else: + return self._poke_af2(dttm_filter) + + def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool: + self._has_checked_existence = True + ti = context["ti"] + + def _get_count(states: list[str]) -> int: + if self.external_task_ids: + return ti.get_ti_count( + dag_id=self.external_dag_id, + task_ids=self.external_task_ids, # type: ignore[arg-type] + logical_dates=dttm_filter, + states=states, + ) + elif self.external_task_group_id: + return 
ti.get_ti_count( + dag_id=self.external_dag_id, + task_group_id=self.external_task_group_id, + logical_dates=dttm_filter, + states=states, + ) + else: + return ti.get_dr_count( + dag_id=self.external_dag_id, + logical_dates=dttm_filter, + states=states, + ) - count_failed = -1 if self.failed_states: - count_failed = self.get_count(dttm_filter, session, self.failed_states) + count = _get_count(self.failed_states) + count_failed = self._calculate_count(count, dttm_filter) + self._handle_failed_states(count_failed) - # Fail if anything in the list has failed. + if self.skipped_states: + count = _get_count(self.skipped_states) + count_skipped = self._calculate_count(count, dttm_filter) + self._handle_skipped_states(count_skipped) + + count = _get_count(self.allowed_states) + count_allowed = self._calculate_count(count, dttm_filter) + return count_allowed == len(dttm_filter) + + def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> float | int: + """Calculate the normalized count based on the type of check.""" + if self.external_task_ids: + return count / len(self.external_task_ids) + elif self.external_task_group_id: + return count / len(dttm_filter) + else: + return count + + def _handle_failed_states(self, count_failed: float | int) -> None: + """Handle failed states and raise appropriate exceptions.""" if count_failed > 0: if self.external_task_ids: if self.soft_fail: @@ -317,7 +374,6 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed." ) - else: if self.soft_fail: raise AirflowSkipException( @@ -325,12 +381,8 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: ) raise AirflowException(f"The external DAG {self.external_dag_id} failed.") - count_skipped = -1 - if self.skipped_states: - count_skipped = self.get_count(dttm_filter, session, self.skipped_states) - - # Skip if anything in the list has skipped. Note if we are checking multiple tasks and one skips - # before another errors, we'll skip first. + def _handle_skipped_states(self, count_skipped: float | int) -> None: + """Handle skipped states and raise appropriate exceptions.""" if count_skipped > 0: if self.external_task_ids: raise AirflowSkipException( @@ -348,7 +400,19 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: "Skipping." ) - # only go green if every single task has reached an allowed state + @provide_session + def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool: + if self.check_existence and not self._has_checked_existence: + self._check_for_existence(session=session) + + if self.failed_states: + count_failed = self.get_count(dttm_filter, session, self.failed_states) + self._handle_failed_states(count_failed) + + if self.skipped_states: + count_skipped = self.get_count(dttm_filter, session, self.skipped_states) + self._handle_skipped_states(count_skipped) + count_allowed = self.get_count(dttm_filter, session, self.allowed_states) return count_allowed == len(dttm_filter) @@ -483,6 +547,9 @@ class ExternalTaskMarker(EmptyOperator): """ template_fields = ["external_dag_id", "external_task_id", "logical_date"] + if not AIRFLOW_V_3_0_PLUS: + template_fields.append("execution_date") + ui_color = "#4db7db" operator_extra_links = [ExternalDagLink()] @@ -510,6 +577,9 @@ def __init__( f"Expected str or datetime.datetime type for logical_date. 
Got {type(logical_date)}" ) + if not AIRFLOW_V_3_0_PLUS: + self.execution_date = self.logical_date + if recursion_depth <= 0: raise ValueError("recursion_depth should be a positive integer") self.recursion_depth = recursion_depth diff --git a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py index ae5d3c12985bc..17d54e371bcb2 100644 --- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py +++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py @@ -21,6 +21,7 @@ from sqlalchemy import func, select, tuple_ from airflow.models import DagBag, DagRun, TaskInstance +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -88,8 +89,10 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable: :param dttm_filter: date time filter for logical date :param external_dag_id: The ID of the external DAG. """ + date_field = model.logical_date if AIRFLOW_V_3_0_PLUS else model.execution_date + return select(func.count()).where( - model.dag_id == external_dag_id, model.state.in_(states), model.logical_date.in_(dttm_filter) + model.dag_id == external_dag_id, model.state.in_(states), date_field.in_(dttm_filter) ) @@ -106,11 +109,13 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id) if task_group: + date_field = TaskInstance.logical_date if AIRFLOW_V_3_0_PLUS else TaskInstance.execution_date + group_tasks = session.scalars( select(TaskInstance).filter( TaskInstance.dag_id == external_dag_id, TaskInstance.task_id.in_(task.task_id for task in task_group), - TaskInstance.logical_date.in_(dttm_filter), + date_field.in_(dttm_filter), ) ) diff --git a/airflow-core/tests/unit/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py similarity index 79% rename from airflow-core/tests/unit/sensors/test_external_task_sensor.py rename to providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 1a7938cc27dd8..95deddd4ade62 100644 --- a/airflow-core/tests/unit/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -19,10 +19,7 @@ import itertools import logging -import os import re -import tempfile -import zipfile from datetime import time, timedelta from unittest import mock @@ -47,8 +44,7 @@ from airflow.providers.standard.triggers.external_task import WorkflowTrigger from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.timetables.base import DataInterval -from airflow.utils.hashlib_wrapper import md5 -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.timezone import coerce_datetime, datetime @@ -57,7 +53,6 @@ from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.mock_operators import MockOperator from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from unit.models import TEST_DAGS_FOLDER if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType @@ -81,42 +76,13 @@ def clean_db(): 
clear_db_runs() -@pytest.fixture -def dag_zip_maker(testing_dag_bundle): - class DagZipMaker: - def __call__(self, *dag_files): - self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files] - dag_files_hash = md5("".join(self.__dag_files).encode()).hexdigest() - self.__tmp_dir = os.sep.join([tempfile.tempdir, dag_files_hash]) - - self.__zip_file_name = os.sep.join([self.__tmp_dir, f"{dag_files_hash}.zip"]) - - if not os.path.exists(self.__tmp_dir): - os.mkdir(self.__tmp_dir) - return self - - def __enter__(self): - with zipfile.ZipFile(self.__zip_file_name, "x") as zf: - for dag_file in self.__dag_files: - zf.write(dag_file, os.path.basename(dag_file)) - dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False) - dagbag.sync_to_db("testing", None) - return dagbag - - def __exit__(self, exc_type, exc_val, exc_tb): - os.unlink(self.__zip_file_name) - os.rmdir(self.__tmp_dir) - - return DagZipMaker() - - -@pytest.mark.usefixtures("testing_dag_bundle") -class TestExternalTaskSensor: +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for v3.0+") +class TestExternalTaskSensorV2: def setup_method(self): self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True) self.args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args) - self.dag_run_id = DagRunType.MANUAL.generate_run_id(suffix=DEFAULT_DATE.isoformat()) + self.dag_run_id = DagRunType.MANUAL.generate_run_id(DEFAULT_DATE) def add_time_sensor(self, task_id=TEST_TASK_ID): # TODO: Remove BaseOperator in https://github.com/apache/airflow/issues/47447 @@ -133,7 +99,7 @@ def add_fake_task_group(self, target_states=None): with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group: _ = [EmptyOperator(task_id=f"task{i}") for i in range(len(target_states))] dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="test_bundle") + SerializedDagModel.write_dag(dag) for idx, task in enumerate(task_group): ti = TaskInstance(task=task, run_id=self.dag_run_id) @@ -156,7 +122,7 @@ def fake_mapped_task(x: int): fake_task() fake_mapped_task.expand(x=list(map_indexes)) dag.sync_to_db() - SerializedDagModel.write_dag(dag, bundle_name="test_bundle") + SerializedDagModel.write_dag(dag) for task in task_group: if task.task_id == "fake_mapped_task": @@ -530,7 +496,7 @@ def test_external_task_sensor_fn_multiple_logical_dates(self): .filter( TI.dag_id == dag_external_id, TI.state == State.FAILED, - TI.logical_date == DEFAULT_DATE + timedelta(seconds=1), + TI.execution_date == DEFAULT_DATE + timedelta(seconds=1), ) .all() ) @@ -977,10 +943,301 @@ def test_fail__check_for_existence( check_existence=True, **kwargs, ) + if not hasattr(op, "never_fail"): + expected_message = "Skipping due to soft_fail is set to True." 
if soft_fail else expected_message with pytest.raises(expected_exception, match=expected_message): op.execute(context={}) +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2") +@pytest.mark.usefixtures("testing_dag_bundle") +class TestExternalTaskSensorV3: + def setup_method(self): + # Create a mock for TaskInstance with get_ti_count method + mock_ti = mock.MagicMock() + mock_ti.get_ti_count = mock.MagicMock(return_value=0) # Default return value + + self.context = { + "execution_date": DEFAULT_DATE, + "logical_date": DEFAULT_DATE, + "ti": mock_ti, + "task": mock.MagicMock(), + "run_id": "test_run_id", + } + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_success(self, dag_maker): + """Test that the sensor succeeds when the external task succeeds.""" + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_external_task_sensor_success", + allowed_states=["success"], + ) + + # Mimic DB response to get_ti_count as 1 + self.context["ti"].get_ti_count.return_value = 1 + + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=["success"], + task_ids=["test_external_task_sensor_success"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_failure(self, dag_maker): + """Test that the sensor fails when the external task fails.""" + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_external_task_sensor_failure", + failed_states=[State.FAILED], + ) + + self.context["ti"].get_ti_count.return_value = 1 + + with pytest.raises(AirflowException): + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=[State.FAILED], + task_ids=["test_external_task_sensor_failure"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_soft_fail(self, dag_maker): + """Test that the sensor skips when soft_fail is True and external task fails.""" + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_external_task_sensor_soft_fail", + failed_states=[State.FAILED], + soft_fail=True, + ) + + self.context["ti"].get_ti_count.return_value = 1 + + with pytest.raises(AirflowSkipException): + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=[State.FAILED], + task_ids=["test_external_task_sensor_soft_fail"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_multiple_task_ids(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_ids=["task1", "task2"], + allowed_states=["success"], + ) + + self.context["ti"].get_ti_count.return_value = 2 + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=["success"], + task_ids=["task1", "task2"], + ) + + @pytest.mark.execution_timeout(10) + def 
test_external_task_sensor_skipped_states(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_external_task_sensor_skipped_states", + skipped_states=[State.SKIPPED], + ) + + self.context["ti"].get_ti_count.return_value = 1 + with pytest.raises(AirflowSkipException): + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=[State.SKIPPED], + task_ids=["test_external_task_sensor_skipped_states"], + ) + + def test_external_task_sensor_invalid_combination(self, dag_maker): + """Test that the sensor raises an error with invalid parameter combinations.""" + with pytest.raises(ValueError): + with dag_maker("test_external_task_sensor_invalid_combination"): + ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag", + external_task_id="test_task", + external_task_ids=["test_task"], + ) + + def test_external_task_sensor_invalid_state(self, dag_maker): + with pytest.raises(ValueError): + with dag_maker("test_external_task_sensor_invalid_state"): + ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag", + external_task_id="test_task", + allowed_states=["invalid_state"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_task_group(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_group_id="test_group", + allowed_states=["success"], + ) + + self.context["ti"].get_ti_count.return_value = 1 + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=["success"], + task_group_id="test_group", + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_execution_date_fn(self, dag_maker): + def execution_date_fn(dt): + return [dt + timedelta(hours=1)] + + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_task", + execution_date_fn=execution_date_fn, + allowed_states=["success"], + ) + + self.context["ti"].get_ti_count.return_value = 1 + op.execute(context=self.context) + + expected_date = DEFAULT_DATE + timedelta(hours=1) + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[expected_date], + states=["success"], + task_ids=["test_task"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_execution_delta(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_task", + execution_delta=timedelta(hours=1), + allowed_states=["success"], + ) + + self.context["ti"].get_ti_count.return_value = 1 + op.execute(context=self.context) + + expected_date = DEFAULT_DATE - timedelta(hours=1) + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[expected_date], + states=["success"], + task_ids=["test_task"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_duplicate_task_ids(self, dag_maker): + with dag_maker("test_dag_child"): + op = 
ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_ids=["task1", "task1"], + allowed_states=["success"], + ) + + with pytest.raises(ValueError, match="Duplicate task_ids passed in external_task_ids parameter"): + op.execute(context=self.context) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_deferrable(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_id="test_task", + deferrable=True, + allowed_states=["success"], + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(context=self.context) + + assert isinstance(exc.value.trigger, WorkflowTrigger) + assert exc.value.trigger.external_dag_id == "test_dag_parent" + assert exc.value.trigger.external_task_ids == ["test_task"] + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_only_dag_id(self, dag_maker): + """Test that the sensor works correctly when only external_dag_id is provided.""" + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + allowed_states=["success"], + ) + + self.context["ti"].get_dr_count = mock.MagicMock(return_value=1) + + op.execute(context=self.context) + + self.context["ti"].get_dr_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=["success"], + ) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_task_group_failed_states(self, dag_maker): + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_group_id="test_group", + failed_states=[State.FAILED], + ) + + self.context["ti"].get_ti_count.return_value = 1 + + with pytest.raises(AirflowException): + op.execute(context=self.context) + + self.context["ti"].get_ti_count.assert_called_once_with( + dag_id="test_dag_parent", + logical_dates=[DEFAULT_DATE], + states=[State.FAILED], + task_group_id="test_group", + ) + + class TestExternalTaskAsyncSensor: TASK_ID = "external_task_sensor_check" EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on @@ -1050,14 +1307,7 @@ def test_defer_execute_check_correct_logging(self): mock_log_info.assert_called_with("External tasks %s has executed successfully.", [EXTERNAL_TASK_ID]) -def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker): - with dag_zip_maker("test_external_task_sensor_check_existense.py") as dagbag: - with create_session() as session: - dag = dagbag.dags["test_external_task_sensor_check_existence"] - op = dag.tasks[0] - op._check_for_existence(session) - - +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Needs Flask app context fixture for AF 2") @pytest.mark.parametrize( argnames=["external_dag_id", "external_task_id", "expected_external_dag_id", "expected_external_task_id"], argvalues=[ @@ -1166,7 +1416,10 @@ def dag_bag_ext(): task_a_3 >> task_b_3 for dag in [dag_0, dag_1, dag_2, dag_3]: - dag_bag.bag_dag(dag=dag) + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=dag) + else: + dag_bag.bag_dag(dag=dag, root_dag=dag) yield dag_bag @@ -1215,7 +1468,10 @@ def dag_bag_parent_child(): ) for dag in [dag_0, dag_1]: - dag_bag.bag_dag(dag=dag) + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=dag) + else: + dag_bag.bag_dag(dag=dag, root_dag=dag) yield dag_bag @@ -1237,22 +1493,37 @@ def 
run_tasks( for dag in dag_bag.dags.values(): data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date)) - runs[dag.dag_id] = dagrun = dag.create_dagrun( - run_id=dag.timetable.generate_run_id( - run_type=DagRunType.MANUAL, + if AIRFLOW_V_3_0_PLUS: + runs[dag.dag_id] = dagrun = dag.create_dagrun( + run_id=dag.timetable.generate_run_id( + run_type=DagRunType.MANUAL, + run_after=logical_date, + data_interval=data_interval, + ), + logical_date=logical_date, + data_interval=data_interval, run_after=logical_date, + run_type=DagRunType.MANUAL, + triggered_by=DagRunTriggeredByType.TEST, + dag_version=None, + state=DagRunState.RUNNING, + start_date=logical_date, + session=session, + ) + else: + runs[dag.dag_id] = dagrun = dag.create_dagrun( # type: ignore[call-arg] + run_id=dag.timetable.generate_run_id( # type: ignore[call-arg] + run_type=DagRunType.MANUAL, + logical_date=logical_date, + data_interval=data_interval, + ), + execution_date=logical_date, data_interval=data_interval, - ), - logical_date=logical_date, - data_interval=data_interval, - run_after=logical_date, - run_type=DagRunType.MANUAL, - triggered_by=DagRunTriggeredByType.TEST, - dag_version=None, - state=DagRunState.RUNNING, - start_date=logical_date, - session=session, - ) + run_type=DagRunType.MANUAL, + state=DagRunState.RUNNING, + start_date=logical_date, + session=session, + ) # we use sorting by task_id here because for the test DAG structure of ours # this is equivalent to topological sort. It would not work in general case # but it works for our case because we specifically constructed test DAGS @@ -1290,7 +1561,7 @@ def clear_tasks( """ Clear the task and its downstream tasks recursively for the dag in the given dagbag. """ - partial: DAG = dag.partial_subset(task_ids=[task.task_id], include_downstream=True) + partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True) return partial.clear( start_date=start_date, end_date=end_date, @@ -1300,6 +1571,7 @@ def clear_tasks( ) +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") def test_external_task_marker_transitive(dag_bag_ext): """ Test clearing tasks across DAGs. @@ -1314,6 +1586,7 @@ def test_external_task_marker_transitive(dag_bag_ext): assert_ti_state_equal(ti_b_3, State.NONE) +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") @provide_session def test_external_task_marker_clear_activate(dag_bag_parent_child, session): """ @@ -1326,22 +1599,9 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session): run_tasks(dag_bag, logical_date=day_1) run_tasks(dag_bag, logical_date=day_2) - from sqlalchemy import select - - run_ids = [] # Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared. - for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]): - run_id = ( - select(DagRun.run_id) - .where(DagRun.logical_date == logical_date) - .order_by(DagRun.id.desc()) - .limit(1) - ) - run_ids.append(run_id) - dagrun = dag.get_dagrun( - run_id=run_id, - session=session, - ) + for dag, execution_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]): + dagrun = dag.get_dagrun(execution_date=execution_date, session=session) dagrun.set_state(State.SUCCESS) session.flush() @@ -1351,10 +1611,10 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session): # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared. 
# Unaffected dagruns should be left as SUCCESS. - dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0], session=session) - dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1], session=session) - dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2], session=session) - dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3], session=session) + dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_1, session=session) + dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_2, session=session) + dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_1, session=session) + dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_2, session=session) assert dagrun_0_1.state == State.QUEUED assert dagrun_0_2.state == State.QUEUED @@ -1362,6 +1622,7 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session): assert dagrun_1_2.state == State.SUCCESS +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") def test_external_task_marker_future(dag_bag_ext): """ Test clearing tasks with no end_date. This is the case when users clear tasks with @@ -1386,6 +1647,7 @@ def test_external_task_marker_future(dag_bag_ext): assert_ti_state_equal(ti_b_3_date_1, State.NONE) +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") def test_external_task_marker_exception(dag_bag_ext): """ Clearing across multiple DAGs should raise AirflowException if more levels are being cleared @@ -1463,13 +1725,17 @@ def _factory(depth: int) -> DagBag: task_a >> task_b for dag in dags: - dag_bag.bag_dag(dag=dag) + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=dag) + else: + dag_bag.bag_dag(dag=dag, root_dag=dag) # type: ignore[call-arg] return dag_bag return _factory +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") def test_external_task_marker_cyclic_deep(dag_bag_cyclic): """ Tests clearing across multiple DAGs that have cyclic dependencies. 
AirflowException should be @@ -1483,6 +1749,7 @@ def test_external_task_marker_cyclic_deep(dag_bag_cyclic): clear_tasks(dag_bag, dag_0, task_a_0) +@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+") def test_external_task_marker_cyclic_shallow(dag_bag_cyclic): """ Tests clearing across multiple DAGs that have cyclic dependencies shallower @@ -1513,8 +1780,13 @@ def dag_bag_multiple(): dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False) daily_dag = DAG("daily_dag", start_date=DEFAULT_DATE, schedule="@daily") agg_dag = DAG("agg_dag", start_date=DEFAULT_DATE, schedule="@daily") - dag_bag.bag_dag(dag=daily_dag) - dag_bag.bag_dag(dag=agg_dag) + + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=daily_dag) + dag_bag.bag_dag(dag=agg_dag) + else: + dag_bag.bag_dag(dag=daily_dag, root_dag=daily_dag) + dag_bag.bag_dag(dag=agg_dag, root_dag=agg_dag) daily_task = EmptyOperator(task_id="daily_tas", dag=daily_dag) @@ -1584,7 +1856,10 @@ def dag_bag_head_tail(): ) head >> body >> tail - dag_bag.bag_dag(dag=dag) + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=dag) + else: + dag_bag.bag_dag(dag=dag, root_dag=dag) return dag_bag @@ -1600,10 +1875,13 @@ def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session): dag_id=dag.dag_id, start_date=logical_date, state=DagRunState.SUCCESS, - logical_date=logical_date, run_type=DagRunType.MANUAL, run_id=f"test_{delta}", ) + if AIRFLOW_V_3_0_PLUS: + dagrun.logical_date = logical_date + else: + dagrun.execution_date = logical_date session.add(dagrun) for task in dag.tasks: ti = TaskInstance(task=task) @@ -1625,10 +1903,13 @@ def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail, dag_id=dag.dag_id, start_date=logical_date, state=DagRunState.SUCCESS, - logical_date=logical_date, run_type=DagRunType.MANUAL, run_id=f"test_{delta}", ) + if AIRFLOW_V_3_0_PLUS: + dagrun.logical_date = logical_date + else: + dagrun.execution_date = logical_date session.add(dagrun) for task in dag.tasks: ti = TaskInstance(task=task) @@ -1689,7 +1970,10 @@ def dummy_task(x: int): ) head >> body >> tail - dag_bag.bag_dag(dag=dag) + if AIRFLOW_V_3_0_PLUS: + dag_bag.bag_dag(dag=dag) + else: + dag_bag.bag_dag(dag=dag, root_dag=dag) return dag_bag @@ -1705,10 +1989,13 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m dag_id=dag.dag_id, start_date=logical_date, state=DagRunState.SUCCESS, - logical_date=logical_date, run_type=DagRunType.MANUAL, run_id=f"test_{delta}", ) + if AIRFLOW_V_3_0_PLUS: + dagrun.logical_date = logical_date + else: + dagrun.execution_date = logical_date session.add(dagrun) for task in dag.tasks: if task.task_id == "dummy_task": @@ -1721,12 +2008,19 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m ti.state = TaskInstanceState.SUCCESS dagrun.task_instances.append(ti) session.flush() + if AIRFLOW_V_3_0_PLUS: + dag = dag.partial_subset( + task_ids=["head"], + include_downstream=True, + include_upstream=False, + ) + else: + dag = dag.partial_subset( + task_ids_or_regex=["head"], + include_downstream=True, + include_upstream=False, + ) - dag = dag.partial_subset( - task_ids=["head"], - include_downstream=True, - include_upstream=False, - ) task_ids = list(dag.task_dict) assert ( dag.clear( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index cfffd2b3823f0..a2fff3a335ba7 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -59,10 +59,12 @@ ) from 
airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( + DRCount, ErrorResponse, OKResponse, SkipDownstreamTasks, TaskRescheduleStartDate, + TICount, ) from airflow.utils.net import get_hostname from airflow.utils.platform import getuser @@ -200,6 +202,31 @@ def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskR resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number}) return TaskRescheduleStartDate.model_construct(start_date=resp.json()) + def get_count( + self, + dag_id: str, + task_ids: list[str] | None = None, + task_group_id: str | None = None, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, + ) -> TICount: + """Get count of task instances matching the given criteria.""" + params = { + "dag_id": dag_id, + "task_ids": task_ids, + "task_group_id": task_group_id, + "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None, + "run_ids": run_ids, + "states": states, + } + + # Remove None values from params + params = {k: v for k, v in params.items() if v is not None} + + resp = self.client.get("task-instances/count", params=params) + return TICount(count=resp.json()) + class ConnectionOperations: __slots__ = ("client",) @@ -452,6 +479,27 @@ def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse: resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state") return DagRunStateResponse.model_validate_json(resp.read()) + def get_count( + self, + dag_id: str, + logical_dates: list[datetime] | None = None, + run_ids: list[str] | None = None, + states: list[str] | None = None, + ) -> DRCount: + """Get count of DAG runs matching the given criteria.""" + params = { + "dag_id": dag_id, + "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None, + "run_ids": run_ids, + "states": states, + } + + # Remove None values from params + params = {k: v for k, v in params.items() if v is not None} + + resp = self.client.get("dag-runs/count", params=params) + return DRCount(count=resp.json()) + class BearerAuth(httpx.Auth): def __init__(self, token: str): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d579aa8fb4866..76c858f3ac1dd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -222,6 +222,20 @@ class TaskRescheduleStartDate(BaseModel): type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate" +class TICount(BaseModel): + """Response containing count of Task Instances matching certain filters.""" + + count: int + type: Literal["TICount"] = "TICount" + + +class DRCount(BaseModel): + """Response containing count of DAG Runs matching certain filters.""" + + count: int + type: Literal["DRCount"] = "DRCount" + + class ErrorResponse(BaseModel): error: ErrorType = ErrorType.GENERIC_ERROR detail: dict | None = None @@ -239,10 +253,12 @@ class OKResponse(BaseModel): AssetEventsResult, ConnectionResult, DagRunStateResult, + DRCount, ErrorResponse, PrevSuccessfulDagRunResult, StartupDetails, TaskRescheduleStartDate, + TICount, VariableResult, XComResult, XComCountResponse, @@ -445,30 +461,50 @@ class GetTaskRescheduleStartDate(BaseModel): type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate" +class GetTICount(BaseModel): + dag_id: str + task_ids: list[str] | None = None + task_group_id: str | None = 
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index d579aa8fb4866..76c858f3ac1dd 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -222,6 +222,20 @@ class TaskRescheduleStartDate(BaseModel):
     type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
 
 
+class TICount(BaseModel):
+    """Response containing count of Task Instances matching certain filters."""
+
+    count: int
+    type: Literal["TICount"] = "TICount"
+
+
+class DRCount(BaseModel):
+    """Response containing count of DAG Runs matching certain filters."""
+
+    count: int
+    type: Literal["DRCount"] = "DRCount"
+
+
 class ErrorResponse(BaseModel):
     error: ErrorType = ErrorType.GENERIC_ERROR
     detail: dict | None = None
@@ -239,10 +253,12 @@ class OKResponse(BaseModel):
         AssetEventsResult,
         ConnectionResult,
         DagRunStateResult,
+        DRCount,
         ErrorResponse,
         PrevSuccessfulDagRunResult,
         StartupDetails,
         TaskRescheduleStartDate,
+        TICount,
         VariableResult,
         XComResult,
         XComCountResponse,
@@ -445,30 +461,50 @@ class GetTaskRescheduleStartDate(BaseModel):
     type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
 
 
+class GetTICount(BaseModel):
+    dag_id: str
+    task_ids: list[str] | None = None
+    task_group_id: str | None = None
+    logical_dates: list[AwareDatetime] | None = None
+    run_ids: list[str] | None = None
+    states: list[str] | None = None
+    type: Literal["GetTICount"] = "GetTICount"
+
+
+class GetDRCount(BaseModel):
+    dag_id: str
+    logical_dates: list[AwareDatetime] | None = None
+    run_ids: list[str] | None = None
+    states: list[str] | None = None
+    type: Literal["GetDRCount"] = "GetDRCount"
+
+
 ToSupervisor = Annotated[
     Union[
-        SucceedTask,
         DeferTask,
+        DeleteXCom,
         GetAssetByName,
         GetAssetByUri,
         GetAssetEventByAsset,
         GetAssetEventByAssetAlias,
         GetConnection,
         GetDagRunState,
+        GetDRCount,
         GetPrevSuccessfulDagRun,
         GetTaskRescheduleStartDate,
+        GetTICount,
         GetVariable,
         GetXCom,
         GetXComCount,
         PutVariable,
         RescheduleTask,
         RetryTask,
-        SkipDownstreamTasks,
         SetRenderedFields,
         SetXCom,
+        SkipDownstreamTasks,
+        SucceedTask,
         TaskState,
         TriggerDagRun,
-        DeleteXCom,
     ],
     Field(discriminator="type"),
 ]
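The new request models ride on the existing `type`-discriminated union, so a frame can be decoded without knowing the concrete class up front. A rough sketch of that round trip using pydantic's TypeAdapter (the supervisor's actual framing code may differ):

    from pydantic import TypeAdapter

    from airflow.sdk.execution_time.comms import GetDRCount, GetTICount, ToSupervisor

    decoder = TypeAdapter(ToSupervisor)

    frame = GetTICount(dag_id="example_dag", states=["success"]).model_dump_json()
    msg = decoder.validate_json(frame)
    assert isinstance(msg, GetTICount)  # resolved via the "type" discriminator

    frame = GetDRCount(dag_id="example_dag", run_ids=["manual__2025-01-01"]).model_dump_json()
    assert isinstance(decoder.validate_json(frame), GetDRCount)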
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index a353ffce23ce9..21ae3cd4cae66 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -77,8 +77,10 @@
     GetAssetEventByAssetAlias,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
     GetPrevSuccessfulDagRun,
     GetTaskRescheduleStartDate,
+    GetTICount,
     GetVariable,
     GetXCom,
     GetXComCount,
@@ -988,6 +990,24 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
         elif isinstance(msg, GetTaskRescheduleStartDate):
             tr_resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
             resp = tr_resp.model_dump_json().encode()
+        elif isinstance(msg, GetTICount):
+            ti_count = self.client.task_instances.get_count(
+                dag_id=msg.dag_id,
+                task_ids=msg.task_ids,
+                task_group_id=msg.task_group_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+                states=msg.states,
+            )
+            resp = ti_count.model_dump_json().encode()
+        elif isinstance(msg, GetDRCount):
+            dr_count = self.client.dag_runs.get_count(
+                dag_id=msg.dag_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+                states=msg.states,
+            )
+            resp = dr_count.model_dump_json().encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index d5d738f52546b..d89fbeb0aa902 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,9 +58,12 @@
 from airflow.sdk.execution_time.comms import (
     DagRunStateResult,
     DeferTask,
+    DRCount,
     ErrorResponse,
     GetDagRunState,
+    GetDRCount,
     GetTaskRescheduleStartDate,
+    GetTICount,
     RescheduleTask,
     RetryTask,
     SetRenderedFields,
@@ -69,6 +72,7 @@
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     ToSupervisor,
     ToTask,
     TriggerDagRun,
@@ -400,6 +404,62 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
 
         return response.start_date
 
+    @staticmethod
+    def get_ti_count(
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int:
+        """Return the number of task instances matching the given criteria."""
+        log = structlog.get_logger(logger_name="task")
+
+        SUPERVISOR_COMMS.send_request(
+            log=log,
+            msg=GetTICount(
+                dag_id=dag_id,
+                task_ids=task_ids,
+                task_group_id=task_group_id,
+                logical_dates=logical_dates,
+                run_ids=run_ids,
+                states=states,
+            ),
+        )
+        response = SUPERVISOR_COMMS.get_message()
+
+        if TYPE_CHECKING:
+            assert isinstance(response, TICount)
+
+        return response.count
+
+    @staticmethod
+    def get_dr_count(
+        dag_id: str,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int:
+        """Return the number of DAG runs matching the given criteria."""
+        log = structlog.get_logger(logger_name="task")
+
+        SUPERVISOR_COMMS.send_request(
+            log=log,
+            msg=GetDRCount(
+                dag_id=dag_id,
+                logical_dates=logical_dates,
+                run_ids=run_ids,
+                states=states,
+            ),
+        )
+        response = SUPERVISOR_COMMS.get_message()
+
+        if TYPE_CHECKING:
+            assert isinstance(response, DRCount)
+
+        return response.count
+
 
 def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
     """Push a XCom through XCom.set, which pushes to XCom Backend if configured."""
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 6760ea5195990..d9544589cf948 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -87,6 +87,24 @@ def get_template_context(self) -> Context: ...
 
     def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ...
 
+    @staticmethod
+    def get_ti_count(
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[AwareDatetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int: ...
+
+    @staticmethod
+    def get_dr_count(
+        dag_id: str,
+        logical_dates: list[AwareDatetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int: ...
+
 
 class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
     """Protocol for managing access to a specific outlet event accessor."""
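Taken together, task code gets two runtime helpers for counting task instances and DAG runs without touching the metadata database directly. A hedged usage sketch of the flow above (the DAG ID, group name, and TaskFlow import path are illustrative only):

    from airflow.sdk import task  # TaskFlow decorator; import path assumed per the task SDK


    @task
    def summarize_upstream(**context):
        ti = context["ti"]

        # Count finished task instances of a (hypothetical) upstream task group in this run.
        done = ti.get_ti_count(
            dag_id="upstream_dag",
            task_group_id="extract_group",
            run_ids=[context["run_id"]],
            states=["success"],
        )

        # Count upstream DAG runs for the same logical date.
        runs = ti.get_dr_count(
            dag_id="upstream_dag",
            logical_dates=[context["logical_date"]],
        )

        print(f"{done} task instance(s) succeeded across {runs} matching run(s)")

Each call is relayed by the supervisor to the execution API shown earlier, so the task subprocess itself never needs database credentials.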
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index f0092f8aea994..3339a6cee4c76 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -405,6 +405,48 @@ def handle_request(request: httpx.Request) -> httpx.Response:
 
         assert result == {"ok": True}
 
+    def test_get_count_basic(self):
+        """Test basic get_count functionality with just dag_id."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == "/task-instances/count"
+            assert request.url.params.get("dag_id") == "test_dag"
+            return httpx.Response(200, json=5)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.get_count(dag_id="test_dag")
+        assert result.count == 5
+
+    def test_get_count_with_all_params(self):
+        """Test get_count with all optional parameters."""
+
+        logical_dates_str = ["2024-01-01T00:00:00+00:00", "2024-01-02T00:00:00+00:00"]
+        logical_dates = [timezone.parse(d) for d in logical_dates_str]
+        task_ids = ["task1", "task2"]
+        states = ["success", "failed"]
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == "/task-instances/count"
+            assert request.method == "GET"
+            params = request.url.params
+            assert params["dag_id"] == "test_dag"
+            assert params.get_list("task_ids") == task_ids
+            assert params["task_group_id"] == "group1"
+            assert params.get_list("logical_dates") == logical_dates_str
+            assert params.get_list("run_ids") == []
+            assert params.get_list("states") == states
+            return httpx.Response(200, json=10)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.get_count(
+            dag_id="test_dag",
+            task_ids=task_ids,
+            task_group_id="group1",
+            logical_dates=logical_dates,
+            states=states,
+        )
+        assert result.count == 10
+
 
 class TestVariableOperations:
     """
@@ -904,6 +946,58 @@ def handle_request(request: httpx.Request) -> httpx.Response:
 
         assert result == DagRunStateResponse(state=DagRunState.RUNNING)
 
+    def test_get_count_basic(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                return httpx.Response(status_code=200, json=1)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag")
+        assert result.count == 1
+
+    def test_get_count_with_states(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("states") == ["success", "failed"]
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag", states=["success", "failed"])
+        assert result.count == 2
+
+    def test_get_count_with_logical_dates(self):
+        logical_dates = [timezone.datetime(2025, 1, 1), timezone.datetime(2025, 1, 2)]
+        logical_dates_str = [d.isoformat() for d in logical_dates]
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("logical_dates") == logical_dates_str
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(
+            dag_id="test_dag", logical_dates=[timezone.datetime(2025, 1, 1), timezone.datetime(2025, 1, 2)]
+        )
+        assert result.count == 2
+
+    def test_get_count_with_run_ids(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("run_ids") == ["run1", "run2"]
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1", "run2"])
+        assert result.count == 2
+
 
 class TestTaskRescheduleOperations:
     def test_get_start_date(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 6490672621965..5cae5232e64e3 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -60,6 +60,7 @@
     DagRunStateResult,
     DeferTask,
     DeleteXCom,
+    DRCount,
     ErrorResponse,
     GetAssetByName,
     GetAssetByUri,
@@ -67,8 +68,10 @@
     GetAssetEventByAssetAlias,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
     GetPrevSuccessfulDagRun,
     GetTaskRescheduleStartDate,
+    GetTICount,
     GetVariable,
     GetXCom,
     OKResponse,
@@ -80,6 +83,7 @@
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     TriggerDagRun,
     VariableResult,
     XComResult,
@@ -1350,6 +1354,36 @@ def watched_subprocess(self, mocker):
                 TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
                 id="get_task_reschedule_start_date",
             ),
+            pytest.param(
+                GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
+                b'{"count":2,"type":"TICount"}\n',
+                "task_instances.get_count",
+                (),
+                {
+                    "dag_id": "test_dag",
+                    "logical_dates": None,
+                    "run_ids": None,
+                    "states": None,
+                    "task_group_id": None,
+                    "task_ids": ["task1", "task2"],
+                },
+                TICount(count=2),
+                id="get_ti_count",
+            ),
+            pytest.param(
+                GetDRCount(dag_id="test_dag", states=["success", "failed"]),
+                b'{"count":2,"type":"DRCount"}\n',
+                "dag_runs.get_count",
+                (),
+                {
+                    "dag_id": "test_dag",
+                    "logical_dates": None,
+                    "run_ids": None,
+                    "states": ["success", "failed"],
+                },
+                DRCount(count=2),
+                id="get_dr_count",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 150b691fc179d..f7ac25f6ae89d 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -67,9 +67,12 @@
     ConnectionResult,
     DagRunStateResult,
     DeferTask,
+    DRCount,
     ErrorResponse,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
+    GetTICount,
     GetVariable,
     GetXCom,
     OKResponse,
@@ -81,6 +84,7 @@
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     TriggerDagRun,
     VariableResult,
     XComResult,
@@ -1396,6 +1400,54 @@ def test_get_first_reschedule_date(
         context = runtime_ti.get_template_context()
         assert runtime_ti.get_first_reschedule_date(context=context) == expected_date
 
+    def test_get_ti_count(self, mock_supervisor_comms):
+        """Test that get_ti_count sends the correct request and returns the count."""
+        mock_supervisor_comms.get_message.return_value = TICount(count=2)
+
+        count = RuntimeTaskInstance.get_ti_count(
+            dag_id="test_dag",
+            task_ids=["task1", "task2"],
+            task_group_id="group1",
+            logical_dates=[timezone.datetime(2024, 1, 1)],
+            run_ids=["run1"],
+            states=["success", "failed"],
+        )
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=GetTICount(
+                dag_id="test_dag",
+                task_ids=["task1", "task2"],
+                task_group_id="group1",
+                logical_dates=[timezone.datetime(2024, 1, 1)],
+                run_ids=["run1"],
+                states=["success", "failed"],
+            ),
+        )
+        assert count == 2
+
+    def test_get_dr_count(self, mock_supervisor_comms):
+        """Test that get_dr_count sends the correct request and returns the count."""
+        mock_supervisor_comms.get_message.return_value = DRCount(count=2)
+
+        count = RuntimeTaskInstance.get_dr_count(
+            dag_id="test_dag",
+            logical_dates=[timezone.datetime(2024, 1, 1)],
+            run_ids=["run1"],
+            states=["success", "failed"],
+        )
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=GetDRCount(
+                dag_id="test_dag",
+                logical_dates=[timezone.datetime(2024, 1, 1)],
+                run_ids=["run1"],
+                states=["success", "failed"],
+            ),
+        )
+        assert count == 2
+
 
 class TestXComAfterTaskExecution:
     @pytest.mark.parametrize(