diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index 55d5414cd8965..a2ebe43565299 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -177,7 +177,11 @@ def _check_worker_liveness(self, session: Session) -> bool: .with_for_update(skip_locked=True) .filter( EdgeWorkerModel.state.not_in( - [EdgeWorkerState.UNKNOWN, EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE] + [ + EdgeWorkerState.UNKNOWN, + EdgeWorkerState.OFFLINE, + EdgeWorkerState.OFFLINE_MAINTENANCE, + ] ), EdgeWorkerModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval * 5)), ) @@ -186,7 +190,17 @@ def _check_worker_liveness(self, session: Session) -> bool: for worker in lifeless_workers: changed = True - worker.state = EdgeWorkerState.UNKNOWN + # If the worker dies in maintenance mode we want to remember it, so it can start in maintenance mode + worker.state = ( + EdgeWorkerState.OFFLINE_MAINTENANCE + if worker.state + in ( + EdgeWorkerState.MAINTENANCE_MODE, + EdgeWorkerState.MAINTENANCE_PENDING, + EdgeWorkerState.MAINTENANCE_REQUEST, + ) + else EdgeWorkerState.UNKNOWN + ) reset_metrics(worker.worker_name) return changed diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py index f251d32acf998..b24330733c8a5 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py @@ -122,7 +122,12 @@ def redefine_state(worker_state: EdgeWorkerState, body_state: EdgeWorkerState) - EdgeWorkerState.MAINTENANCE_PENDING, EdgeWorkerState.MAINTENANCE_MODE, ) - or worker_state == EdgeWorkerState.OFFLINE_MAINTENANCE + or worker_state + in ( + EdgeWorkerState.OFFLINE_MAINTENANCE, + EdgeWorkerState.MAINTENANCE_MODE, + EdgeWorkerState.MAINTENANCE_PENDING, + ) and body_state == EdgeWorkerState.STARTING ): return EdgeWorkerState.MAINTENANCE_REQUEST diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py index 4cf5ffc80031e..3553ec0e083db 100644 --- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py +++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py @@ -143,6 +143,24 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo EdgeWorkerState.MAINTENANCE_REQUEST, id="maintenance_starting", ), + pytest.param( + EdgeWorkerState.MAINTENANCE_MODE, + EdgeWorkerState.STARTING, + EdgeWorkerState.MAINTENANCE_REQUEST, + id="maintenance_crash", + ), + pytest.param( + EdgeWorkerState.MAINTENANCE_PENDING, + EdgeWorkerState.STARTING, + EdgeWorkerState.MAINTENANCE_REQUEST, + id="maintenance_crash_2", + ), + pytest.param( + EdgeWorkerState.MAINTENANCE_REQUEST, + EdgeWorkerState.STARTING, + EdgeWorkerState.MAINTENANCE_REQUEST, + id="maintenance_crash_3", + ), ], ) def test_redefine_state(