From de4d15ef2a7caab9a22f9af7699ce463ff5d3de9 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 6 Mar 2023 00:25:08 -0800 Subject: [PATCH 01/12] Change base_aws.py to support async_conn Add async custom waiter support in get_waiter, and base_waiter.py Add Deferrable mode to RedshiftCreateClusterOperator Add RedshiftCreateClusterTrigger and unit test Add README.md for writing Triggers for AMPP --- airflow/providers/amazon/aws/hooks/base_aws.py | 9 +++++++++ .../providers/amazon/aws/triggers/redshift_cluster.py | 2 +- airflow/providers/amazon/aws/waiters/base_waiter.py | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 541ef37d312a2..5b29c7c0dedd1 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -42,6 +42,7 @@ import jinja2 import requests import tenacity +from aiobotocore.session import AioSession, get_session as async_get_session from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials @@ -658,6 +659,14 @@ def async_conn(self): return self.get_client_type(region_name=self.region_name, deferrable=True) + @cached_property + def async_conn(self): + """Get an Aiobotocore client to use for async operations (cached).""" + if not self.client_type: + raise ValueError("client_type must be specified.") + + return self.get_client_type(region_name=self.region_name, deferrable=True) + @cached_property def conn_client_meta(self) -> ClientMeta: """Get botocore client metadata from Hook connection (cached).""" diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index ef19d0b5a1d66..53704191c8e71 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -21,7 +21,7 @@ from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.triggers.base import BaseTrigger, TriggerEvent - +from typing import Any class RedshiftClusterTrigger(BaseTrigger): """AWS Redshift trigger""" diff --git a/airflow/providers/amazon/aws/waiters/base_waiter.py b/airflow/providers/amazon/aws/waiters/base_waiter.py index 488767a084a21..0662c049a96f0 100644 --- a/airflow/providers/amazon/aws/waiters/base_waiter.py +++ b/airflow/providers/amazon/aws/waiters/base_waiter.py @@ -18,6 +18,7 @@ from __future__ import annotations import boto3 +from aiobotocore.waiter import create_waiter_with_client as create_async_waiter_with_client from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client From 72eda44d7fc2a2e73afa3ae44d97bd06938b112f Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 22 Mar 2023 03:26:40 -0700 Subject: [PATCH 02/12] Add Deferrable mode to Redshift Pause Cluster Operator --- .../amazon/aws/operators/redshift_cluster.py | 62 +++++++------------ .../amazon/aws/triggers/redshift_cluster.py | 50 ++++++++++++++- .../amazon/aws/waiters/redshift.json | 29 +++++++++ .../aws/triggers/test_redshift_cluster.py | 44 ++++++++++++- 4 files changed, 145 insertions(+), 40 deletions(-) create mode 100644 airflow/providers/amazon/aws/waiters/redshift.json diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 77ac521c9baf6..364ca037ff666 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import ( RedshiftClusterTrigger, RedshiftCreateClusterTrigger, + RedshiftPauseClusterTrigger, ) if TYPE_CHECKING: @@ -520,64 +521,49 @@ def __init__( aws_conn_id: str = "aws_default", deferrable: bool = False, poll_interval: int = 10, + max_attempts: int = 15, **kwargs, ): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id self.deferrable = deferrable - self.poll_interval = poll_interval - # These parameters are added to address an issue with the boto3 API where the API + self.max_attempts = max_attempts + # These parameters are used to address an issue with the boto3 API where the API # prematurely reports the cluster as available to receive requests. This causes the cluster # to reject initial attempts to pause the cluster despite reporting the correct state. - self._attempts = 10 - self._attempt_interval = 15 + self.poll_interval = poll_interval + self._attempts = max_attempts def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + while self._attempts >= 1: + try: + redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) + break + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + if self._attempts > 0: + self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error if self.deferrable: self.defer( - timeout=self.execution_timeout, - trigger=RedshiftClusterTrigger( - task_id=self.task_id, + trigger=RedshiftPauseClusterTrigger( + cluster_identifier=self.cluster_identifier, poll_interval=self.poll_interval, + max_attempt=self.max_attempts, aws_conn_id=self.aws_conn_id, - cluster_identifier=self.cluster_identifier, - attempts=self._attempts, - operation_type="pause_cluster", ), method_name="execute_complete", ) - else: - while self._attempts >= 1: - try: - redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) - return - except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: - self._attempts = self._attempts - 1 - - if self._attempts > 0: - self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) - time.sleep(self._attempt_interval) - else: - raise error - def execute_complete(self, context: Context, event: Any = None) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if "status" in event and event["status"] == "error": - msg = f"{event['status']}: {event['message']}" - raise AirflowException(msg) - elif "status" in event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - self.log.info("Paused cluster successfully") - else: - raise AirflowException("No event received from trigger") + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error resuming cluster: {event}") + return class RedshiftDeleteClusterOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index 53704191c8e71..d6b8ebc5bbc9b 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -21,7 +21,6 @@ from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.triggers.base import BaseTrigger, TriggerEvent -from typing import Any class RedshiftClusterTrigger(BaseTrigger): """AWS Redshift trigger""" @@ -137,3 +136,52 @@ async def run(self): }, ) yield TriggerEvent({"status": "success", "message": "Cluster Created"}) + + +class RedshiftPauseClusterTrigger(BaseTrigger): + """ + Trigger for RedshiftPauseClusterOperator. + The trigger will asynchronously poll the boto3 API and wait for the + Redshift cluster to be in the `paused` state. + + :param cluster_identifier: A unique identifier for the cluster. + :param poll_interval: The amount of time in seconds to wait between attempts. + :param max_attempt: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + cluster_identifier: str, + poll_interval: int, + max_attempt: int, + aws_conn_id: str, + ): + self.cluster_identifier = cluster_identifier + self.poll_interval = poll_interval + self.max_attempt = max_attempt + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger", + { + "cluster_identifier": str(self.cluster_identifier), + "poll_interval": str(self.poll_interval), + "max_attempt": str(self.max_attempt), + "aws_conn_id": str(self.aws_conn_id), + }, + ) + + async def run(self): + self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + async with self.redshift_hook.async_conn as client: + waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, client=client) + await waiter.wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": int(self.max_attempt), + }, + ) + yield TriggerEvent({"status": "success", "message": "Cluster paused"}) diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json new file mode 100644 index 0000000000000..748f3c6cf6534 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/redshift.json @@ -0,0 +1,29 @@ +{ + "version": 2, + "waiters": { + "cluster_paused": { + "operation": "DescribeClusters", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "pathAll", + "argument": "Clusters[].ClusterStatus", + "expected": "paused", + "state": "success" + }, + { + "expected": "ClusterNotFound", + "matcher": "error", + "state": "retry" + }, + { + "expected": "deleting", + "matcher": "pathAny", + "state": "failure", + "argument": "Clusters[].ClusterStatus" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 941258659e9ae..58773cba6e7d2 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -20,7 +20,10 @@ import pytest -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftCreateClusterTrigger, + RedshiftPauseClusterTrigger, +) from airflow.triggers.base import TriggerEvent if sys.version_info < (3, 8): @@ -72,3 +75,42 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "message": "Cluster Created"}) + + +class TestRedshiftPauseClusterTrigger: + def test_redshift_resume_cluster_trigger_serialize(self): + redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempt=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + class_path, args = redshift_resume_cluster_trigger.serialize() + assert ( + class_path == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger" + ) + assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER + assert args["poll_interval"] == str(TEST_POLL_INTERVAL) + assert args["max_attempt"] == str(TEST_MAX_ATTEMPT) + assert args["aws_conn_id"] == TEST_AWS_CONN_ID + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_get_waiter): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + + mock_get_waiter().wait = AsyncMock() + + redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempt=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_resume_cluster_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) From 1190c2055db8758e1188dfe289e075bfa9d15a21 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 24 Apr 2023 23:48:11 -0700 Subject: [PATCH 03/12] Rebase with main Fix tests to work with new code --- airflow/providers/amazon/aws/hooks/base_aws.py | 9 --------- .../amazon/aws/operators/redshift_cluster.py | 9 ++++++--- .../amazon/aws/triggers/redshift_cluster.py | 10 +++++++--- .../providers/amazon/aws/waiters/base_waiter.py | 1 - airflow/providers/amazon/aws/waiters/redshift.json | 1 + .../amazon/aws/operators/test_redshift_cluster.py | 14 +++++++++----- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 5b29c7c0dedd1..541ef37d312a2 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -42,7 +42,6 @@ import jinja2 import requests import tenacity -from aiobotocore.session import AioSession, get_session as async_get_session from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials @@ -659,14 +658,6 @@ def async_conn(self): return self.get_client_type(region_name=self.region_name, deferrable=True) - @cached_property - def async_conn(self): - """Get an Aiobotocore client to use for async operations (cached).""" - if not self.client_type: - raise ValueError("client_type must be specified.") - - return self.get_client_type(region_name=self.region_name, deferrable=True) - @cached_property def conn_client_meta(self) -> ClientMeta: """Get botocore client metadata from Hook connection (cached).""" diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 364ca037ff666..25a5e7e18e042 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -529,11 +529,12 @@ def __init__( self.aws_conn_id = aws_conn_id self.deferrable = deferrable self.max_attempts = max_attempts + self.poll_interval = poll_interval # These parameters are used to address an issue with the boto3 API where the API # prematurely reports the cluster as available to receive requests. This causes the cluster # to reject initial attempts to pause the cluster despite reporting the correct state. - self.poll_interval = poll_interval - self._attempts = max_attempts + self._attempts = 10 + self._attempt_interval = 15 def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) @@ -562,7 +563,9 @@ def execute(self, context: Context): def execute_complete(self, context, event=None): if event["status"] != "success": - raise AirflowException(f"Error resuming cluster: {event}") + raise AirflowException(f"Error pausing cluster: {event}") + else: + self.log.info("Paused cluster successfully") return diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index d6b8ebc5bbc9b..8d71732c2ac68 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -22,6 +22,7 @@ from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.triggers.base import BaseTrigger, TriggerEvent + class RedshiftClusterTrigger(BaseTrigger): """AWS Redshift trigger""" @@ -173,10 +174,13 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + @cached_property + def hook(self) -> RedshiftHook: + return RedshiftHook(aws_conn_id=self.aws_conn_id) + async def run(self): - self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - async with self.redshift_hook.async_conn as client: - waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, client=client) + async with self.hook.async_conn as client: + waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) await waiter.wait( ClusterIdentifier=self.cluster_identifier, WaiterConfig={ diff --git a/airflow/providers/amazon/aws/waiters/base_waiter.py b/airflow/providers/amazon/aws/waiters/base_waiter.py index 0662c049a96f0..488767a084a21 100644 --- a/airflow/providers/amazon/aws/waiters/base_waiter.py +++ b/airflow/providers/amazon/aws/waiters/base_waiter.py @@ -18,7 +18,6 @@ from __future__ import annotations import boto3 -from aiobotocore.waiter import create_waiter_with_client as create_async_waiter_with_client from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json index 748f3c6cf6534..28170a38f990a 100644 --- a/airflow/providers/amazon/aws/waiters/redshift.json +++ b/airflow/providers/amazon/aws/waiters/redshift.json @@ -14,6 +14,7 @@ }, { "expected": "ClusterNotFound", + "argument": "Clusters[].ClusterStatus", "matcher": "error", "state": "retry" }, diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 64a276f14d02c..0fc62018fbedf 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -31,7 +31,10 @@ RedshiftPauseClusterOperator, RedshiftResumeClusterOperator, ) -from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, + RedshiftPauseClusterTrigger, +) class TestRedshiftCreateClusterOperator: @@ -377,9 +380,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.pause_cluster.call_count == 10 - def test_pause_cluster_deferrable_mode(self): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") + def test_pause_cluster_deferrable_mode(self, mock_get_conn): """Test Pause cluster operator with defer when deferrable param is true""" - + mock_get_conn().pause_cluster.return_value = True redshift_operator = RedshiftPauseClusterOperator( task_id="task_test", cluster_identifier="test_cluster", deferrable=True ) @@ -388,8 +392,8 @@ def test_pause_cluster_deferrable_mode(self): redshift_operator.execute(context=None) assert isinstance( - exc.value.trigger, RedshiftClusterTrigger - ), "Trigger is not a RedshiftClusterTrigger" + exc.value.trigger, RedshiftPauseClusterTrigger + ), "Trigger is not a RedshiftPauseClusterTrigger" def test_pause_cluster_execute_complete_success(self): """Asserts that logging occurs as expected""" From 09b196fc885cf7133074d5c8806f6c2ec3da9a48 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 26 Apr 2023 08:57:40 -0700 Subject: [PATCH 04/12] update doc string --- airflow/providers/amazon/aws/operators/redshift_cluster.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 25a5e7e18e042..a5f4f7103d0fb 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -507,7 +507,9 @@ class RedshiftPauseClusterOperator(BaseOperator): :param cluster_identifier: id of the AWS Redshift Cluster :param aws_conn_id: aws connection to use - :param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>= + :param deferrable: Run operator in the deferrable mode + :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state + :param max_attempts: Maximum number of attempts to poll the cluster """ template_fields: Sequence[str] = ("cluster_identifier",) From 690c5b91ee98bb8cdd4e0331d2f7b03e8d0cfd87 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 3 May 2023 15:34:56 -0700 Subject: [PATCH 05/12] Add logging to deferrable waiter Add unit tests for deferrable waiters --- .../amazon/aws/operators/redshift_cluster.py | 2 +- .../amazon/aws/triggers/redshift_cluster.py | 42 +++++++---- .../aws/triggers/test_redshift_cluster.py | 69 ++++++++++++++++++- 3 files changed, 97 insertions(+), 16 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index a5f4f7103d0fb..3c57143923e51 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -557,7 +557,7 @@ def execute(self, context: Context): trigger=RedshiftPauseClusterTrigger( cluster_identifier=self.cluster_identifier, poll_interval=self.poll_interval, - max_attempt=self.max_attempts, + max_attempts=self.max_attempts, aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index 8d71732c2ac68..e83b8c4c14e38 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -16,8 +16,11 @@ # under the License. from __future__ import annotations +import asyncio from typing import Any, AsyncIterator +from botocore.exceptions import WaiterError + from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -147,7 +150,7 @@ class RedshiftPauseClusterTrigger(BaseTrigger): :param cluster_identifier: A unique identifier for the cluster. :param poll_interval: The amount of time in seconds to wait between attempts. - :param max_attempt: The maximum number of attempts to be made. + :param max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -155,12 +158,12 @@ def __init__( self, cluster_identifier: str, poll_interval: int, - max_attempt: int, + max_attempts: int, aws_conn_id: str, ): self.cluster_identifier = cluster_identifier self.poll_interval = poll_interval - self.max_attempt = max_attempt + self.max_attempts = max_attempts self.aws_conn_id = aws_conn_id def serialize(self) -> tuple[str, dict[str, Any]]: @@ -169,7 +172,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: { "cluster_identifier": str(self.cluster_identifier), "poll_interval": str(self.poll_interval), - "max_attempt": str(self.max_attempt), + "max_attempts": str(self.max_attempts), "aws_conn_id": str(self.aws_conn_id), }, ) @@ -180,12 +183,27 @@ def hook(self) -> RedshiftHook: async def run(self): async with self.hook.async_conn as client: - waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) - await waiter.wait( - ClusterIdentifier=self.cluster_identifier, - WaiterConfig={ - "Delay": int(self.poll_interval), - "MaxAttempts": int(self.max_attempt), - }, + attempt = 0 + while attempt < int(self.max_attempts): + attempt = attempt + 1 + try: + waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) + await waiter.wait( + ClusterIdentifier=self.cluster_identifier, + WaiterConfig={ + "Delay": int(self.poll_interval), + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + self.log.info( + "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] + ) + await asyncio.sleep(int(self.poll_interval)) + if attempt >= int(self.max_attempts): + yield TriggerEvent( + {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} ) - yield TriggerEvent({"status": "success", "message": "Cluster paused"}) + else: + yield TriggerEvent({"status": "success", "message": "Cluster paused"}) diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 58773cba6e7d2..130ba0dd65daa 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -19,6 +19,7 @@ import sys import pytest +from botocore.exceptions import WaiterError from airflow.providers.amazon.aws.triggers.redshift_cluster import ( RedshiftCreateClusterTrigger, @@ -82,7 +83,7 @@ def test_redshift_resume_cluster_trigger_serialize(self): redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, - max_attempt=TEST_MAX_ATTEMPT, + max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) class_path, args = redshift_resume_cluster_trigger.serialize() @@ -91,7 +92,7 @@ def test_redshift_resume_cluster_trigger_serialize(self): ) assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER assert args["poll_interval"] == str(TEST_POLL_INTERVAL) - assert args["max_attempt"] == str(TEST_MAX_ATTEMPT) + assert args["max_attempts"] == str(TEST_MAX_ATTEMPT) assert args["aws_conn_id"] == TEST_AWS_CONN_ID @pytest.mark.asyncio @@ -106,7 +107,7 @@ async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_g redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, - max_attempt=TEST_MAX_ATTEMPT, + max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) @@ -114,3 +115,65 @@ async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_g response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) + + @pytest.mark.asyncio + @async_mock.patch("asyncio.sleep") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_resume_cluster_trigger_run_multiple_attempts( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_resume_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) + + @pytest.mark.asyncio + @async_mock.patch("asyncio.sleep") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + error = WaiterError( + name="test_name", + reason="test_reason", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) + mock_sleep.return_value = True + + redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=2, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_resume_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 2 + assert response == TriggerEvent( + {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} + ) From 74cddfbcffc764aecc2aa7585e333a5eb8a9d930 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Thu, 4 May 2023 12:21:33 -0700 Subject: [PATCH 06/12] Add check for failure early Add test for waiter failure --- .../amazon/aws/triggers/redshift_cluster.py | 13 +++++-- .../aws/triggers/test_redshift_cluster.py | 39 +++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index e83b8c4c14e38..071cafce03a57 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -170,8 +170,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger", { - "cluster_identifier": str(self.cluster_identifier), - "poll_interval": str(self.poll_interval), + "cluster_identifier": self.cluster_identifier, + "poll_interval": self.poll_interval, "max_attempts": str(self.max_attempts), "aws_conn_id": str(self.aws_conn_id), }, @@ -184,10 +184,10 @@ def hook(self) -> RedshiftHook: async def run(self): async with self.hook.async_conn as client: attempt = 0 + waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) while attempt < int(self.max_attempts): attempt = attempt + 1 try: - waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client) await waiter.wait( ClusterIdentifier=self.cluster_identifier, WaiterConfig={ @@ -197,13 +197,18 @@ async def run(self): ) break except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent( + {"status": "failure", "message": f"Resume Cluster Failed: {error}"} + ) + break self.log.info( "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] ) await asyncio.sleep(int(self.poll_interval)) if attempt >= int(self.max_attempts): yield TriggerEvent( - {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} + {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} ) else: yield TriggerEvent({"status": "success", "message": "Cluster paused"}) diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 130ba0dd65daa..96129c25bfecb 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -177,3 +177,42 @@ async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( assert response == TriggerEvent( {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} ) + + @pytest.mark.asyncio + @async_mock.patch("asyncio.sleep") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") + async def test_redshift_resume_cluster_trigger_run_attempts_failed( + self, mock_async_conn, mock_get_waiter, mock_sleep + ): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + error_available = WaiterError( + name="test_name", + reason="Max attempts exceeded", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + error_failed = WaiterError( + name="test_name", + reason="Waiter encountered a terminal failure state:", + last_response={"Clusters": [{"ClusterStatus": "available"}]}, + ) + mock_get_waiter().wait.side_effect = AsyncMock( + side_effect=[error_available, error_available, error_failed] + ) + mock_sleep.return_value = True + + redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + cluster_identifier=TEST_CLUSTER_IDENTIFIER, + poll_interval=TEST_POLL_INTERVAL, + max_attempts=TEST_MAX_ATTEMPT, + aws_conn_id=TEST_AWS_CONN_ID, + ) + + generator = redshift_resume_cluster_trigger.run() + response = await generator.asend(None) + + assert mock_get_waiter().wait.call_count == 3 + assert response == TriggerEvent( + {"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"} + ) From d3033ccbc074d3cd6ab53fd56de3c17b862a1dbb Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Thu, 4 May 2023 12:43:45 -0700 Subject: [PATCH 07/12] Fix broken tests --- airflow/providers/amazon/aws/triggers/redshift_cluster.py | 8 +++----- .../amazon/aws/triggers/test_redshift_cluster.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index 071cafce03a57..879f027e35824 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -171,9 +171,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger", { "cluster_identifier": self.cluster_identifier, - "poll_interval": self.poll_interval, + "poll_interval": str(self.poll_interval), "max_attempts": str(self.max_attempts), - "aws_conn_id": str(self.aws_conn_id), + "aws_conn_id": self.aws_conn_id, }, ) @@ -198,9 +198,7 @@ async def run(self): break except WaiterError as error: if "terminal failure" in str(error): - yield TriggerEvent( - {"status": "failure", "message": f"Resume Cluster Failed: {error}"} - ) + yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"}) break self.log.info( "Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"] diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 96129c25bfecb..179e5126a50d1 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -175,7 +175,7 @@ async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( assert mock_get_waiter().wait.call_count == 2 assert response == TriggerEvent( - {"status": "failure", "message": "Resume Cluster Failed - max attempts reached."} + {"status": "failure", "message": "Pause Cluster Failed - max attempts reached."} ) @pytest.mark.asyncio @@ -214,5 +214,5 @@ async def test_redshift_resume_cluster_trigger_run_attempts_failed( assert mock_get_waiter().wait.call_count == 3 assert response == TriggerEvent( - {"status": "failure", "message": f"Resume Cluster Failed: {error_failed}"} + {"status": "failure", "message": f"Pause Cluster Failed: {error_failed}"} ) From c4afb07967f58bb5d1de06160f6a7ef3749dbb3e Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Sun, 7 May 2023 21:34:58 -0700 Subject: [PATCH 08/12] Add timeout to Trigger --- airflow/providers/amazon/aws/operators/redshift_cluster.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 8e7848cf95530..b4acba30301de 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations +from datetime import timedelta import time from typing import TYPE_CHECKING, Any, Sequence @@ -565,6 +566,9 @@ def execute(self, context: Context): aws_conn_id=self.aws_conn_id, ), method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=self.max_attempts*self.poll_interval + 60), ) def execute_complete(self, context, event=None): From 2b681265498a1534a3d150649182edfef2cb8bdf Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 8 May 2023 01:21:36 -0700 Subject: [PATCH 09/12] Fix Static checks --- airflow/providers/amazon/aws/operators/redshift_cluster.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index b4acba30301de..0690241bea5d9 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations -from datetime import timedelta import time +from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException @@ -568,7 +568,7 @@ def execute(self, context: Context): method_name="execute_complete", # timeout is set to ensure that if a trigger dies, the timeout does not restart # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) - timeout=timedelta(seconds=self.max_attempts*self.poll_interval + 60), + timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60), ) def execute_complete(self, context, event=None): From cc6eb9b73dc817b64c3f5512ab14a0b1e47f28b2 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Tue, 9 May 2023 11:35:14 -0700 Subject: [PATCH 10/12] rename variable --- .../providers/amazon/aws/operators/redshift_cluster.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 0690241bea5d9..de1438e93ca29 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -540,20 +540,20 @@ def __init__( # These parameters are used to address an issue with the boto3 API where the API # prematurely reports the cluster as available to receive requests. This causes the cluster # to reject initial attempts to pause the cluster despite reporting the correct state. - self._attempts = 10 + self._remaining_attempts = 10 self._attempt_interval = 15 def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - while self._attempts >= 1: + while self._remaining_attempts >= 1: try: redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) break except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: - self._attempts = self._attempts - 1 + self._remaining_attempts = self._remaining_attempts - 1 - if self._attempts > 0: - self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) + if self._remaining_attempts > 0: + self.log.error("Unable to pause cluster. %d attempts remaining.", self._remaining_attempts) time.sleep(self._attempt_interval) else: raise error From ed511dcfdf804149da4a3510d6ff458ab2c91fd3 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Tue, 9 May 2023 12:47:52 -0700 Subject: [PATCH 11/12] Fix static checks --- airflow/providers/amazon/aws/operators/redshift_cluster.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index de1438e93ca29..6b44785da621e 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -553,7 +553,9 @@ def execute(self, context: Context): self._remaining_attempts = self._remaining_attempts - 1 if self._remaining_attempts > 0: - self.log.error("Unable to pause cluster. %d attempts remaining.", self._remaining_attempts) + self.log.error( + "Unable to pause cluster. %d attempts remaining.", self._remaining_attempts + ) time.sleep(self._attempt_interval) else: raise error From 6bea10b21015a02a4450718b598f021c8bfbb6a7 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Tue, 9 May 2023 13:54:21 -0700 Subject: [PATCH 12/12] Use patch.object for patching in unit tests --- .../amazon/aws/waiters/redshift.json | 10 ++-- .../aws/operators/test_redshift_cluster.py | 2 +- .../aws/triggers/test_redshift_cluster.py | 47 ++++++++++--------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/airflow/providers/amazon/aws/waiters/redshift.json b/airflow/providers/amazon/aws/waiters/redshift.json index 28170a38f990a..587f8ce989702 100644 --- a/airflow/providers/amazon/aws/waiters/redshift.json +++ b/airflow/providers/amazon/aws/waiters/redshift.json @@ -13,16 +13,16 @@ "state": "success" }, { - "expected": "ClusterNotFound", - "argument": "Clusters[].ClusterStatus", "matcher": "error", + "argument": "Clusters[].ClusterStatus", + "expected": "ClusterNotFound", "state": "retry" }, { - "expected": "deleting", "matcher": "pathAny", - "state": "failure", - "argument": "Clusters[].ClusterStatus" + "argument": "Clusters[].ClusterStatus", + "expected": "deleting", + "state": "failure" } ] } diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 06b81270381a7..f4bb22d4b5905 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -392,7 +392,7 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): redshift_operator.execute(None) assert mock_conn.pause_cluster.call_count == 10 - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") + @mock.patch.object(RedshiftHook, "get_conn") def test_pause_cluster_deferrable_mode(self, mock_get_conn): """Test Pause cluster operator with defer when deferrable param is true""" mock_get_conn().pause_cluster.return_value = True diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py index 179e5126a50d1..2e7f6490d6846 100644 --- a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -21,6 +21,7 @@ import pytest from botocore.exceptions import WaiterError +from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook from airflow.providers.amazon.aws.triggers.redshift_cluster import ( RedshiftCreateClusterTrigger, RedshiftPauseClusterTrigger, @@ -79,14 +80,14 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn): class TestRedshiftPauseClusterTrigger: - def test_redshift_resume_cluster_trigger_serialize(self): - redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + def test_redshift_pause_cluster_trigger_serialize(self): + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) - class_path, args = redshift_resume_cluster_trigger.serialize() + class_path, args = redshift_pause_cluster_trigger.serialize() assert ( class_path == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger" ) @@ -96,31 +97,31 @@ def test_redshift_resume_cluster_trigger_serialize(self): assert args["aws_conn_id"] == TEST_AWS_CONN_ID @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_resume_cluster_trigger_run(self, mock_async_conn, mock_get_waiter): + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run(self, mock_async_conn, mock_get_waiter): mock = async_mock.MagicMock() mock_async_conn.__aenter__.return_value = mock mock_get_waiter().wait = AsyncMock() - redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) - generator = redshift_resume_cluster_trigger.run() + generator = redshift_pause_cluster_trigger.run() response = await generator.asend(None) assert response == TriggerEvent({"status": "success", "message": "Cluster paused"}) @pytest.mark.asyncio @async_mock.patch("asyncio.sleep") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_resume_cluster_trigger_run_multiple_attempts( + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run_multiple_attempts( self, mock_async_conn, mock_get_waiter, mock_sleep ): mock = async_mock.MagicMock() @@ -133,14 +134,14 @@ async def test_redshift_resume_cluster_trigger_run_multiple_attempts( mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) mock_sleep.return_value = True - redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) - generator = redshift_resume_cluster_trigger.run() + generator = redshift_pause_cluster_trigger.run() response = await generator.asend(None) assert mock_get_waiter().wait.call_count == 3 @@ -148,9 +149,9 @@ async def test_redshift_resume_cluster_trigger_run_multiple_attempts( @pytest.mark.asyncio @async_mock.patch("asyncio.sleep") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run_attempts_exceeded( self, mock_async_conn, mock_get_waiter, mock_sleep ): mock = async_mock.MagicMock() @@ -163,14 +164,14 @@ async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True]) mock_sleep.return_value = True - redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, max_attempts=2, aws_conn_id=TEST_AWS_CONN_ID, ) - generator = redshift_resume_cluster_trigger.run() + generator = redshift_pause_cluster_trigger.run() response = await generator.asend(None) assert mock_get_waiter().wait.call_count == 2 @@ -180,9 +181,9 @@ async def test_redshift_resume_cluster_trigger_run_attempts_exceeded( @pytest.mark.asyncio @async_mock.patch("asyncio.sleep") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter") - @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn") - async def test_redshift_resume_cluster_trigger_run_attempts_failed( + @async_mock.patch.object(RedshiftHook, "get_waiter") + @async_mock.patch.object(RedshiftHook, "async_conn") + async def test_redshift_pause_cluster_trigger_run_attempts_failed( self, mock_async_conn, mock_get_waiter, mock_sleep ): mock = async_mock.MagicMock() @@ -202,14 +203,14 @@ async def test_redshift_resume_cluster_trigger_run_attempts_failed( ) mock_sleep.return_value = True - redshift_resume_cluster_trigger = RedshiftPauseClusterTrigger( + redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger( cluster_identifier=TEST_CLUSTER_IDENTIFIER, poll_interval=TEST_POLL_INTERVAL, max_attempts=TEST_MAX_ATTEMPT, aws_conn_id=TEST_AWS_CONN_ID, ) - generator = redshift_resume_cluster_trigger.run() + generator = redshift_pause_cluster_trigger.run() response = await generator.asend(None) assert mock_get_waiter().wait.call_count == 3