From ed9cfe813820bb9ce628d35eabdcfeb3c1896734 Mon Sep 17 00:00:00 2001 From: "D. Ferruzzi" Date: Mon, 4 Mar 2024 10:20:23 -0800 Subject: [PATCH] Fix external_executor_id being overwritten (#37784) --- airflow/models/taskinstance.py | 5 +++- tests/conftest.py | 2 ++ tests/models/test_taskinstance.py | 47 ++++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 1b77a71ac73705..57c9483cd4ee7b 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2215,7 +2215,10 @@ def _check_and_change_state_before_execution( ti.state = TaskInstanceState.RUNNING ti.emit_state_change_metric(TaskInstanceState.RUNNING) - ti.external_executor_id = external_executor_id + + if external_executor_id: + ti.external_executor_id = external_executor_id + ti.end_date = None if not test_mode: session.merge(ti).task = task diff --git a/tests/conftest.py b/tests/conftest.py index f6fd2927aa37a1..7fb6a2402f95bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -916,6 +916,7 @@ def maker( run_id=None, run_type=None, data_interval=None, + external_executor_id=None, map_index=-1, **kwargs, ) -> TaskInstance: @@ -936,6 +937,7 @@ def maker( (ti,) = dagrun.task_instances ti.task = task ti.state = state + ti.external_executor_id = external_executor_id ti.map_index = map_index dag_maker.session.flush() diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 4df302b471997a..8431381a6ff92c 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1782,7 +1782,11 @@ def post_execute(self, context, result=None): ti.run() def test_check_and_change_state_before_execution(self, create_task_instance): - ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") + expected_external_executor_id = "banana" + ti = create_task_instance( + dag_id="test_check_and_change_state_before_execution", + external_executor_id=expected_external_executor_id, + ) SerializedDagModel.write_dag(ti.task.dag) serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag @@ -1791,6 +1795,46 @@ def test_check_and_change_state_before_execution(self, create_task_instance): assert ti_from_deserialized_task._try_number == 0 assert ti_from_deserialized_task.check_and_change_state_before_execution() # State should be running, and try_number column should be incremented + assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id + assert ti_from_deserialized_task.state == State.RUNNING + assert ti_from_deserialized_task._try_number == 1 + + def test_check_and_change_state_before_execution_provided_id_overrides(self, create_task_instance): + expected_external_executor_id = "banana" + ti = create_task_instance( + dag_id="test_check_and_change_state_before_execution", + external_executor_id="apple", + ) + assert ti.external_executor_id == "apple" + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + + assert ti_from_deserialized_task._try_number == 0 + assert ti_from_deserialized_task.check_and_change_state_before_execution( + external_executor_id=expected_external_executor_id + ) + # State should be running, and try_number column should be incremented + assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id + assert ti_from_deserialized_task.state == State.RUNNING + assert ti_from_deserialized_task._try_number == 1 + + def test_check_and_change_state_before_execution_with_exec_id(self, create_task_instance): + expected_external_executor_id = "minions" + ti = create_task_instance(dag_id="test_check_and_change_state_before_execution") + assert ti.external_executor_id is None + SerializedDagModel.write_dag(ti.task.dag) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI(task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id) + + assert ti_from_deserialized_task._try_number == 0 + assert ti_from_deserialized_task.check_and_change_state_before_execution( + external_executor_id=expected_external_executor_id + ) + # State should be running, and try_number column should be incremented + assert ti_from_deserialized_task.external_executor_id == expected_external_executor_id assert ti_from_deserialized_task.state == State.RUNNING assert ti_from_deserialized_task._try_number == 1 @@ -1817,6 +1861,7 @@ def test_check_and_change_state_before_execution_dep_not_met_already_running(sel assert not ti_from_deserialized_task.check_and_change_state_before_execution() assert ti_from_deserialized_task.state == State.RUNNING + assert ti_from_deserialized_task.external_executor_id is None def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( self, create_task_instance