diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index c1a58f8016161..4d4c451318af3 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -50,6 +50,7 @@ AirflowDagCycleException, AirflowDagDuplicatedIdException, AirflowException, + AirflowTaskTimeout, ) from airflow.listeners.listener import get_listener_manager from airflow.models.base import Base, StringID @@ -64,7 +65,6 @@ ) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.timeout import timeout from airflow.utils.types import NOTSET if TYPE_CHECKING: @@ -117,6 +117,30 @@ class FileLoadStat(NamedTuple): warning_num: int +@contextlib.contextmanager +def timeout(seconds=1, error_message="Timeout"): + import logging + + log = logging.getLogger(__name__) + error_message = error_message + ", PID: " + str(os.getpid()) + + def handle_timeout(signum, frame): + """Log information and raises AirflowTaskTimeout.""" + log.error("Process timed out, PID: %s", str(os.getpid())) + raise AirflowTaskTimeout(error_message) + + try: + try: + signal.signal(signal.SIGALRM, handle_timeout) + signal.setitimer(signal.ITIMER_REAL, seconds) + except ValueError: + log.warning("timeout can't be used in the current context", exc_info=True) + yield + finally: + with contextlib.suppress(ValueError): + signal.setitimer(signal.ITIMER_REAL, 0) + + class DagBag(LoggingMixin): """ A dagbag is a collection of dags, parsed out of a folder tree and has high level configuration settings. diff --git a/airflow-core/src/airflow/utils/__init__.py b/airflow-core/src/airflow/utils/__init__.py index 3a72f96fff646..9ed8b48407cee 100644 --- a/airflow-core/src/airflow/utils/__init__.py +++ b/airflow-core/src/airflow/utils/__init__.py @@ -42,6 +42,9 @@ "remove_task_decorator": "airflow.sdk.definitions._internal.decorators.remove_task_decorator", "fixup_decorator_warning_stack": "airflow.sdk.definitions._internal.decorators.fixup_decorator_warning_stack", }, + "timeout": { + "timeout": "airflow.sdk.execution_time.timeout.timeout", + }, } add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow-core/src/airflow/utils/timeout.py b/airflow-core/src/airflow/utils/timeout.py deleted file mode 100644 index 11a5e1bfa1e22..0000000000000 --- a/airflow-core/src/airflow/utils/timeout.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import os -import signal -from contextlib import AbstractContextManager -from threading import Timer - -from airflow.exceptions import AirflowTaskTimeout -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.platform import IS_WINDOWS - -_timeout = AbstractContextManager[None] - - -class TimeoutWindows(_timeout, LoggingMixin): - """Windows timeout version: To be used in a ``with`` block and timeout its content.""" - - def __init__(self, seconds=1, error_message="Timeout"): - super().__init__() - self._timer: Timer | None = None - self.seconds = seconds - self.error_message = error_message + ", PID: " + str(os.getpid()) - - def handle_timeout(self, *args): - """Log information and raises AirflowTaskTimeout.""" - self.log.error("Process timed out, PID: %s", str(os.getpid())) - raise AirflowTaskTimeout(self.error_message) - - def __enter__(self): - if self._timer: - self._timer.cancel() - self._timer = Timer(self.seconds, self.handle_timeout) - self._timer.start() - - def __exit__(self, type_, value, traceback): - if self._timer: - self._timer.cancel() - self._timer = None - - -class TimeoutPosix(_timeout, LoggingMixin): - """POSIX Timeout version: To be used in a ``with`` block and timeout its content.""" - - def __init__(self, seconds=1, error_message="Timeout"): - super().__init__() - self.seconds = seconds - self.error_message = error_message + ", PID: " + str(os.getpid()) - - def handle_timeout(self, signum, frame): - """Log information and raises AirflowTaskTimeout.""" - self.log.error("Process timed out, PID: %s", str(os.getpid())) - raise AirflowTaskTimeout(self.error_message) - - def __enter__(self): - try: - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.setitimer(signal.ITIMER_REAL, self.seconds) - except ValueError: - self.log.warning("timeout can't be used in the current context", exc_info=True) - - def __exit__(self, type_, value, traceback): - try: - signal.setitimer(signal.ITIMER_REAL, 0) - except ValueError: - self.log.warning("timeout can't be used in the current context", exc_info=True) - - -if IS_WINDOWS: - timeout: type[TimeoutWindows | TimeoutPosix] = TimeoutWindows -else: - timeout = TimeoutPosix diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index 230ddca9f89e2..e0de8c76751ee 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -77,6 +77,17 @@ def setup_class(self): def teardown_class(self): db_clean_up() + def test_timeout_context_manager_raises_exception(self): + """Test that the timeout context manager raises AirflowTaskTimeout when time limit is exceeded.""" + import time + + from airflow.exceptions import AirflowTaskTimeout + from airflow.models.dagbag import timeout + + with pytest.raises(AirflowTaskTimeout): + with timeout(1, "Test timeout"): + time.sleep(2) + def test_get_existing_dag(self, tmp_path): """ Test that we're able to parse some example DAGs and retrieve them