Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RdsDbSensor to amazon provider package #26003

Merged
merged 17 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions airflow/providers/amazon/aws/sensors/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optiona
super().__init__(*args, **kwargs)

def _describe_item(self, item_type: str, item_name: str) -> list:

if item_type == 'instance_snapshot':
db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
return db_snaps['DBSnapshots']
Expand All @@ -51,17 +50,29 @@ def _describe_item(self, item_type: str, item_name: str) -> list:
elif item_type == 'export_task':
exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
return exports['ExportTasks']
elif item_type == "db_instance":
instances = self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name)
return instances["DBInstances"]
elif item_type == "db_cluster":
Copy link
Contributor Author

@hankehly hankehly Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022/09/05 update
Added db_cluster to be consistent with #21231 / #20907

clusters = self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name)
return clusters["DBClusters"]
else:
raise AirflowException(f"Method for {item_type} is not implemented")

def _check_item(self, item_type: str, item_name: str) -> bool:
"""Get certain item from `_describe_item()` and check its status"""
if item_type == "db_instance":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update 2022/09/06
#26003 (comment)

status_field = "DBInstanceStatus"
else:
status_field = "Status"
try:
items = self._describe_item(item_type, item_name)
except ClientError:
return False
else:
return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses))
return bool(items) and any(
map(lambda status: items[0][status_field].lower() == status, self.target_statuses)
)
Comment on lines +73 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the return could be modified from map to the generator if you wish.

Suggested change
return bool(items) and any(
map(lambda status: items[0][status_field].lower() == status, self.target_statuses)
)
return bool(items) and any(items[0][status_field].lower() == status for status in self.target_statuses)

For me, both variants are well readable but maybe this one is better. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions. I'd like to make these updates in a separate PR.
The formatting is the result of black (line too long)



class RdsSnapshotExistenceSensor(RdsBaseSensor):
Expand Down Expand Up @@ -149,7 +160,48 @@ def poke(self, context: 'Context'):
return self._check_item(item_type='export_task', item_name=self.export_task_identifier)


class RdsDbSensor(RdsBaseSensor):
Copy link
Contributor Author

@hankehly hankehly Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022/09/05 update
Renamed RdsInstanceSensor to RdsDbSensor because it considers "cluster" databases as well

"""
Waits for an RDS instance or cluster to enter one of a number of states

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsDbSensor`

:param db_type: Type of the DB - either "instance" or "cluster"
:param db_identifier: The AWS identifier for the DB
:param target_statuses: Target status of DB
"""

def __init__(
self,
*,
db_identifier: str,
db_type: str = "instance",
target_statuses: Optional[List[str]] = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.db_identifier = db_identifier
self.target_statuses = target_statuses or ["available"]
self.db_type = RdsDbType(db_type)

def poke(self, context: 'Context'):
self.log.info(
"Poking for statuses : %s\nfor db instance %s", self.target_statuses, self.db_identifier
)
item_type = self._check_item_type()
return self._check_item(item_type=item_type, item_name=self.db_identifier)

def _check_item_type(self):
if self.db_type == RdsDbType.CLUSTER:
return "db_cluster"
return "db_instance"


__all__ = [
"RdsExportTaskExistenceSensor",
"RdsDbSensor",
"RdsSnapshotExistenceSensor",
]
17 changes: 17 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/rds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,22 @@ To delete a AWS DB instance you can use
Sensors
-------

.. _howto/sensor:RdsDbSensor:

Wait on an Amazon RDS instance or cluster status
================================================

To wait for an Amazon RDS instance or cluster to reach a specific status you can use
:class:`~airflow.providers.amazon.aws.sensors.rds.RdsDbSensor`.
By default, the sensor waits for a database instance to reach the ``available`` state.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_rds_instance]
:end-before: [END howto_sensor_rds_instance]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rds-instance-sensor

Copy link
Contributor Author

@hankehly hankehly Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022/09/05 update

Screen Shot 2022-09-05 at 13 13 54


.. _howto/sensor:RdsSnapshotExistenceSensor:

Wait on an Amazon RDS snapshot status
Expand Down Expand Up @@ -204,3 +220,4 @@ Reference
---------

* `AWS boto3 library documentation for RDS <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html>`__
* `RDS DB instance statuses <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html>`__
84 changes: 78 additions & 6 deletions tests/providers/amazon/aws/sensors/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.sensors.rds import (
RdsBaseSensor,
RdsDbSensor,
RdsExportTaskExistenceSensor,
RdsSnapshotExistenceSensor,
)
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.utils import timezone

try:
Expand All @@ -50,15 +52,18 @@
EXPORT_TASK_SOURCE = 'arn:aws:rds:es-east-1::snapshot:my-db-instance-snap'


def _create_db_instance_snapshot(hook: RdsHook):
def _create_db_instance(hook: RdsHook):
hook.conn.create_db_instance(
DBInstanceIdentifier=DB_INSTANCE_NAME,
DBInstanceClass='db.m4.large',
Engine='postgres',
DBInstanceClass="db.t4g.micro",
hankehly marked this conversation as resolved.
Show resolved Hide resolved
Engine="postgres",
)
if not hook.conn.describe_db_instances()['DBInstances']:
raise ValueError('AWS not properly mocked')
if not hook.conn.describe_db_instances()["DBInstances"]:
raise ValueError("AWS not properly mocked")


def _create_db_instance_snapshot(hook: RdsHook):
_create_db_instance(hook)
hook.conn.create_db_snapshot(
DBInstanceIdentifier=DB_INSTANCE_NAME,
DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT,
Expand All @@ -67,7 +72,7 @@ def _create_db_instance_snapshot(hook: RdsHook):
raise ValueError('AWS not properly mocked')


def _create_db_cluster_snapshot(hook: RdsHook):
def _create_db_cluster(hook: RdsHook):
hook.conn.create_db_cluster(
DBClusterIdentifier=DB_CLUSTER_NAME,
Engine='mysql',
Expand All @@ -77,6 +82,9 @@ def _create_db_cluster_snapshot(hook: RdsHook):
if not hook.conn.describe_db_clusters()['DBClusters']:
raise ValueError('AWS not properly mocked')


def _create_db_cluster_snapshot(hook: RdsHook):
_create_db_cluster(hook)
hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=DB_CLUSTER_NAME,
DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT,
Expand Down Expand Up @@ -225,3 +233,67 @@ def test_export_task_poke_false(self):
dag=self.dag,
)
assert not op.poke(None)


@pytest.mark.skipif(mock_rds is None, reason="mock_rds package not present")
class TestRdsDbSensor:
@classmethod
def setup_class(cls):
cls.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1")

@classmethod
def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds
def test_poke_true_instance(self):
"""
By default RdsDbSensor should wait for an instance to enter the 'available' state
"""
_create_db_instance(self.hook)
op = RdsDbSensor(
task_id="instance_poke_true",
db_identifier=DB_INSTANCE_NAME,
aws_conn_id=AWS_CONN,
dag=self.dag,
)
assert op.poke(None)

@mock_rds
def test_poke_false_instance(self):
_create_db_instance(self.hook)
op = RdsDbSensor(
task_id="instance_poke_false",
db_identifier=DB_INSTANCE_NAME,
target_statuses=["stopped"],
aws_conn_id=AWS_CONN,
dag=self.dag,
)
assert not op.poke(None)

@mock_rds
def test_poke_true_cluster(self):
_create_db_cluster(self.hook)
op = RdsDbSensor(
Copy link
Contributor Author

@hankehly hankehly Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022/09/05 update
Added unit test cases for "cluster" db type

task_id="cluster_poke_true",
db_identifier=DB_CLUSTER_NAME,
db_type=RdsDbType.CLUSTER,
aws_conn_id=AWS_CONN,
dag=self.dag,
)
assert op.poke(None)

@mock_rds
def test_poke_false_cluster(self):
_create_db_cluster(self.hook)
op = RdsDbSensor(
task_id="cluster_poke_false",
db_identifier=DB_CLUSTER_NAME,
target_statuses=["stopped"],
db_type=RdsDbType.CLUSTER,
aws_conn_id=AWS_CONN,
dag=self.dag,
)
assert not op.poke(None)
12 changes: 10 additions & 2 deletions tests/system/providers/amazon/aws/rds/example_rds_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RdsCreateDbInstanceOperator,
RdsDeleteDbInstanceOperator,
)
from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor
from tests.system.providers.amazon.aws.utils import set_env_id

ENV_ID = set_env_id()
Expand All @@ -46,7 +47,7 @@
create_db_instance = RdsCreateDbInstanceOperator(
task_id='create_db_instance',
db_instance_identifier=RDS_DB_IDENTIFIER,
db_instance_class="db.m5.large",
db_instance_class="db.t4g.micro",
hankehly marked this conversation as resolved.
Show resolved Hide resolved
engine="postgres",
rds_kwargs={
"MasterUsername": RDS_USERNAME,
Expand All @@ -56,6 +57,13 @@
)
# [END howto_operator_rds_create_db_instance]

# [START howto_sensor_rds_instance]
db_instance_available = RdsDbSensor(
task_id="db_instance_available",
db_identifier=RDS_DB_IDENTIFIER,
)
# [END howto_sensor_rds_instance]

# [START howto_operator_rds_delete_db_instance]
delete_db_instance = RdsDeleteDbInstanceOperator(
task_id='delete_db_instance',
Expand All @@ -66,7 +74,7 @@
)
# [END howto_operator_rds_delete_db_instance]

chain(create_db_instance, delete_db_instance)
chain(create_db_instance, db_instance_available, delete_db_instance)

from tests.system.utils import get_test_run # noqa: E402

Expand Down