Skip to content

Commit

Permalink
Adding cluster to ecs trigger event to avoid defer error (#40482)
Browse files Browse the repository at this point in the history
* Adding cluster to ecs trigger event to avoid defer error

* restore cluster after defer
  • Loading branch information
ellisms authored Jun 28, 2024
1 parent 2423238 commit 6c12744
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
if event["status"] != "success":
raise AirflowException(f"Error in task execution: {event}")
self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps
self.cluster = event["cluster"]
self._after_execution()
if self._aws_logs_enabled():
# same behavior as non-deferrable mode, return last line of logs of the task.
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/triggers/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
)
# we reach this point only if the waiter met a success criteria
yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
yield TriggerEvent(
{"status": "success", "task_arn": self.task_arn, "cluster": self.cluster}
)
return
except WaiterError as error:
if "terminal failure" in str(error):
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,13 @@ def test_with_defer(self, client_mock, xcom_mock):

@mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
def test_execute_complete(self, client_mock):
event = {"status": "success", "task_arn": "my_arn"}
event = {"status": "success", "task_arn": "my_arn", "cluster": "test_cluster"}
self.ecs.reattach = True

self.ecs.execute_complete(None, event)

# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"])
client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])

@pytest.mark.db_test
@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions tests/providers/amazon/aws/triggers/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ async def test_run_success(self, _, client_mock):

assert response.payload["status"] == "success"
assert response.payload["task_arn"] == "my_task_arn"
assert response.payload["cluster"] == "cluster"

0 comments on commit 6c12744

Please sign in to comment.