-
Notifications
You must be signed in to change notification settings - Fork 14.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RdsDbSensor to amazon provider package #26003
Changes from all commits
7163a62
51f2055
c0d61ca
67c1531
522026d
b5c05b3
ef1c5e8
8281317
2176ffc
0d537b1
26f1f4e
ea4977d
1facc0c
afefe36
e310910
ae6f53e
0aa3d52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -41,7 +41,6 @@ def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optiona | |||||||||
super().__init__(*args, **kwargs) | ||||||||||
|
||||||||||
def _describe_item(self, item_type: str, item_name: str) -> list: | ||||||||||
|
||||||||||
if item_type == 'instance_snapshot': | ||||||||||
db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name) | ||||||||||
return db_snaps['DBSnapshots'] | ||||||||||
|
@@ -51,17 +50,29 @@ def _describe_item(self, item_type: str, item_name: str) -> list: | |||||||||
elif item_type == 'export_task': | ||||||||||
exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name) | ||||||||||
return exports['ExportTasks'] | ||||||||||
elif item_type == "db_instance": | ||||||||||
instances = self.hook.conn.describe_db_instances(DBInstanceIdentifier=item_name) | ||||||||||
return instances["DBInstances"] | ||||||||||
elif item_type == "db_cluster": | ||||||||||
clusters = self.hook.conn.describe_db_clusters(DBClusterIdentifier=item_name) | ||||||||||
return clusters["DBClusters"] | ||||||||||
else: | ||||||||||
raise AirflowException(f"Method for {item_type} is not implemented") | ||||||||||
|
||||||||||
def _check_item(self, item_type: str, item_name: str) -> bool: | ||||||||||
"""Get certain item from `_describe_item()` and check its status""" | ||||||||||
if item_type == "db_instance": | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update 2022/09/06 |
||||||||||
status_field = "DBInstanceStatus" | ||||||||||
else: | ||||||||||
status_field = "Status" | ||||||||||
try: | ||||||||||
items = self._describe_item(item_type, item_name) | ||||||||||
except ClientError: | ||||||||||
return False | ||||||||||
else: | ||||||||||
return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses)) | ||||||||||
return bool(items) and any( | ||||||||||
map(lambda status: items[0][status_field].lower() == status, self.target_statuses) | ||||||||||
) | ||||||||||
Comment on lines
+73
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, the
Suggested change
For me, both variants are well readable but maybe this one is better. WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestions. I'd like to make these updates in a separate PR. |
||||||||||
|
||||||||||
|
||||||||||
class RdsSnapshotExistenceSensor(RdsBaseSensor): | ||||||||||
|
@@ -149,7 +160,48 @@ def poke(self, context: 'Context'): | |||||||||
return self._check_item(item_type='export_task', item_name=self.export_task_identifier) | ||||||||||
|
||||||||||
|
||||||||||
class RdsDbSensor(RdsBaseSensor): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2022/09/05 update |
||||||||||
""" | ||||||||||
Waits for an RDS instance or cluster to enter one of a number of states | ||||||||||
|
||||||||||
.. seealso:: | ||||||||||
For more information on how to use this sensor, take a look at the guide: | ||||||||||
:ref:`howto/sensor:RdsDbSensor` | ||||||||||
|
||||||||||
:param db_type: Type of the DB - either "instance" or "cluster" | ||||||||||
:param db_identifier: The AWS identifier for the DB | ||||||||||
:param target_statuses: Target status of DB | ||||||||||
""" | ||||||||||
|
||||||||||
def __init__( | ||||||||||
self, | ||||||||||
*, | ||||||||||
db_identifier: str, | ||||||||||
db_type: str = "instance", | ||||||||||
target_statuses: Optional[List[str]] = None, | ||||||||||
aws_conn_id: str = "aws_default", | ||||||||||
**kwargs, | ||||||||||
): | ||||||||||
super().__init__(aws_conn_id=aws_conn_id, **kwargs) | ||||||||||
self.db_identifier = db_identifier | ||||||||||
self.target_statuses = target_statuses or ["available"] | ||||||||||
self.db_type = RdsDbType(db_type) | ||||||||||
|
||||||||||
def poke(self, context: 'Context'): | ||||||||||
self.log.info( | ||||||||||
"Poking for statuses : %s\nfor db instance %s", self.target_statuses, self.db_identifier | ||||||||||
) | ||||||||||
item_type = self._check_item_type() | ||||||||||
return self._check_item(item_type=item_type, item_name=self.db_identifier) | ||||||||||
|
||||||||||
def _check_item_type(self): | ||||||||||
if self.db_type == RdsDbType.CLUSTER: | ||||||||||
return "db_cluster" | ||||||||||
return "db_instance" | ||||||||||
|
||||||||||
|
||||||||||
__all__ = [ | ||||||||||
"RdsExportTaskExistenceSensor", | ||||||||||
"RdsDbSensor", | ||||||||||
"RdsSnapshotExistenceSensor", | ||||||||||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -169,6 +169,22 @@ To delete a AWS DB instance you can use | |
Sensors | ||
------- | ||
|
||
.. _howto/sensor:RdsDbSensor: | ||
|
||
Wait on an Amazon RDS instance or cluster status | ||
================================================ | ||
|
||
To wait for an Amazon RDS instance or cluster to reach a specific status you can use | ||
:class:`~airflow.providers.amazon.aws.sensors.rds.RdsDbSensor`. | ||
By default, the sensor waits for a database instance to reach the ``available`` state. | ||
|
||
.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py | ||
:language: python | ||
:dedent: 4 | ||
:start-after: [START howto_sensor_rds_instance] | ||
:end-before: [END howto_sensor_rds_instance] | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
.. _howto/sensor:RdsSnapshotExistenceSensor: | ||
|
||
Wait on an Amazon RDS snapshot status | ||
|
@@ -204,3 +220,4 @@ Reference | |
--------- | ||
|
||
* `AWS boto3 library documentation for RDS <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html>`__ | ||
* `RDS DB instance statuses <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html>`__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,9 +22,11 @@ | |
from airflow.providers.amazon.aws.hooks.rds import RdsHook | ||
from airflow.providers.amazon.aws.sensors.rds import ( | ||
RdsBaseSensor, | ||
RdsDbSensor, | ||
RdsExportTaskExistenceSensor, | ||
RdsSnapshotExistenceSensor, | ||
) | ||
from airflow.providers.amazon.aws.utils.rds import RdsDbType | ||
from airflow.utils import timezone | ||
|
||
try: | ||
|
@@ -50,15 +52,18 @@ | |
EXPORT_TASK_SOURCE = 'arn:aws:rds:es-east-1::snapshot:my-db-instance-snap' | ||
|
||
|
||
def _create_db_instance_snapshot(hook: RdsHook): | ||
def _create_db_instance(hook: RdsHook): | ||
hook.conn.create_db_instance( | ||
DBInstanceIdentifier=DB_INSTANCE_NAME, | ||
DBInstanceClass='db.m4.large', | ||
Engine='postgres', | ||
DBInstanceClass="db.t4g.micro", | ||
hankehly marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Engine="postgres", | ||
) | ||
if not hook.conn.describe_db_instances()['DBInstances']: | ||
raise ValueError('AWS not properly mocked') | ||
if not hook.conn.describe_db_instances()["DBInstances"]: | ||
raise ValueError("AWS not properly mocked") | ||
|
||
|
||
def _create_db_instance_snapshot(hook: RdsHook): | ||
_create_db_instance(hook) | ||
hook.conn.create_db_snapshot( | ||
DBInstanceIdentifier=DB_INSTANCE_NAME, | ||
DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT, | ||
|
@@ -67,7 +72,7 @@ def _create_db_instance_snapshot(hook: RdsHook): | |
raise ValueError('AWS not properly mocked') | ||
|
||
|
||
def _create_db_cluster_snapshot(hook: RdsHook): | ||
def _create_db_cluster(hook: RdsHook): | ||
hook.conn.create_db_cluster( | ||
DBClusterIdentifier=DB_CLUSTER_NAME, | ||
Engine='mysql', | ||
|
@@ -77,6 +82,9 @@ def _create_db_cluster_snapshot(hook: RdsHook): | |
if not hook.conn.describe_db_clusters()['DBClusters']: | ||
raise ValueError('AWS not properly mocked') | ||
|
||
|
||
def _create_db_cluster_snapshot(hook: RdsHook): | ||
_create_db_cluster(hook) | ||
hook.conn.create_db_cluster_snapshot( | ||
DBClusterIdentifier=DB_CLUSTER_NAME, | ||
DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT, | ||
|
@@ -225,3 +233,67 @@ def test_export_task_poke_false(self): | |
dag=self.dag, | ||
) | ||
assert not op.poke(None) | ||
|
||
|
||
@pytest.mark.skipif(mock_rds is None, reason="mock_rds package not present") | ||
class TestRdsDbSensor: | ||
@classmethod | ||
def setup_class(cls): | ||
cls.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) | ||
cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name="us-east-1") | ||
|
||
@classmethod | ||
def teardown_class(cls): | ||
del cls.dag | ||
del cls.hook | ||
|
||
@mock_rds | ||
def test_poke_true_instance(self): | ||
""" | ||
By default RdsDbSensor should wait for an instance to enter the 'available' state | ||
""" | ||
_create_db_instance(self.hook) | ||
op = RdsDbSensor( | ||
task_id="instance_poke_true", | ||
db_identifier=DB_INSTANCE_NAME, | ||
aws_conn_id=AWS_CONN, | ||
dag=self.dag, | ||
) | ||
assert op.poke(None) | ||
|
||
@mock_rds | ||
def test_poke_false_instance(self): | ||
_create_db_instance(self.hook) | ||
op = RdsDbSensor( | ||
task_id="instance_poke_false", | ||
db_identifier=DB_INSTANCE_NAME, | ||
target_statuses=["stopped"], | ||
aws_conn_id=AWS_CONN, | ||
dag=self.dag, | ||
) | ||
assert not op.poke(None) | ||
|
||
@mock_rds | ||
def test_poke_true_cluster(self): | ||
_create_db_cluster(self.hook) | ||
op = RdsDbSensor( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2022/09/05 update |
||
task_id="cluster_poke_true", | ||
db_identifier=DB_CLUSTER_NAME, | ||
db_type=RdsDbType.CLUSTER, | ||
aws_conn_id=AWS_CONN, | ||
dag=self.dag, | ||
) | ||
assert op.poke(None) | ||
|
||
@mock_rds | ||
def test_poke_false_cluster(self): | ||
_create_db_cluster(self.hook) | ||
op = RdsDbSensor( | ||
task_id="cluster_poke_false", | ||
db_identifier=DB_CLUSTER_NAME, | ||
target_statuses=["stopped"], | ||
db_type=RdsDbType.CLUSTER, | ||
aws_conn_id=AWS_CONN, | ||
dag=self.dag, | ||
) | ||
assert not op.poke(None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2022/09/05 update
Added
db_cluster
to be consistent with #21231 / #20907