diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index 6b15cd342cc05..89741d95bc83c 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -399,6 +399,35 @@ def end(self) -> None: def terminate(self): """Terminate the executor is not doing anything.""" + @provide_session + def revoke_task(self, *, ti: TaskInstance, session: Session = NEW_SESSION): + """ + Revoke a task instance from the executor. + + This method removes the task from the executor's internal state and deletes + the corresponding EdgeJobModel record to prevent edge workers from picking it up. + + :param ti: Task instance to revoke + :param session: Database session + """ + # Remove from executor's internal state + self.running.discard(ti.key) + self.queued_tasks.pop(ti.key, None) + if ti.key in self.last_reported_state: + del self.last_reported_state[ti.key] + + # Delete the job from the database to prevent edge workers from picking it up + session.execute( + delete(EdgeJobModel).where( + EdgeJobModel.dag_id == ti.dag_id, + EdgeJobModel.task_id == ti.task_id, + EdgeJobModel.run_id == ti.run_id, + EdgeJobModel.map_index == ti.map_index, + EdgeJobModel.try_number == ti.try_number, + ) + ) + self.log.info("Revoked task instance %s from EdgeExecutor", ti.key) + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ Try to adopt running task instances that have been abandoned by a SchedulerJob dying. diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py index ca831fb3e06dc..de9760ec9719a 100644 --- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py +++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py @@ -435,3 +435,86 @@ def test_queue_workload_updates_existing_job(self): job = jobs[0] assert job.queue == "updated-queue" assert job.command != "old-command" + + def test_revoke_task(self): + """Test that revoke_task removes task from executor and database.""" + executor = EdgeExecutor() + key = TaskInstanceKey( + dag_id="test_dag", run_id="test_run", task_id="test_task", map_index=-1, try_number=1 + ) + + # Create a mock task instance + ti = MagicMock() + ti.key = key + ti.dag_id = "test_dag" + ti.task_id = "test_task" + ti.run_id = "test_run" + ti.map_index = -1 + ti.try_number = 1 + + # Add task to executor's internal state + executor.running.add(key) + executor.queued_tasks[key] = [None, None, None, ti] + executor.last_reported_state[key] = TaskInstanceState.QUEUED + + # Add corresponding job to database + with create_session() as session: + session.add( + EdgeJobModel( + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=-1, + try_number=1, + state=TaskInstanceState.QUEUED, + queue="default", + command="mock", + concurrency_slots=1, + ) + ) + session.commit() + + # Verify job exists before revoke + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(jobs) == 1 + + # Revoke the task + executor.revoke_task(ti=ti) + + # Verify task is removed from executor's internal state + assert key not in executor.running + assert key not in executor.queued_tasks + assert key not in executor.last_reported_state + + # Verify job is removed from database + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(jobs) == 0 + + def test_revoke_task_nonexistent(self): + """Test that revoke_task handles non-existent tasks gracefully.""" + executor = EdgeExecutor() + key = TaskInstanceKey( + dag_id="nonexistent_dag", + run_id="nonexistent_run", + task_id="nonexistent_task", + map_index=-1, + try_number=1, + ) + + # Create a mock task instance + ti = MagicMock() + ti.key = key + ti.dag_id = "nonexistent_dag" + ti.task_id = "nonexistent_task" + ti.run_id = "nonexistent_run" + ti.map_index = -1 + ti.try_number = 1 + + # Revoke a task that doesn't exist (should not raise error) + executor.revoke_task(ti=ti) + + # Verify nothing breaks + assert key not in executor.running + assert key not in executor.queued_tasks