diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index de3be5e384ae1..6b15cd342cc05 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -140,20 +140,40 @@ def execute_async(
         del self.edge_queued_tasks[key]
 
         self.validate_airflow_tasks_run_command(command)  # type: ignore[attr-defined]
-        session.add(
-            EdgeJobModel(
+
+        # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
+        existing_job = (
+            session.query(EdgeJobModel)
+            .filter_by(
                 dag_id=key.dag_id,
                 task_id=key.task_id,
                 run_id=key.run_id,
                 map_index=key.map_index,
                 try_number=key.try_number,
-                state=TaskInstanceState.QUEUED,
-                queue=queue or DEFAULT_QUEUE,
-                concurrency_slots=task_instance.pool_slots,
-                command=str(command),
             )
+            .first()
         )
+        if existing_job:
+            existing_job.state = TaskInstanceState.QUEUED
+            existing_job.queue = queue or DEFAULT_QUEUE
+            existing_job.concurrency_slots = task_instance.pool_slots
+            existing_job.command = str(command)
+        else:
+            session.add(
+                EdgeJobModel(
+                    dag_id=key.dag_id,
+                    task_id=key.task_id,
+                    run_id=key.run_id,
+                    map_index=key.map_index,
+                    try_number=key.try_number,
+                    state=TaskInstanceState.QUEUED,
+                    queue=queue or DEFAULT_QUEUE,
+                    concurrency_slots=task_instance.pool_slots,
+                    command=str(command),
+                )
+            )
+
 
     @provide_session
     def queue_workload(
         self,
@@ -168,20 +188,40 @@ def queue_workload(
 
         task_instance = workload.ti
         key = task_instance.key
-        session.add(
-            EdgeJobModel(
+
+        # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
+        existing_job = (
+            session.query(EdgeJobModel)
+            .filter_by(
                 dag_id=key.dag_id,
                 task_id=key.task_id,
                 run_id=key.run_id,
                 map_index=key.map_index,
                 try_number=key.try_number,
-                state=TaskInstanceState.QUEUED,
-                queue=task_instance.queue,
-                concurrency_slots=task_instance.pool_slots,
-                command=workload.model_dump_json(),
             )
+            .first()
         )
+        if existing_job:
+            existing_job.state = TaskInstanceState.QUEUED
+            existing_job.queue = task_instance.queue
+            existing_job.concurrency_slots = task_instance.pool_slots
+            existing_job.command = workload.model_dump_json()
+        else:
+            session.add(
+                EdgeJobModel(
+                    dag_id=key.dag_id,
+                    task_id=key.task_id,
+                    run_id=key.run_id,
+                    map_index=key.map_index,
+                    try_number=key.try_number,
+                    state=TaskInstanceState.QUEUED,
+                    queue=task_instance.queue,
+                    concurrency_slots=task_instance.pool_slots,
+                    command=workload.model_dump_json(),
+                )
+            )
+
 
     def _check_worker_liveness(self, session: Session) -> bool:
         """Reset worker state if heartbeat timed out."""
         changed = False
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index 7c86b39f12dd4..ca831fb3e06dc 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -247,6 +247,11 @@ def test_sync_active_worker(self):
 
         # Prepare some data
         with create_session() as session:
+            # Clear existing workers to avoid unique constraint violation
+            session.query(EdgeWorkerModel).delete()
+            session.commit()
+
+            # Add workers with different states
             for worker_name, state, last_heartbeat in [
                 (
                     "inactive_timed_out_worker",
@@ -338,3 +343,95 @@ def test_queue_workload(self):
         with create_session() as session:
             jobs = session.query(EdgeJobModel).all()
             assert len(jobs) == 1
+
+    @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="API only available in Airflow <3.0")
+    def test_execute_async_updates_existing_job(self):
+        executor, key = self.get_test_executor()
+
+        # First insert a job with the same key
+        with create_session() as session:
+            session.add(
+                EdgeJobModel(
+                    dag_id=key.dag_id,
+                    run_id=key.run_id,
+                    task_id=key.task_id,
+                    map_index=key.map_index,
+                    try_number=key.try_number,
+                    state=TaskInstanceState.SCHEDULED,
+                    queue="default",
+                    concurrency_slots=1,
+                    command="old-command",
+                    last_update=timezone.utcnow(),
+                )
+            )
+            session.commit()
+
+        # Trigger execute_async which should update the existing job
+        executor.edge_queued_tasks = deepcopy(executor.queued_tasks)
+        executor.execute_async(key=key, command=["airflow", "tasks", "run", "new", "command"])
+
+        with create_session() as session:
+            jobs = session.query(EdgeJobModel).all()
+            assert len(jobs) == 1
+            job = jobs[0]
+            assert job.state == TaskInstanceState.QUEUED
+            assert job.command != "old-command"
+            assert "new" in job.command
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="API only available in Airflow 3.0+")
+    def test_queue_workload_updates_existing_job(self):
+        from uuid import uuid4
+
+        from airflow.executors.workloads import ExecuteTask, TaskInstance
+
+        executor = self.get_test_executor()[0]
+
+        key = TaskInstanceKey(dag_id="mock", run_id="mock", task_id="mock", map_index=-1, try_number=1)
+
+        # Insert an existing job
+        with create_session() as session:
+            session.add(
+                EdgeJobModel(
+                    dag_id=key.dag_id,
+                    task_id=key.task_id,
+                    run_id=key.run_id,
+                    map_index=key.map_index,
+                    try_number=key.try_number,
+                    state=TaskInstanceState.SCHEDULED,
+                    queue="default",
+                    command="old-command",
+                    concurrency_slots=1,
+                    last_update=timezone.utcnow(),
+                )
+            )
+            session.commit()
+
+        # Queue a workload with same key
+        workload = ExecuteTask(
+            token="mock",
+            ti=TaskInstance(
+                id=uuid4(),
+                task_id=key.task_id,
+                dag_id=key.dag_id,
+                run_id=key.run_id,
+                try_number=key.try_number,
+                map_index=key.map_index,
+                pool_slots=1,
+                queue="updated-queue",
+                priority_weight=1,
+                start_date=timezone.utcnow(),
+                dag_version_id=uuid4(),
+            ),
+            dag_rel_path="mock.py",
+            log_path="mock.log",
+            bundle_info={"name": "n/a", "version": "no matter"},
+        )
+
+        executor.queue_workload(workload=workload)
+
+        with create_session() as session:
+            jobs = session.query(EdgeJobModel).all()
+            assert len(jobs) == 1
+            job = jobs[0]
+            assert job.queue == "updated-queue"
+            assert job.command != "old-command"
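
For reviewers, here is a minimal, runnable sketch of the query-then-update-or-insert ("upsert") pattern both executor hunks above introduce, assuming SQLAlchemy; `Job`, `upsert_job`, and the column set are hypothetical stand-ins for `EdgeJobModel` and its schema, not the provider's actual code:

```python
# Illustrative sketch only: `Job` and `upsert_job` are hypothetical stand-ins
# for EdgeJobModel and the logic added in execute_async/queue_workload.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Job(Base):
    __tablename__ = "job"
    # Composite key mirrors the lookup key used by the patch.
    dag_id = Column(String, primary_key=True)
    task_id = Column(String, primary_key=True)
    run_id = Column(String, primary_key=True)
    map_index = Column(Integer, primary_key=True)
    try_number = Column(Integer, primary_key=True)
    state = Column(String)
    command = Column(String)


def upsert_job(session: Session, key: dict, state: str, command: str) -> None:
    """Update the row matching `key` if it exists, otherwise insert a new one."""
    existing = session.query(Job).filter_by(**key).first()
    if existing:
        # Re-queueing the same task instance mutates the existing row
        # instead of violating the primary-key constraint.
        existing.state = state
        existing.command = command
    else:
        session.add(Job(**key, state=state, command=command))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    key = dict(dag_id="d", task_id="t", run_id="r", map_index=-1, try_number=1)
    upsert_job(session, key, "scheduled", "old-command")
    upsert_job(session, key, "queued", "new-command")  # second call updates in place
    session.commit()
    assert session.query(Job).count() == 1  # still a single row, now updated
```

Worth noting: the SELECT-then-write sequence is not atomic, so it only avoids duplicate-key errors as long as two sessions do not queue the same task-instance key concurrently; a database-native upsert (e.g. PostgreSQL's `INSERT ... ON CONFLICT DO UPDATE`) would close that window, at the cost of dialect-specific SQL.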