Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions providers/amazon/docs/operators/rds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
101 changes: 83 additions & 18 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
import json
from collections.abc import Sequence
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.rds import (
RdsDbAvailableTrigger,
RdsDbDeletedTrigger,
RdsDbStoppedTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
Expand All @@ -44,9 +44,10 @@
from airflow.utils.context import Context


class RdsBaseOperator(BaseOperator):
class RdsBaseOperator(AwsBaseOperator[RdsHook]):
"""Base operator that implements common functions for all operators."""

aws_hook_class = RdsHook
ui_color = "#eeaa88"
ui_fgcolor = "#ffffff"

Expand All @@ -63,10 +64,6 @@ def __init__(

self._await_interval = 60 # seconds

@cached_property
def hook(self) -> RdsHook:
return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context) -> str:
"""Different implementations for snapshots, tasks and events."""
raise NotImplementedError
Expand All @@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
:param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
:param aws_conn_id: The Airflow connection used for AWS credentials.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ellisms

Any reason you included aws_conn_id param here but not for any of the other operators?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it looks like an oversight when adding aws_conn_id, region_name, and verify. I can do another MR to add them to the other operators.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not a fan of copying parent class docstrings

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, I think we should not include docstrings from parent classes. If there are multiple parents, the list can become long and some docstrings can quickly become out of date. Another example: #51236

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a:

seealso::
:class: <RespectiveBaseOperator> regarding additional parameters

Although this could lead to “chaining” in the case of 3+ levels of inheritance

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is more about maintaining consistency with these 4 important params. We're trying to signal to the user (and their IDE via param-driven autocompletion) what can still be done with these operators, despite those params being moved to the superclass. So for me personally, I'm still in favour of this. But as usual, if I'm outvoted we can stop doing it and go back and remove all the places where we have (there are a LOT so that will be a great line count booster PR for someone 😉 )

If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags")

def __init__(
self,
Expand Down Expand Up @@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
Only when db_type='instance'
:param source_region: The ID of the region that contains the snapshot to be copied
:param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = (
template_fields = aws_template_fields(
"source_db_snapshot_identifier",
"target_db_snapshot_identifier",
"tags",
Expand Down Expand Up @@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):

:param db_type: Type of the DB - either "instance" or "cluster"
:param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_snapshot_identifier",)
template_fields = aws_template_fields(
"db_snapshot_identifier",
)

def __init__(
self,
Expand Down Expand Up @@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
:param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
:param waiter_interval: The number of seconds to wait before checking the export status. (default: 30)
:param waiter_max_attempts: The number of attempts to make before failing. (default: 40)
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = (
template_fields = aws_template_fields(
"export_task_identifier",
"source_arn",
"s3_bucket_name",
Expand Down Expand Up @@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
:param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
:param check_interval: The amount of time in seconds to wait between attempts
:param max_attempts: The maximum number of attempts to be made
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("export_task_identifier",)
template_fields = aws_template_fields(
"export_task_identifier",
)

def __init__(
self,
Expand Down Expand Up @@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
:param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = (
template_fields = aws_template_fields(
"subscription_name",
"sns_topic_arn",
"source_type",
Expand Down Expand Up @@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
:ref:`howto/operator:RdsDeleteEventSubscriptionOperator`

:param subscription_name: The name of the RDS event notification subscription you want to delete
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("subscription_name",)
template_fields = aws_template_fields(
"subscription_name",
)

def __init__(
self,
Expand Down Expand Up @@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
template_fields = aws_template_fields(
"db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
)

def __init__(
self,
Expand Down Expand Up @@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the DB instance to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_instance_identifier", "rds_kwargs")
template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs")

def __init__(
self,
Expand Down Expand Up @@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
:param deferrable: If True, the operator will wait asynchronously for the DB to be started.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_identifier", "db_type")
template_fields = aws_template_fields("db_identifier", "db_type")

def __init__(
self,
Expand Down Expand Up @@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
:param waiter_max_attempts: The maximum number of attempts to check DB instance state
:param deferrable: If True, the operator will wait asynchronously for the DB to be stopped.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type")

def __init__(
self,
Expand Down
43 changes: 23 additions & 20 deletions providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,30 @@
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class RdsBaseSensor(BaseSensorOperator):
class RdsBaseSensor(AwsBaseSensor[RdsHook]):
"""Base operator that implements common functions for all sensors."""

aws_hook_class = RdsHook
ui_color = "#ddbb77"
ui_fgcolor = "#ffffff"

def __init__(
self, *args, aws_conn_id: str | None = "aws_conn_id", hook_params: dict | None = None, **kwargs
):
def __init__(self, *args, hook_params: dict | None = None, **kwargs):
self.hook_params = hook_params or {}
self.aws_conn_id = aws_conn_id
self.target_statuses: list[str] = []
super().__init__(*args, **kwargs)

@cached_property
def hook(self):
return RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)


class RdsSnapshotExistenceSensor(RdsBaseSensor):
"""
Expand All @@ -59,9 +53,19 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
:param db_type: Type of the DB - either "instance" or "cluster"
:param db_snapshot_identifier: The identifier for the DB snapshot
:param target_statuses: Target status of snapshot
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"db_snapshot_identifier",
"target_statuses",
)
Expand All @@ -72,10 +76,9 @@ def __init__(
db_type: str,
db_snapshot_identifier: str,
target_statuses: list[str] | None = None,
aws_conn_id: str | None = "aws_conn_id",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
super().__init__(**kwargs)
self.db_type = RdsDbType(db_type)
self.db_snapshot_identifier = db_snapshot_identifier
self.target_statuses = target_statuses or ["available"]
Expand Down Expand Up @@ -107,18 +110,19 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
:param error_statuses: Target error status of export task to fail the sensor
"""

template_fields: Sequence[str] = ("export_task_identifier", "target_statuses", "error_statuses")
template_fields: Sequence[str] = aws_template_fields(
"export_task_identifier", "target_statuses", "error_statuses"
)

def __init__(
self,
*,
export_task_identifier: str,
target_statuses: list[str] | None = None,
error_statuses: list[str] | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
super().__init__(**kwargs)

self.export_task_identifier = export_task_identifier
self.target_statuses = target_statuses or [
Expand Down Expand Up @@ -159,7 +163,7 @@ class RdsDbSensor(RdsBaseSensor):
:param target_statuses: Target status of DB
"""

template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"db_identifier",
"db_type",
"target_statuses",
Expand All @@ -171,10 +175,9 @@ def __init__(
db_identifier: str,
db_type: RdsDbType | str = RdsDbType.INSTANCE,
target_statuses: list[str] | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
super().__init__(**kwargs)
self.db_identifier = db_identifier
self.target_statuses = target_statuses or ["available"]
self.db_type = db_type
Expand Down
Loading