diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py index b1d165617546d..8e290e61a87a7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/rds.py @@ -20,7 +20,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from airflow.exceptions import AirflowNotFoundException +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.utils.rds import RdsDbType from airflow.sensors.base import BaseSensorOperator @@ -104,18 +104,17 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor): :param export_task_identifier: A unique identifier for the snapshot export task. :param target_statuses: Target status of export task + :param error_statuses: Target error status of export task to fail the sensor """ - template_fields: Sequence[str] = ( - "export_task_identifier", - "target_statuses", - ) + template_fields: Sequence[str] = ("export_task_identifier", "target_statuses", "error_statuses") def __init__( self, *, export_task_identifier: str, target_statuses: list[str] | None = None, + error_statuses: list[str] | None = None, aws_conn_id: str | None = "aws_default", **kwargs, ): @@ -129,6 +128,7 @@ def __init__( "canceling", "canceled", ] + self.error_statuses = error_statuses or ["failed"] def poke(self, context: Context): self.log.info( @@ -136,6 +136,11 @@ def poke(self, context: Context): ) try: state = self.hook.get_export_task_state(self.export_task_identifier) + if state in self.error_statuses: + raise AirflowException( + f"Export task {self.export_task_identifier} failed with status {state}" + ) + except AirflowNotFoundException: return False return state in self.target_statuses diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py index 4edad83add0ef..b585ca21c7b6e 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py @@ -16,8 +16,12 @@ # under the License. from __future__ import annotations +from unittest.mock import patch + +import pytest from moto import mock_aws +from airflow.exceptions import AirflowException from airflow.models import DAG from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.sensors.rds import ( @@ -99,6 +103,11 @@ def _start_export_task(hook: RdsHook): raise ValueError("AWS not properly mocked") +def _start_export_task_with_error(hook: RdsHook, mock_describe_export_tasks): + _create_db_instance_snapshot(hook) + mock_describe_export_tasks.return_value = "failed" + + class TestBaseRdsSensor: dag = None base_sensor = None @@ -223,6 +232,22 @@ def test_export_task_poke_false(self): ) assert not op.poke(None) + @mock_aws + @patch("airflow.providers.amazon.aws.hooks.rds.RdsHook.get_export_task_state") + def test_error_statuses(self, mock_describe_export_tasks): + # Simulate an error condition + _start_export_task_with_error(self.hook, mock_describe_export_tasks) + op = RdsExportTaskExistenceSensor( + task_id="export_task_error", + export_task_identifier=EXPORT_TASK_NAME, + aws_conn_id=AWS_CONN, + dag=self.dag, + ) + with pytest.raises(AirflowException): + op.poke(None) + + assert "failed" in op.error_statuses + class TestRdsDbSensor: @classmethod