diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index 6382756b14a77..ea818b6fb7043 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -30,7 +30,7 @@ from airflow.cli.cli_config import GroupCommand from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor -from airflow.models.taskinstance import TaskInstance, TaskInstanceState +from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.sdk import timezone from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge3.models.edge_job import EdgeJobModel @@ -40,6 +40,7 @@ from airflow.stats import Stats from airflow.utils.db import DBLocks, create_global_lock from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: import argparse @@ -68,7 +69,8 @@ def _check_db_schema(self, engine: Engine) -> None: """ Check if already existing table matches the newest table schema. - workaround till Airflow 3.0.0, then it is possible to use alembic also for provider distributions. + workaround till support for Airflow 2.x is dropped, + then it is possible to use alembic also for provider distributions. """ inspector = inspect(engine) edge_job_columns = None @@ -78,7 +80,7 @@ def _check_db_schema(self, engine: Engine) -> None: edge_job_columns = [column["name"] for column in edge_job_schema] for column in edge_job_schema: if column["name"] == "command": - edge_job_command_len = column["type"].length + edge_job_command_len = column["type"].length # type: ignore[attr-defined] # version 0.6.0rc1 added new column concurrency_slots if edge_job_columns and "concurrency_slots" not in edge_job_columns: @@ -284,7 +286,7 @@ def _update_orphaned_jobs(self, session: Session) -> bool: map_index=job.map_index, session=session, ) - job.state = ti.state if ti else TaskInstanceState.REMOVED + job.state = ti.state if ti and ti.state else TaskInstanceState.REMOVED if job.state != TaskInstanceState.RUNNING: # Edge worker does not backport emitted Airflow metrics, so export some metrics diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py index e836716fe2e5e..ccf21de848fcd 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py @@ -94,4 +94,4 @@ def key(self): @property def last_update_t(self) -> float: - return self.last_update.timestamp() + return self.last_update.timestamp() if self.last_update else datetime.now().timestamp() diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py index 9eda47517edcb..7d4d53d39f16a 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py @@ -230,7 +230,9 @@ def request_maintenance( ) -> None: """Write maintenance request to the db.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") worker.state = EdgeWorkerState.MAINTENANCE_REQUEST worker.maintenance_comment = maintenance_comment @@ -239,7 +241,9 @@ def request_maintenance( def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None: """Write maintenance exit to the db.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") worker.state = EdgeWorkerState.MAINTENANCE_EXIT worker.maintenance_comment = None @@ -248,7 +252,9 @@ def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None: def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None: """Remove a worker that is offline or just gone from DB.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") if worker.state in ( EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE, @@ -267,7 +273,9 @@ def change_maintenance_comment( ) -> None: """Write maintenance comment in the db.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") if worker.state in ( EdgeWorkerState.MAINTENANCE_MODE, EdgeWorkerState.MAINTENANCE_PENDING, @@ -285,7 +293,9 @@ def change_maintenance_comment( def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None: """Request to shutdown the edge worker.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") if worker.state not in ( EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE, @@ -298,7 +308,9 @@ def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None: def add_worker_queues(worker_name: str, queues: list[str], session: Session = NEW_SESSION) -> None: """Add queues to an edge worker.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") if worker.state in ( EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE, @@ -314,7 +326,9 @@ def add_worker_queues(worker_name: str, queues: list[str], session: Session = NE def remove_worker_queues(worker_name: str, queues: list[str], session: Session = NEW_SESSION) -> None: """Remove queues from an edge worker.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") if worker.state in ( EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE, diff --git a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py index d37f30805cd24..110e9efb7fcd4 100644 --- a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py +++ b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py @@ -66,8 +66,7 @@ def _get_api_endpoint(session: Session = NEW_SESSION) -> dict[str, Any]: from sqlalchemy import select from airflow.auth.managers.models.resource_details import AccessView - from airflow.models.taskinstance import TaskInstanceState - from airflow.utils.state import State + from airflow.utils.state import State, TaskInstanceState from airflow.utils.yaml import safe_load from airflow.www.auth import has_access_view diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py index 14889bd94ad89..046dbea7b630b 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py @@ -68,6 +68,7 @@ class status: # type: ignore[no-redef] HTTP_204_NO_CONTENT = 204 HTTP_400_BAD_REQUEST = 400 HTTP_403_FORBIDDEN = 403 + HTTP_404_NOT_FOUND = 404 HTTP_500_INTERNAL_SERVER_ERROR = 500 class HTTPException(ProblemException): # type: ignore[no-redef] diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py index 4f1804fa9b69c..a162ffc9db70a 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py @@ -39,7 +39,6 @@ status, ) from airflow.stats import Stats -from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import TaskInstanceState jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs") @@ -78,8 +77,8 @@ def fetch( if body.queues: query = query.where(EdgeJobModel.queue.in_(body.queues)) query = query.limit(1) - query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True) - job: EdgeJobModel = session.scalar(query) + query = query.with_for_update(skip_locked=True) + job: EdgeJobModel | None = session.scalar(query) if not job: return None job.state = TaskInstanceState.RUNNING @@ -148,7 +147,7 @@ def state( ) Stats.incr("edge_worker.ti.finish", tags=tags) - query = ( + query2 = ( update(EdgeJobModel) .where( EdgeJobModel.dag_id == dag_id, @@ -159,4 +158,4 @@ def state( ) .values(state=state, last_update=timezone.utcnow()) ) - session.execute(query) + session.execute(query2) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py index 11ea9b71e9054..14b44deaff7ee 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py @@ -18,6 +18,7 @@ from __future__ import annotations from datetime import datetime +from typing import TYPE_CHECKING from fastapi import Depends, HTTPException, status from sqlalchemy import select @@ -44,6 +45,10 @@ Worker, WorkerCollectionResponse, ) +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.engine import ScalarResult ui_router = AirflowRouter(tags=["UI"]) @@ -59,7 +64,7 @@ def worker( ) -> WorkerCollectionResponse: """Return Edge Workers.""" query = select(EdgeWorkerModel).order_by(EdgeWorkerModel.worker_name) - workers: list[EdgeWorkerModel] = session.scalars(query) + workers: ScalarResult[EdgeWorkerModel] = session.scalars(query) result = [ Worker( @@ -91,7 +96,7 @@ def jobs( ) -> JobCollectionResponse: """Return Edge Jobs.""" query = select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm) - jobs: list[EdgeJobModel] = session.scalars(query) + jobs: ScalarResult[EdgeJobModel] = session.scalars(query) result = [ Job( @@ -100,7 +105,7 @@ def jobs( run_id=j.run_id, map_index=j.map_index, try_number=j.try_number, - state=j.state, + state=TaskInstanceState(j.state), queue=j.queue, queued_dttm=j.queued_dttm, edge_worker=j.edge_worker, diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py index d99ae0b270b4c..1a696509d1ed6 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py @@ -172,7 +172,7 @@ def register( """Register a new worker to the backend.""" _assert_version(body.sysinfo) query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) if not worker: worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues) worker.state = redefine_state(worker.state, body.state) @@ -194,7 +194,9 @@ def set_state( ) -> WorkerSetStateReturn: """Set state of worker and returns the current assigned queues.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found") worker.state = redefine_state(worker.state, body.state) worker.maintenance_comment = redefine_maintenance_comments( worker.maintenance_comment, body.maintenance_comments @@ -229,7 +231,9 @@ def update_queues( session: SessionDep, ) -> None: query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found") if body.new_queues: worker.add_queues(body.new_queues) if body.remove_queues: