diff --git a/providers/amazon/docs/operators/mwaa.rst b/providers/amazon/docs/operators/mwaa.rst
index 863aef79963be..fc248288c10f8 100644
--- a/providers/amazon/docs/operators/mwaa.rst
+++ b/providers/amazon/docs/operators/mwaa.rst
@@ -76,6 +76,22 @@ In the following example, the task ``wait_for_dag_run`` waits for the DAG run cr
     :start-after: [START howto_sensor_mwaa_dag_run]
     :end-before: [END howto_sensor_mwaa_dag_run]

+.. _howto/sensor:MwaaTaskSensor:
+
+Wait on the state of an AWS MWAA Task
+=====================================
+
+To wait for a task instance in an external MWAA environment until it reaches one of the given states, you can use the
+:class:`~airflow.providers.amazon.aws.sensors.mwaa.MwaaTaskSensor`.
+
+In the following example, the task ``wait_for_task`` waits for a task in the DAG run created by the task above to complete.
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_mwaa.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_mwaa_task]
+    :end-before: [END howto_sensor_mwaa_task]
+
 References
 ----------

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
index b21f73837f72e..b19f5164fcf53 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -24,10 +24,9 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
-from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
-from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger, MwaaTaskCompletedTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, TaskInstanceState

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -139,7 +138,7 @@ def poke(self, context: Context) -> bool:
         return state in self.success_states

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        validate_execute_complete_event(event)
+        return None

     def execute(self, context: Context):
         if self.deferrable:
@@ -150,10 +149,152 @@ def execute(self, context: Context):
                     external_dag_run_id=self.external_dag_run_id,
                     success_states=self.success_states,
                     failure_states=self.failure_states,
-                    # somehow the type of poke_interval is derived as float ??
-                    waiter_delay=self.poke_interval,  # type: ignore[arg-type]
+                    waiter_delay=int(self.poke_interval),
                     waiter_max_attempts=self.max_retries,
                     aws_conn_id=self.aws_conn_id,
+                    end_from_trigger=True,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+
+class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
+    """
+    Waits for a task instance in an MWAA environment to complete.
+
+    If the task instance fails, an AirflowException is raised.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:MwaaTaskSensor`
+
+    :param external_env_name: The external MWAA environment name that contains the task instance you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the task instance you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG run ID in the external MWAA environment that you want to wait for (templated).
+        If not provided, the most recent DAG run is used.
+    :param external_task_id: The task ID in the external MWAA environment that you want to wait for (templated)
+    :param success_states: Collection of task instance states that mark this sensor as successful, default is
+        ``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated)
+    :param failure_states: Collection of task instance states that mark this sensor as failed and raise an
+        AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires the aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the state of the task instance. (default: 60)
+    :param max_retries: Maximum number of polling attempts before the sensor gives up. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = MwaaHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "external_env_name",
+        "external_dag_id",
+        "external_dag_run_id",
+        "external_task_id",
+        "success_states",
+        "failure_states",
+        "deferrable",
+        "max_retries",
+        "poke_interval",
+    )
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str | None = None,
+        external_task_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 60,
+        max_retries: int = 720,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.success_states = set(success_states) if success_states else {TaskInstanceState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {TaskInstanceState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        self.external_env_name = external_env_name
+        self.external_dag_id = external_dag_id
+        self.external_dag_run_id = external_dag_run_id
+        self.external_task_id = external_task_id
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for task %s of DAG run %s of DAG %s in MWAA environment %s",
+            self.external_task_id,
+            self.external_dag_run_id,
+            self.external_dag_id,
+            self.external_env_name,
+        )
+
+        response = self.hook.invoke_rest_api(
+            env_name=self.external_env_name,
+            path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}/taskInstances/{self.external_task_id}",
+            method="GET",
+        )
+        # If RestApiStatusCode == 200, the RestApiResponse is guaranteed to contain the "state" key; if it
+        # doesn't, something went wrong in the API and the resulting KeyError is the right way to surface it
+        # If RestApiStatusCode >= 300, a botocore exception would've already been raised during the
+        # self.hook.invoke_rest_api call
+        # This sensor is only responsible for raising an AirflowException when the task reaches a failure state
+
+        state = response["RestApiResponse"]["state"]
+
+        if state in self.failure_states:
+            raise AirflowException(
+                f"The task {self.external_task_id} of DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
+                f"failed with state: {state}"
+            )
+
+        return state in self.success_states
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        return None
+
+    def execute(self, context: Context):
+        if self.external_dag_run_id is None:
+            response = self.hook.invoke_rest_api(
+                env_name=self.external_env_name,
+                path=f"/dags/{self.external_dag_id}/dagRuns",
+                method="GET",
+            )
+            self.external_dag_run_id = response["RestApiResponse"]["dag_runs"][-1]["dag_run_id"]
+
+        if self.deferrable:
+            self.defer(
+                trigger=MwaaTaskCompletedTrigger(
+                    external_env_name=self.external_env_name,
+                    external_dag_id=self.external_dag_id,
+                    external_dag_run_id=self.external_dag_run_id,
+                    external_task_id=self.external_task_id,
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                    end_from_trigger=True,
                 ),
                 method_name="execute_complete",
             )
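For context on how the new sensor is meant to be wired up, here is a minimal usage sketch covering both poll and deferrable modes; the environment, DAG, and task names are placeholders, not values from this change:

```python
# Hypothetical usage sketch for the MwaaTaskSensor added above.
# "MyMwaaEnv", "data_pipeline", and "load_step" are placeholder names.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.mwaa import MwaaTaskSensor

with DAG(dag_id="watch_remote_task", start_date=datetime(2025, 1, 1), schedule=None):
    # Poll mode: occupies a worker slot and calls poke() every poke_interval seconds.
    wait_polling = MwaaTaskSensor(
        task_id="wait_polling",
        external_env_name="MyMwaaEnv",
        external_dag_id="data_pipeline",
        external_task_id="load_step",
        # external_dag_run_id omitted: the sensor falls back to the most recent DAG run.
        poke_interval=60,
    )

    # Deferrable mode: frees the worker slot and hands polling to the triggerer.
    wait_deferred = MwaaTaskSensor(
        task_id="wait_deferred",
        external_env_name="MyMwaaEnv",
        external_dag_id="data_pipeline",
        external_task_id="load_step",
        deferrable=True,
    )
```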
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index 7a2701aab9d40..4b7ddfd405467 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -80,7 +80,7 @@ def __init__(
         waiter_delay: int,
         waiter_max_attempts: int,
         waiter_config_overrides: dict[str, Any] | None = None,
-        aws_conn_id: str | None,
+        aws_conn_id: str | None = "aws_default",
         region_name: str | None = None,
         verify: bool | str | None = None,
         botocore_config: dict | None = None,
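The base-trigger change above gives ``aws_conn_id`` a default of ``"aws_default"`` instead of requiring every subclass to thread it through. A quick sketch of what that means for callers, mirrored by the unit tests further down (the ids are placeholder values):

```python
from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger

# aws_conn_id is no longer a required keyword; it now falls back to "aws_default".
trigger = MwaaDagRunCompletedTrigger(
    external_env_name="MyMwaaEnv",  # placeholder values
    external_dag_id="data_pipeline",
    external_dag_run_id="manual__2025-01-01T00:00:00",
)
assert trigger.hook().aws_conn_id == "aws_default"

# An explicitly passed connection id still wins.
trigger = MwaaDagRunCompletedTrigger(
    external_env_name="MyMwaaEnv",
    external_dag_id="data_pipeline",
    external_dag_run_id="manual__2025-01-01T00:00:00",
    aws_conn_id="my_aws_conn",
)
assert trigger.hook().aws_conn_id == "my_aws_conn"
```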
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
index 51638084cb188..b31f5a018b9a5 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -22,7 +22,7 @@

 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, State, TaskInstanceState

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -48,7 +48,7 @@ class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):

     def __init__(
         self,
-        *,
+        *args,
         external_env_name: str,
         external_dag_id: str,
         external_dag_run_id: str,
@@ -56,7 +56,7 @@ def __init__(
         failure_states: Collection[str] | None = None,
         waiter_delay: int = 60,
         waiter_max_attempts: int = 720,
-        aws_conn_id: str | None = None,
+        **kwargs,
     ) -> None:
         self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
         self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
@@ -87,7 +87,6 @@ def __init__(
             return_value=external_dag_run_id,
             waiter_delay=waiter_delay,
             waiter_max_attempts=waiter_max_attempts,
-            aws_conn_id=aws_conn_id,
             waiter_config_overrides={
                 "acceptors": _build_waiter_acceptors(
                     success_states=self.success_states,
@@ -95,6 +94,93 @@ def __init__(
                     in_progress_states=in_progress_states,
                 )
             },
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+class MwaaTaskCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a task instance in an MWAA environment is complete.
+
+    :param external_env_name: The external MWAA environment name that contains the task instance you want to wait for
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the task instance you want to wait for
+    :param external_dag_run_id: The DAG run ID in the external MWAA environment that you want to wait for. The trigger
+        itself expects a concrete value; when this trigger is created by MwaaTaskSensor, the sensor resolves the most
+        recent DAG run before deferring.
+    :param external_task_id: The task ID in the external MWAA environment that you want to wait for
+    :param success_states: Collection of task instance states that make this trigger succeed, default is
+        ``airflow.utils.state.State.success_states``
+    :param failure_states: Collection of task instance states that make this trigger fail, default is
+        ``airflow.utils.state.State.failed_states``
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *args,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str | None = None,
+        external_task_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        **kwargs,
+    ) -> None:
+        self.success_states = (
+            set(success_states) if success_states else {state.value for state in State.success_states}
+        )
+        self.failure_states = (
+            set(failure_states) if failure_states else {state.value for state in State.failed_states}
+        )
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        in_progress_states = {s.value for s in TaskInstanceState} - self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "external_task_id": external_task_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_task_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}/taskInstances/{external_task_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The task {external_task_id} of DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of the task instance",
+            status_queries=["RestApiResponse.state"],
+            return_key="task_id",
+            return_value=external_task_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+            **kwargs,
         )

     def hook(self) -> AwsGenericHook:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
index a06e90c42a98b..48e81a0a4c1d2 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
@@ -31,6 +31,11 @@
                 "state": "failure"
             }
         ]
+    },
+    "mwaa_task_complete": {
+        "delay": 60,
+        "maxAttempts": 20,
+        "operation": "InvokeRestApi"
     }
   }
 }
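The ``mwaa_task_complete`` waiter entry above only carries polling defaults; its acceptors are injected at runtime through ``waiter_config_overrides`` in the trigger. Roughly, the generated acceptors are boto3-style matchers on ``RestApiResponse.state``; the following is an illustration of that shape, not the literal output of ``_build_waiter_acceptors``:

```python
# Illustrative shape of the acceptors injected via waiter_config_overrides.
# The real list is produced by _build_waiter_acceptors; treat this as a sketch.
success_states = {"success"}
failure_states = {"failed"}
in_progress_states = {"queued", "running", "up_for_retry"}  # abridged set of TaskInstanceState values

acceptors = (
    [{"matcher": "path", "argument": "RestApiResponse.state", "expected": s, "state": "success"}
     for s in success_states]
    + [{"matcher": "path", "argument": "RestApiResponse.state", "expected": s, "state": "failure"}
       for s in failure_states]
    + [{"matcher": "path", "argument": "RestApiResponse.state", "expected": s, "state": "retry"}
       for s in in_progress_states]
)
```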
diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa.py b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
index 0c6695181b484..c5f2f980153ac 100644
--- a/providers/amazon/tests/system/amazon/aws/example_mwaa.py
+++ b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
@@ -23,7 +23,7 @@
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.hooks.sts import StsHook
 from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
-from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor, MwaaTaskSensor

 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

@@ -43,6 +43,7 @@
 # Externally fetched variables:
 EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
 EXISTING_DAG_ID_KEY = "DAG_ID"
+EXISTING_TASK_ID_KEY = "TASK_ID"
 ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY = "ROLE_WITHOUT_INVOKE_REST_API_ARN"

 sys_test_context_task = (
@@ -61,6 +62,7 @@
     .add_variable(EXISTING_ENVIRONMENT_NAME_KEY)
     .add_variable(EXISTING_DAG_ID_KEY)
     .add_variable(ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY)
+    .add_variable(EXISTING_TASK_ID_KEY)
     .build()
 )

@@ -107,6 +109,7 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
     test_context = sys_test_context_task()
     env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY]
     trigger_dag_id = test_context[EXISTING_DAG_ID_KEY]
+    task_id = test_context[EXISTING_TASK_ID_KEY]
     restricted_role_arn = test_context[ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY]

     # [START howto_operator_mwaa_trigger_dag_run]
@@ -118,6 +121,16 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
     )
     # [END howto_operator_mwaa_trigger_dag_run]

+    # [START howto_sensor_mwaa_task]
+    wait_for_task = MwaaTaskSensor(
+        task_id="wait_for_task",
+        external_env_name=env_name,
+        external_dag_id=trigger_dag_id,
+        external_task_id=task_id,
+        poke_interval=5,
+    )
+    # [END howto_sensor_mwaa_task]
+
     # [START howto_sensor_mwaa_dag_run]
     wait_for_dag_run = MwaaDagRunSensor(
         task_id="wait_for_dag_run",
@@ -128,15 +141,29 @@ def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
     )
     # [END howto_sensor_mwaa_dag_run]

-    chain(
-        # TEST SETUP
-        test_context,
-        # TEST BODY
+    trigger_dag_run_dont_wait = MwaaTriggerDagRunOperator(
+        task_id="trigger_dag_run_dont_wait",
+        env_name=env_name,
+        trigger_dag_id=trigger_dag_id,
+        wait_for_completion=False,
+    )
+
+    wait_for_task_concurrent = MwaaTaskSensor(
+        task_id="wait_for_task_concurrent",
+        external_env_name=env_name,
+        external_dag_id=trigger_dag_id,
+        external_task_id=task_id,
+        poke_interval=5,
+    )
+
+    test_context >> [
         unpause_dag(env_name, trigger_dag_id),
-        trigger_dag_run,
-        wait_for_dag_run,
         test_iam_fallback(restricted_role_arn, env_name),
-    )
+        trigger_dag_run,
+        trigger_dag_run_dont_wait,
+    ]
+    chain(trigger_dag_run, wait_for_task, wait_for_dag_run)
+    chain(trigger_dag_run_dont_wait, wait_for_task_concurrent)

     from tests_common.test_utils.watcher import watcher
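The ``wait_for_task`` sensor above relies on the latest-run fallback. If the example ever needs to pin the exact run created by ``trigger_dag_run``, the run id could be pulled from that task's XCom instead; a sketch, assuming the operator pushes the MWAA API response (the XCom layout shown here is an assumption and should be verified against ``MwaaTriggerDagRunOperator``):

```python
# Hypothetical: pin the sensor to the run created by trigger_dag_run instead of
# relying on the latest-run fallback. The XCom shape used here is an assumption.
wait_for_specific_run_task = MwaaTaskSensor(
    task_id="wait_for_specific_run_task",
    external_env_name=env_name,
    external_dag_id=trigger_dag_id,
    external_dag_run_id="{{ ti.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id'] }}",
    external_task_id=task_id,
    poke_interval=5,
)
```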
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
index 345d4838412d1..d3e13da34a6ab 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -22,10 +22,10 @@

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
-from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
-from airflow.utils.state import DagRunState
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor, MwaaTaskSensor
+from airflow.utils.state import DagRunState, TaskInstanceState

-SENSOR_KWARGS = {
+SENSOR_DAG_RUN_KWARGS = {
     "task_id": "test_mwaa_sensor",
     "external_env_name": "test_env",
     "external_dag_id": "test_dag",
@@ -35,6 +35,17 @@
     "max_retries": 100,
 }

+SENSOR_TASK_KWARGS = {
+    "task_id": "test_mwaa_sensor",
+    "external_env_name": "test_env",
+    "external_dag_id": "test_dag",
+    "external_dag_run_id": "test_run_id",
+    "external_task_id": "test_task_id",
+    "deferrable": True,
+    "poke_interval": 5,
+    "max_retries": 100,
+}
+
 SENSOR_STATE_KWARGS = {
     "success_states": ["a", "b"],
     "failure_states": ["c", "d"],
@@ -49,38 +60,80 @@ def mock_invoke_rest_api():

 class TestMwaaDagRunSuccessSensor:
     def test_init_success(self):
-        sensor = MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS)
-        assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
-        assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
-        assert sensor.external_dag_run_id == SENSOR_KWARGS["external_dag_run_id"]
+        sensor = MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS, **SENSOR_STATE_KWARGS)
+        assert sensor.external_env_name == SENSOR_DAG_RUN_KWARGS["external_env_name"]
+        assert sensor.external_dag_id == SENSOR_DAG_RUN_KWARGS["external_dag_id"]
+        assert sensor.external_dag_run_id == SENSOR_DAG_RUN_KWARGS["external_dag_run_id"]
         assert set(sensor.success_states) == set(SENSOR_STATE_KWARGS["success_states"])
         assert set(sensor.failure_states) == set(SENSOR_STATE_KWARGS["failure_states"])
-        assert sensor.deferrable == SENSOR_KWARGS["deferrable"]
-        assert sensor.poke_interval == SENSOR_KWARGS["poke_interval"]
-        assert sensor.max_retries == SENSOR_KWARGS["max_retries"]
+        assert sensor.deferrable == SENSOR_DAG_RUN_KWARGS["deferrable"]
+        assert sensor.poke_interval == SENSOR_DAG_RUN_KWARGS["poke_interval"]
+        assert sensor.max_retries == SENSOR_DAG_RUN_KWARGS["max_retries"]

-        sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
+        sensor = MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS)
         assert sensor.success_states == {DagRunState.SUCCESS.value}
         assert sensor.failure_states == {DagRunState.FAILED.value}

     def test_init_failure(self):
         with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"):
             MwaaDagRunSensor(
-                **SENSOR_KWARGS, success_states={"state1", "state2"}, failure_states={"state2", "state3"}
+                **SENSOR_DAG_RUN_KWARGS,
+                success_states={"state1", "state2"},
+                failure_states={"state2", "state3"},
+            )
+
+    @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["success_states"])
+    def test_poke_completed(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
+        assert MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+
+    @pytest.mark.parametrize("state", ["e", "f"])
+    def test_poke_not_completed(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
+        assert not MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+
+    @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["failure_states"])
+    def test_poke_terminated(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
+        with pytest.raises(AirflowException, match=f".*{state}.*"):
+            MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+
+
+class TestMwaaTaskSuccessSensor:
+    def test_init_success(self):
+        sensor = MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS)
+        assert sensor.external_env_name == SENSOR_TASK_KWARGS["external_env_name"]
+        assert sensor.external_dag_id == SENSOR_TASK_KWARGS["external_dag_id"]
+        assert sensor.external_dag_run_id == SENSOR_TASK_KWARGS["external_dag_run_id"]
+        assert sensor.external_task_id == SENSOR_TASK_KWARGS["external_task_id"]
+        assert set(sensor.success_states) == set(SENSOR_STATE_KWARGS["success_states"])
+        assert set(sensor.failure_states) == set(SENSOR_STATE_KWARGS["failure_states"])
+        assert sensor.deferrable == SENSOR_TASK_KWARGS["deferrable"]
+        assert sensor.poke_interval == SENSOR_TASK_KWARGS["poke_interval"]
+        assert sensor.max_retries == SENSOR_TASK_KWARGS["max_retries"]
+
+        sensor = MwaaTaskSensor(**SENSOR_TASK_KWARGS)
+        assert sensor.success_states == {TaskInstanceState.SUCCESS.value}
+        assert sensor.failure_states == {TaskInstanceState.FAILED.value}
+
+    def test_init_failure(self):
+        with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"):
+            MwaaTaskSensor(
+                **SENSOR_TASK_KWARGS, success_states={"state1", "state2"}, failure_states={"state2", "state3"}
             )

     @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["success_states"])
     def test_poke_completed(self, mock_invoke_rest_api, state):
         mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
-        assert MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+        assert MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS).poke({})

     @pytest.mark.parametrize("state", ["e", "f"])
     def test_poke_not_completed(self, mock_invoke_rest_api, state):
         mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
-        assert not MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+        assert not MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS).poke({})

     @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["failure_states"])
     def test_poke_terminated(self, mock_invoke_rest_api, state):
         mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}}
         with pytest.raises(AirflowException, match=f".*{state}.*"):
-            MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({})
+            MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS).poke({})
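One path the suite above does not exercise is the latest-run fallback in ``MwaaTaskSensor.execute``. A sketch of such a test in the same mocking style; the patch target for the base sensor's execute and the trimmed kwargs dict are assumptions, not part of this change:

```python
# Hypothetical extra test: execute() should resolve the most recent DAG run
# when external_dag_run_id is omitted. Not part of the suite above.
def test_execute_resolves_latest_dag_run(mock_invoke_rest_api):
    kwargs = {k: v for k, v in SENSOR_TASK_KWARGS.items() if k != "external_dag_run_id"}
    kwargs["deferrable"] = False
    mock_invoke_rest_api.return_value = {
        "RestApiResponse": {"dag_runs": [{"dag_run_id": "old_run"}, {"dag_run_id": "new_run"}]}
    }
    sensor = MwaaTaskSensor(**kwargs)
    # Skip the poke loop; only the run-id resolution is under test here.
    with mock.patch("airflow.sensors.base.BaseSensorOperator.execute"):
        sensor.execute({})
    assert sensor.external_dag_run_id == "new_run"
```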
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
index c58bdce6dd7f7..4293bcfb9d9cc 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
@@ -22,23 +22,30 @@
 import pytest

 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
-from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger, MwaaTaskCompletedTrigger
 from airflow.triggers.base import TriggerEvent
 from airflow.utils.state import DagRunState

 from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type

 BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.mwaa."

-TRIGGER_KWARGS = {
+TRIGGER_DAG_RUN_KWARGS = {
     "external_env_name": "test_env",
     "external_dag_id": "test_dag",
     "external_dag_run_id": "test_run_id",
 }

+TRIGGER_TASK_KWARGS = {
+    "external_env_name": "test_env",
+    "external_dag_id": "test_dag",
+    "external_dag_run_id": "test_run_id",
+    "external_task_id": "test_task_id",
+}
+

 class TestMwaaDagRunCompletedTrigger:
     def test_init_states(self):
-        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_DAG_RUN_KWARGS)
         assert trigger.success_states == {DagRunState.SUCCESS.value}
         assert trigger.failure_states == {DagRunState.FAILED.value}
         acceptors = trigger.waiter_config_overrides["acceptors"]
@@ -75,19 +82,31 @@ def test_init_states(self):

     def test_init_fail(self):
         with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"):
-            MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS, success_states=("a", "b"), failure_states=("b", "c"))
+            MwaaDagRunCompletedTrigger(
+                **TRIGGER_DAG_RUN_KWARGS, success_states=("a", "b"), failure_states=("b", "c")
+            )
+
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = MwaaDagRunCompletedTrigger(**TRIGGER_DAG_RUN_KWARGS, aws_conn_id=OVERWRITTEN_CONN)
+        assert op.hook().aws_conn_id == OVERWRITTEN_CONN
+
+    def test_no_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = MwaaDagRunCompletedTrigger(**TRIGGER_DAG_RUN_KWARGS)
+        assert op.hook().aws_conn_id == DEFAULT_CONN

     def test_serialization(self):
         success_states = ["a", "b"]
         failure_states = ["c", "d"]
         trigger = MwaaDagRunCompletedTrigger(
-            **TRIGGER_KWARGS, success_states=success_states, failure_states=failure_states
+            **TRIGGER_DAG_RUN_KWARGS, success_states=success_states, failure_states=failure_states
         )

         classpath, kwargs = trigger.serialize()
         assert classpath == BASE_TRIGGER_CLASSPATH + "MwaaDagRunCompletedTrigger"
-        assert kwargs.get("external_env_name") == TRIGGER_KWARGS["external_env_name"]
-        assert kwargs.get("external_dag_id") == TRIGGER_KWARGS["external_dag_id"]
-        assert kwargs.get("external_dag_run_id") == TRIGGER_KWARGS["external_dag_run_id"]
+        assert kwargs.get("external_env_name") == TRIGGER_DAG_RUN_KWARGS["external_env_name"]
+        assert kwargs.get("external_dag_id") == TRIGGER_DAG_RUN_KWARGS["external_dag_id"]
+        assert kwargs.get("external_dag_run_id") == TRIGGER_DAG_RUN_KWARGS["external_dag_run_id"]
         assert kwargs.get("success_states") == success_states
         assert kwargs.get("failure_states") == failure_states

@@ -97,13 +116,63 @@ def test_serialization(self):
     async def test_run_success(self, mock_async_conn, mock_get_waiter):
         mock_async_conn.__aenter__.return_value = mock.MagicMock()
         mock_get_waiter().wait = AsyncMock()
-        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_DAG_RUN_KWARGS)

         generator = trigger.run()
         response = await generator.asend(None)

         assert response == TriggerEvent(
-            {"status": "success", "dag_run_id": TRIGGER_KWARGS["external_dag_run_id"]}
+            {"status": "success", "dag_run_id": TRIGGER_DAG_RUN_KWARGS["external_dag_run_id"]}
         )
         assert_expected_waiter_type(mock_get_waiter, "mwaa_dag_run_complete")
         mock_get_waiter().wait.assert_called_once()
+
+
+class TestMwaaTaskCompletedTrigger:
+    def test_overwritten_conn_passed_to_hook(self):
+        OVERWRITTEN_CONN = "new-conn-id"
+        op = MwaaTaskCompletedTrigger(**TRIGGER_TASK_KWARGS, aws_conn_id=OVERWRITTEN_CONN)
+        assert op.hook().aws_conn_id == OVERWRITTEN_CONN
+
+    def test_no_conn_passed_to_hook(self):
+        DEFAULT_CONN = "aws_default"
+        op = MwaaTaskCompletedTrigger(**TRIGGER_TASK_KWARGS)
+        assert op.hook().aws_conn_id == DEFAULT_CONN
+
+    def test_init_fail(self):
+        with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"):
+            MwaaTaskCompletedTrigger(
+                **TRIGGER_TASK_KWARGS, success_states=("a", "b"), failure_states=("b", "c")
+            )
+
+    def test_serialization(self):
+        success_states = ["a", "b"]
+        failure_states = ["c", "d"]
+        trigger = MwaaTaskCompletedTrigger(
+            **TRIGGER_TASK_KWARGS, success_states=success_states, failure_states=failure_states
+        )
+
+        classpath, kwargs = trigger.serialize()
+        assert classpath == BASE_TRIGGER_CLASSPATH + "MwaaTaskCompletedTrigger"
+        assert kwargs.get("external_env_name") == TRIGGER_TASK_KWARGS["external_env_name"]
+        assert kwargs.get("external_dag_id") == TRIGGER_TASK_KWARGS["external_dag_id"]
+        assert kwargs.get("external_dag_run_id") == TRIGGER_TASK_KWARGS["external_dag_run_id"]
+        assert kwargs.get("external_task_id") == TRIGGER_TASK_KWARGS["external_task_id"]
+        assert kwargs.get("success_states") == success_states
+        assert kwargs.get("failure_states") == failure_states
+
+    @pytest.mark.asyncio
+    @mock.patch.object(MwaaHook, "get_waiter")
+    @mock.patch.object(MwaaHook, "get_async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
+        trigger = MwaaTaskCompletedTrigger(**TRIGGER_TASK_KWARGS)
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {"status": "success", "task_id": TRIGGER_TASK_KWARGS["external_task_id"]}
+        )
+        assert_expected_waiter_type(mock_get_waiter, "mwaa_task_complete")
+        mock_get_waiter().wait.assert_called_once()
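A possible follow-up to the serialization tests above is a round-trip check that the serialized form actually reconstructs an equivalent trigger. A sketch, assuming ``import_string`` resolves the classpath and that ``serialize()`` returns constructor-compatible kwargs (both assumptions, not asserted by this change):

```python
# Hypothetical round-trip test: rebuild the trigger the way the triggerer
# process would, then confirm it serializes back to the same form.
from airflow.utils.module_loading import import_string


def test_serialization_round_trip():
    trigger = MwaaTaskCompletedTrigger(**TRIGGER_TASK_KWARGS)
    classpath, kwargs = trigger.serialize()

    rebuilt = import_string(classpath)(**kwargs)

    assert rebuilt.serialize() == (classpath, kwargs)
```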