@@ -18,13 +18,15 @@
 from __future__ import annotations

 import logging
+from typing import Annotated

-from fastapi import HTTPException, status
-from sqlalchemy import select
+from fastapi import HTTPException, Query, status
+from sqlalchemy import func, select

 from airflow.api.common.trigger_dag import trigger_dag
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
 from airflow.exceptions import DagRunAlreadyExists
 from airflow.models.dag import DagModel
@@ -150,3 +152,27 @@ def get_dagrun_state(
     )

     return DagRunStateResponse(state=dag_run.state)
+
+
+@router.get("/count", status_code=status.HTTP_200_OK)
+def get_dr_count(
+    dag_id: str,
+    session: SessionDep,
+    logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+    run_ids: Annotated[list[str] | None, Query()] = None,
+    states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+    """Get the count of DAG runs matching the given criteria."""
+    query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id)
+
+    if logical_dates:
+        query = query.where(DagRun.logical_date.in_(logical_dates))
+
+    if run_ids:
+        query = query.where(DagRun.run_id.in_(run_ids))
+
+    if states:
+        query = query.where(DagRun.state.in_(states))
+
+    count = session.scalar(query)
+    return count or 0
Comment on lines +157 to +178

Contributor: We will need to add a Cadwyn migration for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/

Member (Author): I don't think so. We only need to add a migration for breaking changes (or changes to existing endpoints), from what I understand.

Member (Author): I just had a chat with the Cadwyn author (@zmievsa), who recommends adding one only for breaking changes:

    "Depending on your needs. My general recommendation is to only add migrations for breaking changes"
    https://docs.cadwyn.dev/how_to/change_endpoints/#add-a-new-endpoint

@zmievsa (Apr 5, 2025): I'll make Cadwyn's docs more verbose on when it makes the most sense to add a migration. The Concepts section mostly focuses on what's possible with Cadwyn, while the "How to" section focuses on what you should actually do.

Either way, 99% of the time it makes sense to add an endpoint to all versions, since it's not a breaking change. Your users will thank you later.

Update: added a bunch of notes about this here and there: https://docs.cadwyn.dev/concepts/endpoint_migrations/#defining-endpoints-that-didnt-exist-in-old-versions
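As the thread concludes, no migration is needed for an additive endpoint. For readers unfamiliar with Cadwyn, a migration of the kind the reviewer initially asked about would look roughly like the following sketch, based on the endpoint-migration docs linked above; the class name and wiring are hypothetical, and the PR deliberately does not add this.

```python
from cadwyn.structure import VersionChange, endpoint


class AddCountEndpoints(VersionChange):
    # Hypothetical example only: additive endpoints are not breaking
    # changes, so this migration is not actually needed.
    description = "Added GET /count endpoints for DAG runs and task instances."
    instructions_to_migrate_to_previous_version = (
        # Tell Cadwyn the endpoint did not exist in older API versions.
        endpoint("/count", ["GET"]).didnt_exist,
    )
```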

@@ -23,13 +23,14 @@
 from uuid import UUID

 from cadwyn import VersionedAPIRouter
-from fastapi import Body, Depends, HTTPException, status
+from fastapi import Body, Depends, HTTPException, Query, status
 from pydantic import JsonValue
-from sqlalchemy import func, tuple_, update
+from sqlalchemy import func, or_, tuple_, update
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.sql import select

 from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     PrevSuccessfulDagRunResponse,
     TIDeferredStatePayload,
@@ -45,6 +46,7 @@
     TITerminalStatePayload,
 )
 from airflow.api_fastapi.execution_api.deps import JWTBearer
+from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
 from airflow.models.taskreschedule import TaskReschedule
@@ -53,7 +55,9 @@
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState

-router = VersionedAPIRouter(
+router = VersionedAPIRouter()
+
+ti_id_router = VersionedAPIRouter(
Comment on lines +58 to +60

Contributor: Can we add a comment here explaining why we need this one? It will be clear to the reader too.

     dependencies=[
         # This checks that the UUID in the url matches the one in the token for us.
         Depends(JWTBearer(path_param_name="task_instance_id")),
@@ -64,7 +68,7 @@
 log = logging.getLogger(__name__)


-@router.patch(
+@ti_id_router.patch(
     "/{task_instance_id}/run",
     status_code=status.HTTP_200_OK,
     responses={
@@ -243,7 +247,7 @@ def ti_run(
     )


-@router.patch(
+@ti_id_router.patch(
     "/{task_instance_id}/state",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -404,7 +408,7 @@ def ti_update_state(
     )


-@router.patch(
+@ti_id_router.patch(
     "/{task_instance_id}/skip-downstream",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -436,7 +440,7 @@ def ti_skip_downstream(
     log.info("TI %s updated the state of %s task(s) to skipped", ti_id_str, result.rowcount)


-@router.put(
+@ti_id_router.put(
     "/{task_instance_id}/heartbeat",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -498,7 +502,7 @@ def ti_heartbeat(
     log.debug("Task with %s state heartbeated", previous_state)


-@router.put(
+@ti_id_router.put(
     "/{task_instance_id}/rtif",
     status_code=status.HTTP_201_CREATED,
     # TODO: Add description to the operation
@@ -528,7 +532,7 @@ def ti_put_rtif(
     return {"message": "Rendered task instance fields successfully set"}


-@router.get(
+@ti_id_router.get(
     "/{task_instance_id}/previous-successful-dagrun",
     status_code=status.HTTP_200_OK,
     responses={
@@ -564,8 +568,86 @@ def get_previous_successful_dagrun(
     return PrevSuccessfulDagRunResponse.model_validate(dag_run)


-@router.only_exists_in_older_versions
-@router.post(
+@router.get("/count", status_code=status.HTTP_200_OK)
Contributor: We will need to add a Cadwyn migration for the new endpoints: https://docs.cadwyn.dev/concepts/version_changes/

Member (Author): I don't think so. We only need to add a migration for breaking changes (or changes to existing endpoints), from what I understand.

Member (Author): Check #48651 (comment) :)

+def get_count(
+    dag_id: str,
+    session: SessionDep,
+    task_ids: Annotated[list[str] | None, Query()] = None,
+    task_group_id: Annotated[str | None, Query()] = None,
+    logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+    run_ids: Annotated[list[str] | None, Query()] = None,
+    states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+    """Get the count of task instances matching the given criteria."""
+    query = select(func.count()).select_from(TI).where(TI.dag_id == dag_id)
+
+    if task_ids:
+        query = query.where(TI.task_id.in_(task_ids))
+
+    if logical_dates:
+        query = query.where(TI.logical_date.in_(logical_dates))
+
+    if run_ids:
+        query = query.where(TI.run_id.in_(run_ids))
+
+    if task_group_id:
+        # Get all tasks in the task group
+        dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session)
+        if not dag:
+            raise HTTPException(
+                status.HTTP_404_NOT_FOUND,
+                detail={
+                    "reason": "not_found",
+                    "message": f"DAG {dag_id} not found",
+                },
+            )
+
+        task_group = dag.task_group_dict.get(task_group_id)
+        if not task_group:
+            raise HTTPException(
+                status.HTTP_404_NOT_FOUND,
+                detail={
+                    "reason": "not_found",
+                    "message": f"Task group {task_group_id} not found in DAG {dag_id}",
+                },
+            )
+
+        # First get all task instances to get the task_id, map_index pairs
+        group_tasks = session.scalars(
+            select(TI).where(
+                TI.dag_id == dag_id,
+                TI.task_id.in_(task.task_id for task in task_group.iter_tasks()),
+                *([TI.logical_date.in_(logical_dates)] if logical_dates else []),
+                *([TI.run_id.in_(run_ids)] if run_ids else []),
+            )
+        ).all()
+
+        # Get unique (task_id, map_index) pairs
+        task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks]
+        if not task_map_pairs:
+            # If no task group tasks found, default to checking the task group ID itself
+            # This matches the behavior in _get_external_task_group_task_ids
+            task_map_pairs = [(task_group_id, -1)]
+
+        # Update query to use task_id, map_index pairs
+        query = query.where(tuple_(TI.task_id, TI.map_index).in_(task_map_pairs))
+
+    if states:
+        if "null" in states:
+            not_none_states = [s for s in states if s != "null"]
+            if not_none_states:
+                query = query.where(or_(TI.state.is_(None), TI.state.in_(not_none_states)))
+            else:
+                query = query.where(TI.state.is_(None))
+        else:
+            query = query.where(TI.state.in_(states))
+
+    count = session.scalar(query)
+    return count or 0
+
+
+@ti_id_router.only_exists_in_older_versions
+@ti_id_router.post(
     "/{task_instance_id}/runtime-checks",
     status_code=status.HTTP_204_NO_CONTENT,
     # TODO: Add description to the operation
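A hedged usage sketch for the task-instance variant: the string "null" acts as a sentinel for task instances whose state column is SQL NULL (e.g. not yet run) and can be mixed with concrete states. The `/execution/task-instances` prefix, the base URL, and the token are assumptions; only the `/execution/dag-runs` prefix is confirmed by the tests below.

```python
import httpx

# Placeholder URL/token, as in the earlier dag-runs example; the
# /execution/task-instances prefix is an assumption, not confirmed here.
response = httpx.get(
    "http://localhost:8080/execution/task-instances/count",
    params={
        "dag_id": "my_dag",
        "task_group_id": "extract_group",  # hypothetical task group id
        # "null" is matched as TI.state IS NULL; other values are
        # matched against TI.state directly
        "states": ["null", "failed"],
    },
    headers={"Authorization": "Bearer <execution-api-jwt>"},
)
print(response.json())  # integer count
```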
@@ -602,3 +684,7 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
     # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
     # retries from the task SDK now, we can handle using max_tries
     return max_tries != 0 and try_number <= max_tries
+
+
+# This line should be at the end of the file to ensure all routes are registered
+router.include_router(ti_id_router)
11 changes: 0 additions & 11 deletions airflow-core/tests/conftest.py
@@ -78,17 +78,6 @@ def clear_all_logger_handlers():
     remove_all_non_pytest_log_handlers()


-@pytest.fixture
-def testing_dag_bundle():
-    from airflow.models.dagbundle import DagBundleModel
-    from airflow.utils.session import create_session
-
-    with create_session() as session:
-        if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
-            testing = DagBundleModel(name="testing")
-            session.add(testing)
-
-
 @contextmanager
 def _config_bundles(bundles: dict[str, Path | str]):
     from tests_common.test_utils.config import conf_vars
@@ -23,7 +23,7 @@
 from airflow.models.dagrun import DagRun
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.utils import timezone
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, State

 from tests_common.test_utils.db import clear_db_runs

@@ -218,3 +218,99 @@ def test_dag_run_not_found(self, client):
         response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear")

         assert response.status_code == 404
+
+
+class TestGetDagRunCount:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_get_count_basic(self, client, session, dag_maker):
+        with dag_maker("test_dag"):
+            pass
+        dag_maker.create_dagrun()
+        session.commit()
+
+        response = client.get("/execution/dag-runs/count", params={"dag_id": "test_dag"})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+    def test_get_count_with_states(self, client, session, dag_maker):
+        """Test counting DAG runs in specific states."""
+        with dag_maker("test_get_count_with_states"):
+            pass
+
+        # Create DAG runs with different states
+        dag_maker.create_dagrun(
+            state=State.SUCCESS, logical_date=timezone.datetime(2025, 1, 1), run_id="test_run_id1"
+        )
+        dag_maker.create_dagrun(
+            state=State.FAILED, logical_date=timezone.datetime(2025, 1, 2), run_id="test_run_id2"
+        )
+        dag_maker.create_dagrun(
+            state=State.RUNNING, logical_date=timezone.datetime(2025, 1, 3), run_id="test_run_id3"
+        )
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_logical_dates(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_logical_dates"):
+            pass
+
+        date1 = timezone.datetime(2025, 1, 1)
+        date2 = timezone.datetime(2025, 1, 2)
+
+        dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1)
+        dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2)
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={
+                "dag_id": "test_get_count_with_logical_dates",
+                "logical_dates": [date1.isoformat(), date2.isoformat()],
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_run_ids(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_run_ids"):
+            pass
+
+        dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1))
+        dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2))
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_run_ids", "run_ids": ["run1", "run2"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_mixed_states(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_mixed"):
+            pass
+        dag_maker.create_dagrun(
+            state=State.SUCCESS, run_id="runid1", logical_date=timezone.datetime(2025, 1, 1)
+        )
+        dag_maker.create_dagrun(
+            state=State.QUEUED, run_id="runid2", logical_date=timezone.datetime(2025, 1, 2)
+        )
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_mixed", "states": [State.SUCCESS, State.QUEUED]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2