diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py index 891b4d727ce9c..8131be4f65412 100644 --- a/airflow/providers/amazon/aws/operators/eks.py +++ b/airflow/providers/amazon/aws/operators/eks.py @@ -19,6 +19,7 @@ import warnings from ast import literal_eval +from datetime import timedelta from typing import TYPE_CHECKING, Any, List, Sequence, cast from botocore.exceptions import ClientError, WaiterError @@ -26,6 +27,10 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.eks import EksHook +from airflow.providers.amazon.aws.triggers.eks import ( + EksCreateFargateProfileTrigger, + EksDeleteFargateProfileTrigger, +) try: from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator @@ -353,6 +358,11 @@ class EksCreateFargateProfileOperator(BaseOperator): maintained on each worker node). :param region: Which AWS region the connection should use. (templated) If this is None or empty then the default boto3 behaviour is used. + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check profile status + :param waiter_max_attempts: The maximum number of attempts to check the status of the profile. + :param deferrable: If True, the operator will wait asynchronously for the profile to be created. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ template_fields: Sequence[str] = ( @@ -371,11 +381,14 @@ def __init__( cluster_name: str, pod_execution_role_arn: str, selectors: list, - fargate_profile_name: str | None = DEFAULT_FARGATE_PROFILE_NAME, + fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME, create_fargate_profile_kwargs: dict | None = None, wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + waiter_delay: int = 10, + waiter_max_attempts: int = 60, + deferrable: bool = False, **kwargs, ) -> None: self.cluster_name = cluster_name @@ -386,6 +399,9 @@ def __init__( self.wait_for_completion = wait_for_completion self.aws_conn_id = aws_conn_id self.region = region + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable super().__init__(**kwargs) def execute(self, context: Context): @@ -401,13 +417,35 @@ def execute(self, context: Context): selectors=self.selectors, **self.create_fargate_profile_kwargs, ) - - if self.wait_for_completion: + if self.deferrable: + self.defer( + trigger=EksCreateFargateProfileTrigger( + cluster_name=self.cluster_name, + fargate_profile_name=self.fargate_profile_name, + aws_conn_id=self.aws_conn_id, + poll_interval=self.waiter_delay, + max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)), + ) + elif self.wait_for_completion: self.log.info("Waiting for Fargate profile to provision. This will take some time.") eks_hook.conn.get_waiter("fargate_profile_active").wait( - clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name + clusterName=self.cluster_name, + fargateProfileName=self.fargate_profile_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error creating Fargate profile: {event}") + else: + self.log.info("Fargate profile created successfully") + return + class EksDeleteClusterOperator(BaseOperator): """ @@ -587,6 +625,11 @@ class EksDeleteFargateProfileOperator(BaseOperator): maintained on each worker node). :param region: Which AWS region the connection should use. (templated) If this is None or empty then the default boto3 behaviour is used. + :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check profile status + :param waiter_max_attempts: The maximum number of attempts to check the status of the profile. + :param deferrable: If True, the operator will wait asynchronously for the profile to be deleted. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ template_fields: Sequence[str] = ( @@ -604,6 +647,9 @@ def __init__( wait_for_completion: bool = False, aws_conn_id: str = DEFAULT_CONN_ID, region: str | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -612,6 +658,9 @@ def __init__( self.wait_for_completion = wait_for_completion self.aws_conn_id = aws_conn_id self.region = region + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): eks_hook = EksHook( @@ -622,12 +671,35 @@ def execute(self, context: Context): eks_hook.delete_fargate_profile( clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name ) - if self.wait_for_completion: + if self.deferrable: + self.defer( + trigger=EksDeleteFargateProfileTrigger( + cluster_name=self.cluster_name, + fargate_profile_name=self.fargate_profile_name, + aws_conn_id=self.aws_conn_id, + poll_interval=self.waiter_delay, + max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)), + ) + elif self.wait_for_completion: self.log.info("Waiting for Fargate profile to delete. This will take some time.") eks_hook.conn.get_waiter("fargate_profile_deleted").wait( - clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name + clusterName=self.cluster_name, + fargateProfileName=self.fargate_profile_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, ) + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error deleting Fargate profile: {event}") + else: + self.log.info("Fargate profile deleted successfully") + return + class EksPodOperator(KubernetesPodOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py new file mode 100644 index 0000000000000..dddab74b30496 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/eks.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any + +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.eks import EksHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class EksCreateFargateProfileTrigger(BaseTrigger): + """ + Trigger for EksCreateFargateProfileOperator. + The trigger will asynchronously wait for the fargate profile to be created. + + :param cluster_name: The name of the EKS cluster + :param fargate_profile_name: The name of the fargate profile + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + cluster_name: str, + fargate_profile_name: str, + poll_interval: int, + max_attempts: int, + aws_conn_id: str, + ): + self.cluster_name = cluster_name + self.fargate_profile_name = fargate_profile_name + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster_name": self.cluster_name, + "fargate_profile_name": self.fargate_profile_name, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": self.aws_conn_id, + }, + ) + + async def run(self): + self.hook = EksHook(aws_conn_id=self.aws_conn_id) + async with self.hook.async_conn as client: + attempt = 0 + waiter = client.get_waiter("fargate_profile_active") + while attempt < int(self.max_attempts): + attempt += 1 + try: + await waiter.wait( + clusterName=self.cluster_name, + fargateProfileName=self.fargate_profile_name, + WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1}, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + raise AirflowException(f"Create Fargate Profile failed: {error}") + self.log.info( + "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + raise AirflowException( + f"Create Fargate Profile failed - max attempts reached: {self.max_attempts}" + ) + else: + yield TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) + + +class EksDeleteFargateProfileTrigger(BaseTrigger): + """ + Trigger for EksDeleteFargateProfileOperator. + The trigger will asynchronously wait for the fargate profile to be deleted. + + :param cluster_name: The name of the EKS cluster + :param fargate_profile_name: The name of the fargate profile + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + cluster_name: str, + fargate_profile_name: str, + poll_interval: int, + max_attempts: int, + aws_conn_id: str, + ): + self.cluster_name = cluster_name + self.fargate_profile_name = fargate_profile_name + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster_name": self.cluster_name, + "fargate_profile_name": self.fargate_profile_name, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": self.aws_conn_id, + }, + ) + + async def run(self): + self.hook = EksHook(aws_conn_id=self.aws_conn_id) + async with self.hook.async_conn as client: + attempt = 0 + waiter = client.get_waiter("fargate_profile_deleted") + while attempt < int(self.max_attempts): + attempt += 1 + try: + await waiter.wait( + clusterName=self.cluster_name, + fargateProfileName=self.fargate_profile_name, + WaiterConfig={"Delay": int(self.poll_interval), "MaxAttempts": 1}, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + raise AirflowException(f"Delete Fargate Profile failed: {error}") + self.log.info( + "Status of fargate profile is %s", error.last_response["fargateProfile"]["status"] + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + raise AirflowException( + f"Delete Fargate Profile failed - max attempts reached: {self.max_attempts}" + ) + else: + yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 05924eebc9742..51af2e4d24673 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -528,6 +528,9 @@ triggers: - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.triggers.emr + - integration-name: Amazon Elastic Kubernetes Service (EKS) + python-modules: + - airflow.providers.amazon.aws.triggers.eks transfers: - source-integration-name: Amazon DynamoDB diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 37342bbdaa216..089aef17046f4 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -23,6 +23,7 @@ import pytest from botocore.waiter import Waiter +from airflow.exceptions import TaskDeferred from airflow.providers.amazon.aws.hooks.eks import ClusterStates, EksHook from airflow.providers.amazon.aws.operators.eks import ( EksCreateClusterOperator, @@ -33,6 +34,10 @@ EksDeleteNodegroupOperator, EksPodOperator, ) +from airflow.providers.amazon.aws.triggers.eks import ( + EksCreateFargateProfileTrigger, + EksDeleteFargateProfileTrigger, +) from airflow.typing_compat import TypedDict from tests.providers.amazon.aws.utils.eks_test_constants import ( NODEROLE_ARN, @@ -369,10 +374,27 @@ def test_execute_with_wait_when_fargate_profile_does_not_already_exist( operator.execute({}) mock_create_fargate_profile.assert_called_with(**convert_keys(parameters)) mock_waiter.assert_called_with( - mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME + mock.ANY, + clusterName=CLUSTER_NAME, + fargateProfileName=FARGATE_PROFILE_NAME, + WaiterConfig={"Delay": 10, "MaxAttempts": 60}, ) assert_expected_waiter_type(mock_waiter, "FargateProfileActive") + @mock.patch.object(EksHook, "create_fargate_profile") + def test_create_fargate_profile_deferrable(self, _): + op_kwargs = {**self.create_fargate_profile_params} + operator = EksCreateFargateProfileOperator( + task_id=TASK_ID, + **op_kwargs, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + operator.execute({}) + assert isinstance( + exc.value.trigger, EksCreateFargateProfileTrigger + ), "Trigger is not a EksCreateFargateProfileTrigger" + class TestEksCreateNodegroupOperator: def setup_method(self) -> None: @@ -532,10 +554,23 @@ def test_existing_fargate_profile_with_wait(self, mock_delete_fargate_profile, m clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name ) mock_waiter.assert_called_with( - mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME + mock.ANY, + clusterName=CLUSTER_NAME, + fargateProfileName=FARGATE_PROFILE_NAME, + WaiterConfig={"Delay": 30, "MaxAttempts": 60}, ) assert_expected_waiter_type(mock_waiter, "FargateProfileDeleted") + @mock.patch.object(EksHook, "delete_fargate_profile") + def test_delete_fargate_profile_deferrable(self, _): + self.delete_fargate_profile_operator.deferrable = True + + with pytest.raises(TaskDeferred) as exc: + self.delete_fargate_profile_operator.execute({}) + assert isinstance( + exc.value.trigger, EksDeleteFargateProfileTrigger + ), "Trigger is not a EksDeleteFargateProfileTrigger" + class TestEksPodOperator: @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py new file mode 100644 index 0000000000000..abab121d243d9 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_eks.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.eks import EksHook +from airflow.providers.amazon.aws.triggers.eks import ( + EksCreateFargateProfileTrigger, + EksDeleteFargateProfileTrigger, +) +from airflow.triggers.base import TriggerEvent + +TEST_CLUSTER_IDENTIFIER = "test-cluster" +TEST_FARGATE_PROFILE_NAME = "test-fargate-profile" +TEST_POLL_INTERVAL = 10 +TEST_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-aws-id" + + +class TestEksCreateFargateProfileTrigger: + def test_eks_create_fargate_profile_serialize(self): + eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + class_path, args = eks_create_fargate_profile_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksCreateFargateProfileTrigger" + assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER + assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + + @pytest.mark.asyncio + @mock.patch.object(EksHook, "async_conn") + async def test_eks_create_fargate_profile_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + a_mock.get_waiter().wait = AsyncMock() + + eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = eks_create_fargate_profile_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_create_fargate_profile_trigger_run_multiple_attempts( + self, mock_async_conn, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "CREATING"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = eks_create_fargate_profile_trigger.run() + response = await generator.asend(None) + + assert a_mock.get_waiter().wait.call_count == 3 + assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Created"}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_create_fargate_profile_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "CREATING"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + ) + with pytest.raises(AirflowException) as exc: + generator = eks_create_fargate_profile_trigger.run() + await generator.asend(None) + assert "Create Fargate Profile failed - max attempts reached:" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_create_fargate_profile_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error_creating = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "CREATING"}}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"fargateProfile": {"status": "CREATE_FAILED"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) + mock_sleep.return_value = True + + eks_create_fargate_profile_trigger = EksCreateFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + with pytest.raises(AirflowException) as exc: + generator = eks_create_fargate_profile_trigger.run() + await generator.asend(None) + assert f"Create Fargate Profile failed: {error_failed}" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 3 + + +class TestEksDeleteFargateProfileTrigger: + def test_eks_delete_fargate_profile_serialize(self): + eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + class_path, args = eks_delete_fargate_profile_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksDeleteFargateProfileTrigger" + assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER + assert args["fargate_profile_name"] == TEST_FARGATE_PROFILE_NAME + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS) + + @pytest.mark.asyncio + @mock.patch.object(EksHook, "async_conn") + async def test_eks_delete_fargate_profile_trigger_run(self, mock_async_conn): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + + a_mock.get_waiter().wait = AsyncMock() + + eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = eks_delete_fargate_profile_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_delete_fargate_profile_trigger_run_multiple_attempts( + self, mock_async_conn, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "DELETING"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + + generator = eks_delete_fargate_profile_trigger.run() + response = await generator.asend(None) + assert a_mock.get_waiter().wait.call_count == 3 + assert response == TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"}) + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_delete_fargate_profile_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_sleep + ): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "DELETING"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True]) + mock_sleep.return_value = True + + eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + ) + with pytest.raises(AirflowException) as exc: + generator = eks_delete_fargate_profile_trigger.run() + await generator.asend(None) + assert "Delete Fargate Profile failed - max attempts reached: 2" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 2 + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch.object(EksHook, "async_conn") + async def test_eks_delete_fargate_profile_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep): + a_mock = mock.MagicMock() + mock_async_conn.__aenter__.return_value = a_mock + error_creating = WaiterError( + name="test_name", + reason="test_reason", + last_response={"fargateProfile": {"status": "DELETING"}}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"fargateProfile": {"status": "DELETE_FAILED"}}, + ) + a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed]) + mock_sleep.return_value = True + + eks_delete_fargate_profile_trigger = EksDeleteFargateProfileTrigger( + cluster_name=TEST_CLUSTER_IDENTIFIER, + fargate_profile_name=TEST_FARGATE_PROFILE_NAME, + aws_conn_id=TEST_AWS_CONN_ID, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPTS, + ) + with pytest.raises(AirflowException) as exc: + generator = eks_delete_fargate_profile_trigger.run() + await generator.asend(None) + assert f"Delete Fargate Profile failed: {error_failed}" in str(exc.value) + assert a_mock.get_waiter().wait.call_count == 3