diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 0a508923a6733..94ad23eab46f0 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -702,6 +702,9 @@ triggers:
   - integration-name: AWS Lambda
     python-modules:
       - airflow.providers.amazon.aws.triggers.lambda_function
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.mwaa
   - integration-name: Amazon Managed Service for Apache Flink
     python-modules:
       - airflow.providers.amazon.aws.triggers.kinesis_analytics
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
index 862a210cd23cd..0230dea91f14e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -943,6 +943,7 @@ def get_waiter(
         self,
         waiter_name: str,
         parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
         deferrable: bool = False,
         client=None,
     ) -> Waiter:
@@ -962,6 +963,9 @@ def get_waiter(
         :param parameters: will scan the waiter config for the keys of that dict,
             and replace them with the corresponding value. If a custom waiter has such keys to be expanded,
             they need to be provided here.
+            Note: cannot be combined with ``config_overrides`` when the override contains ``acceptors``.
+        :param config_overrides: will update values of provided keys in the waiter's
+            config. Only specified keys will be updated.
         :param deferrable: If True, the waiter is going to be an async custom waiter. An async client must be
             provided in that case.
         :param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@ def get_waiter(
         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable waiter.")
+        if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
+            raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
         # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
         client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could be None.
             with open(self.waiter_path) as config_file:
-                config = json.loads(config_file.read())
+                config: dict = json.loads(config_file.read())
+                if config_overrides is not None:
+                    config["waiters"][waiter_name].update(config_overrides)
             config = self._apply_parameters_value(config, waiter_name, parameters)
             return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
                 waiter_name
             )
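Reviewer note (not part of the patch): the new `config_overrides` is a shallow `dict.update` on the named waiter's entry in the custom-waiter JSON, so only the top-level keys you pass (`delay`, `maxAttempts`, `acceptors`, ...) are replaced. A minimal sketch of the behavior, assuming default AWS credentials are available:

```python
# Sketch: shallow-override the custom "mwaa_dag_run_complete" waiter that
# this patch ships in waiters/mwaa.json (see below).
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook()
waiter = hook.get_waiter(
    "mwaa_dag_run_complete",
    # Only these top-level keys are replaced; "operation" and "acceptors"
    # keep the values from the JSON config.
    config_overrides={"delay": 30, "maxAttempts": 20},
)
```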
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
index 9007379e22c5e..bbd23c60cb38d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -18,13 +18,16 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         (templated)
     :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
     :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
-        ``airflow.utils.state.State.success_states`` (templated)
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
     :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
-        AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         "external_dag_run_id",
         "success_states",
         "failure_states",
+        "deferrable",
+        "max_retries",
+        "poke_interval",
     )
 
     def __init__(
@@ -68,19 +89,25 @@ def __init__(
         self,
         *,
         external_env_name: str,
         external_dag_id: str,
         external_dag_run_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 60,
+        max_retries: int = 720,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.success_states = set(success_states if success_states else State.success_states)
-        self.failure_states = set(failure_states if failure_states else State.failed_states)
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
 
         if len(self.success_states & self.failure_states):
-            raise AirflowException("allowed_states and failed_states must not have any values in common")
+            raise ValueError("success_states and failure_states must not have any values in common")
 
         self.external_env_name = external_env_name
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
         self.log.info(
@@ -102,12 +129,32 @@ def poke(self, context: Context) -> bool:
 
         # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun
         state = response["RestApiResponse"]["state"]
-        if state in self.success_states:
-            return True
 
         if state in self.failure_states:
             raise AirflowException(
                 f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
-                f"failed with state {state}."
+                f"failed with state: {state}"
             )
-        return False
+
+        return state in self.success_states
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        validate_execute_complete_event(event)
+
+    def execute(self, context: Context):
+        if self.deferrable:
+            self.defer(
+                trigger=MwaaDagRunCompletedTrigger(
+                    external_env_name=self.external_env_name,
+                    external_dag_id=self.external_dag_id,
+                    external_dag_run_id=self.external_dag_run_id,
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    waiter_delay=self.poke_interval,
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
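For context, a usage sketch of the sensor in deferrable mode (environment, DAG, and run IDs below are placeholders):

```python
# Usage sketch (placeholder names): wait for a run in another MWAA
# environment without occupying a worker slot while it polls.
from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

wait_for_remote_run = MwaaDagRunSensor(
    task_id="wait_for_remote_run",
    external_env_name="my-mwaa-env",
    external_dag_id="remote_dag",
    external_dag_run_id="manual__2025-01-01T00:00:00+00:00",
    deferrable=True,   # polling happens on the triggerer via the new trigger
    poke_interval=60,  # seconds between InvokeRestApi calls
    max_retries=720,   # give up after ~12 hours at 60 s intervals
)
```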
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index f2c71a99adce9..6183020a8f925 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
 
     :param waiter_delay: The amount of time in seconds to wait between attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@ def __init__(
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ def __init__(
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides
         self.return_key = return_key
         self.return_value = return_value
 
@@ -140,7 +144,12 @@ def hook(self) -> AwsGenericHook:
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
         async with await hook.get_async_conn() as client:
-            waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,
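To illustrate the plumbing, here is a hypothetical subclass exercising the new knob; the waiter name, response shape, and hook are made up, and the real use is the MWAA trigger added below:

```python
# Hypothetical sketch: "my_service_ready" and "Resource.Status" do not
# exist; they only show how waiter_config_overrides flows into get_waiter().
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger


class MyServiceReadyTrigger(AwsBaseWaiterTrigger):
    def __init__(self, *, resource_id: str, waiter_delay: int = 30) -> None:
        super().__init__(
            serialized_fields={"resource_id": resource_id},
            waiter_name="my_service_ready",
            waiter_args={"ResourceId": resource_id},
            failure_message="Resource failed",
            status_message="Resource status is",
            status_queries=["Resource.Status"],
            return_key="resource_id",
            return_value=resource_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=60,
            aws_conn_id="aws_default",
            # Shallow-merged into the waiter's JSON model by get_waiter().
            waiter_config_overrides={"delay": waiter_delay},
        )

    def hook(self) -> AwsGenericHook:
        # Placeholder; a real trigger returns its service's concrete hook.
        return AwsGenericHook(aws_conn_id=self.aws_conn_id, client_type="mwaa")
```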
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
new file mode 100644
index 0000000000000..bb6306d288ee0
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an MWAA DAG Run is complete.
+
+    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to
+        wait for
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want
+        to wait for
+    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for
+    :param success_states: Collection of DAG Run states that would make this task marked as successful,
+        default is ``{airflow.utils.state.DagRunState.SUCCESS}``
+    :param failure_states: Collection of DAG Run states that would make this task marked as failed and
+        raise an AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}``
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_dag_run_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of DAG run",
+            status_queries=["RestApiResponse.state"],
+            return_key="dag_run_id",
+            return_value=external_dag_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+def _build_waiter_acceptors(
+    success_states: set[str], failure_states: set[str], in_progress_states: set[str]
+) -> list:
+    def build_acceptor(dag_run_state: str, state_waiter_category: str):
+        return {
+            "matcher": "path",
+            "argument": "RestApiResponse.state",
+            "expected": dag_run_state,
+            "state": state_waiter_category,
+        }
+
+    acceptors = []
+    for state_set, state_waiter_category in (
+        (success_states, "success"),
+        (failure_states, "failure"),
+        (in_progress_states, "retry"),
+    ):
+        for dag_run_state in state_set:
+            acceptors.append(build_acceptor(dag_run_state, state_waiter_category))
+
+    return acceptors
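To make the override concrete, here is what the generated acceptor list looks like with the default states (a sketch; `_build_waiter_acceptors` is a private helper and subject to change):

```python
# Illustration of the acceptor generation above. With the default states,
# this reproduces the four acceptors shipped in waiters/mwaa.json below.
from airflow.providers.amazon.aws.triggers.mwaa import _build_waiter_acceptors

acceptors = _build_waiter_acceptors(
    success_states={"success"},
    failure_states={"failed"},
    in_progress_states={"queued", "running"},
)
# Each entry looks like:
# {"matcher": "path", "argument": "RestApiResponse.state",
#  "expected": "success", "state": "success"}
```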
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 43d8bdf26d3fc..575a089382ad0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -136,15 +136,16 @@ async def async_wait(
             last_response = error.last_response
 
             if "terminal failure" in error_reason:
-                log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
-                raise AirflowException(f"{failure_message}: {error}")
+                raise AirflowException(
+                    f"{failure_message}: {_LazyStatusFormatter(status_args, last_response)}\n{error}"
+                )
 
             if (
                 "An error occurred" in error_reason
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}: {error}")
+                raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
 
             log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
         else:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
new file mode 100644
index 0000000000000..c1de661aa7b3d
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
@@ -0,0 +1,36 @@
+{
+    "version": 2,
+    "waiters": {
+        "mwaa_dag_run_complete": {
+            "delay": 60,
+            "maxAttempts": 720,
+            "operation": "InvokeRestApi",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "queued",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "running",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "success",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "failed",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
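A usage note (not part of the patch): the file name matches `MwaaHook`'s `client_type`, which is how `get_waiter()` discovers it. The default model above can also be exercised synchronously through the hook; values below are placeholders:

```python
# Sketch (placeholder environment/DAG/run IDs): the custom waiter polls
# InvokeRestApi until RestApiResponse.state hits a terminal acceptor.
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

waiter = MwaaHook().get_waiter("mwaa_dag_run_complete")
waiter.wait(
    Name="my-mwaa-env",
    Path="/dags/remote_dag/dagRuns/manual__2025-01-01T00:00:00+00:00",
    Method="GET",
    WaiterConfig={"Delay": 60, "MaxAttempts": 720},  # optional per-call override
)
```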
"test_dag", "external_dag_run_id": "test_run_id", + "deferrable": False, + "poke_interval": 5, + "max_retries": 100, +} + +SENSOR_STATE_KWARGS = { + "success_states": ["a", "b"], + "failure_states": ["c", "d"], } @@ -41,35 +49,38 @@ def mock_invoke_rest_api(): class TestMwaaDagRunSuccessSensor: def test_init_success(self): - success_states = {"state1", "state2"} - failure_states = {"state3", "state4"} - sensor = MwaaDagRunSensor( - **SENSOR_KWARGS, success_states=success_states, failure_states=failure_states - ) + sensor = MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS) assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"] assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"] assert sensor.external_dag_run_id == SENSOR_KWARGS["external_dag_run_id"] - assert set(sensor.success_states) == success_states - assert set(sensor.failure_states) == failure_states + assert set(sensor.success_states) == set(SENSOR_STATE_KWARGS["success_states"]) + assert set(sensor.failure_states) == set(SENSOR_STATE_KWARGS["failure_states"]) + assert sensor.deferrable == SENSOR_KWARGS["deferrable"] + assert sensor.poke_interval == SENSOR_KWARGS["poke_interval"] + assert sensor.max_retries == SENSOR_KWARGS["max_retries"] + + sensor = MwaaDagRunSensor(**SENSOR_KWARGS) + assert sensor.success_states == {DagRunState.SUCCESS.value} + assert sensor.failure_states == {DagRunState.FAILED.value} def test_init_failure(self): - with pytest.raises(AirflowException): + with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"): MwaaDagRunSensor( **SENSOR_KWARGS, success_states={"state1", "state2"}, failure_states={"state2", "state3"} ) - @pytest.mark.parametrize("status", sorted(State.success_states)) - def test_poke_completed(self, mock_invoke_rest_api, status): - mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}} - assert MwaaDagRunSensor(**SENSOR_KWARGS).poke({}) + @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["success_states"]) + def test_poke_completed(self, mock_invoke_rest_api, state): + mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}} + assert MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({}) - @pytest.mark.parametrize("status", ["running", "queued"]) - def test_poke_not_completed(self, mock_invoke_rest_api, status): - mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}} - assert not MwaaDagRunSensor(**SENSOR_KWARGS).poke({}) + @pytest.mark.parametrize("state", ["e", "f"]) + def test_poke_not_completed(self, mock_invoke_rest_api, state): + mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}} + assert not MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({}) - @pytest.mark.parametrize("status", sorted(State.failed_states)) - def test_poke_terminated(self, mock_invoke_rest_api, status): - mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}} - with pytest.raises(AirflowException): - MwaaDagRunSensor(**SENSOR_KWARGS).poke({}) + @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["failure_states"]) + def test_poke_terminated(self, mock_invoke_rest_api, state): + mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": state}} + with pytest.raises(AirflowException, match=f".*{state}.*"): + MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({}) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py new file mode 100644 index 
index 0000000000000..18c53e11f1802
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import DagRunState
+from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.mwaa."
+TRIGGER_KWARGS = {
+    "external_env_name": "test_env",
+    "external_dag_id": "test_dag",
+    "external_dag_run_id": "test_run_id",
+}
+
+
+class TestMwaaDagRunCompletedTrigger:
+    def test_init_states(self):
+        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+        assert trigger.success_states == {DagRunState.SUCCESS.value}
+        assert trigger.failure_states == {DagRunState.FAILED.value}
+        acceptors = trigger.waiter_config_overrides["acceptors"]
+        expected_acceptors = [
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.SUCCESS.value,
+                "state": "success",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.FAILED.value,
+                "state": "failure",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.RUNNING.value,
+                "state": "retry",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.QUEUED.value,
+                "state": "retry",
+            },
+        ]
+        assert len(acceptors) == len(DagRunState)
+        assert {tuple(sorted(a.items())) for a in acceptors} == {
+            tuple(sorted(a.items())) for a in expected_acceptors
+        }
+
+    def test_init_fail(self):
+        with pytest.raises(ValueError, match=r".*success_states.*failure_states.*"):
+            MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS, success_states=("a", "b"), failure_states=("b", "c"))
+
+    def test_serialization(self):
+        success_states = ["a", "b"]
+        failure_states = ["c", "d"]
+        trigger = MwaaDagRunCompletedTrigger(
+            **TRIGGER_KWARGS, success_states=success_states, failure_states=failure_states
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == BASE_TRIGGER_CLASSPATH + "MwaaDagRunCompletedTrigger"
+        assert kwargs.get("external_env_name") == TRIGGER_KWARGS["external_env_name"]
+        assert kwargs.get("external_dag_id") == TRIGGER_KWARGS["external_dag_id"]
+        assert kwargs.get("external_dag_run_id") == TRIGGER_KWARGS["external_dag_run_id"]
+        assert kwargs.get("success_states") == success_states
+        assert kwargs.get("failure_states") == failure_states
+
+    @pytest.mark.asyncio
"get_waiter") + @mock.patch.object(MwaaHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + {"status": "success", "dag_run_id": TRIGGER_KWARGS["external_dag_run_id"]} + ) + assert_expected_waiter_type(mock_get_waiter, "mwaa_dag_run_complete") + mock_get_waiter().wait.assert_called_once()