diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 04062bddec77f..c66f5c246cbcb 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -440,6 +440,7 @@ repos:
           ^providers/celery/.*\.py$|
           ^providers/cncf/kubernetes/.*\.py$|
           ^providers/databricks/.*\.py$|
+          ^providers/edge3/.*\.py$|
           ^providers/mysql/.*\.py$|
           ^providers/openlineage/.*\.py$|
           ^task_sdk.*\.py$
diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index f918e4b803559..b82b5df02ed90 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -23,7 +23,7 @@
 from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import delete, inspect, text
+from sqlalchemy import delete, inspect, select, text
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.orm import Session
 
@@ -140,17 +140,15 @@ def queue_workload(
         key = task_instance.key
 
         # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
-        existing_job = (
-            session.query(EdgeJobModel)
-            .filter_by(
-                dag_id=key.dag_id,
-                task_id=key.task_id,
-                run_id=key.run_id,
-                map_index=key.map_index,
-                try_number=key.try_number,
+        existing_job = session.scalars(
+            select(EdgeJobModel).where(
+                EdgeJobModel.dag_id == key.dag_id,
+                EdgeJobModel.task_id == key.task_id,
+                EdgeJobModel.run_id == key.run_id,
+                EdgeJobModel.map_index == key.map_index,
+                EdgeJobModel.try_number == key.try_number,
             )
-            .first()
-        )
+        ).first()
 
         if existing_job:
             existing_job.state = TaskInstanceState.QUEUED
@@ -176,10 +174,10 @@ def _check_worker_liveness(self, session: Session) -> bool:
         """Reset worker state if heartbeat timed out."""
         changed = False
         heartbeat_interval: int = conf.getint("edge", "heartbeat_interval")
-        lifeless_workers: list[EdgeWorkerModel] = (
-            session.query(EdgeWorkerModel)
+        lifeless_workers: Sequence[EdgeWorkerModel] = session.scalars(
+            select(EdgeWorkerModel)
             .with_for_update(skip_locked=True)
-            .filter(
+            .where(
                 EdgeWorkerModel.state.not_in(
                     [
                         EdgeWorkerState.UNKNOWN,
@@ -189,8 +187,7 @@ def _check_worker_liveness(self, session: Session) -> bool:
                 ),
                 EdgeWorkerModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval * 5)),
             )
-            .all()
-        )
+        ).all()
 
         for worker in lifeless_workers:
             changed = True
@@ -212,15 +209,14 @@ def _update_orphaned_jobs(self, session: Session) -> bool:
         """Update status ob jobs when workers die and don't update anymore."""
         heartbeat_interval: int = conf.getint("scheduler", "task_instance_heartbeat_timeout")
-        lifeless_jobs: list[EdgeJobModel] = (
-            session.query(EdgeJobModel)
+        lifeless_jobs: Sequence[EdgeJobModel] = session.scalars(
+            select(EdgeJobModel)
             .with_for_update(skip_locked=True)
-            .filter(
+            .where(
                 EdgeJobModel.state == TaskInstanceState.RUNNING,
                 EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)),
             )
-            .all()
-        )
+        ).all()
 
         for job in lifeless_jobs:
             ti = TaskInstance.get_task_instance(
@@ -254,10 +250,10 @@ def _purge_jobs(self, session: Session) -> bool:
         purged_marker = False
         job_success_purge = conf.getint("edge", "job_success_purge")
         job_fail_purge = conf.getint("edge", "job_fail_purge")
-        jobs: list[EdgeJobModel] = (
-            session.query(EdgeJobModel)
+        jobs: Sequence[EdgeJobModel] = session.scalars(
+            select(EdgeJobModel)
             .with_for_update(skip_locked=True)
-            .filter(
+            .where(
                 EdgeJobModel.state.in_(
                     [
                         TaskInstanceState.RUNNING,
@@ -269,8 +265,7 @@ def _purge_jobs(self, session: Session) -> bool:
                     ]
                 )
             )
-            .all()
-        )
+        ).all()
 
         # Sync DB with executor otherwise runs out of sync in multi scheduler deployment
         already_removed = self.running - set(job.key for job in jobs)
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index b8f411ebd0e03..c38dffed3e918 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -21,6 +21,7 @@
 
 import pytest
 import time_machine
+from sqlalchemy import delete, select
 
 from airflow.configuration import conf
 from airflow.models.taskinstancekey import TaskInstanceKey
@@ -40,7 +41,7 @@ class TestEdgeExecutor:
     @pytest.fixture(autouse=True)
     def setup_test_cases(self):
         with create_session() as session:
-            session.query(EdgeJobModel).delete()
+            session.execute(delete(EdgeJobModel))
 
     def get_test_executor(self, pool_slots=1):
         key = TaskInstanceKey(
@@ -104,7 +105,7 @@ def test_sync_orphaned_tasks(self, mock_stats_incr):
         mock_stats_incr.call_count == 2
 
         with create_session() as session:
-            jobs = session.query(EdgeJobModel).all()
+            jobs = session.scalars(select(EdgeJobModel)).all()
             assert len(jobs) == 1
             assert jobs[0].task_id == "started_running_orphaned"
             assert jobs[0].state == TaskInstanceState.REMOVED
@@ -154,8 +155,8 @@ def remove_from_running(key: TaskInstanceKey):
         executor.sync()
 
         with create_session() as session:
-            jobs = session.query(EdgeJobModel).all()
-            assert len(session.query(EdgeJobModel).all()) == 1
+            jobs = session.scalars(select(EdgeJobModel)).all()
+            assert len(session.scalars(select(EdgeJobModel)).all()) == 1
             assert jobs[0].task_id == "started_running"
             assert jobs[0].state == TaskInstanceState.RUNNING
 
@@ -215,7 +216,7 @@ def test_sync_active_worker(self):
         # Prepare some data
         with create_session() as session:
             # Clear existing workers to avoid unique constraint violation
-            session.query(EdgeWorkerModel).delete()
+            session.execute(delete(EdgeWorkerModel))
             session.commit()
 
             # Add workers with different states
@@ -253,7 +254,7 @@ def test_sync_active_worker(self):
         executor.sync()
 
         with create_session() as session:
-            for worker in session.query(EdgeWorkerModel).all():
+            for worker in session.scalars(select(EdgeWorkerModel)).all():
                 print(worker.worker_name)
                 if "maintenance_" in worker.worker_name:
                     EdgeWorkerState.OFFLINE_MAINTENANCE
@@ -304,7 +305,7 @@ def test_revoke_task(self):
 
         # Verify job exists before revoke
         with create_session() as session:
-            jobs = session.query(EdgeJobModel).all()
+            jobs = session.scalars(select(EdgeJobModel)).all()
             assert len(jobs) == 1
 
         # Revoke the task
@@ -317,7 +318,7 @@ def test_revoke_task(self):
 
         # Verify job is removed from database
         with create_session() as session:
-            jobs = session.query(EdgeJobModel).all()
+            jobs = session.scalars(select(EdgeJobModel)).all()
             assert len(jobs) == 0
 
     def test_revoke_task_nonexistent(self):
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
index 6845a6d2257ff..e62426f003875 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
@@ -20,6 +20,7 @@
 from unittest.mock import patch
 
 import pytest
+from sqlalchemy import delete, select
 
 from airflow.providers.common.compat.sdk import Stats
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -42,7 +43,7 @@ class TestJobsApiRoutes:
 
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, dag_maker, session: Session):
-        session.query(EdgeJobModel).delete()
+        session.execute(delete(EdgeJobModel))
         session.commit()
 
     @patch(f"{Stats.__module__}.Stats.incr")
@@ -94,4 +95,6 @@ def test_state(self, mock_stats_incr, session: Session):
         )
         mock_stats_incr.call_count == 2
 
-        assert session.query(EdgeJobModel).scalar().state == TaskInstanceState.SUCCESS
+        db_job: EdgeJobModel | None = session.scalar(select(EdgeJobModel))
+        assert db_job is not None
+        assert db_job.state == TaskInstanceState.SUCCESS
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_logs.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_logs.py
index 9854b32f9c8e5..5cc8a533c3ee6 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_logs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_logs.py
@@ -16,9 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
 import pytest
+from sqlalchemy import delete, select
 
 from airflow.providers.common.compat.sdk import timezone
 from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
@@ -45,7 +47,7 @@ def setup_test_cases(self, dag_maker, session: Session):
             EmptyOperator(task_id=TASK_ID)
         dag_maker.create_dagrun(run_id=RUN_ID)
 
-        session.query(EdgeLogsModel).delete()
+        session.execute(delete(EdgeLogsModel))
         session.commit()
 
     def test_logfile_path(self, session: Session):
@@ -68,7 +70,7 @@ def test_push_logs(self, session: Session):
             body=log_data,
             session=session,
         )
-        logs: list[EdgeLogsModel] = session.query(EdgeLogsModel).all()
+        logs: Sequence[EdgeLogsModel] = session.scalars(select(EdgeLogsModel)).all()
         assert len(logs) == 1
         assert logs[0].dag_id == DAG_ID
         assert logs[0].task_id == TASK_ID
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_ui.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_ui.py
index 7dec06c1004e5..f1934b5d0f6dd 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_ui.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_ui.py
@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING
 
 import pytest
+from sqlalchemy import delete
 
 from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState
 
@@ -34,7 +35,7 @@ class TestUiApiRoutes:
 
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, session: Session):
-        session.query(EdgeWorkerModel).delete()
+        session.execute(delete(EdgeWorkerModel))
         session.add(EdgeWorkerModel(worker_name="worker1", queues=["default"], state=EdgeWorkerState.RUNNING))
         session.commit()
 
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
index 5f1b3b6db935c..57a0bd810213d 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
@@ -16,11 +16,13 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 import pytest
 from fastapi import HTTPException
+from sqlalchemy import delete, select
 
 from airflow.providers.common.compat.sdk import timezone
 from airflow.providers.edge3.cli.worker import EdgeWorker
@@ -47,7 +49,7 @@ def cli_worker(self, tmp_path: Path) -> EdgeWorker:
 
     @pytest.fixture(autouse=True)
     def setup_test_cases(self, session: Session):
-        session.query(EdgeWorkerModel).delete()
+        session.execute(delete(EdgeWorkerModel))
 
     def test_assert_version(self):
         from airflow import __version__ as airflow_version
@@ -87,7 +89,7 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo
         register("test_worker", body, session)
         session.commit()
 
-        worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+        worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
         assert len(worker) == 1
         assert worker[0].worker_name == "test_worker"
         if input_queues:
@@ -138,7 +140,9 @@ def test_register_duplicate_worker(
         # Should succeed for offline/unknown states
         register("test_worker", body, session)
         session.commit()
-        worker = session.query(EdgeWorkerModel).filter_by(worker_name="test_worker").first()
+        worker = session.execute(
+            select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == "test_worker")
+        ).scalar_one_or_none()
         assert worker is not None
         # State should be updated (or redefined based on redefine_state logic)
         assert worker.state is not None
@@ -237,7 +241,7 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker):
         )
         return_queues = set_state("test2_worker", body, session).queues
 
-        worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+        worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
         assert len(worker) == 1
         assert worker[0].worker_name == "test2_worker"
         assert worker[0].state == EdgeWorkerState.RUNNING
@@ -271,7 +275,7 @@ def test_update_queues(
         session.commit()
         body = WorkerQueueUpdateBody(new_queues=add_queues, remove_queues=remove_queues)
         update_queues("test2_worker", body, session)
-        worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+        worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
         assert len(worker) == 1
         assert worker[0].worker_name == "test2_worker"
         assert len(expected_queues) == len(worker[0].queues or [])
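
For reference, the patch rewrites the provider's legacy session.query(...) calls into the SQLAlchemy 2.0-style select() / delete() constructs. Below is a minimal, self-contained sketch of the before/after idioms; the ToyJob model and the in-memory SQLite engine are hypothetical stand-ins (not provider code), and it assumes SQLAlchemy 2.0 is installed.

# Sketch: legacy Query API vs. the 2.0-style constructs applied in this diff.
# "ToyJob" and the in-memory engine are illustrative only, not part of the provider.
from sqlalchemy import String, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ToyJob(Base):
    __tablename__ = "toy_job"

    id: Mapped[int] = mapped_column(primary_key=True)
    task_id: Mapped[str] = mapped_column(String(50))
    state: Mapped[str] = mapped_column(String(20))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(ToyJob(task_id="t1", state="queued"))
    session.commit()

    # Legacy: session.query(ToyJob).filter_by(task_id="t1").first()
    job = session.scalars(select(ToyJob).where(ToyJob.task_id == "t1")).first()

    # Legacy: session.query(ToyJob).filter_by(...).first();
    # scalar_one_or_none() additionally raises if more than one row matches.
    maybe_job = session.execute(
        select(ToyJob).where(ToyJob.task_id == "t1")
    ).scalar_one_or_none()

    # Legacy: session.query(ToyJob).all()
    all_jobs = session.scalars(select(ToyJob)).all()

    assert job is not None and maybe_job is not None
    print(len(all_jobs))  # 1

    # Legacy bulk delete: session.query(ToyJob).delete()
    session.execute(delete(ToyJob))
    session.commit()

Note that Session.scalars() returns a ScalarResult whose .all() is typed as a Sequence rather than a list, which is why the annotations in the diff move from list[...] to collections.abc.Sequence[...] and why the test modules gain the Sequence import.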