diff --git a/providers/standard/src/airflow/providers/standard/sensors/date_time.py b/providers/standard/src/airflow/providers/standard/sensors/date_time.py index d04f524cbabc8..b1c5b5da2976d 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/standard/src/airflow/providers/standard/sensors/date_time.py @@ -25,6 +25,7 @@ from airflow.providers.standard.triggers.temporal import DateTimeTrigger from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator +from airflow.utils import timezone try: from airflow.triggers.base import StartTriggerArgs @@ -41,8 +42,6 @@ class StartTriggerArgs: # type: ignore[no-redef] timeout: datetime.timedelta | None = None -from airflow.utils import timezone - if TYPE_CHECKING: try: from airflow.sdk.definitions.context import Context @@ -99,6 +98,13 @@ def poke(self, context: Context) -> bool: self.log.info("Checking if the time (%s) has come", self.target_time) return timezone.utcnow() > timezone.parse(self.target_time) + @property + def _moment(self) -> datetime.datetime: + if isinstance(self.target_time, datetime.datetime): + return self.target_time + + return timezone.parse(self.target_time) + class DateTimeSensorAsync(DateTimeSensor): """ @@ -145,11 +151,11 @@ def execute(self, context: Context) -> NoReturn: self.defer( method_name="execute_complete", trigger=DateTimeTrigger( - moment=timezone.parse(self.target_time), + moment=self._moment, end_from_trigger=self.end_from_trigger, ) if AIRFLOW_V_3_0_PLUS - else DateTimeTrigger(moment=timezone.parse(self.target_time)), + else DateTimeTrigger(moment=self._moment), ) def execute_complete(self, context: Context, event: Any = None) -> None: diff --git a/providers/standard/tests/unit/standard/sensors/test_date_time.py b/providers/standard/tests/unit/standard/sensors/test_date_time.py index c51b5316206de..188237a957bd4 100644 --- a/providers/standard/tests/unit/standard/sensors/test_date_time.py +++ b/providers/standard/tests/unit/standard/sensors/test_date_time.py @@ -19,8 +19,10 @@ from unittest.mock import patch +import pendulum import pytest +from airflow import macros from airflow.models.dag import DAG from airflow.providers.standard.sensors.date_time import DateTimeSensor from airflow.utils import timezone @@ -90,3 +92,34 @@ def test_invalid_input(self): def test_poke(self, mock_utcnow, task_id, target_time, expected): op = DateTimeSensor(task_id=task_id, target_time=target_time, dag=self.dag) assert op.poke(None) == expected + + @pytest.mark.parametrize( + "native, target_time, expected_type", + [ + (False, "2025-01-01T00:00:00+00:00", pendulum.DateTime), + (True, "{{ data_interval_end }}", pendulum.DateTime), + (False, pendulum.datetime(2025, 1, 1, tz="UTC"), pendulum.DateTime), + ], + ) + def test_moment(self, native, target_time, expected_type): + dag = DAG( + dag_id="moment_dag", + start_date=pendulum.datetime(2025, 1, 1, tz="UTC"), + schedule=None, + render_template_as_native_obj=native, + ) + + sensor = DateTimeSensor( + task_id="moment", + target_time=target_time, + dag=dag, + ) + + ctx = { + "data_interval_end": pendulum.datetime(2025, 1, 1, tz="UTC"), + "macros": macros, + "dag": dag, + } + sensor.render_template_fields(ctx) + + assert isinstance(sensor._moment, expected_type)