Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ def create_app(apps: str = "all") -> FastAPI:
dag_bag = create_dag_bag()

if "execution" in apps_list or "all" in apps_list:
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

task_exec_api_app = create_task_execution_api_app()
task_exec_api_app.state.dag_bag = SchedulerDagBag()
task_exec_api_app.state.dag_bag = dag_bag
init_error_handlers(task_exec_api_app)
app.mount("/execution", task_exec_api_app)

Expand Down
11 changes: 5 additions & 6 deletions airflow-core/src/airflow/api_fastapi/common/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@

from fastapi import Depends, Request

from airflow.models.dagbag import DagBag
from airflow.settings import DAGS_FOLDER
from airflow.models.dagbag import SchedulerDagBag


def create_dag_bag() -> DagBag:
def create_dag_bag() -> SchedulerDagBag:
"""Create DagBag to retrieve DAGs from the database."""
return DagBag(DAGS_FOLDER, read_dags_from_db=True)
return SchedulerDagBag()


def dag_bag_from_app(request: Request) -> DagBag:
def dag_bag_from_app(request: Request) -> SchedulerDagBag:
"""
FastAPI dependency resolver that returns the shared DagBag instance from app.state.

Expand All @@ -39,4 +38,4 @@ def dag_bag_from_app(request: Request) -> DagBag:
return request.app.state.dag_bag


DagBagDep = Annotated[DagBag, Depends(dag_bag_from_app)]
DagBagDep = Annotated[SchedulerDagBag, Depends(dag_bag_from_app)]
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class DAGRunClearBody(StrictBaseModel):
only_failed: bool = False
run_on_latest_version: bool = Field(
default=False,
description="(Experimental) Run on the latest bundle version of the DAG after clearing the DAG Run.",
description="(Experimental) Run on the latest bundle version of the Dag after clearing the Dag Run.",
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class ClearTaskInstancesBody(StrictBaseModel):
include_past: bool = False
run_on_latest_version: bool = Field(
default=False,
description="(Experimental) Run on the latest bundle version of the DAG after "
description="(Experimental) Run on the latest bundle version of the dag after "
"clearing the task instances.",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8690,7 +8690,7 @@ components:
run_on_latest_version:
type: boolean
title: Run On Latest Version
description: (Experimental) Run on the latest bundle version of the DAG
description: (Experimental) Run on the latest bundle version of the dag
after clearing the task instances.
default: false
additionalProperties: false
Expand Down Expand Up @@ -9333,8 +9333,8 @@ components:
run_on_latest_version:
type: boolean
title: Run On Latest Version
description: (Experimental) Run on the latest bundle version of the DAG
after clearing the DAG Run.
description: (Experimental) Run on the latest bundle version of the Dag
after clearing the Dag Run.
default: false
additionalProperties: false
type: object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def materialize_asset(
)

dag: DAG | None
if not (dag := dag_bag.get_dag(dag_id)):
if not (dag := dag_bag.get_latest_version_of_dag(dag_id, session)):
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with ID `{dag_id}` was not found")

return dag.create_dagrun(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DAG, DagModel, DagRun
from airflow.models import DagModel, DagRun
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

Expand Down Expand Up @@ -168,7 +168,7 @@ def patch_dag_run(
f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found",
)

dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_dag_for_run(dag_run, session=session)

if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
Expand Down Expand Up @@ -274,9 +274,11 @@ def clear_dag_run(
f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found",
)

dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_dag_for_run(dag_run, session=session)

if body.dry_run:
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
task_instances = dag.clear(
run_id=dag_run_id,
task_ids=None,
Expand All @@ -290,6 +292,8 @@ def clear_dag_run(
task_instances=cast("list[TaskInstanceResponse]", task_instances),
total_entries=len(task_instances),
)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
dag.clear(
run_id=dag_run_id,
task_ids=None,
Expand Down Expand Up @@ -352,7 +356,7 @@ def get_dag_runs(
query = select(DagRun)

if dag_id != "~":
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found")

Expand Down Expand Up @@ -417,7 +421,9 @@ def trigger_dag_run(
)

try:
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with dag_id: '{dag_id}' not found")
params = body.validate_context(dag)

dag_run = dag.create_dagrun(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion

dag_versions_router = AirflowRouter(tags=["DagVersion"], prefix="/dags/{dag_id}/dagVersions")
Expand Down Expand Up @@ -112,7 +111,7 @@ def get_dag_versions(
query = select(DagVersion).options(joinedload(DagVersion.dag_model))

if dag_id != "~":
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DAG, DagModel
from airflow.models import DagModel
from airflow.models.dag_favorite import DagFavorite
from airflow.models.dagrun import DagRun

Expand Down Expand Up @@ -172,7 +172,7 @@ def get_dag(
dag_bag: DagBagDep,
) -> DAGResponse:
"""Get basic information about a DAG."""
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")

Expand All @@ -199,8 +199,7 @@ def get_dag(
)
def get_dag_details(dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> DAGDetailsResponse:
"""Get details of DAG."""
# todo: can we use lazy deser dag here?
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import DagAccessEntity, requires_access_dag
from airflow.exceptions import TaskNotFound
from airflow.models import DagRun

if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator


extra_links_router = AirflowRouter(
tags=["Extra Links"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/links"
)
Expand All @@ -57,7 +57,13 @@ def get_extra_links(
"""Get extra links for task instance."""
from airflow.models.taskinstance import TaskInstance

if (dag := dag_bag.get_dag(dag_id)) is None:
dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))

if dag_run:
dag = dag_bag.get_dag_for_run(dag_run, session=session)
else:
dag = dag_bag.get_latest_version_of_dag(dag_id, session=session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with ID = {dag_id} not found")

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_log(
metadata["end_of_log"] = True
raise HTTPException(status.HTTP_404_NOT_FOUND, "TaskInstance not found")

dag = dag_bag.get_dag(dag_id)
dag = dag_bag.get_dag_for_run(ti.dag_run, session=session)
if dag:
with contextlib.suppress(TaskNotFound):
ti.task = dag.get_task(ti.task_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ def get_mapped_task_instances(
# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
unfiltered_total_count = get_query_count(query, session=session)
if unfiltered_total_count == 0:
dag = dag_bag.get_dag(dag_id)
dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))
if dag_run:
dag = dag_bag.get_dag_for_run(dag_run, session=session)
else:
dag = dag_bag.get_latest_version_of_dag(dag_id, session=session)
if not dag:
error_message = f"DAG {dag_id} not found"
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
Expand Down Expand Up @@ -258,7 +262,8 @@ def get_task_instance_dependencies(
deps = []

if ti.state in [None, TaskInstanceState.SCHEDULED]:
dag = dag_bag.get_dag(ti.dag_id)
dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == ti.dag_id, DagRun.run_id == ti.run_id))
dag = dag_bag.get_dag_for_run(dag_run, session=session)

if dag:
try:
Expand Down Expand Up @@ -437,20 +442,14 @@ def get_task_instances(
This endpoint allows specifying `~` as the dag_id, dag_run_id to retrieve Task Instances for all DAGs
and DAG runs.
"""
dag_run = None
query = (
select(TI)
.join(TI.dag_run)
.outerjoin(TI.dag_version)
.options(joinedload(TI.dag_version))
.options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model)))
)

