From 44e4ccbf745a394807403222a6d17c893269e712 Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Fri, 28 Jun 2024 10:55:48 -0400 Subject: [PATCH 1/2] Adding cluster to ecs trigger event to avoid defer error --- airflow/providers/amazon/aws/triggers/ecs.py | 4 +++- tests/providers/amazon/aws/operators/test_ecs.py | 2 +- tests/providers/amazon/aws/triggers/test_ecs.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py index 1177aa657a2b..dd86899f2200 100644 --- a/airflow/providers/amazon/aws/triggers/ecs.py +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -179,7 +179,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} ) # we reach this point only if the waiter met a success criteria - yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) + yield TriggerEvent( + {"status": "success", "task_arn": self.task_arn, "cluster": self.cluster} + ) return except WaiterError as error: if "terminal failure" in str(error): diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 9bccd22f9e21..4f4a6178d800 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -674,7 +674,7 @@ def test_with_defer(self, client_mock, xcom_mock): @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock) def test_execute_complete(self, client_mock): - event = {"status": "success", "task_arn": "my_arn"} + event = {"status": "success", "task_arn": "my_arn", "cluster": "test_cluster"} self.ecs.reattach = True self.ecs.execute_complete(None, event) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py index 27c815be8237..a5c11d92986b 100644 --- a/tests/providers/amazon/aws/triggers/test_ecs.py +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -92,3 +92,4 @@ async def test_run_success(self, _, client_mock): assert response.payload["status"] == "success" assert response.payload["task_arn"] == "my_task_arn" + assert response.payload["cluster"] == "cluster" From 55b345db5a0c551e53eccb4a6153f03d55000e6d Mon Sep 17 00:00:00 2001 From: Mike Ellis Date: Fri, 28 Jun 2024 11:13:35 -0400 Subject: [PATCH 2/2] restore cluster after defer --- airflow/providers/amazon/aws/operators/ecs.py | 1 + tests/providers/amazon/aws/operators/test_ecs.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 294291dc0afd..1cd8685cf282 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -586,6 +586,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if event["status"] != "success": raise AirflowException(f"Error in task execution: {event}") self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps + self.cluster = event["cluster"] self._after_execution() if self._aws_logs_enabled(): # same behavior as non-deferrable mode, return last line of logs of the task. diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 4f4a6178d800..a6915214a076 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -680,7 +680,7 @@ def test_execute_complete(self, client_mock): self.ecs.execute_complete(None, event) # task gets described to assert its success - client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"]) + client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"]) @pytest.mark.db_test @pytest.mark.parametrize(