From 00406fff8ba882b86dcabe729eb8eeeb051e925a Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Tue, 21 May 2024 07:26:04 -0600 Subject: [PATCH] Better typing for BaseOperator `defer` This adds typing for the `defer` method, and covers the core deferrable sensors. It also fixes 1 error in a databricks provider, but leaves the rest of the provider ecosystem alone. --- airflow/models/baseoperator.py | 3 ++- airflow/providers/databricks/operators/databricks.py | 2 +- airflow/sensors/date_time.py | 6 +++--- airflow/sensors/time_delta.py | 6 +++--- airflow/sensors/time_sensor.py | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f08119bd559b9..98532d90b0256 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -41,6 +41,7 @@ Callable, Collection, Iterable, + NoReturn, Sequence, TypeVar, Union, @@ -1706,7 +1707,7 @@ def defer( method_name: str, kwargs: dict[str, Any] | None = None, timeout: timedelta | None = None, - ): + ) -> NoReturn: """ Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 7ae802db104a7..ff8de101326be 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -1046,7 +1046,7 @@ def monitor_databricks_job(self) -> None: run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) if self.deferrable and not run_state.is_terminal: - return self.defer( + self.defer( trigger=DatabricksExecutionTrigger( run_id=self.databricks_run_id, databricks_conn_id=self.databricks_conn_id, diff --git a/airflow/sensors/date_time.py b/airflow/sensors/date_time.py index 65880ebb9e754..b0763ebd40a87 100644 --- a/airflow/sensors/date_time.py +++ b/airflow/sensors/date_time.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, NoReturn, Sequence from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -90,13 +90,13 @@ class DateTimeSensorAsync(DateTimeSensor): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - def execute(self, context: Context): + def execute(self, context: Context) -> NoReturn: trigger = DateTimeTrigger(moment=timezone.parse(self.target_time)) self.defer( trigger=trigger, method_name="execute_complete", ) - def execute_complete(self, context, event=None): + def execute_complete(self, context, event=None) -> None: """Execute when the trigger fires - returns immediately.""" return None diff --git a/airflow/sensors/time_delta.py b/airflow/sensors/time_delta.py index 3595e551b04b4..82d16bbae6575 100644 --- a/airflow/sensors/time_delta.py +++ b/airflow/sensors/time_delta.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NoReturn from airflow.exceptions import AirflowSkipException from airflow.sensors.base import BaseSensorOperator @@ -66,7 +66,7 @@ class TimeDeltaSensorAsync(TimeDeltaSensor): """ - def execute(self, context: Context): + def execute(self, context: Context) -> NoReturn: target_dttm = context["data_interval_end"] target_dttm += self.delta try: @@ -78,6 +78,6 @@ def execute(self, context: Context): self.defer(trigger=trigger, method_name="execute_complete") - def execute_complete(self, context, event=None): + def execute_complete(self, context, event=None) -> None: """Execute for when the trigger fires - return immediately.""" return None diff --git a/airflow/sensors/time_sensor.py b/airflow/sensors/time_sensor.py index 6df67bc855b24..91c1354782593 100644 --- a/airflow/sensors/time_sensor.py +++ b/airflow/sensors/time_sensor.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NoReturn from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger @@ -72,7 +72,7 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None: self.target_datetime = timezone.convert_to_utc(aware_time) - def execute(self, context: Context): + def execute(self, context: Context) -> NoReturn: trigger = DateTimeTrigger(moment=self.target_datetime) self.defer( trigger=trigger,