Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Deferrable implementation for RedshiftPauseClusterOperator to follow standard #30853

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 31 additions & 36 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException
Expand All @@ -25,6 +26,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -510,7 +512,9 @@ class RedshiftPauseClusterOperator(BaseOperator):

:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
:param deferrable: Run operator in the deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
:param max_attempts: Maximum number of attempts to poll the cluster
"""

template_fields: Sequence[str] = ("cluster_identifier",)
Expand All @@ -524,64 +528,55 @@ def __init__(
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
max_attempts: int = 15,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# These parameters are used to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
self._attempts = 10
self._remaining_attempts = 10
o-nikolas marked this conversation as resolved.
Show resolved Hide resolved
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while self._remaining_attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._remaining_attempts = self._remaining_attempts - 1

if self._remaining_attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._remaining_attempts)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
trigger=RedshiftPauseClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error pausing cluster: {event}")
else:
raise AirflowException("No event received from trigger")
self.log.info("Paused cluster successfully")
return


class RedshiftDeleteClusterOperator(BaseOperator):
Expand Down
73 changes: 73 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -137,3 +140,73 @@ async def run(self):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})


class RedshiftPauseClusterTrigger(BaseTrigger):
"""
Trigger for RedshiftPauseClusterOperator.
The trigger will asynchronously poll the boto3 API and wait for the
Redshift cluster to be in the `paused` state.

:param cluster_identifier: A unique identifier for the cluster.
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
cluster_identifier: str,
poll_interval: int,
max_attempts: int,
aws_conn_id: str,
):
self.cluster_identifier = cluster_identifier
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
{
"cluster_identifier": self.cluster_identifier,
"poll_interval": str(self.poll_interval),
"max_attempts": str(self.max_attempts),
"aws_conn_id": self.aws_conn_id,
},
)

@cached_property
def hook(self) -> RedshiftHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)

async def run(self):
async with self.hook.async_conn as client:
attempt = 0
waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client)
while attempt < int(self.max_attempts):
attempt = attempt + 1
try:
await waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": int(self.poll_interval),
"MaxAttempts": 1,
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
},
)
break
except WaiterError as error:
if "terminal failure" in str(error):
yield TriggerEvent({"status": "failure", "message": f"Pause Cluster Failed: {error}"})
break
self.log.info(
"Status of cluster is %s", error.last_response["Clusters"][0]["ClusterStatus"]
)
await asyncio.sleep(int(self.poll_interval))
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
if attempt >= int(self.max_attempts):
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
yield TriggerEvent(
{"status": "failure", "message": "Pause Cluster Failed - max attempts reached."}
)
else:
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
yield TriggerEvent({"status": "success", "message": "Cluster paused"})
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/redshift.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"cluster_paused": {
"operation": "DescribeClusters",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "pathAll",
"argument": "Clusters[].ClusterStatus",
"expected": "paused",
"state": "success"
},
{
"expected": "ClusterNotFound",
"argument": "Clusters[].ClusterStatus",
"matcher": "error",
"state": "retry"
},
{
"expected": "deleting",
"matcher": "pathAny",
"state": "failure",
"argument": "Clusters[].ClusterStatus"
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
}
]
}
}
}
14 changes: 9 additions & 5 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftPauseClusterTrigger,
)


class TestRedshiftCreateClusterOperator:
Expand Down Expand Up @@ -389,9 +392,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10

def test_pause_cluster_deferrable_mode(self):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
def test_pause_cluster_deferrable_mode(self, mock_get_conn):
"""Test Pause cluster operator with defer when deferrable param is true"""

mock_get_conn().pause_cluster.return_value = True
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", deferrable=True
)
Expand All @@ -400,8 +404,8 @@ def test_pause_cluster_deferrable_mode(self):
redshift_operator.execute(context=None)

assert isinstance(
exc.value.trigger, RedshiftClusterTrigger
), "Trigger is not a RedshiftClusterTrigger"
exc.value.trigger, RedshiftPauseClusterTrigger
), "Trigger is not a RedshiftPauseClusterTrigger"

def test_pause_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
Expand Down
Loading