Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from pathlib import Path

from airflow import DAG
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Variable
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensorAsync
import pytest

from system.openlineage.operator import OpenLineageTestOperator

Expand All @@ -45,21 +47,22 @@ def check_start_amount_func():
schedule=None,
catchup=False,
) as dag:
# Timedelta is compared to the DAGRun start timestamp, which can occur long before a worker picks up the
# task. We need to ensure the sensor gets deferred at least once, so setting 180s.
wait = TimeDeltaSensorAsync(task_id="wait", delta=timedelta(seconds=180))
with pytest.warns(AirflowProviderDeprecationWarning):
# Timedelta is compared to the DAGRun start timestamp, which can occur long before a worker picks up the
# task. We need to ensure the sensor gets deferred at least once, so setting 180s.
wait = TimeDeltaSensorAsync(task_id="wait", delta=timedelta(seconds=180))

check_start_events_amount = PythonOperator(
task_id="check_start_events_amount", python_callable=check_start_amount_func
)
check_start_events_amount = PythonOperator(
task_id="check_start_events_amount", python_callable=check_start_amount_func
)

check_events = OpenLineageTestOperator(
task_id="check_events",
file_path=str(Path(__file__).parent / "example_openlineage_defer.json"),
allow_duplicate_events=True,
)
check_events = OpenLineageTestOperator(
task_id="check_events",
file_path=str(Path(__file__).parent / "example_openlineage_defer.json"),
allow_duplicate_events=True,
)

wait >> check_start_events_amount >> check_events
wait >> check_start_events_amount >> check_events


from tests_common.test_utils.system_tests import get_test_run # noqa: E402
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
# under the License.
from __future__ import annotations

import warnings
from datetime import datetime, timedelta
from time import sleep
from typing import TYPE_CHECKING, Any, NoReturn

from deprecated.classic import deprecated
from packaging.version import Version

from airflow.configuration import conf
from airflow.exceptions import AirflowSkipException
from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -52,16 +54,26 @@ class TimeDeltaSensor(BaseSensorOperator):
otherwise run_after will be used.

:param delta: time to wait before succeeding.
:param deferrable: Run sensor in deferrable mode. If set to True, task will defer itself to avoid taking up a worker slot while it is waiting.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensor`

"""

def __init__(self, *, delta, **kwargs):
def __init__(
self,
*,
delta: timedelta,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
end_from_trigger: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.delta = delta
self.deferrable = deferrable
self.end_from_trigger = end_from_trigger

def _derive_base_time(self, context: Context) -> datetime:
"""
Expand Down Expand Up @@ -90,27 +102,21 @@ def poke(self, context: Context) -> bool:
self.log.info("Checking if the delta has elapsed base_time=%s, delta=%s", base_time, self.delta)
return timezone.utcnow() > target_dttm


class TimeDeltaSensorAsync(TimeDeltaSensor):
"""
A deferrable drop-in replacement for TimeDeltaSensor.

Will defers itself to avoid taking up a worker slot while it is waiting.

:param delta: time length to wait after the data interval before succeeding.
:param end_from_trigger: End the task directly from the triggerer without going into the worker.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensorAsync`

