diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 2880240b1532d..6b44785da621e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import time
+from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.exceptions import AirflowException
@@ -25,6 +26,7 @@
 from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftClusterTrigger,
     RedshiftCreateClusterTrigger,
+    RedshiftPauseClusterTrigger,
 )
 
 if TYPE_CHECKING:
@@ -510,7 +512,9 @@ class RedshiftPauseClusterOperator(BaseOperator):
 
     :param cluster_identifier: id of the AWS Redshift Cluster
    :param aws_conn_id: aws connection to use
-    :param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
+    :param deferrable: Run operator in the deferrable mode
+    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
+    :param max_attempts: Maximum number of attempts to poll the cluster
     """
 
     template_fields: Sequence[str] = ("cluster_identifier",)
@@ -524,64 +528,57 @@ def __init__(
         aws_conn_id: str = "aws_default",
         deferrable: bool = False,
         poll_interval: int = 10,
+        max_attempts: int = 15,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
         self.deferrable = deferrable
+        self.max_attempts = max_attempts
         self.poll_interval = poll_interval
-        # These parameters are added to address an issue with the boto3 API where the API
+        # These parameters are used to address an issue with the boto3 API where the API
         # prematurely reports the cluster as available to receive requests. This causes the cluster
         # to reject initial attempts to pause the cluster despite reporting the correct state.
-        self._attempts = 10
+        self._remaining_attempts = 10
         self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        while self._remaining_attempts >= 1:
+            try:
+                redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+                break
+            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._remaining_attempts = self._remaining_attempts - 1
+                if self._remaining_attempts > 0:
+                    self.log.error(
+                        "Unable to pause cluster. %d attempts remaining.", self._remaining_attempts
+                    )
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
         if self.deferrable:
             self.defer(
-                timeout=self.execution_timeout,
-                trigger=RedshiftClusterTrigger(
-                    task_id=self.task_id,
+                trigger=RedshiftPauseClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
                     poll_interval=self.poll_interval,
+                    max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    cluster_identifier=self.cluster_identifier,
-                    attempts=self._attempts,
-                    operation_type="pause_cluster",
                 ),
                 method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
             )
-        else:
-            while self._attempts >= 1:
-                try:
-                    redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
-                    return
-                except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
-                    self._attempts = self._attempts - 1
-
-                    if self._attempts > 0:
-                        self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
-                        time.sleep(self._attempt_interval)
-                    else:
-                        raise error
-
-    def execute_complete(self, context: Context, event: Any = None) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        if event:
-            if "status" in event and event["status"] == "error":
-                msg = f"{event['status']}: {event['message']}"
-                raise AirflowException(msg)
-            elif "status" in event and event["status"] == "success":
-                self.log.info("%s completed successfully.", self.task_id)
-                self.log.info("Paused cluster successfully")
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error pausing cluster: {event}")
         else:
-            raise AirflowException("No event received from trigger")
+            self.log.info("Paused cluster successfully")
+        return
 
 
 class RedshiftDeleteClusterOperator(BaseOperator):
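
For orientation, here is a minimal sketch of how the reworked operator could be used from a DAG. Only the operator parameters (cluster_identifier, aws_conn_id, deferrable, poll_interval, max_attempts) come from the change above; the DAG id, schedule, dates, and cluster name are illustrative assumptions.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

    with DAG(
        dag_id="example_redshift_pause",  # hypothetical DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        pause_cluster = RedshiftPauseClusterOperator(
            task_id="pause_redshift_cluster",
            cluster_identifier="my-redshift-cluster",  # hypothetical cluster name
            aws_conn_id="aws_default",
            deferrable=True,   # defer to RedshiftPauseClusterTrigger instead of blocking a worker
            poll_interval=10,  # seconds between waiter attempts inside the trigger
            max_attempts=15,   # the trigger fails the task after this many attempts
        )
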
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index ef19d0b5a1d66..879f027e35824 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -16,8 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 from typing import Any, AsyncIterator
 
+from botocore.exceptions import WaiterError
+
 from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -137,3 +140,73 @@ async def run(self):
             },
         )
         yield TriggerEvent({"status": "success", "message": "Cluster Created"})
+
+
+class RedshiftPauseClusterTrigger(BaseTrigger):
+    """
+    Trigger for RedshiftPauseClusterOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Redshift cluster to be in the `paused` state.
+
+    :param cluster_identifier: A unique identifier for the cluster.
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        cluster_identifier: str,
+        poll_interval: int,
+        max_attempts: int,
+        aws_conn_id: str,
+    ):
+        self.cluster_identifier = cluster_identifier
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
+            {
+                "cluster_identifier": self.cluster_identifier,
+                "poll_interval": str(self.poll_interval),
+                "max_attempts": str(self.max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> RedshiftHook:
+        return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self):
+        async with self.hook.async_conn as client:
+            attempt = 0
+            waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client)
+            while attempt < int(self.max_attempts):
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterIdentifier=self.cluster_identifier,
+                        WaiterConfig={
+                            "Delay": int(self.poll_interval),
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"})
+                        break
+                    self.log.info(
+                        "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+            if attempt >= int(self.max_attempts):
+                yield TriggerEvent(
+                    {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."}
+                )
+            else:
+                yield TriggerEvent({"status": "success", "message": "Cluster paused"})
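
To make the contract between the two pieces explicit: the trigger resolves with a small {"status", "message"} payload, and the operator's new execute_complete fails the task for anything other than "success". A rough, self-contained sketch of that behaviour, using the payloads the trigger yields above (the task_id and cluster name are placeholders):

    from airflow.exceptions import AirflowException
    from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftPauseClusterOperator

    op = RedshiftPauseClusterOperator(task_id="pause_cluster", cluster_identifier="my-redshift-cluster")

    # Success payload yielded by RedshiftPauseClusterTrigger: logs and returns.
    op.execute_complete(context={}, event={"status": "success", "message": "Cluster paused"})

    # Failure (or any non-success) payload: raises, which fails the task.
    try:
        op.execute_complete(
            context={},
            event={"status": "failure", "message": "Pause Cluster Failed - max attempts reached."},
        )
    except AirflowException as err:
        print(err)  # Error pausing cluster: {'status': 'failure', ...}
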
+ """ + + def __init__( + self, + cluster_identifier: str, + poll_interval: int, + max_attempts: int, + aws_conn_id: str, + ): + self.cluster_identifier = cluster_identifier + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger", + { + "cluster_identifier": self.cluster_identifier, + "poll_interval": str(self.poll_interval), + "max_attempts": str(self.max_attempts), + "aws_conn_id": self.aws_conn_id, + }, + ) + + @cached_property + def hook(self) -> RedshiftHook: + return RedshiftHook(aws_conn_id=self.aws_conn_id) + + async def run(self): + async with self.hook.async_conn as client: + attempt = 0 + waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) + while attempt < int(self.max_attempts): + attempt = attempt + 1 + try: + await waiter.wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"}) + break + self.log.info( + "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + yield TriggerEvent( + {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} + ) + else: + yield TriggerEvent({"status": "success", "message": "Cluster paused"}) diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json new file mode 100644 index 0000000000000..587f8ce989702 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/redshift.json @@ -0,0 +1,30 @@ +{ + "version": 2, + "waiters": { + "cluster_paused": { + "operation": "DescribeClusters", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "pathAll", + "argument": "Clusters[].ClusterStatus", + "expected": "paused", + "state": "success" + }, + { + "matcher": "error", + "argument": "Clusters[].ClusterStatus", + "expected": "ClusterNotFound", + "state": "retry" + }, + { + "matcher": "pathAny", + "argument": "Clusters[].ClusterStatus", + "expected": "deleting", + "state": "failure" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index fca7dafdaad8c..f4bb22d4b5905 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -31,7 +31,10 @@ RedshiftPauseClusterOperator, RedshiftResumeClusterOperator, ) -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, + RedshiftPauseClusterTrigger, +) class TestRedshiftCreateClusterOperator: @@ -389,9 +392,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.pause_cluster.call_count == 10 - def test_pause_cluster_deferrable_mode(self): + @mock.patch.object(RedshiftHook, "get_conn") + def test_pause_cluster_deferrable_mode(self, mock_get_conn): """Test Pause cluster operator with defer when deferrable param is true""" - + 
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index fca7dafdaad8c..f4bb22d4b5905 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -31,7 +31,10 @@
     RedshiftPauseClusterOperator,
     RedshiftResumeClusterOperator,
 )
-from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftClusterTrigger,
+    RedshiftPauseClusterTrigger,
+)
 
 
 class TestRedshiftCreateClusterOperator:
@@ -389,9 +392,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
             redshift_operator.execute(None)
         assert mock_conn.pause_cluster.call_count == 10
 
-    def test_pause_cluster_deferrable_mode(self):
+    @mock.patch.object(RedshiftHook, "get_conn")
+    def test_pause_cluster_deferrable_mode(self, mock_get_conn):
         """Test Pause cluster operator with defer when deferrable param is true"""
-
+        mock_get_conn().pause_cluster.return_value = True
         redshift_operator = RedshiftPauseClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", deferrable=True
         )
@@ -400,8 +404,8 @@ def test_pause_cluster_deferrable_mode(self):
             redshift_operator.execute(context=None)
 
         assert isinstance(
-            exc.value.trigger, RedshiftClusterTrigger
-        ), "Trigger is not a RedshiftClusterTrigger"
+            exc.value.trigger, RedshiftPauseClusterTrigger
+        ), "Trigger is not a RedshiftPauseClusterTrigger"
 
     def test_pause_cluster_execute_complete_success(self):
         """Asserts that logging occurs as expected"""
[{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_pause_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) + + @pytest.mark.asyncio + @async_mock.patch("asyncio.sleep") + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_pause_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 2 + assert response == TriggerEvent( + {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} + ) + + @pytest.mark.asyncio + @async_mock.patch("asyncio.sleep") + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run_attempts_failed( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + error_available = WaiterError( + name="test_name", + reason="Max attempts exceeded", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_available, error_available, error_failed] + ) + mock_sleep.return_value = True + + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_pause_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + {"status": "failure", "message": f"Pause Cluster Failed: {error_failed}"} + )