if dag_id != "~":
dag = dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: `{dag_id}` was not found")
query = query.where(TI.dag_id == dag_id)

if dag_run_id != "~":
dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id))
if not dag_run:
Expand All @@ -459,6 +458,14 @@ def get_task_instances(
f"DagRun with run_id: `{dag_run_id}` was not found",
)
query = query.where(TI.run_id == dag_run_id)
if dag_id != "~":
if dag_run:
dag = dag_bag.get_dag_for_run(dag_run, session)
else:
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with dag_id: `{dag_id}` was not found")
query = query.where(TI.dag_id == dag_id)

task_instance_select, total_entries = paginated_select(
statement=query,
Expand Down Expand Up @@ -654,7 +661,7 @@ def post_clear_task_instances(
session: SessionDep,
) -> TaskInstanceCollectionResponse:
"""Clear task instances."""
dag = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
error_message = f"DAG {dag_id} not found"
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
Expand All @@ -675,11 +682,10 @@ def post_clear_task_instances(
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
# If dag_run_id is provided, we should get the dag from SchedulerDagBag
# to ensure we get the right version.
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

dag = SchedulerDagBag().get_dag(dag_run=dag_run, session=session)
# Get the specific dag version:
dag = dag_bag.get_dag_for_run(dag_run=dag_run, session=session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
if past or future:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -707,21 +713,30 @@ def post_clear_task_instances(
# If we had upstream/downstream etc then also include those!
task_ids.extend(tid for tid in dag.task_dict if tid != task_id)

task_instances = dag.clear(
dry_run=True,
run_id=None if past or future else dag_run_id,
task_ids=task_ids,
dag_bag=dag_bag,
session=session,
**body.model_dump(
include={
"start_date",
"end_date",
"only_failed",
"only_running",
}
),
)
# Prepare common parameters
common_params = {
"dry_run": True,
"task_ids": task_ids,
"dag_bag": dag_bag,
"session": session,
"run_on_latest_version": body.run_on_latest_version,
"only_failed": body.only_failed,
"only_running": body.only_running,
}

if dag_run_id is not None and not (past or future):
# Use run_id-based clearing when we have a specific dag_run_id and not using past/future
task_instances = dag.clear(
**common_params,
run_id=dag_run_id,
)
else:
# Use date-based clearing when no dag_run_id or when past/future is specified
task_instances = dag.clear(
**common_params,
start_date=body.start_date,
end_date=body.end_date,
)

if not dry_run:
clear_task_instances(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.dagbag import DagBagDep
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.tasks import TaskCollectionResponse, TaskResponse
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.exceptions import TaskNotFound
from airflow.models import DAG

tasks_router = AirflowRouter(tags=["Task"], prefix="/dags/{dag_id}/tasks")

Expand All @@ -47,10 +47,11 @@
def get_tasks(
dag_id: str,
dag_bag: DagBagDep,
session: SessionDep,
order_by: str = "task_id",
) -> TaskCollectionResponse:
"""Get tasks for DAG."""
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
try:
Expand All @@ -73,9 +74,9 @@ def get_tasks(
),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK))],
)
def get_task(dag_id: str, task_id, dag_bag: DagBagDep) -> TaskResponse:
def get_task(dag_id: str, task_id, session: SessionDep, dag_bag: DagBagDep) -> TaskResponse:
"""Get simplified representation of a task."""
dag: DAG = dag_bag.get_dag(dag_id)
dag = dag_bag.get_latest_version_of_dag(dag_id, session)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
try:
Expand Down
Loading