Asynchronous execution
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
super().__init__(delta=delta, **kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
"""
Depending on the deferrable flag, either execute the sensor in a blocking way or defer it.

- Sync path → use BaseSensorOperator.execute() which loops over ``poke``.
- Async path → defer to DateTimeTrigger and free the worker slot.
"""
if not self.deferrable:
return super().execute(context=context)

# Deferrable path
base_time = self._derive_base_time(context=context)
target_dttm: datetime = base_time + self.delta

Expand Down Expand Up @@ -146,6 +152,26 @@ def execute_complete(self, context: Context, event: Any = None) -> None:
return None


# TODO: Remove in the next major release
@deprecated(
"Use `TimeDeltaSensor` with `deferrable=True` instead", category=AirflowProviderDeprecationWarning
)
class TimeDeltaSensorAsync(TimeDeltaSensor):
"""
Deprecated. Use TimeDeltaSensor with deferrable=True instead.

:sphinx-autoapi-skip:
"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
warnings.warn(
"TimeDeltaSensorAsync is deprecated and will be removed in a future version. Use `TimeDeltaSensor` with `deferrable=True` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
super().__init__(delta=delta, deferrable=True, end_from_trigger=end_from_trigger, **kwargs)


class WaitSensor(BaseSensorOperator):
"""
A sensor that waits a specified period of time before completing.
Expand Down
6 changes: 4 additions & 2 deletions providers/standard/tests/system/standard/example_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from airflow.providers.standard.sensors.filesystem import FileSensor
from airflow.providers.standard.sensors.python import PythonSensor
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensor, TimeDeltaSensorAsync
from airflow.providers.standard.sensors.time_delta import TimeDeltaSensor
from airflow.providers.standard.sensors.weekday import DayOfWeekSensor
from airflow.providers.standard.utils.weekday import WeekDay
from airflow.sdk import DAG
Expand Down Expand Up @@ -57,7 +57,9 @@ def failure_callable():
# [END example_time_delta_sensor]

# [START example_time_delta_sensor_async]
t0a = TimeDeltaSensorAsync(task_id="wait_some_seconds_async", delta=datetime.timedelta(seconds=2))
t0a = TimeDeltaSensor(
task_id="wait_some_seconds_async", delta=datetime.timedelta(seconds=2), deferrable=True
)
# [END example_time_delta_sensor_async]

# [START example_time_sensors]
Expand Down
137 changes: 97 additions & 40 deletions providers/standard/tests/unit/standard/sensors/test_time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@
from __future__ import annotations

from datetime import timedelta
from typing import Any
from unittest import mock

import pendulum
import pytest
import time_machine

from airflow.exceptions import TaskDeferred
from airflow.exceptions import AirflowProviderDeprecationWarning, TaskDeferred
from airflow.models import DagBag
from airflow.models.dag import DAG
from airflow.providers.standard.sensors.time_delta import (
TimeDeltaSensor,
TimeDeltaSensorAsync,
WaitSensor,
)
from airflow.providers.standard.triggers.temporal import DateTimeTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone
from airflow.utils.timezone import datetime
Expand Down Expand Up @@ -105,6 +107,57 @@ def test_timedelta_sensor_run_after_vs_interval(run_after, interval_end, dag_mak
assert actual == expected


@pytest.mark.parametrize(
"run_after, interval_end",
[
(timezone.utcnow() + timedelta(days=1), timezone.utcnow() + timedelta(days=2)),
(timezone.utcnow() + timedelta(days=1), None),
],
)
def test_timedelta_sensor_deferrable_run_after_vs_interval(run_after, interval_end, dag_maker):
"""Test that TimeDeltaSensor defers correctly when flag is enabled."""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context: dict[str, Any] = {}
if interval_end:
context["data_interval_end"] = interval_end

with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

delta = timedelta(minutes=5)
sensor = TimeDeltaSensor(
task_id="timedelta_sensor_deferrable",
delta=delta,
dag=dag,
deferrable=True, # <-- the feature under test
)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)

expected_base = interval_end or run_after
expected_fire_time = expected_base + delta

with pytest.raises(TaskDeferred) as td:
sensor.execute(context)

# The sensor should defer once with a DateTimeTrigger
trigger = td.value.trigger
assert isinstance(trigger, DateTimeTrigger)
assert trigger.moment == expected_fire_time


class TestTimeDeltaSensorAsync:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
Expand All @@ -117,17 +170,20 @@ def setup_method(self):
)
@mock.patch(DEFER_PATH)
def test_timedelta_sensor(self, defer_mock, should_defer):
delta = timedelta(hours=1)
op = TimeDeltaSensorAsync(task_id="timedelta_sensor_check", delta=delta, dag=self.dag)
if should_defer:
data_interval_end = pendulum.now("UTC").add(hours=1)
else:
data_interval_end = pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1)
op.execute({"data_interval_end": data_interval_end})
if should_defer:
defer_mock.assert_called_once()
else:
defer_mock.assert_not_called()
with pytest.warns(AirflowProviderDeprecationWarning):
delta = timedelta(hours=1)
op = TimeDeltaSensorAsync(task_id="timedelta_sensor_check", delta=delta, dag=self.dag)
if should_defer:
data_interval_end = pendulum.now("UTC").add(hours=1)
else:
data_interval_end = (
pendulum.now("UTC").replace(microsecond=0, second=0, minute=0).add(hours=-1)
)
op.execute({"data_interval_end": data_interval_end})
if should_defer:
defer_mock.assert_called_once()
else:
defer_mock.assert_not_called()

@pytest.mark.parametrize(
"should_defer",
Expand Down Expand Up @@ -157,31 +213,32 @@ def test_wait_sensor(self, sleep_mock, defer_mock, should_defer):
)
def test_timedelta_sensor_async_run_after_vs_interval(self, run_after, interval_end, dag_maker):
"""Interval end should be used as base time when present else run_after"""
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)
delta = timedelta(seconds=1)
op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag)
base_time = interval_end or run_after
expected_time = base_time + delta
with pytest.raises(TaskDeferred) as caught:
op.execute(context)

assert caught.value.trigger.moment == expected_time
with pytest.warns(AirflowProviderDeprecationWarning):
if not AIRFLOW_V_3_0_PLUS and not interval_end:
pytest.skip("not applicable")

context = {}
if interval_end:
context["data_interval_end"] = interval_end
with dag_maker() as dag:
kwargs = {}
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

kwargs.update(triggered_by=DagRunTriggeredByType.TEST, run_after=run_after)

dr = dag.create_dagrun(
run_id="abcrhroceuh",
run_type=DagRunType.MANUAL,
state=None,
**kwargs,
)
context.update(dag_run=dr)
delta = timedelta(seconds=1)
op = TimeDeltaSensorAsync(task_id="wait_sensor_check", delta=delta, dag=dag)
base_time = interval_end or run_after
expected_time = base_time + delta
with pytest.raises(TaskDeferred) as caught:
op.execute(context)

assert caught.value.trigger.moment == expected_time
Loading