1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -440,6 +440,7 @@ repos:
^providers/celery/.*\.py$|
^providers/cncf/kubernetes/.*\.py$|
^providers/databricks/.*\.py$|
^providers/edge3/.*\.py$|
^providers/mysql/.*\.py$|
^providers/openlineage/.*\.py$|
^task_sdk.*\.py$
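
The new ^providers/edge3/.*\.py$| branch extends this hook's files regex, an alternation of path patterns, so the edge3 provider's Python sources are now covered by the hook as well. A quick sanity check of the idea, using a hypothetical trimmed-down version of the pattern rather than the full entry from .pre-commit-config.yaml:

import re

# Hypothetical, trimmed-down stand-in for the hook's "files" regex; the real
# entry in .pre-commit-config.yaml carries many more branches.
FILES_PATTERN = re.compile(
    r"^providers/edge3/.*\.py$|"
    r"^providers/mysql/.*\.py$|"
    r"^task_sdk.*\.py$"
)

assert FILES_PATTERN.match("providers/edge3/src/example.py")        # newly covered
assert FILES_PATTERN.match("providers/mysql/src/example.py")        # already covered
assert not FILES_PATTERN.match("providers/edge3/docs/example.rst")  # non-Python file

The remaining hunks migrate the edge3 executor and its tests from SQLAlchemy's legacy Query API (session.query(...)) to 2.0-style statements (select(), delete()) executed on the session.
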
@@ -23,7 +23,7 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any

from sqlalchemy import delete, inspect, text
from sqlalchemy import delete, inspect, select, text
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session

@@ -140,17 +140,15 @@ def queue_workload(
key = task_instance.key

# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
existing_job = (
session.query(EdgeJobModel)
.filter_by(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == key.dag_id,
EdgeJobModel.task_id == key.task_id,
EdgeJobModel.run_id == key.run_id,
EdgeJobModel.map_index == key.map_index,
EdgeJobModel.try_number == key.try_number,
)
.first()
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
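
For reference, the two lookups above are equivalent. A minimal, self-contained sketch of the before/after pattern used throughout this PR, assuming SQLAlchemy 2.0 and a hypothetical stand-in model (not the real EdgeJobModel):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Job(Base):  # hypothetical stand-in for EdgeJobModel
    __tablename__ = "job"
    id: Mapped[int] = mapped_column(primary_key=True)
    dag_id: Mapped[str] = mapped_column(String(250))
    task_id: Mapped[str] = mapped_column(String(250))
    try_number: Mapped[int] = mapped_column(default=1)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Job(dag_id="dag", task_id="task", try_number=1))
    session.commit()

    # Legacy 1.x Query API (the style this PR removes).
    old_style = session.query(Job).filter_by(dag_id="dag", task_id="task").first()

    # 2.0-style statement API (the style this PR introduces). Both return the
    # matching ORM instance, or None when nothing matches.
    new_style = session.scalars(
        select(Job).where(Job.dag_id == "dag", Job.task_id == "task")
    ).first()

    assert old_style is new_style
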
@@ -176,10 +174,10 @@ def _check_worker_liveness(self, session: Session) -> bool:
"""Reset worker state if heartbeat timed out."""
changed = False
heartbeat_interval: int = conf.getint("edge", "heartbeat_interval")
lifeless_workers: list[EdgeWorkerModel] = (
session.query(EdgeWorkerModel)
lifeless_workers: Sequence[EdgeWorkerModel] = session.scalars(
select(EdgeWorkerModel)
.with_for_update(skip_locked=True)
.filter(
.where(
EdgeWorkerModel.state.not_in(
[
EdgeWorkerState.UNKNOWN,
@@ -189,8 +187,7 @@ def _check_worker_liveness(self, session: Session) -> bool:
),
EdgeWorkerModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval * 5)),
)
.all()
)
).all()

for worker in lifeless_workers:
changed = True
@@ -212,15 +209,14 @@ def _check_worker_liveness(self, session: Session) -> bool:
def _update_orphaned_jobs(self, session: Session) -> bool:
"""Update status ob jobs when workers die and don't update anymore."""
heartbeat_interval: int = conf.getint("scheduler", "task_instance_heartbeat_timeout")
lifeless_jobs: list[EdgeJobModel] = (
session.query(EdgeJobModel)
lifeless_jobs: Sequence[EdgeJobModel] = session.scalars(
select(EdgeJobModel)
.with_for_update(skip_locked=True)
.filter(
.where(
EdgeJobModel.state == TaskInstanceState.RUNNING,
EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)),
)
.all()
)
).all()

for job in lifeless_jobs:
ti = TaskInstance.get_task_instance(
@@ -254,10 +250,10 @@ def _purge_jobs(self, session: Session) -> bool:
purged_marker = False
job_success_purge = conf.getint("edge", "job_success_purge")
job_fail_purge = conf.getint("edge", "job_fail_purge")
jobs: list[EdgeJobModel] = (
session.query(EdgeJobModel)
jobs: Sequence[EdgeJobModel] = session.scalars(
select(EdgeJobModel)
.with_for_update(skip_locked=True)
.filter(
.where(
EdgeJobModel.state.in_(
[
TaskInstanceState.RUNNING,
@@ -269,8 +265,7 @@ def _purge_jobs(self, session: Session) -> bool:
]
)
)
.all()
)
).all()

# Sync DB with executor otherwise runs out of sync in multi scheduler deployment
already_removed = self.running - set(job.key for job in jobs)
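
The executor queries above keep the row-level lock (with_for_update(skip_locked=True), so rows already locked by another scheduler are skipped) and change the annotations from list[...] to Sequence[...], matching the Sequence return type of ScalarResult.all(). A small sketch of the locking statement, again with a hypothetical stand-in model and only compiled against the PostgreSQL dialect rather than executed:

from sqlalchemy import String, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Job(Base):  # hypothetical stand-in for EdgeJobModel
    __tablename__ = "job"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str] = mapped_column(String(20))

# Same shape as the executor queries: select the rows, ask for a row lock,
# and skip rows that are already locked elsewhere.
stmt = (
    select(Job)
    .with_for_update(skip_locked=True)
    .where(Job.state == "running")
)

# On PostgreSQL this renders with "FOR UPDATE SKIP LOCKED" at the end.
print(stmt.compile(dialect=postgresql.dialect()))

# When executed, session.scalars(stmt).all() yields a Sequence of Job
# instances, hence the list[...] -> Sequence[...] annotation change.
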
@@ -21,6 +21,7 @@

import pytest
import time_machine
from sqlalchemy import delete, select

from airflow.configuration import conf
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -40,7 +41,7 @@ class TestEdgeExecutor:
@pytest.fixture(autouse=True)
def setup_test_cases(self):
with create_session() as session:
session.query(EdgeJobModel).delete()
session.execute(delete(EdgeJobModel))

def get_test_executor(self, pool_slots=1):
key = TaskInstanceKey(
@@ -104,7 +105,7 @@ def test_sync_orphaned_tasks(self, mock_stats_incr):
        assert mock_stats_incr.call_count == 2

with create_session() as session:
jobs = session.query(EdgeJobModel).all()
jobs = session.scalars(select(EdgeJobModel)).all()
assert len(jobs) == 1
assert jobs[0].task_id == "started_running_orphaned"
assert jobs[0].state == TaskInstanceState.REMOVED
@@ -154,8 +155,8 @@ def remove_from_running(key: TaskInstanceKey):
executor.sync()

with create_session() as session:
jobs = session.query(EdgeJobModel).all()
assert len(session.query(EdgeJobModel).all()) == 1
jobs = session.scalars(select(EdgeJobModel)).all()
assert len(session.scalars(select(EdgeJobModel)).all()) == 1
assert jobs[0].task_id == "started_running"
assert jobs[0].state == TaskInstanceState.RUNNING

@@ -215,7 +216,7 @@ def test_sync_active_worker(self):
# Prepare some data
with create_session() as session:
# Clear existing workers to avoid unique constraint violation
session.query(EdgeWorkerModel).delete()
session.execute(delete(EdgeWorkerModel))
session.commit()

# Add workers with different states
@@ -253,7 +254,7 @@ def test_sync_active_worker(self):
executor.sync()

with create_session() as session:
for worker in session.query(EdgeWorkerModel).all():
for worker in session.scalars(select(EdgeWorkerModel)).all():
print(worker.worker_name)
if "maintenance_" in worker.worker_name:
EdgeWorkerState.OFFLINE_MAINTENANCE
@@ -304,7 +305,7 @@ def test_revoke_task(self):

# Verify job exists before revoke
with create_session() as session:
jobs = session.query(EdgeJobModel).all()
jobs = session.scalars(select(EdgeJobModel)).all()
assert len(jobs) == 1

# Revoke the task
@@ -317,7 +318,7 @@

# Verify job is removed from database
with create_session() as session:
jobs = session.query(EdgeJobModel).all()
jobs = session.scalars(select(EdgeJobModel)).all()
assert len(jobs) == 0

def test_revoke_task_nonexistent(self):
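
The test fixtures in this file swap the legacy bulk delete (session.query(Model).delete()) for the 2.0 delete() construct executed on the session; both issue a single bulk DELETE without loading rows first. A minimal sketch with a hypothetical stand-in model:

from sqlalchemy import String, create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Worker(Base):  # hypothetical stand-in for EdgeJobModel / EdgeWorkerModel
    __tablename__ = "worker"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Worker(name="a"), Worker(name="b")])
    session.commit()

    # Legacy: session.query(Worker).delete()
    # 2.0 style: build a DELETE statement and execute it on the session.
    result = session.execute(delete(Worker))
    assert result.rowcount == 2
    session.commit()

    assert session.scalar(select(func.count()).select_from(Worker)) == 0
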
@@ -20,6 +20,7 @@
from unittest.mock import patch

import pytest
from sqlalchemy import delete, select

from airflow.providers.common.compat.sdk import Stats
from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -42,7 +43,7 @@
class TestJobsApiRoutes:
@pytest.fixture(autouse=True)
def setup_test_cases(self, dag_maker, session: Session):
session.query(EdgeJobModel).delete()
session.execute(delete(EdgeJobModel))
session.commit()

@patch(f"{Stats.__module__}.Stats.incr")
@@ -94,4 +95,6 @@ def test_state(self, mock_stats_incr, session: Session):
)
        assert mock_stats_incr.call_count == 2

assert session.query(EdgeJobModel).scalar().state == TaskInstanceState.SUCCESS
db_job: EdgeJobModel | None = session.scalar(select(EdgeJobModel))
assert db_job is not None
assert db_job.state == TaskInstanceState.SUCCESS
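
The rewritten assertion reads the single row back with session.scalar(select(...)), which returns the first matching instance or None, so the added None check both narrows the type and fails more clearly than chaining off a possibly-None value. A hypothetical stand-in sketch:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Job(Base):  # hypothetical stand-in for EdgeJobModel
    __tablename__ = "job"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str] = mapped_column(String(20))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # No rows yet: scalar() returns None instead of raising.
    assert session.scalar(select(Job)) is None

    session.add(Job(state="success"))
    session.commit()

    job = session.scalar(select(Job))
    assert job is not None  # narrows Job | None for the type checker
    assert job.state == "success"
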
@@ -16,9 +16,11 @@
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING

import pytest
from sqlalchemy import delete, select

from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
@@ -45,7 +47,7 @@ def setup_test_cases(self, dag_maker, session: Session):
EmptyOperator(task_id=TASK_ID)
dag_maker.create_dagrun(run_id=RUN_ID)

session.query(EdgeLogsModel).delete()
session.execute(delete(EdgeLogsModel))
session.commit()

def test_logfile_path(self, session: Session):
@@ -68,7 +70,7 @@ def test_push_logs(self, session: Session):
body=log_data,
session=session,
)
logs: list[EdgeLogsModel] = session.query(EdgeLogsModel).all()
logs: Sequence[EdgeLogsModel] = session.scalars(select(EdgeLogsModel)).all()
assert len(logs) == 1
assert logs[0].dag_id == DAG_ID
assert logs[0].task_id == TASK_ID
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING

import pytest
from sqlalchemy import delete

from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState

@@ -34,7 +35,7 @@
class TestUiApiRoutes:
@pytest.fixture(autouse=True)
def setup_test_cases(self, session: Session):
session.query(EdgeWorkerModel).delete()
session.execute(delete(EdgeWorkerModel))
session.add(EdgeWorkerModel(worker_name="worker1", queues=["default"], state=EdgeWorkerState.RUNNING))
session.commit()

@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING

import pytest
from fastapi import HTTPException
from sqlalchemy import delete, select

from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.cli.worker import EdgeWorker
@@ -47,7 +49,7 @@ def cli_worker(self, tmp_path: Path) -> EdgeWorker:

@pytest.fixture(autouse=True)
def setup_test_cases(self, session: Session):
session.query(EdgeWorkerModel).delete()
session.execute(delete(EdgeWorkerModel))

def test_assert_version(self):
from airflow import __version__ as airflow_version
@@ -87,7 +89,7 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo
register("test_worker", body, session)
session.commit()

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
assert len(worker) == 1
assert worker[0].worker_name == "test_worker"
if input_queues:
@@ -138,7 +140,9 @@ def test_register_duplicate_worker(
# Should succeed for offline/unknown states
register("test_worker", body, session)
session.commit()
worker = session.query(EdgeWorkerModel).filter_by(worker_name="test_worker").first()
worker = session.execute(
select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == "test_worker")
).scalar_one_or_none()
assert worker is not None
# State should be updated (or redefined based on redefine_state logic)
assert worker.state is not None
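
The duplicate-worker lookup now goes through session.execute(...).scalar_one_or_none(): one match returns the instance, no match returns None, and more than one match raises MultipleResultsFound, which is stricter than the old .filter_by(...).first(). A hypothetical stand-in sketch:

from sqlalchemy import String, create_engine, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Worker(Base):  # hypothetical stand-in for EdgeWorkerModel
    __tablename__ = "worker"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Worker(name="test_worker"))
    session.commit()

    stmt = select(Worker).where(Worker.name == "test_worker")

    # Exactly one match -> the instance; no match -> None.
    assert session.execute(stmt).scalar_one_or_none() is not None
    missing = select(Worker).where(Worker.name == "missing")
    assert session.execute(missing).scalar_one_or_none() is None

    # More than one match raises, unlike the old .first() lookup.
    session.add(Worker(name="test_worker"))
    session.commit()
    try:
        session.execute(stmt).scalar_one_or_none()
    except MultipleResultsFound:
        pass
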
@@ -237,7 +241,7 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker):
)
return_queues = set_state("test2_worker", body, session).queues

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"
assert worker[0].state == EdgeWorkerState.RUNNING
@@ -271,7 +275,7 @@ def test_update_queues(
session.commit()
body = WorkerQueueUpdateBody(new_queues=add_queues, remove_queues=remove_queues)
update_queues("test2_worker", body, session)
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"
assert len(expected_queues) == len(worker[0].queues or [])