Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from airflow.cli.cli_config import GroupCommand
from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance, TaskInstanceState
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS
from airflow.providers.edge3.models.edge_job import EdgeJobModel
Expand All @@ -40,6 +40,7 @@
from airflow.stats import Stats
from airflow.utils.db import DBLocks, create_global_lock
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
import argparse
Expand Down Expand Up @@ -68,7 +69,8 @@ def _check_db_schema(self, engine: Engine) -> None:
"""
Check if already existing table matches the newest table schema.

workaround till Airflow 3.0.0, then it is possible to use alembic also for provider distributions.
workaround till support for Airflow 2.x is dropped,
then it is possible to use alembic also for provider distributions.
"""
inspector = inspect(engine)
edge_job_columns = None
Expand All @@ -78,7 +80,7 @@ def _check_db_schema(self, engine: Engine) -> None:
edge_job_columns = [column["name"] for column in edge_job_schema]
for column in edge_job_schema:
if column["name"] == "command":
edge_job_command_len = column["type"].length
edge_job_command_len = column["type"].length # type: ignore[attr-defined]

# version 0.6.0rc1 added new column concurrency_slots
if edge_job_columns and "concurrency_slots" not in edge_job_columns:
Expand Down Expand Up @@ -284,7 +286,7 @@ def _update_orphaned_jobs(self, session: Session) -> bool:
map_index=job.map_index,
session=session,
)
job.state = ti.state if ti else TaskInstanceState.REMOVED
job.state = ti.state if ti and ti.state else TaskInstanceState.REMOVED

if job.state != TaskInstanceState.RUNNING:
# Edge worker does not backport emitted Airflow metrics, so export some metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def key(self):

@property
def last_update_t(self) -> float:
return self.last_update.timestamp()
return self.last_update.timestamp() if self.last_update else datetime.now().timestamp()
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def request_maintenance(
) -> None:
"""Write maintenance request to the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
worker.state = EdgeWorkerState.MAINTENANCE_REQUEST
worker.maintenance_comment = maintenance_comment

Expand All @@ -239,7 +241,9 @@ def request_maintenance(
def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Write maintenance exit to the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
worker.state = EdgeWorkerState.MAINTENANCE_EXIT
worker.maintenance_comment = None

Expand All @@ -248,7 +252,9 @@ def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Remove a worker that is offline or just gone from DB."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
Expand All @@ -267,7 +273,9 @@ def change_maintenance_comment(
) -> None:
"""Write maintenance comment in the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
if worker.state in (
EdgeWorkerState.MAINTENANCE_MODE,
EdgeWorkerState.MAINTENANCE_PENDING,
Expand All @@ -285,7 +293,9 @@ def change_maintenance_comment(
def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Request to shutdown the edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
if worker.state not in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
Expand All @@ -298,7 +308,9 @@ def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None:
def add_worker_queues(worker_name: str, queues: list[str], session: Session = NEW_SESSION) -> None:
"""Add queues to an edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
Expand All @@ -314,7 +326,9 @@ def add_worker_queues(worker_name: str, queues: list[str], session: Session = NE
def remove_worker_queues(worker_name: str, queues: list[str], session: Session = NEW_SESSION) -> None:
"""Remove queues from an edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def _get_api_endpoint(session: Session = NEW_SESSION) -> dict[str, Any]:
from sqlalchemy import select

from airflow.auth.managers.models.resource_details import AccessView
from airflow.models.taskinstance import TaskInstanceState
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.yaml import safe_load
from airflow.www.auth import has_access_view

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class status: # type: ignore[no-redef]
HTTP_204_NO_CONTENT = 204
HTTP_400_BAD_REQUEST = 400
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
HTTP_500_INTERNAL_SERVER_ERROR = 500

class HTTPException(ProblemException): # type: ignore[no-redef]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
status,
)
from airflow.stats import Stats
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import TaskInstanceState

jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
Expand Down Expand Up @@ -78,8 +77,8 @@ def fetch(
if body.queues:
query = query.where(EdgeJobModel.queue.in_(body.queues))
query = query.limit(1)
query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
job: EdgeJobModel = session.scalar(query)
query = query.with_for_update(skip_locked=True)
job: EdgeJobModel | None = session.scalar(query)
if not job:
return None
job.state = TaskInstanceState.RUNNING
Expand Down Expand Up @@ -148,7 +147,7 @@ def state(
)
Stats.incr("edge_worker.ti.finish", tags=tags)

query = (
query2 = (
update(EdgeJobModel)
.where(
EdgeJobModel.dag_id == dag_id,
Expand All @@ -159,4 +158,4 @@ def state(
)
.values(state=state, last_update=timezone.utcnow())
)
session.execute(query)
session.execute(query2)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from fastapi import Depends, HTTPException, status
from sqlalchemy import select
Expand All @@ -44,6 +45,10 @@
Worker,
WorkerCollectionResponse,
)
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.engine import ScalarResult

ui_router = AirflowRouter(tags=["UI"])

Expand All @@ -59,7 +64,7 @@ def worker(
) -> WorkerCollectionResponse:
"""Return Edge Workers."""
query = select(EdgeWorkerModel).order_by(EdgeWorkerModel.worker_name)
workers: list[EdgeWorkerModel] = session.scalars(query)
workers: ScalarResult[EdgeWorkerModel] = session.scalars(query)

result = [
Worker(
Expand Down Expand Up @@ -91,7 +96,7 @@ def jobs(
) -> JobCollectionResponse:
"""Return Edge Jobs."""
query = select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm)
jobs: list[EdgeJobModel] = session.scalars(query)
jobs: ScalarResult[EdgeJobModel] = session.scalars(query)

result = [
Job(
Expand All @@ -100,7 +105,7 @@ def jobs(
run_id=j.run_id,
map_index=j.map_index,
try_number=j.try_number,
state=j.state,
state=TaskInstanceState(j.state),
queue=j.queue,
queued_dttm=j.queued_dttm,
edge_worker=j.edge_worker,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def register(
"""Register a new worker to the backend."""
_assert_version(body.sysinfo)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues)
worker.state = redefine_state(worker.state, body.state)
Expand All @@ -194,7 +194,9 @@ def set_state(
) -> WorkerSetStateReturn:
"""Set state of worker and returns the current assigned queues."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
worker.state = redefine_state(worker.state, body.state)
worker.maintenance_comment = redefine_maintenance_comments(
worker.maintenance_comment, body.maintenance_comments
Expand Down Expand Up @@ -229,7 +231,9 @@ def update_queues(
session: SessionDep,
) -> None:
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
if body.new_queues:
worker.add_queues(body.new_queues)
if body.remove_queues:
Expand Down
Loading