diff --git a/providers/standard/src/airflow/providers/standard/exceptions.py b/providers/standard/src/airflow/providers/standard/exceptions.py new file mode 100644 index 0000000000000..66acd54aa450f --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/exceptions.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Exceptions used by Standard Provider.""" + +from __future__ import annotations + +from airflow.exceptions import AirflowException + + +class AirflowExternalTaskSensorException(AirflowException): + """Base exception for all ExternalTaskSensor related errors.""" + + +class ExternalDagNotFoundError(AirflowExternalTaskSensorException): + """Raised when the external DAG does not exist.""" + + +class ExternalDagDeletedError(AirflowExternalTaskSensorException): + """Raised when the external DAG was deleted.""" + + +class ExternalTaskNotFoundError(AirflowExternalTaskSensorException): + """Raised when the external task does not exist.""" + + +class ExternalTaskGroupNotFoundError(AirflowExternalTaskSensorException): + """Raised when the external task group does not exist.""" + + +class ExternalTaskFailedError(AirflowExternalTaskSensorException): + """Raised when the external task failed.""" + + +class ExternalTaskGroupFailedError(AirflowExternalTaskSensorException): + """Raised when the external task group failed.""" + + +class ExternalDagFailedError(AirflowExternalTaskSensorException): + """Raised when the external DAG failed.""" + + +class DuplicateStateError(AirflowExternalTaskSensorException): + """Raised when duplicate states are provided across allowed, skipped and failed states.""" diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 3304e0cd2de4e..325cfcd4d3c51 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -24,9 +23,19 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowSkipException from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag +from airflow.providers.standard.exceptions import ( + DuplicateStateError, + ExternalDagDeletedError, + ExternalDagFailedError, + ExternalDagNotFoundError, + ExternalTaskFailedError, + ExternalTaskGroupFailedError, + ExternalTaskGroupNotFoundError, + ExternalTaskNotFoundError, +) from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.external_task import WorkflowTrigger from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids @@ -190,7 +199,7 @@ def __init__( total_states = set(self.allowed_states + self.skipped_states + self.failed_states) if len(total_states) != len(self.allowed_states) + len(self.skipped_states) + len(self.failed_states): - raise AirflowException( + raise DuplicateStateError( "Duplicate values provided across allowed_states, skipped_states and failed_states." ) @@ -356,7 +365,7 @@ def _handle_failed_states(self, count_failed: float | int) -> None: f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." ) - raise AirflowException( + raise ExternalTaskFailedError( f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} failed." ) @@ -366,7 +375,7 @@ def _handle_failed_states(self, count_failed: float | int) -> None: f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." ) - raise AirflowException( + raise ExternalTaskGroupFailedError( f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed." ) @@ -374,7 +383,7 @@ def _handle_failed_states(self, count_failed: float | int) -> None: raise AirflowSkipException( f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." ) - raise AirflowException(f"The external DAG {self.external_dag_id} failed.") + raise ExternalDagFailedError(f"The external DAG {self.external_dag_id} failed.") def _handle_skipped_states(self, count_skipped: float | int) -> None: """Handle skipped states and raise appropriate exceptions.""" @@ -443,10 +452,14 @@ def execute_complete(self, context, event=None): self.log.info("External tasks %s has executed successfully.", self.external_task_ids) elif event["status"] == "skipped": raise AirflowSkipException("External job has skipped skipping.") + elif event["status"] == "failed": + if self.soft_fail: + raise AirflowSkipException("External job has failed skipping.") + raise ExternalDagFailedError("External job has failed.") else: if self.soft_fail: raise AirflowSkipException("External job has failed skipping.") - raise AirflowException( + raise ExternalTaskNotFoundError( "Error occurred while trying to retrieve task status. Please, check the " "name of executed task and Dag." ) @@ -455,23 +468,31 @@ def _check_for_existence(self, session) -> None: dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: - raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.") + raise ExternalDagNotFoundError(f"The external DAG {self.external_dag_id} does not exist.") if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)): - raise AirflowException(f"The external DAG {self.external_dag_id} was deleted.") + raise ExternalDagDeletedError(f"The external DAG {self.external_dag_id} was deleted.") if self.external_task_ids: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) + if not refreshed_dag_info: + raise ExternalDagNotFoundError( + f"The external DAG {self.external_dag_id} could not be loaded." + ) for external_task_id in self.external_task_ids: if not refreshed_dag_info.has_task(external_task_id): - raise AirflowException( + raise ExternalTaskNotFoundError( f"The external task {external_task_id} in DAG {self.external_dag_id} does not exist." ) if self.external_task_group_id: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) + if not refreshed_dag_info: + raise ExternalDagNotFoundError( + f"The external DAG {self.external_dag_id} could not be loaded." + ) if not refreshed_dag_info.has_task_group(self.external_task_group_id): - raise AirflowException( + raise ExternalTaskGroupNotFoundError( f"The external task group '{self.external_task_group_id}' in " f"DAG '{self.external_dag_id}' does not exist." ) diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index e267ecfe3ba18..fe90c14fabd8a 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -115,9 +115,7 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: if failed_count > 0: yield TriggerEvent({"status": "failed"}) return - else: - yield TriggerEvent({"status": "success"}) - return + if self.skipped_states: skipped_count = await get_count_func(self.skipped_states) if skipped_count > 0: diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index ab1a3a7a9240d..5c1990a6b7f08 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -26,12 +26,27 @@ import pytest from airflow import settings -from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowSensorTimeout, + AirflowSkipException, + TaskDeferred, +) from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.models.xcom_arg import XComArg +from airflow.providers.standard.exceptions import ( + DuplicateStateError, + ExternalDagDeletedError, + ExternalDagFailedError, + ExternalDagNotFoundError, + ExternalTaskFailedError, + ExternalTaskGroupFailedError, + ExternalTaskGroupNotFoundError, + ExternalTaskNotFoundError, +) from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator @@ -226,7 +241,7 @@ def test_external_task_group_not_exists_without_check_existence(self): dag=self.dag, poke_interval=0.1, ) - with pytest.raises(AirflowException, match="Sensor has timed out"): + with pytest.raises(AirflowSensorTimeout, match="Sensor has timed out"): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_group_sensor_success(self): @@ -253,13 +268,13 @@ def test_external_task_group_sensor_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + ExternalTaskGroupFailedError, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): - with pytest.raises(AirflowException): + with pytest.raises(DuplicateStateError): ExternalTaskSensor( task_id="test_external_task_sensor_check", external_dag_id=TEST_DAG_ID, @@ -303,7 +318,7 @@ def test_external_task_sensor_failed_states_as_success(self, caplog): error_message = rf"Some of the external tasks \['{TEST_TASK_ID}'\] in DAG {TEST_DAG_ID} failed\." with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(ExternalTaskFailedError, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}'] in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... " @@ -404,7 +419,7 @@ def test_external_task_sensor_failed_states_as_success_mulitple_task_ids(self, c ) with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(ExternalTaskFailedError, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}', '{TEST_TASK_ID_ALTERNATE}'] " @@ -552,12 +567,12 @@ def test_external_task_sensor_fn_multiple_logical_dates(self): dag=dag, ) - # We need to test for an AirflowException explicitly since + # We need to test for an ExternalTaskFailedError explicitly since # AirflowSensorTimeout is a subclass that will be raised if this does # not execute properly. - with pytest.raises(AirflowException) as ex_ctx: + with pytest.raises(ExternalTaskFailedError) as ex_ctx: task_chain_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - assert type(ex_ctx.value) is AirflowException + assert type(ex_ctx.value) is ExternalTaskFailedError def test_external_task_sensor_delta(self): self.add_time_sensor() @@ -745,15 +760,16 @@ def test_catch_invalid_allowed_states(self): ) def test_external_task_sensor_waits_for_task_check_existence(self): + self.add_time_sensor() op = ExternalTaskSensor( task_id="test_external_task_sensor_check", - external_dag_id="example_bash_operator", + external_dag_id=TEST_DAG_ID, external_task_id="non-existing-task", check_existence=True, dag=self.dag, ) - with pytest.raises(AirflowException): + with pytest.raises(ExternalDagNotFoundError): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_waits_for_dag_check_existence(self): @@ -765,7 +781,7 @@ def test_external_task_sensor_waits_for_dag_check_existence(self): dag=self.dag, ) - with pytest.raises(AirflowException): + with pytest.raises(ExternalDagNotFoundError): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_group_with_mapped_tasks_sensor_success(self): @@ -791,7 +807,7 @@ def test_external_task_group_with_mapped_tasks_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + ExternalTaskGroupFailedError, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -846,7 +862,7 @@ def test_external_task_group_when_there_is_no_TIs(self): ( ( False, - AirflowException, + ExternalTaskFailedError, ), ( True, @@ -870,8 +886,24 @@ def test_fail_poke( deferrable=False, **kwargs, ) - with pytest.raises(expected_exception, match=expected_message): - op.execute(context={}) + + # We need to handle the specific exception types based on kwargs + if not soft_fail: + expected_exc = expected_exception + if "external_task_ids" in kwargs: + expected_exc = ExternalTaskFailedError + elif "external_task_group_id" in kwargs: + expected_exc = ExternalTaskGroupFailedError + elif "failed_states" in kwargs and not any( + k in kwargs for k in ["external_task_ids", "external_task_group_id"] + ): + expected_exc = ExternalDagFailedError + + with pytest.raises(expected_exc, match=expected_message): + op.execute(context={}) + else: + with pytest.raises(expected_exception, match=expected_message): + op.execute(context={}) @pytest.mark.parametrize( "response_get_current, response_exists, kwargs, expected_message", @@ -903,11 +935,11 @@ def test_fail_poke( ( ( False, - AirflowException, + ExternalDagNotFoundError, ), ( True, - AirflowException, + ExternalDagNotFoundError, ), ), ) @@ -946,7 +978,17 @@ def test_fail__check_for_existence( ) if not hasattr(op, "never_fail"): expected_message = "Skipping due to soft_fail is set to True." if soft_fail else expected_message - with pytest.raises(expected_exception, match=expected_message): + specific_exception = expected_exception + if response_get_current is None: + specific_exception = ExternalDagNotFoundError + elif not response_exists: + specific_exception = ExternalDagDeletedError + elif "external_task_ids" in kwargs: + specific_exception = ExternalTaskNotFoundError + elif "external_task_group_id" in kwargs: + specific_exception = ExternalTaskGroupNotFoundError + + with pytest.raises(specific_exception, match=expected_message): op.execute(context={}) @@ -1002,7 +1044,7 @@ def test_external_task_sensor_failure(self, dag_maker): self.context["ti"].get_ti_count.return_value = 1 - with pytest.raises(AirflowException): + with pytest.raises(ExternalTaskFailedError): op.execute(context=self.context) self.context["ti"].get_ti_count.assert_called_once_with( @@ -1227,7 +1269,7 @@ def test_external_task_sensor_task_group_failed_states(self, dag_maker): self.context["ti"].get_task_states.return_value = {"run_id": {"test_group.task_id": State.FAILED}} - with pytest.raises(AirflowException): + with pytest.raises(ExternalTaskGroupFailedError): op.execute(context=self.context) self.context["ti"].get_task_states.assert_called_once_with( @@ -1261,7 +1303,7 @@ def test_defer_and_fire_task_state_trigger(self): assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger" def test_defer_and_fire_failed_state_trigger(self): - """Tests that an AirflowException is raised in case of error event""" + """Tests that an ExternalTaskNotFoundError is raised in case of error event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1269,13 +1311,13 @@ def test_defer_and_fire_failed_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(ExternalTaskNotFoundError): sensor.execute_complete( context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} ) def test_defer_and_fire_timeout_state_trigger(self): - """Tests that an AirflowException is raised in case of timeout event""" + """Tests that an ExternalTaskNotFoundError is raised in case of timeout event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1283,7 +1325,7 @@ def test_defer_and_fire_timeout_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(ExternalTaskNotFoundError): sensor.execute_complete( context=mock.MagicMock(), event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."}, @@ -1305,6 +1347,55 @@ def test_defer_execute_check_correct_logging(self): ) mock_log_info.assert_called_with("External tasks %s has executed successfully.", [EXTERNAL_TASK_ID]) + def test_defer_execute_check_failed_status(self): + """Tests that the execute_complete method properly handles the 'failed' status from WorkflowTrigger""" + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + ) + + with pytest.raises(ExternalDagFailedError, match="External job has failed."): + sensor.execute_complete( + context=mock.MagicMock(), + event={"status": "failed"}, + ) + + def test_defer_execute_check_failed_status_soft_fail(self): + """Tests that the execute_complete method properly handles the 'failed' status with soft_fail=True""" + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + soft_fail=True, + ) + + with pytest.raises(AirflowSkipException, match="External job has failed skipping."): + sensor.execute_complete( + context=mock.MagicMock(), + event={"status": "failed"}, + ) + + def test_defer_with_failed_states(self): + """Tests that failed_states are properly passed to the WorkflowTrigger when the sensor is deferred""" + failed_states = ["failed", "upstream_failed"] + sensor = ExternalTaskSensor( + task_id=TASK_ID, + external_task_id=EXTERNAL_TASK_ID, + external_dag_id=EXTERNAL_DAG_ID, + deferrable=True, + failed_states=failed_states, + ) + + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context=mock.MagicMock()) + + trigger = exc.value.trigger + assert isinstance(trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger" + assert trigger.failed_states == failed_states, "failed_states not properly passed to WorkflowTrigger" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Needs Flask app context fixture for AF 2") @pytest.mark.parametrize( diff --git a/providers/standard/tests/unit/standard/test_exceptions.py b/providers/standard/tests/unit/standard/test_exceptions.py new file mode 100644 index 0000000000000..8e52f5e19b2e9 --- /dev/null +++ b/providers/standard/tests/unit/standard/test_exceptions.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.standard.exceptions import ( + AirflowExternalTaskSensorException, + DuplicateStateError, + ExternalDagDeletedError, + ExternalDagFailedError, + ExternalDagNotFoundError, + ExternalTaskFailedError, + ExternalTaskGroupFailedError, + ExternalTaskGroupNotFoundError, + ExternalTaskNotFoundError, +) + + +def test_external_task_sensor_exception(): + """Test if AirflowExternalTaskSensorException can be raised correctly.""" + with pytest.raises(AirflowExternalTaskSensorException, match="Task execution failed"): + raise AirflowExternalTaskSensorException("Task execution failed") + + +def test_external_dag_not_found_error(): + """Test if ExternalDagNotFoundError can be raised correctly.""" + with pytest.raises(ExternalDagNotFoundError, match="External DAG not found"): + raise ExternalDagNotFoundError("External DAG not found") + + # Verify it's a subclass of AirflowExternalTaskSensorException + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalDagNotFoundError("External DAG not found") + + +def test_external_dag_deleted_error(): + """Test if ExternalDagDeletedError can be raised correctly.""" + with pytest.raises(ExternalDagDeletedError, match="External DAG was deleted"): + raise ExternalDagDeletedError("External DAG was deleted") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalDagDeletedError("External DAG was deleted") + + +def test_external_task_not_found_error(): + """Test if ExternalTaskNotFoundError can be raised correctly.""" + with pytest.raises(ExternalTaskNotFoundError, match="External task not found"): + raise ExternalTaskNotFoundError("External task not found") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalTaskNotFoundError("External task not found") + + +def test_external_task_group_not_found_error(): + """Test if ExternalTaskGroupNotFoundError can be raised correctly.""" + with pytest.raises(ExternalTaskGroupNotFoundError, match="External task group not found"): + raise ExternalTaskGroupNotFoundError("External task group not found") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalTaskGroupNotFoundError("External task group not found") + + +def test_external_task_failed_error(): + """Test if ExternalTaskFailedError can be raised correctly.""" + with pytest.raises(ExternalTaskFailedError, match="External task failed"): + raise ExternalTaskFailedError("External task failed") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalTaskFailedError("External task failed") + + +def test_external_task_group_failed_error(): + """Test if ExternalTaskGroupFailedError can be raised correctly.""" + with pytest.raises(ExternalTaskGroupFailedError, match="External task group failed"): + raise ExternalTaskGroupFailedError("External task group failed") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalTaskGroupFailedError("External task group failed") + + +def test_external_dag_failed_error(): + """Test if ExternalDagFailedError can be raised correctly.""" + with pytest.raises(ExternalDagFailedError, match="External DAG failed"): + raise ExternalDagFailedError("External DAG failed") + + with pytest.raises(AirflowExternalTaskSensorException): + raise ExternalDagFailedError("External DAG failed") + + +def test_duplicate_state_error(): + """Test if DuplicateStateError can be raised correctly.""" + with pytest.raises(DuplicateStateError, match="Duplicate state provided"): + raise DuplicateStateError("Duplicate state provided") + + with pytest.raises(AirflowExternalTaskSensorException): + raise DuplicateStateError("Duplicate state provided") diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py index 5f5b4861b04e9..492087b2c5663 100644 --- a/providers/standard/tests/unit/standard/triggers/test_external_task.py +++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py @@ -122,7 +122,7 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): @pytest.mark.asyncio @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_ti_count") async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): - mock_get_count.return_value = 0 + mock_get_count.side_effect = [0, 1] # First 0 for failed_states, then 1 for allowed_states trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, @@ -130,6 +130,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): run_ids=[self.RUN_ID], external_task_ids=[self.TASK_ID], failed_states=self.STATES, + allowed_states=self.STATES, poke_interval=0.2, ) @@ -140,13 +141,22 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "success"} - mock_get_count.assert_called_once_with( - dag_id="external_task", - task_ids=["external_task_op"], - logical_dates=[self.LOGICAL_DATE], - run_ids=[self.RUN_ID], - states=["success", "fail"], + + # Verify both calls were made + assert mock_get_count.call_count == 2 + mock_get_count.assert_has_calls( + [ + mock.call( + dag_id="external_task", + task_ids=["external_task_op"], + logical_dates=[self.LOGICAL_DATE], + run_ids=[self.RUN_ID], + states=["success", "fail"], + ), + ] + * 2 ) + # test that it returns after yielding with pytest.raises(StopAsyncIteration): await gen.__anext__() @@ -468,15 +478,20 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): await gen.__anext__() @mock.patch("airflow.providers.standard.triggers.external_task._get_count") + @mock.patch("asyncio.sleep") @pytest.mark.asyncio - async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): - mock_get_count.return_value = 0 + async def test_task_workflow_trigger_fail_count_eq_0(self, mock_sleep, mock_get_count): + mock_get_count.side_effect = [ + 0, + 1, + ] trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, **_DATES, external_task_ids=[self.TASK_ID], failed_states=self.STATES, + allowed_states=self.STATES, poke_interval=0.2, ) @@ -487,14 +502,23 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): result = trigger_task.result() assert isinstance(result, TriggerEvent) assert result.payload == {"status": "success"} - mock_get_count.assert_called_once_with( - dttm_filter=value, - external_task_ids=["external_task_op"], - external_task_group_id=None, - external_dag_id="external_task", - states=["success", "fail"], + + assert mock_get_count.call_count == 2 + mock_get_count.assert_has_calls( + [ + mock.call( + dttm_filter=value, + external_task_ids=["external_task_op"], + external_task_group_id=None, + external_dag_id="external_task", + states=["success", "fail"], + ), + ] + * 2 ) - # test that it returns after yielding + + mock_sleep.assert_not_called() + with pytest.raises(StopAsyncIteration): await gen.__anext__()