@@ -35,7 +35,12 @@
from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
from airflow.utils.state import (
DagRunState,
IntermediateTIState,
TaskInstanceState as TIState,
TerminalTIState,
)
from airflow.utils.types import DagRunType

AwareDatetimeAdapter = TypeAdapter(AwareDatetime)
@@ -292,6 +297,7 @@ class DagRun(StrictBaseModel):
end_date: UtcDateTime | None
clear_number: int = 0
run_type: DagRunType
state: DagRunState
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
consumed_asset_events: list[AssetEventDagRunReference]

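Note (illustrative, not part of the diff): with `state` now on the execution API `DagRun` datamodel, a payload validates as sketched below. Everything other than the `state` field itself is inferred from the fixture change further down in this PR, so treat the field names and values as assumptions.

from datetime import datetime, timezone as tz

from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

# Hypothetical payload; every value here is made up for illustration.
dag_run = DagRun.model_validate(
    {
        "dag_id": "example_dag",
        "run_id": "manual__2025-01-01",
        "logical_date": datetime(2025, 1, 1, tzinfo=tz.utc),
        "data_interval_start": datetime(2025, 1, 1, tzinfo=tz.utc),
        "data_interval_end": datetime(2025, 1, 2, tzinfo=tz.utc),
        "run_after": datetime(2025, 1, 2, tzinfo=tz.utc),
        "start_date": datetime(2025, 1, 1, tzinfo=tz.utc),
        "end_date": None,
        "run_type": DagRunType.MANUAL,
        "state": DagRunState.RUNNING,  # the field added by this change
        "consumed_asset_events": [],
    }
)
assert dag_run.state == DagRunState.RUNNING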
@@ -27,9 +27,11 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun
from airflow.exceptions import DagRunAlreadyExists
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.dagrun import DagRun as DagRunModel
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType

router = APIRouter()
@@ -123,7 +125,9 @@ def clear_dag_run(
)
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
dag_run = session.scalar(
select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
)
dag_bag = SchedulerDagBag()
dag = dag_bag.get_dag(dag_run=dag_run, session=session)
if not dag:
@@ -150,7 +154,9 @@ def get_dagrun_state(
session: SessionDep,
) -> DagRunStateResponse:
"""Get a DAG Run State."""
dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
dag_run = session.scalar(
select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
)
if dag_run is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
@@ -172,16 +178,45 @@ def get_dr_count(
states: Annotated[list[str] | None, Query()] = None,
) -> int:
"""Get the count of DAG runs matching the given criteria."""
query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id)
query = select(func.count()).select_from(DagRunModel).where(DagRunModel.dag_id == dag_id)

if logical_dates:
query = query.where(DagRun.logical_date.in_(logical_dates))
query = query.where(DagRunModel.logical_date.in_(logical_dates))

if run_ids:
query = query.where(DagRun.run_id.in_(run_ids))
query = query.where(DagRunModel.run_id.in_(run_ids))

if states:
query = query.where(DagRun.state.in_(states))
query = query.where(DagRunModel.state.in_(states))

count = session.scalar(query)
return count or 0


@router.get("/{dag_id}/previous", status_code=status.HTTP_200_OK)
def get_previous_dagrun(
dag_id: str,
logical_date: UtcDateTime,
session: SessionDep,
state: Annotated[DagRunState | None, Query()] = None,
) -> DagRun | None:
"""Get the previous DAG run before the given logical date, optionally filtered by state."""
query = (
select(DagRunModel)
.where(
DagRunModel.dag_id == dag_id,
DagRunModel.logical_date < logical_date,
)
.order_by(DagRunModel.logical_date.desc())
.limit(1)
)

if state:
query = query.where(DagRunModel.state == state)

dag_run = session.scalar(query)

if not dag_run:
return None

return DagRun.model_validate(dag_run)
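Note (illustrative, not part of the diff): a minimal client-side sketch of the new endpoint. The `/execution` prefix and the query parameter names are taken from the tests below; the base URL and the use of plain httpx rather than the Task SDK client are assumptions.

import httpx

BASE_URL = "http://localhost:8080"  # placeholder; point at a real execution API server

response = httpx.get(
    f"{BASE_URL}/execution/dag-runs/example_dag/previous",
    params={
        "logical_date": "2025-01-10T00:00:00+00:00",
        "state": "success",  # optional filter; omit to match any state
    },
)
response.raise_for_status()
previous_run = response.json()  # serialized DagRun, or None when nothing matches
if previous_run is not None:
    print(previous_run["run_id"], previous_run["state"])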
@@ -21,11 +21,14 @@

from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField
from airflow.api_fastapi.execution_api.versions.v2025_05_20 import DowngradeUpstreamMapIndexes
from airflow.api_fastapi.execution_api.versions.v2025_08_10 import AddDagVersionIdField
from airflow.api_fastapi.execution_api.versions.v2025_08_10 import (
AddDagRunStateFieldAndPreviousEndpoint,
AddDagVersionIdField,
)

bundle = VersionBundle(
HeadVersion(),
Version("2025-08-10", AddDagVersionIdField),
Version("2025-08-10", AddDagVersionIdField, AddDagRunStateFieldAndPreviousEndpoint),
Version("2025-05-20", DowngradeUpstreamMapIndexes),
Version("2025-04-28", AddRenderedMapIndexField),
Version("2025-04-11"),
@@ -17,9 +17,9 @@

from __future__ import annotations

from cadwyn import VersionChange, schema
from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TaskInstance, TIRunContext


class AddDagVersionIdField(VersionChange):
@@ -28,3 +28,20 @@ class AddDagVersionIdField(VersionChange):
description = __doc__

instructions_to_migrate_to_previous_version = (schema(TaskInstance).field("dag_version_id").didnt_exist,)


class AddDagRunStateFieldAndPreviousEndpoint(VersionChange):
"""Add the `state` field to DagRun model and `/dag-runs/{dag_id}/previous` endpoint."""

description = __doc__

instructions_to_migrate_to_previous_version = (
schema(DagRun).field("state").didnt_exist,
endpoint("/dag-runs/{dag_id}/previous", ["GET"]).didnt_exist,
)

@convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type]
def remove_state_from_dag_run(response: ResponseInfo) -> None: # type: ignore[misc]
"""Remove the `state` field from the dag_run object when converting to the previous version."""
if "dag_run" in response.body and isinstance(response.body["dag_run"], dict):
response.body["dag_run"].pop("state", None)
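Note (illustrative, not part of the diff): a sketch of what the downgrade converter above does to a response body when a client pinned to an earlier API version fetches a `TIRunContext`. The payload is trimmed to the relevant part and is made up.

new_body = {
    "dag_run": {"dag_id": "example_dag", "run_id": "run1", "state": "running"},
    "task_reschedule_count": 0,
}

# Equivalent of remove_state_from_dag_run applied to response.body:
if "dag_run" in new_body and isinstance(new_body["dag_run"], dict):
    new_body["dag_run"].pop("state", None)

assert "state" not in new_body["dag_run"]  # older clients never see the new field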
@@ -314,3 +314,142 @@ def test_get_count_with_mixed_states(self, client, session, dag_maker):
)
assert response.status_code == 200
assert response.json() == 2


class TestGetPreviousDagRun:
def setup_method(self):
clear_db_runs()

def teardown_method(self):
clear_db_runs()

def test_get_previous_dag_run_basic(self, client, session, dag_maker):
"""Test getting the previous DAG run without state filtering."""
dag_id = "test_get_previous_basic"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

# Create multiple DAG runs
dag_maker.create_dagrun(
run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
)
dag_maker.create_dagrun(
run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED
)
dag_maker.create_dagrun(
run_id="run3", logical_date=timezone.datetime(2025, 1, 10), state=DagRunState.SUCCESS
)
session.commit()

# Query for previous DAG run before 2025-01-10
response = client.get(
f"/execution/dag-runs/{dag_id}/previous",
params={
"logical_date": timezone.datetime(2025, 1, 10).isoformat(),
},
)

assert response.status_code == 200
result = response.json()
assert result["dag_id"] == dag_id
assert result["run_id"] == "run2" # Most recent before 2025-01-10
assert result["state"] == "failed"

def test_get_previous_dag_run_with_state_filter(self, client, session, dag_maker):
"""Test getting the previous DAG run with state filtering."""
dag_id = "test_get_previous_with_state"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

# Create multiple DAG runs with different states
dag_maker.create_dagrun(
run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
)
dag_maker.create_dagrun(
run_id="run2", logical_date=timezone.datetime(2025, 1, 5), state=DagRunState.FAILED
)
dag_maker.create_dagrun(
run_id="run3", logical_date=timezone.datetime(2025, 1, 8), state=DagRunState.SUCCESS
)
session.commit()

# Query for previous successful DAG run before 2025-01-10
response = client.get(
f"/execution/dag-runs/{dag_id}/previous",
params={"logical_date": timezone.datetime(2025, 1, 10).isoformat(), "state": "success"},
)

assert response.status_code == 200
result = response.json()
assert result["dag_id"] == dag_id
assert result["run_id"] == "run3" # Most recent successful run before 2025-01-10
assert result["state"] == "success"

def test_get_previous_dag_run_no_previous_found(self, client, session, dag_maker):
"""Test getting previous DAG run when none exists returns null."""
dag_id = "test_get_previous_none"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

# Create only one DAG run - no previous should exist
dag_maker.create_dagrun(
run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
)

response = client.get(f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-01T00:00:00Z")

assert response.status_code == 200
assert response.json() is None # Should return null

def test_get_previous_dag_run_no_matching_state(self, client, session, dag_maker):
"""Test getting previous DAG run with state filter that matches nothing returns null."""
dag_id = "test_get_previous_no_match"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

# Create DAG runs with different states
dag_maker.create_dagrun(
run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.FAILED
)
dag_maker.create_dagrun(
run_id="run2", logical_date=timezone.datetime(2025, 1, 2), state=DagRunState.FAILED
)

# Look for previous success but only failed runs exist
response = client.get(
f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-03T00:00:00Z&state=success"
)

assert response.status_code == 200
assert response.json() is None

    def test_get_previous_dag_run_dag_not_found(self, client, session):
        """Test getting previous DAG run for a non-existent DAG returns null rather than a 404."""
response = client.get(
"/execution/dag-runs/nonexistent_dag/previous?logical_date=2025-01-01T00:00:00Z"
)

assert response.status_code == 200
assert response.json() is None

def test_get_previous_dag_run_invalid_state_parameter(self, client, session, dag_maker):
"""Test that invalid state parameter returns 422 validation error."""
dag_id = "test_get_previous_invalid_state"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

dag_maker.create_dagrun(
run_id="run1", logical_date=timezone.datetime(2025, 1, 1), state=DagRunState.SUCCESS
)
session.commit()

response = client.get(
f"/execution/dag-runs/{dag_id}/previous?logical_date=2025-01-02T00:00:00Z&state=invalid_state"
)

assert response.status_code == 422
@@ -35,7 +35,7 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import Asset, TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
from airflow.utils.state import DagRunState, State, TaskInstanceState, TerminalTIState

from tests_common.test_utils.db import (
clear_db_assets,
@@ -155,6 +155,7 @@ def test_ti_run_state_to_running(
ti = create_task_instance(
task_id="test_ti_run_state_to_running",
state=State.QUEUED,
dagrun_state=DagRunState.RUNNING,
session=session,
start_date=instant,
dag_id=str(uuid4()),
@@ -184,6 +185,7 @@ def test_ti_run_state_to_running(
"data_interval_end": instant_str,
"run_after": instant_str,
"start_date": instant_str,
"state": "running",
"end_date": None,
"run_type": "manual",
"conf": {},
2 changes: 2 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
@@ -50,6 +50,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.sdk import DAG
from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import DagRunState
from airflow.sdk.execution_time import comms
from airflow.utils import timezone
from airflow.utils.session import create_session
@@ -691,6 +692,7 @@ def fake_collect_dags(self, *args, **kwargs):
logical_date=timezone.utcnow(),
start_date=timezone.utcnow(),
run_type="manual",
state=DagRunState.RUNNING,
)
dag_run.run_after = timezone.utcnow()

27 changes: 15 additions & 12 deletions devel-common/src/tests_common/pytest_plugin.py
@@ -2233,7 +2233,7 @@ def _create_task_instance(
should_retry: bool | None = None,
max_tries: int | None = None,
) -> RuntimeTaskInstance:
from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext
from airflow.utils.types import DagRunType

if not ti_id:
@@ -2267,17 +2267,20 @@
run_after = data_interval_end or logical_date or timezone.utcnow()

ti_context = TIRunContext(
dag_run=DagRun(
dag_id=dag_id,
run_id=run_id,
logical_date=logical_date, # type: ignore
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
start_date=start_date, # type: ignore
run_type=run_type, # type: ignore
run_after=run_after, # type: ignore
conf=conf,
consumed_asset_events=[],
dag_run=DagRun.model_validate(
{
"dag_id": dag_id,
"run_id": run_id,
"logical_date": logical_date, # type: ignore
"data_interval_start": data_interval_start,
"data_interval_end": data_interval_end,
"start_date": start_date, # type: ignore
"run_type": run_type, # type: ignore
"run_after": run_after, # type: ignore
"conf": conf,
"consumed_asset_events": [],
**({"state": DagRunState.RUNNING} if "state" in DagRun.model_fields else {}),
}
),
task_reschedule_count=task_reschedule_count,
max_tries=task_retries if max_tries is None else max_tries,
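Note (illustrative, not part of the diff): the fixture above switches to `model_validate` so it can feature-detect the new field via `DagRun.model_fields`, presumably to keep working against generated Task SDK models that predate `state`. A minimal sketch of that pattern with a stand-in model:

from pydantic import BaseModel

class LegacyOrCurrentDagRun(BaseModel):  # stand-in for the generated DagRun model
    dag_id: str
    state: str = "running"  # may be absent on older generated models

payload = {
    "dag_id": "example_dag",
    # Only supply `state` when the target model actually declares the field.
    **({"state": "running"} if "state" in LegacyOrCurrentDagRun.model_fields else {}),
}
dag_run = LegacyOrCurrentDagRun.model_validate(payload)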