Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Deferrable implementation for RedshiftPauseClusterOperator to follow standard #30853

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 23 additions & 34 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -520,64 +521,52 @@ def __init__(
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
max_attempts: int = 15,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.max_attempts = max_attempts
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# These parameters are used to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
self._attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while self._attempts >= 1:
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
trigger=RedshiftPauseClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempt=self.max_attempts,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error pausing cluster: {event}")
else:
raise AirflowException("No event received from trigger")
self.log.info("Paused cluster successfully")
return


class RedshiftDeleteClusterOperator(BaseOperator):
Expand Down
52 changes: 52 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,55 @@ async def run(self):
},
)
yield TriggerEvent({"status": "success", "message": "Cluster Created"})


class RedshiftPauseClusterTrigger(BaseTrigger):
    """
    Trigger for RedshiftPauseClusterOperator.

    The trigger will asynchronously poll the boto3 API and wait for the
    Redshift cluster to be in the `paused` state.

    :param cluster_identifier: A unique identifier for the cluster.
    :param poll_interval: The amount of time in seconds to wait between attempts.
    :param max_attempt: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        cluster_identifier: str,
        poll_interval: int,
        max_attempt: int,
        aws_conn_id: str,
    ):
        # BaseTrigger performs its own bookkeeping in __init__; it must run
        # for the trigger to be handled correctly by the triggerer.
        super().__init__()
        self.cluster_identifier = cluster_identifier
        self.poll_interval = poll_interval
        self.max_attempt = max_attempt
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger arguments so it can be rebuilt by the triggerer."""
        # All values are stringified for serialization; run() converts the
        # numeric ones back with int().
        return (
            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger",
            {
                "cluster_identifier": str(self.cluster_identifier),
                "poll_interval": str(self.poll_interval),
                "max_attempt": str(self.max_attempt),
                "aws_conn_id": str(self.aws_conn_id),
            },
        )

    @cached_property
    def hook(self) -> RedshiftHook:
        # Cached so repeated accesses reuse a single hook instance.
        return RedshiftHook(aws_conn_id=self.aws_conn_id)

    async def run(self):
        """Poll via the custom ``cluster_paused`` waiter, then emit a success event."""
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("cluster_paused", deferrable=True, client=client)
            await waiter.wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={
                    "Delay": int(self.poll_interval),
                    "MaxAttempts": int(self.max_attempt),
                },
            )
        yield TriggerEvent({"status": "success", "message": "Cluster paused"})
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/redshift.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"cluster_paused": {
"operation": "DescribeClusters",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "pathAll",
"argument": "Clusters[].ClusterStatus",
"expected": "paused",
"state": "success"
},
{
"expected": "ClusterNotFound",
"argument": "Clusters[].ClusterStatus",
"matcher": "error",
"state": "retry"
},
      {
        "expected": "deleting",
        "matcher": "pathAny",
        "state": "failure",
        "argument": "Clusters[].ClusterStatus"
      }
]
}
}
}
14 changes: 9 additions & 5 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftPauseClusterTrigger,
)


class TestRedshiftCreateClusterOperator:
Expand Down Expand Up @@ -377,9 +380,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 10

def test_pause_cluster_deferrable_mode(self):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
syedahsn marked this conversation as resolved.
Show resolved Hide resolved
def test_pause_cluster_deferrable_mode(self, mock_get_conn):
"""Test Pause cluster operator with defer when deferrable param is true"""

mock_get_conn().pause_cluster.return_value = True
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", deferrable=True
)
Expand All @@ -388,8 +392,8 @@ def test_pause_cluster_deferrable_mode(self):
redshift_operator.execute(context=None)

assert isinstance(
exc.value.trigger, RedshiftClusterTrigger
), "Trigger is not a RedshiftClusterTrigger"
exc.value.trigger, RedshiftPauseClusterTrigger
), "Trigger is not a RedshiftPauseClusterTrigger"

def test_pause_cluster_execute_complete_success(self):
"""Asserts that logging occurs as expected"""
Expand Down
44 changes: 43 additions & 1 deletion tests/providers/amazon/aws/triggers/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

import pytest

from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftCreateClusterTrigger
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftCreateClusterTrigger,
RedshiftPauseClusterTrigger,
)
from airflow.triggers.base import TriggerEvent

if sys.version_info < (3, 8):
Expand Down Expand Up @@ -72,3 +75,42 @@ async def test_redshift_create_cluster_trigger_run(self, mock_async_conn):
response = await generator.asend(None)

assert response == TriggerEvent({"status": "success", "message": "Cluster Created"})


class TestRedshiftPauseClusterTrigger:
    def test_redshift_pause_cluster_trigger_serialize(self):
        """serialize() must return the trigger class path and stringified args."""
        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
            poll_interval=TEST_POLL_INTERVAL,
            max_attempt=TEST_MAX_ATTEMPT,
            aws_conn_id=TEST_AWS_CONN_ID,
        )
        class_path, args = redshift_pause_cluster_trigger.serialize()
        assert (
            class_path == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftPauseClusterTrigger"
        )
        assert args["cluster_identifier"] == TEST_CLUSTER_IDENTIFIER
        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
        assert args["max_attempt"] == str(TEST_MAX_ATTEMPT)
        assert args["aws_conn_id"] == TEST_AWS_CONN_ID

    @pytest.mark.asyncio
    @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_waiter")
    @async_mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.async_conn")
    async def test_redshift_pause_cluster_trigger_run(self, mock_async_conn, mock_get_waiter):
        """run() should yield a success TriggerEvent once the waiter completes."""
        mock = async_mock.MagicMock()
        mock_async_conn.__aenter__.return_value = mock

        mock_get_waiter().wait = AsyncMock()

        redshift_pause_cluster_trigger = RedshiftPauseClusterTrigger(
            cluster_identifier=TEST_CLUSTER_IDENTIFIER,
            poll_interval=TEST_POLL_INTERVAL,
            max_attempt=TEST_MAX_ATTEMPT,
            aws_conn_id=TEST_AWS_CONN_ID,
        )

        generator = redshift_pause_cluster_trigger.run()
        response = await generator.asend(None)

        assert response == TriggerEvent({"status": "success", "message": "Cluster paused"})