diff --git a/providers/amazon/docs/operators/rds.rst b/providers/amazon/docs/operators/rds.rst index 841dd90a6a3b6..916c61d854849 100644 --- a/providers/amazon/docs/operators/rds.rst +++ b/providers/amazon/docs/operators/rds.rst @@ -29,6 +29,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py index 18f227d85ae7f..c7841a08ce162 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/rds.py @@ -20,19 +20,19 @@ import json from collections.abc import Sequence from datetime import timedelta -from functools import cached_property from typing import TYPE_CHECKING, Any from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.rds import RdsHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.rds import ( RdsDbAvailableTrigger, RdsDbDeletedTrigger, RdsDbStoppedTrigger, ) from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.amazon.aws.utils.rds import RdsDbType from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.providers.amazon.aws.utils.waiter_with_logging import wait @@ -44,9 +44,10 @@ from airflow.utils.context import Context -class RdsBaseOperator(BaseOperator): +class RdsBaseOperator(AwsBaseOperator[RdsHook]): """Base operator that implements common functions for all operators.""" + aws_hook_class = RdsHook ui_color = "#eeaa88" ui_fgcolor = "#ffffff" @@ -63,10 +64,6 @@ def __init__( self._await_interval = 60 # seconds - @cached_property - def hook(self) -> RdsHook: - return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - def execute(self, context: Context) -> str: """Different implementations for snapshots, tasks and events.""" raise NotImplementedError @@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]` `USER Tagging `__ :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_snapshot_identifier", "db_identifier", "tags") + template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags") def __init__( self, @@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator): Only when db_type='instance' :param source_region: The ID of the region that contains the snapshot to be copied :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True) + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ( + template_fields = aws_template_fields( "source_db_snapshot_identifier", "target_db_snapshot_identifier", "tags", @@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator): :param db_type: Type of the DB - either "instance" or "cluster" :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_snapshot_identifier",) + template_fields = aws_template_fields( + "db_snapshot_identifier", + ) def __init__( self, @@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator): :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True) :param waiter_interval: The number of seconds to wait before checking the export status. (default: 30) :param waiter_max_attempts: The number of attempts to make before failing. (default: 40) + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ( + template_fields = aws_template_fields( "export_task_identifier", "source_arn", "s3_bucket_name", @@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator): :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True) :param check_interval: The amount of time in seconds to wait between attempts :param max_attempts: The maximum number of attempts to be made + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("export_task_identifier",) + template_fields = aws_template_fields( + "export_task_identifier", + ) def __init__( self, @@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator): :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]` `USER Tagging `__ :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True) + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ( + template_fields = aws_template_fields( "subscription_name", "sns_topic_arn", "source_type", @@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator): :ref:`howto/operator:RdsDeleteEventSubscriptionOperator` :param subscription_name: The name of the RDS event notification subscription you want to delete + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("subscription_name",) + template_fields = aws_template_fields( + "subscription_name", + ) def __init__( self, @@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator): :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs") + template_fields = aws_template_fields( + "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs" + ) def __init__( self, @@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator): :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_instance_identifier", "rds_kwargs") + template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs") def __init__( self, @@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator): :param waiter_max_attempts: The maximum number of attempts to check DB instance state :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created. This implies waiting for completion. This mode requires aiobotocore module to be installed. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_identifier", "db_type") + template_fields = aws_template_fields("db_identifier", "db_type") def __init__( self, @@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator): :param waiter_max_attempts: The maximum number of attempts to check DB instance state :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created. This implies waiting for completion. This mode requires aiobotocore module to be installed. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields = ("db_identifier", "db_snapshot_identifier", "db_type") + template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type") def __init__( self, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py index 8e290e61a87a7..03a170d29a14d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py @@ -17,36 +17,30 @@ from __future__ import annotations from collections.abc import Sequence -from functools import cached_property from typing import TYPE_CHECKING from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.rds import RdsHook +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.amazon.aws.utils.rds import RdsDbType -from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context -class RdsBaseSensor(BaseSensorOperator): +class RdsBaseSensor(AwsBaseSensor[RdsHook]): """Base operator that implements common functions for all sensors.""" + aws_hook_class = RdsHook ui_color = "#ddbb77" ui_fgcolor = "#ffffff" - def __init__( - self, *args, aws_conn_id: str | None = "aws_conn_id", hook_params: dict | None = None, **kwargs - ): + def __init__(self, *args, hook_params: dict | None = None, **kwargs): self.hook_params = hook_params or {} - self.aws_conn_id = aws_conn_id self.target_statuses: list[str] = [] super().__init__(*args, **kwargs) - @cached_property - def hook(self): - return RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params) - class RdsSnapshotExistenceSensor(RdsBaseSensor): """ @@ -59,9 +53,19 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor): :param db_type: Type of the DB - either "instance" or "cluster" :param db_snapshot_identifier: The identifier for the DB snapshot :param target_statuses: Target status of snapshot + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ( + template_fields: Sequence[str] = aws_template_fields( "db_snapshot_identifier", "target_statuses", ) @@ -72,10 +76,9 @@ def __init__( db_type: str, db_snapshot_identifier: str, target_statuses: list[str] | None = None, - aws_conn_id: str | None = "aws_conn_id", **kwargs, ): - super().__init__(aws_conn_id=aws_conn_id, **kwargs) + super().__init__(**kwargs) self.db_type = RdsDbType(db_type) self.db_snapshot_identifier = db_snapshot_identifier self.target_statuses = target_statuses or ["available"] @@ -107,7 +110,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor): :param error_statuses: Target error status of export task to fail the sensor """ - template_fields: Sequence[str] = ("export_task_identifier", "target_statuses", "error_statuses") + template_fields: Sequence[str] = aws_template_fields( + "export_task_identifier", "target_statuses", "error_statuses" + ) def __init__( self, @@ -115,10 +120,9 @@ def __init__( export_task_identifier: str, target_statuses: list[str] | None = None, error_statuses: list[str] | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ): - super().__init__(aws_conn_id=aws_conn_id, **kwargs) + super().__init__(**kwargs) self.export_task_identifier = export_task_identifier self.target_statuses = target_statuses or [ @@ -159,7 +163,7 @@ class RdsDbSensor(RdsBaseSensor): :param target_statuses: Target status of DB """ - template_fields: Sequence[str] = ( + template_fields: Sequence[str] = aws_template_fields( "db_identifier", "db_type", "target_statuses", @@ -171,10 +175,9 @@ def __init__( db_identifier: str, db_type: RdsDbType | str = RdsDbType.INSTANCE, target_statuses: list[str] | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ): - super().__init__(aws_conn_id=aws_conn_id, **kwargs) + super().__init__(**kwargs) self.db_identifier = db_identifier self.target_statuses = target_statuses or ["available"] self.db_type = db_type diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py index c6537c9391b0d..6c22c75e8aff7 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py @@ -54,6 +54,8 @@ AWS_CONN = "amazon_default" +REGION = "us-east-1" + DB_INSTANCE_NAME = "my-db-instance" DB_CLUSTER_NAME = "my-db-cluster" @@ -282,6 +284,7 @@ def test_template_fields(self): db_snapshot_identifier=DB_INSTANCE_SNAPSHOT, db_identifier=DB_INSTANCE_NAME, aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -410,6 +413,7 @@ def test_template_fields(self): source_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, target_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT_COPY, aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -527,6 +531,7 @@ def test_template_fields(self): db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, aws_conn_id=AWS_CONN, wait_for_completion=False, + region_name=REGION, ) validate_template_fields(operator) @@ -610,6 +615,7 @@ def test_template_fields(self): s3_bucket_name=EXPORT_TASK_BUCKET, aws_conn_id=AWS_CONN, wait_for_completion=False, + region_name=REGION, ) validate_template_fields(operator) @@ -682,6 +688,7 @@ def test_template_fields(self): task_id="test_cancel", export_task_identifier=EXPORT_TASK_NAME, aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -759,6 +766,7 @@ def test_template_fields(self): source_type="db-instance", source_ids=[DB_INSTANCE_NAME], aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -800,6 +808,7 @@ def test_template_fields(self): task_id="test_delete", subscription_name=SUBSCRIPTION_NAME, aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -879,6 +888,7 @@ def test_template_fields(self): "DBName": DB_INSTANCE_NAME, }, aws_conn_id=AWS_CONN, + region_name=REGION, ) validate_template_fields(operator) @@ -949,6 +959,7 @@ def test_template_fields(self): }, aws_conn_id=AWS_CONN, wait_for_completion=False, + region_name=REGION, ) validate_template_fields(operator) @@ -1062,6 +1073,7 @@ def test_template_fields(self): db_identifier=DB_CLUSTER_NAME, db_type="cluster", db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, + region_name=REGION, ) validate_template_fields(operator) @@ -1133,6 +1145,10 @@ def test_deferred(self, conn_mock): def test_template_fields(self): operator = RdsStartDbOperator( - task_id="test_start_db_cluster", db_identifier=DB_CLUSTER_NAME, db_type="cluster" + region_name=REGION, + aws_conn_id=AWS_CONN, + task_id="test_start_db_cluster", + db_identifier=DB_CLUSTER_NAME, + db_type="cluster", ) validate_template_fields(operator) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py index b585ca21c7b6e..3bdd8f673bf3b 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py @@ -33,6 +33,8 @@ from airflow.providers.amazon.aws.utils.rds import RdsDbType from airflow.utils import timezone +from unit.amazon.aws.utils.test_template_fields import validate_template_fields + DEFAULT_DATE = timezone.datetime(2019, 1, 1) AWS_CONN = "aws_default" @@ -146,6 +148,16 @@ def teardown_class(cls): del cls.dag del cls.hook + def test_template_fields(self): + sensor = RdsSnapshotExistenceSensor( + task_id="test_template_fields", + db_type="instance", + db_snapshot_identifier=DB_INSTANCE_SNAPSHOT, + aws_conn_id=AWS_CONN, + region_name="us-east-1", + ) + validate_template_fields(sensor) + @mock_aws def test_db_instance_snapshot_poke_true(self): _create_db_instance_snapshot(self.hook) @@ -209,6 +221,15 @@ def teardown_class(cls): del cls.dag del cls.hook + def test_template_fields(self): + sensor = RdsExportTaskExistenceSensor( + task_id="test_template_fields", + export_task_identifier=EXPORT_TASK_NAME, + aws_conn_id=AWS_CONN, + region_name="us-east-1", + ) + validate_template_fields(sensor) + @mock_aws def test_export_task_poke_true(self): _create_db_instance_snapshot(self.hook) @@ -264,6 +285,15 @@ def teardown_class(cls): del cls.dag del cls.hook + def test_template_fields(self): + sensor = RdsDbSensor( + task_id="test_template_fields", + db_identifier=DB_INSTANCE_NAME, + aws_conn_id=AWS_CONN, + region_name="us-east-1", + ) + validate_template_fields(sensor) + @mock_aws def test_poke_true_instance(self): """