Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from botocore.exceptions import WaiterError

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.redshift_cluster import (
Expand Down Expand Up @@ -105,7 +107,9 @@ class RedshiftCreateClusterOperator(AwsBaseOperator[RedshiftHook]):
:param wait_for_completion: Whether wait for the cluster to be in ``available`` state
:param max_attempt: The maximum number of attempts to be made. Default: 5
:param poll_interval: The amount of time in seconds to wait between attempts. Default: 60
:param deferrable: If True, the operator will run in deferrable mode
:param deferrable: If True, the operator will run in deferrable mode.
:param delete_cluster_on_failure: If True, best-effort deletion of the redshift cluster will be attempted
after post-creation failure. Default: True.
"""

template_fields: Sequence[str] = aws_template_fields(
Expand Down Expand Up @@ -188,6 +192,7 @@ def __init__(
max_attempt: int = 5,
poll_interval: int = 60,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
delete_cluster_on_failure: bool = True,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -229,6 +234,7 @@ def __init__(
self.poll_interval = poll_interval
self.deferrable = deferrable
self.kwargs = kwargs
self.delete_cluster_on_failure = delete_cluster_on_failure

def execute(self, context: Context):
self.log.info("Creating Redshift cluster %s", self.cluster_identifier)
Expand Down Expand Up @@ -311,17 +317,39 @@ def execute(self, context: Context):
),
method_name="execute_complete",
)
if self.wait_for_completion:
self.hook.get_conn().get_waiter("cluster_available").wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempt,
},
)

self.log.info("Created Redshift cluster %s", self.cluster_identifier)
self.log.info(cluster)
try:
if self.wait_for_completion:
self.hook.get_conn().get_waiter("cluster_available").wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempt,
},
)

self.log.info("Created Redshift cluster %s", self.cluster_identifier)
self.log.info(cluster)
except WaiterError:
# Best-effort cleanup when post-initiation steps fail (e.g. IAM/permission errors).
if cluster:
self.log.warning(
"Execution failed after Redshift cluster %s was started by this task instance.",
self.cluster_identifier,
)

if self.delete_cluster_on_failure:
try:
self.log.warning(
"Attempting deletion of Redshift cluster %s.", self.cluster_identifier
)
self.hook.delete_cluster(cluster_identifier=self.cluster_identifier)
except Exception:
self.log.exception(
"Failed while attempting to delete Reshift cluster %s.",
self.cluster_identifier,
)
raise

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
validated_event = validate_execute_complete_event(event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import boto3
import pytest
from botocore.exceptions import ClientError, WaiterError

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.operators.redshift_cluster import (
Expand Down Expand Up @@ -151,6 +152,96 @@ def test_template_fields(self):
)
validate_template_fields(operator)

@mock.patch.object(RedshiftHook, "delete_cluster")
@mock.patch.object(RedshiftHook, "conn")
def test_create_cluster_cleanup_on_waiter_auth_failure(
    self,
    mock_conn,
    mock_delete_cluster,
):
    """Best-effort cluster deletion must be attempted when the availability waiter fails."""
    # A raising waiter simulates a post-creation failure (e.g. DescribeClusters
    # being denied) that occurs AFTER the create call has already gone out.
    simulated_failure = WaiterError(
        name="ClusterAvailable",
        reason="AccessDenied for DescribeClusters",
        last_response={},
    )
    mock_conn.get_waiter.return_value.wait.side_effect = simulated_failure

    op = RedshiftCreateClusterOperator(
        task_id="task_test",
        cluster_identifier="test-cluster",
        node_type="ra3.large",
        master_username="adminuser",
        master_user_password="Test123$",
        cluster_type="single-node",
        wait_for_completion=True,
        delete_cluster_on_failure=True,
    )

    with pytest.raises(WaiterError):
        op.execute({})

    # The create request was issued before the waiter raised...
    mock_conn.create_cluster.assert_called_once()
    # ...and the operator then attempted to clean up the half-created cluster.
    mock_delete_cluster.assert_called_once_with(cluster_identifier="test-cluster")

@mock.patch.object(RedshiftHook, "delete_cluster")
@mock.patch.object(RedshiftHook, "conn")
def test_create_cluster_cleanup_failure_does_not_mask_original_error(
    self,
    mock_conn,
    mock_delete_cluster,
):
    """A failing cleanup must not replace the waiter error the task actually failed on."""
    # Post-creation failure: the availability waiter raises (e.g. the role
    # cannot call DescribeClusters).
    original_error = WaiterError(
        name="ClusterAvailable",
        reason="AccessDenied for DescribeClusters",
        last_response={},
    )
    mock_conn.get_waiter.return_value.wait.side_effect = original_error

    # Cleanup failure: the best-effort DeleteCluster call is itself denied.
    mock_delete_cluster.side_effect = ClientError(
        error_response={
            "Error": {
                "Code": "UnauthorizedOperation",
                "Message": "You are not authorized to perform this operation",
            }
        },
        operation_name="DeleteCluster",
    )

    op = RedshiftCreateClusterOperator(
        task_id="task_test",
        cluster_identifier="test-cluster",
        node_type="ra3.large",
        master_username="adminuser",
        master_user_password="Test123$",
        cluster_type="single-node",
        wait_for_completion=True,
        delete_cluster_on_failure=True,
    )

    with pytest.raises(WaiterError) as exc:
        op.execute({})

    # The very same waiter exception object must surface — not the ClientError.
    assert exc.value is original_error

    # Creation was issued, and cleanup was attempted despite its own failure.
    mock_conn.create_cluster.assert_called_once()
    mock_delete_cluster.assert_called_once_with(cluster_identifier="test-cluster")


class TestRedshiftCreateClusterSnapshotOperator:
@mock.patch.object(RedshiftHook, "cluster_status")
Expand Down