Skip to content

Commit

Permalink
bugfix: handle invalid cluster states in NeptuneStopDbClusterOperator (
Browse files Browse the repository at this point in the history
…#38287)

* Cache s3 resource to reduce memory usage

* Fixed missing rebase changes and error in resource init

* catch invalid state exceptions and wait if not terminal status

* PR changes

* Update airflow/providers/amazon/aws/triggers/neptune.py

Co-authored-by: Niko Oliveira <onikolas@amazon.com>

* Update airflow/providers/amazon/aws/operators/neptune.py

Co-authored-by: Niko Oliveira <onikolas@amazon.com>

* PR review changes. Created helper function to deal with waitable ClientErrors

* added missing exception code

* Update airflow/providers/amazon/aws/operators/neptune.py

Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com>

* fix static checks

* Update airflow/providers/amazon/aws/operators/neptune.py

Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>

* Update airflow/providers/amazon/aws/operators/neptune.py

Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>

* minor PR changes

---------

Co-authored-by: Niko Oliveira <onikolas@amazon.com>
Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com>
Co-authored-by: Elad Kalif <45845474+eladkal@users.noreply.github.com>
Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>
  • Loading branch information
5 people authored May 22, 2024
1 parent 791f3cf commit a78ee74
Show file tree
Hide file tree
Showing 6 changed files with 484 additions and 24 deletions.
37 changes: 36 additions & 1 deletion airflow/providers/amazon/aws/hooks/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class NeptuneHook(AwsBaseHook):

AVAILABLE_STATES = ["available"]
STOPPED_STATES = ["stopped"]
ERROR_STATES = [
"cloning-failed",
"inaccessible-encryption-credentials",
"inaccessible-encryption-credentials-recoverable",
"migration-failed",
]

def __init__(self, *args, **kwargs):
kwargs["client_type"] = "neptune"
Expand Down Expand Up @@ -82,4 +88,33 @@ def get_cluster_status(self, cluster_id: str) -> str:
:param cluster_id: The ID of the cluster to get the status of.
:return: The status of the cluster.
"""
return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
return self.conn.describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]

def get_db_instance_status(self, instance_id: str) -> str:
"""
Get the status of a Neptune instance.
:param instance_id: The ID of the instance to get the status of.
:return: The status of the instance.
"""
return self.conn.describe_db_instances(DBInstanceIdentifier=instance_id)["DBInstances"][0][
"DBInstanceStatus"
]

def wait_for_cluster_instance_availability(
self, cluster_id: str, delay: int = 30, max_attempts: int = 60
) -> None:
"""
Wait for Neptune instances in a cluster to be available.
:param cluster_id: The cluster ID of the instances to wait for.
:param delay: Time in seconds to delay between polls.
:param max_attempts: Maximum number of attempts to poll for completion.
:return: The status of the instances.
"""
filters = [{"Name": "db-cluster-id", "Values": [cluster_id]}]
self.log.info("Waiting for instances in cluster %s.", cluster_id)
self.get_waiter("db_instance_available").wait(
Filters=filters, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
)
self.log.info("Finished waiting for instances in cluster %s.", cluster_id)
149 changes: 128 additions & 21 deletions airflow/providers/amazon/aws/operators/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

from typing import TYPE_CHECKING, Any, Sequence

from botocore.exceptions import ClientError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.neptune import (
NeptuneClusterAvailableTrigger,
NeptuneClusterInstancesAvailableTrigger,
NeptuneClusterStoppedTrigger,
)
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
Expand All @@ -32,6 +36,50 @@
from airflow.utils.context import Context


def handle_waitable_exception(
operator: NeptuneStartDbClusterOperator | NeptuneStopDbClusterOperator, err: str
):
"""
Handle client exceptions for invalid cluster or invalid instance status that are temporary.
After status change, it's possible to retry. Waiter will handle terminal status.
"""
code = err

if code in ("InvalidDBInstanceStateFault", "InvalidDBInstanceState"):
if operator.deferrable:
operator.log.info("Deferring until instances become available: %s", operator.cluster_id)
operator.defer(
trigger=NeptuneClusterInstancesAvailableTrigger(
aws_conn_id=operator.aws_conn_id,
db_cluster_id=operator.cluster_id,
region_name=operator.region_name,
botocore_config=operator.botocore_config,
verify=operator.verify,
),
method_name="execute",
)
else:
operator.log.info("Need to wait for instances to become available: %s", operator.cluster_id)
operator.hook.wait_for_cluster_instance_availability(cluster_id=operator.cluster_id)
if code in ["InvalidClusterState", "InvalidDBClusterStateFault"]:
if operator.deferrable:
operator.log.info("Deferring until cluster becomes available: %s", operator.cluster_id)
operator.defer(
trigger=NeptuneClusterAvailableTrigger(
aws_conn_id=operator.aws_conn_id,
db_cluster_id=operator.cluster_id,
region_name=operator.region_name,
botocore_config=operator.botocore_config,
verify=operator.verify,
),
method_name="execute",
)
else:
operator.log.info("Need to wait for cluster to become available: %s", operator.cluster_id)
operator.hook.wait_for_cluster_availability(operator.cluster_id)


class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
"""Starts an Amazon Neptune DB cluster.
Expand Down Expand Up @@ -78,20 +126,43 @@ def __init__(
self.cluster_id = db_cluster_id
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.delay = waiter_delay
self.max_attempts = waiter_max_attempts
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts

def execute(self, context: Context) -> dict[str, str]:
def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
self.log.info("Starting Neptune cluster: %s", self.cluster_id)

# Check to make sure the cluster is not already available.
status = self.hook.get_cluster_status(self.cluster_id)
if status.lower() in NeptuneHook.AVAILABLE_STATES:
self.log.info("Neptune cluster %s is already available.", self.cluster_id)
return {"db_cluster_id": self.cluster_id}

resp = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
status = resp.get("DBClusters", {}).get("Status", "Unknown")
elif status.lower() in NeptuneHook.ERROR_STATES:
# some states will not allow you to start the cluster
self.log.error(
"Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status
)
raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")

"""
A cluster and its instances must be in a valid state to send the start request.
This loop covers the case where the cluster is not available and also the case where
the cluster is available, but one or more of the instances are in an invalid state.
If either are in an invalid state, wait for the availability and retry.
Let the waiters handle retries and detecting the error states.
"""
try:
self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
except ClientError as ex:
code = ex.response["Error"]["Code"]
self.log.warning("Received client error when attempting to start the cluster: %s", code)

if code in ["InvalidDBInstanceState", "InvalidClusterState", "InvalidDBClusterStateFault"]:
handle_waitable_exception(operator=self, err=code)

else:
# re raise for any other type of client error
raise

if self.deferrable:
self.log.info("Deferring for cluster start: %s", self.cluster_id)
Expand All @@ -100,15 +171,17 @@ def execute(self, context: Context) -> dict[str, str]:
trigger=NeptuneClusterAvailableTrigger(
aws_conn_id=self.aws_conn_id,
db_cluster_id=self.cluster_id,
waiter_delay=self.delay,
waiter_max_attempts=self.max_attempts,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
)

elif self.wait_for_completion:
self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
self.hook.wait_for_cluster_availability(self.cluster_id, self.delay, self.max_attempts)
self.hook.wait_for_cluster_availability(
self.cluster_id, self.waiter_delay, self.waiter_max_attempts
)

return {"db_cluster_id": self.cluster_id}

Expand Down Expand Up @@ -171,20 +244,53 @@ def __init__(
self.cluster_id = db_cluster_id
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.delay = waiter_delay
self.max_attempts = waiter_max_attempts
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts

def execute(self, context: Context) -> dict[str, str]:
def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
self.log.info("Stopping Neptune cluster: %s", self.cluster_id)

# Check to make sure the cluster is not already stopped.
# Check to make sure the cluster is not already stopped or that its not in a bad state
status = self.hook.get_cluster_status(self.cluster_id)
self.log.info("Current status: %s", status)

if status.lower() in NeptuneHook.STOPPED_STATES:
self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
return {"db_cluster_id": self.cluster_id}

resp = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
status = resp.get("DBClusters", {}).get("Status", "Unknown")
elif status.lower() in NeptuneHook.ERROR_STATES:
# some states will not allow you to stop the cluster
self.log.error(
"Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status
)
raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")

"""
A cluster and its instances must be in a valid state to send the stop request.
This loop covers the case where the cluster is not available and also the case where
the cluster is available, but one or more of the instances are in an invalid state.
If either are in an invalid state, wait for the availability and retry.
Let the waiters handle retries and detecting the error states.
"""

try:
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)

# cluster must be in available state to stop it
except ClientError as ex:
code = ex.response["Error"]["Code"]
self.log.warning("Received client error when attempting to stop the cluster: %s", code)

# these can be handled by a waiter
if code in [
"InvalidDBInstanceState",
"InvalidDBInstanceStateFault",
"InvalidClusterState",
"InvalidDBClusterStateFault",
]:
handle_waitable_exception(self, code)
else:
# re raise for any other type of client error
raise

if self.deferrable:
self.log.info("Deferring for cluster stop: %s", self.cluster_id)
Expand All @@ -193,22 +299,23 @@ def execute(self, context: Context) -> dict[str, str]:
trigger=NeptuneClusterStoppedTrigger(
aws_conn_id=self.aws_conn_id,
db_cluster_id=self.cluster_id,
waiter_delay=self.delay,
waiter_max_attempts=self.max_attempts,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
)

elif self.wait_for_completion:
self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
self.hook.wait_for_cluster_stopped(self.cluster_id, self.delay, self.max_attempts)
self.log.info("Waiting for Neptune cluster %s to stop.", self.cluster_id)

self.hook.wait_for_cluster_stopped(self.cluster_id, self.waiter_delay, self.waiter_max_attempts)

return {"db_cluster_id": self.cluster_id}

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
status = ""
cluster_id = ""

self.log.info(event)
if event:
status = event.get("status", "")
cluster_id = event.get("cluster_id", "")
Expand Down
45 changes: 45 additions & 0 deletions airflow/providers/amazon/aws/triggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,48 @@ def hook(self) -> AwsGenericHook:
verify=self.verify,
config=self.botocore_config,
)


class NeptuneClusterInstancesAvailableTrigger(AwsBaseWaiterTrigger):
"""
Triggers when a Neptune Cluster Instance is available.
:param db_cluster_id: Cluster ID to wait on instances from
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: AWS region name (example: us-east-1)
"""

def __init__(
self,
*,
db_cluster_id: str,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
aws_conn_id: str | None = None,
region_name: str | None = None,
**kwargs,
) -> None:
super().__init__(
serialized_fields={"db_cluster_id": db_cluster_id},
waiter_name="db_instance_available",
waiter_args={"Filters": [{"Name": "db-cluster-id", "Values": [db_cluster_id]}]},
failure_message="Failed to start Neptune instances",
status_message="Status of Neptune instances are",
status_queries=["DBInstances[].Status"],
return_key="db_cluster_id",
return_value=db_cluster_id,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return NeptuneHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
8 changes: 8 additions & 0 deletions tests/providers/amazon/aws/hooks/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from typing import Generator
from unittest import mock

import pytest
from moto import mock_aws
Expand Down Expand Up @@ -50,3 +51,10 @@ def test_get_conn_returns_a_boto3_connection(self):

def test_get_cluster_status(self, neptune_hook: NeptuneHook, neptune_cluster_id):
assert neptune_hook.get_cluster_status(neptune_cluster_id) is not None

@mock.patch.object(NeptuneHook, "get_waiter")
def test_wait_for_cluster_instance_availability(
self, mock_get_waiter, neptune_hook: NeptuneHook, neptune_cluster_id
):
neptune_hook.wait_for_cluster_instance_availability(neptune_cluster_id)
mock_get_waiter.assert_called_once_with("db_instance_available")
Loading

0 comments on commit a78ee74

Please sign in to comment.