Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion providers/standard/tests/unit/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import warnings
from collections import namedtuple
from collections.abc import Generator
from datetime import date, datetime, timezone as _timezone
from datetime import date, datetime, timedelta, timezone as _timezone
from functools import partial
from importlib.util import find_spec
from pathlib import Path
Expand All @@ -46,6 +46,7 @@
AirflowException,
AirflowProviderDeprecationWarning,
DeserializingResultError,
TaskDeferred,
)
from airflow.models.connection import Connection
from airflow.models.taskinstance import TaskInstance, clear_task_instances
Expand All @@ -62,6 +63,7 @@
_PythonVersionInfo,
get_current_context,
)
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.providers.standard.utils.python_virtualenv import execute_in_subprocess, prepare_virtualenv
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -893,6 +895,53 @@ def poke(self, context):
"Sensor should be skipped by ShortCircuitOperator"
)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 implementation is different")
def test_short_circuit_operator_skips_deferrable_sensors(self):
"""Ensure ShortCircuitOperator skips downstream deferrable sensor"""
from airflow.sdk.bases.sensor import BaseSensorOperator

class CustomDeferrableS3Sensor(BaseSensorOperator):
def __init__(self, deferrable: bool = True, **kwargs):
super().__init__(**kwargs)
self.deferrable = deferrable

def execute(self, context: Context) -> str | None:
if self.deferrable:
raise TaskDeferred(
trigger=TimeDeltaTrigger(timedelta(seconds=1)),
method_name="execute_complete",
)
return "done"

def execute_complete(self, context: Context, event=None) -> str:
return "done"

with self.dag_maker("dag_test_shortcircuit_deferrable_sensor"):
short_circuit = ShortCircuitOperator(
task_id="check_if_should_continue",
python_callable=lambda: False,
)

deferrable_sensor = CustomDeferrableS3Sensor(
deferrable=True,
task_id="wait_for_s3_file_deferrable",
)

short_circuit >> deferrable_sensor

dr = self.dag_maker.create_dagrun()

self.dag_maker.run_ti("check_if_should_continue", dr)

tis = dr.get_task_instances()
xcom_data = tis[0].xcom_pull(task_ids="check_if_should_continue", key="skipmixin_key")

assert xcom_data is not None, "XCom data should exist"
skipped_task_ids = set(xcom_data.get("skipped", []))
assert "wait_for_s3_file_deferrable" in skipped_task_ids, (
"Deferrable sensor should be skipped by ShortCircuitOperator"
)


virtualenv_string_args: list[str] = []

Expand Down
104 changes: 101 additions & 3 deletions providers/standard/tests/unit/standard/utils/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
# under the License.
from __future__ import annotations

import datetime
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, Mock, patch

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance as TI
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
Expand All @@ -44,6 +46,9 @@
from airflow.decorators import task, task_group
from airflow.models.skipmixin import SkipMixin

if TYPE_CHECKING:
from airflow.utils.context import Context

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
DEFAULT_DAG_RUN_ID = "test1"

Expand All @@ -62,7 +67,7 @@ def teardown_method(self):

@patch("airflow.utils.timezone.utcnow")
def test_skip(self, mock_now, dag_maker, session):
now = datetime.datetime.now(tz=datetime.timezone.utc)
now = datetime.now(tz=timezone.utc)
mock_now.return_value = now
with dag_maker("dag"):
tasks = [EmptyOperator(task_id="task")]
Expand Down Expand Up @@ -399,6 +404,54 @@ def poke(self, context):
assert sensor_in_list is not None, "Sensor task should be found in list"
assert isinstance(sensor_in_list, SDKBaseOperator), "Sensor should be instance of SDK BaseOperator"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Issue only exists in Airflow 3.x")
def test_ensure_tasks_includes_deferrable_sensors_airflow_3x(self, dag_maker):
"""Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks."""
from airflow.providers.standard.utils.skipmixin import _ensure_tasks
from airflow.sdk import BaseOperator as SDKBaseOperator
from airflow.sdk.bases.sensor import BaseSensorOperator

class DummyDeferableSensor(BaseSensorOperator):
def __init__(self, deferrable: bool = True, **kwargs):
super().__init__(**kwargs)
self.deferrable = deferrable

def execute(self, context) -> str | None:
if self.deferrable:
raise TaskDeferred(
trigger=TimeDeltaTrigger(timedelta(seconds=1)),
method_name="execute_complete",
)
return "done"

def execute_complete(self, context, event=None) -> str:
return "done"

with dag_maker("dag_test_sensor_skipping") as dag:
regular_task = EmptyOperator(task_id="regular_task")
deferrable_sensor_task = DummyDeferableSensor(task_id="deferrable_sensor_task")
downstream_task = EmptyOperator(task_id="downstream_task")

regular_task >> [deferrable_sensor_task, downstream_task]

dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

downstream_nodes = dag.get_task("regular_task").downstream_list
task_list = _ensure_tasks(downstream_nodes)

# Verify both the regular operator and sensor are included
task_ids = [t.task_id for t in task_list]
assert "deferrable_sensor_task" in task_ids, "Sensor should be included in task list"
assert "downstream_task" in task_ids, "Regular task should be included in task list"
assert len(task_list) == 2, "Both tasks should be included"

# Also verify that the sensor is actually an instance of the correct BaseOperator
sensor_in_list = next((t for t in task_list if t.task_id == "deferrable_sensor_task"), None)
assert sensor_in_list is not None, "Deferrable Sensor task should be found in list"
assert isinstance(sensor_in_list, SDKBaseOperator), (
"Deferrable Sensor should be instance of SDK BaseOperator"
)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Integration test for Airflow 3.x sensor skipping")
def test_skip_sensor_in_branching_scenario(self, dag_maker):
"""Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x."""
Expand Down Expand Up @@ -440,3 +493,48 @@ def poke(self, context):
# Verify that the regular task is properly marked for skipping
skipped_tasks = set(exc_info.value.tasks)
assert ("regular_task", -1) in skipped_tasks, "Regular task should be marked for skipping"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3.x only")
def test_skip_deferrable_sensor_in_branching_scenario(self, dag_maker):
from airflow.sdk.bases.sensor import BaseSensorOperator

class DummyDeferrableSensor(BaseSensorOperator):
def __init__(self, deferrable: bool = True, **kwargs):
super().__init__(**kwargs)
self.deferrable = deferrable

def execute(self, context: Context) -> str | None:
if self.deferrable:
raise TaskDeferred(
trigger=TimeDeltaTrigger(timedelta(seconds=1)),
method_name="execute_complete",
)
return "done"

def execute_complete(self, context: Context, event=None) -> str:
return "done"

with dag_maker("dag_test_branch_deferrable_sensor_skipping"):
branch_task = EmptyOperator(task_id="branch_task")
regular_task = EmptyOperator(task_id="regular_task")
deferrable_sensor = DummyDeferrableSensor(task_id="deferrable_sensor_task")
branch_task >> [regular_task, deferrable_sensor]

dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

dag_version = DagVersion.get_latest_version(branch_task.dag_id)
ti_branch = TI(branch_task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id)

# Sensor should be skipped if regular_task is selected
with pytest.raises(DownstreamTasksSkipped) as exc_info:
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="regular_task")

skipped = set(exc_info.value.tasks)
assert ("deferrable_sensor_task", -1) in skipped

# Regular should be skipped if sensor_task is selected
with pytest.raises(DownstreamTasksSkipped) as exc_info:
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="deferrable_sensor_task")

skipped = set(exc_info.value.tasks)
assert ("regular_task", -1) in skipped