diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 7a166559a72a5..3f8d8be6b3ce6 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -28,7 +28,7 @@ import warnings from collections import namedtuple from collections.abc import Generator -from datetime import date, datetime, timezone as _timezone +from datetime import date, datetime, timedelta, timezone as _timezone from functools import partial from importlib.util import find_spec from pathlib import Path @@ -46,6 +46,7 @@ AirflowException, AirflowProviderDeprecationWarning, DeserializingResultError, + TaskDeferred, ) from airflow.models.connection import Connection from airflow.models.taskinstance import TaskInstance, clear_task_instances @@ -62,6 +63,7 @@ _PythonVersionInfo, get_current_context, ) +from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger from airflow.providers.standard.utils.python_virtualenv import execute_in_subprocess, prepare_virtualenv from airflow.utils import timezone from airflow.utils.session import create_session @@ -893,6 +895,53 @@ def poke(self, context): "Sensor should be skipped by ShortCircuitOperator" ) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 implementation is different") + def test_short_circuit_operator_skips_deferrable_sensors(self): + """Ensure ShortCircuitOperator skips downstream deferrable sensor""" + from airflow.sdk.bases.sensor import BaseSensorOperator + + class CustomDeferrableS3Sensor(BaseSensorOperator): + def __init__(self, deferrable: bool = True, **kwargs): + super().__init__(**kwargs) + self.deferrable = deferrable + + def execute(self, context: Context) -> str | None: + if self.deferrable: + raise TaskDeferred( + trigger=TimeDeltaTrigger(timedelta(seconds=1)), + method_name="execute_complete", + ) + return "done" + + def execute_complete(self, context: Context, event=None) -> str: + return "done" + + with self.dag_maker("dag_test_shortcircuit_deferrable_sensor"): + short_circuit = ShortCircuitOperator( + task_id="check_if_should_continue", + python_callable=lambda: False, + ) + + deferrable_sensor = CustomDeferrableS3Sensor( + deferrable=True, + task_id="wait_for_s3_file_deferrable", + ) + + short_circuit >> deferrable_sensor + + dr = self.dag_maker.create_dagrun() + + self.dag_maker.run_ti("check_if_should_continue", dr) + + tis = dr.get_task_instances() + xcom_data = tis[0].xcom_pull(task_ids="check_if_should_continue", key="skipmixin_key") + + assert xcom_data is not None, "XCom data should exist" + skipped_task_ids = set(xcom_data.get("skipped", [])) + assert "wait_for_s3_file_deferrable" in skipped_task_ids, ( + "Deferrable sensor should be skipped by ShortCircuitOperator" + ) + virtualenv_string_args: list[str] = [] diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index 7d4c9875efbf7..488bd0f8b6313 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -17,15 +17,17 @@ # under the License. from __future__ import annotations -import datetime +from datetime import datetime, timedelta +from typing import TYPE_CHECKING from unittest.mock import MagicMock, Mock, patch import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance as TI from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -44,6 +46,9 @@ from airflow.decorators import task, task_group from airflow.models.skipmixin import SkipMixin +if TYPE_CHECKING: + from airflow.utils.context import Context + DEFAULT_DATE = timezone.datetime(2016, 1, 1) DEFAULT_DAG_RUN_ID = "test1" @@ -62,7 +67,7 @@ def teardown_method(self): @patch("airflow.utils.timezone.utcnow") def test_skip(self, mock_now, dag_maker, session): - now = datetime.datetime.now(tz=datetime.timezone.utc) + now = datetime.now(tz=timezone.utc) mock_now.return_value = now with dag_maker("dag"): tasks = [EmptyOperator(task_id="task")] @@ -399,6 +404,54 @@ def poke(self, context): assert sensor_in_list is not None, "Sensor task should be found in list" assert isinstance(sensor_in_list, SDKBaseOperator), "Sensor should be instance of SDK BaseOperator" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Issue only exists in Airflow 3.x") + def test_ensure_tasks_includes_deferrable_sensors_airflow_3x(self, dag_maker): + """Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks.""" + from airflow.providers.standard.utils.skipmixin import _ensure_tasks + from airflow.sdk import BaseOperator as SDKBaseOperator + from airflow.sdk.bases.sensor import BaseSensorOperator + + class DummyDeferableSensor(BaseSensorOperator): + def __init__(self, deferrable: bool = True, **kwargs): + super().__init__(**kwargs) + self.deferrable = deferrable + + def execute(self, context) -> str | None: + if self.deferrable: + raise TaskDeferred( + trigger=TimeDeltaTrigger(timedelta(seconds=1)), + method_name="execute_complete", + ) + return "done" + + def execute_complete(self, context, event=None) -> str: + return "done" + + with dag_maker("dag_test_sensor_skipping") as dag: + regular_task = EmptyOperator(task_id="regular_task") + deferrable_sensor_task = DummyDeferableSensor(task_id="deferrable_sensor_task") + downstream_task = EmptyOperator(task_id="downstream_task") + + regular_task >> [deferrable_sensor_task, downstream_task] + + dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) + + downstream_nodes = dag.get_task("regular_task").downstream_list + task_list = _ensure_tasks(downstream_nodes) + + # Verify both the regular operator and sensor are included + task_ids = [t.task_id for t in task_list] + assert "deferrable_sensor_task" in task_ids, "Sensor should be included in task list" + assert "downstream_task" in task_ids, "Regular task should be included in task list" + assert len(task_list) == 2, "Both tasks should be included" + + # Also verify that the sensor is actually an instance of the correct BaseOperator + sensor_in_list = next((t for t in task_list if t.task_id == "deferrable_sensor_task"), None) + assert sensor_in_list is not None, "Deferrable Sensor task should be found in list" + assert isinstance(sensor_in_list, SDKBaseOperator), ( + "Deferrable Sensor should be instance of SDK BaseOperator" + ) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Integration test for Airflow 3.x sensor skipping") def test_skip_sensor_in_branching_scenario(self, dag_maker): """Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x.""" @@ -440,3 +493,48 @@ def poke(self, context): # Verify that the regular task is properly marked for skipping skipped_tasks = set(exc_info.value.tasks) assert ("regular_task", -1) in skipped_tasks, "Regular task should be marked for skipping" + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3.x only") + def test_skip_deferrable_sensor_in_branching_scenario(self, dag_maker): + from airflow.sdk.bases.sensor import BaseSensorOperator + + class DummyDeferrableSensor(BaseSensorOperator): + def __init__(self, deferrable: bool = True, **kwargs): + super().__init__(**kwargs) + self.deferrable = deferrable + + def execute(self, context: Context) -> str | None: + if self.deferrable: + raise TaskDeferred( + trigger=TimeDeltaTrigger(timedelta(seconds=1)), + method_name="execute_complete", + ) + return "done" + + def execute_complete(self, context: Context, event=None) -> str: + return "done" + + with dag_maker("dag_test_branch_deferrable_sensor_skipping"): + branch_task = EmptyOperator(task_id="branch_task") + regular_task = EmptyOperator(task_id="regular_task") + deferrable_sensor = DummyDeferrableSensor(task_id="deferrable_sensor_task") + branch_task >> [regular_task, deferrable_sensor] + + dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) + + dag_version = DagVersion.get_latest_version(branch_task.dag_id) + ti_branch = TI(branch_task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + + # Sensor should be skipped if regular_task is selected + with pytest.raises(DownstreamTasksSkipped) as exc_info: + SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="regular_task") + + skipped = set(exc_info.value.tasks) + assert ("deferrable_sensor_task", -1) in skipped + + # Regular should be skipped if sensor_task is selected + with pytest.raises(DownstreamTasksSkipped) as exc_info: + SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="deferrable_sensor_task") + + skipped = set(exc_info.value.tasks) + assert ("regular_task", -1) in skipped