diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9e7bb782518fc..3a9f1a705af04 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,7 +40,7 @@ from airflow.configuration import conf from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException +from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, AirflowTaskTimeout from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( @@ -869,7 +869,6 @@ def run( AirflowSensorTimeout, AirflowSkipException, AirflowTaskTerminated, - AirflowTaskTimeout, DagRunTriggerException, DownstreamTasksSkipped, TaskDeferred, @@ -1157,8 +1156,6 @@ def _send_task_error_email(to: Iterable[str], ti: RuntimeTaskInstance, exception def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): """Execute Task (optionally with a Timeout) and push Xcom results.""" - from airflow.exceptions import AirflowTaskTimeout - task = ti.task execute = task.execute @@ -1187,9 +1184,9 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): _run_task_state_change_callbacks(task, "on_execute_callback", context, log) if task.execution_timeout: - # TODO: handle timeout in case of deferral - from airflow.utils.timeout import timeout + from airflow.sdk.execution_time.timeout import timeout + # TODO: handle timeout in case of deferral timeout_seconds = task.execution_timeout.total_seconds() try: # It's possible we're already timed out, so fast-fail if true diff --git a/task-sdk/src/airflow/sdk/execution_time/timeout.py b/task-sdk/src/airflow/sdk/execution_time/timeout.py new file mode 100644 index 0000000000000..fe4a0e8bd8c52 --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/timeout.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +import structlog + +from airflow.exceptions import AirflowTaskTimeout + + +class TimeoutPosix: + """POSIX Timeout version: To be used in a ``with`` block and timeout its content.""" + + def __init__(self, seconds=1, error_message="Timeout"): + super().__init__() + self.seconds = seconds + self.error_message = error_message + ", PID: " + str(os.getpid()) + self.log = structlog.get_logger(logger_name="task") + + def handle_timeout(self, signum, frame): + """Log information and raises AirflowTaskTimeout.""" + self.log.error("Process timed out", pid=os.getpid()) + raise AirflowTaskTimeout(self.error_message) + + def __enter__(self): + import signal + + try: + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.setitimer(signal.ITIMER_REAL, self.seconds) + except ValueError: + self.log.warning("timeout can't be used in the current context", exc_info=True) + return self + + def __exit__(self, type_, value, traceback): + import signal + + try: + signal.setitimer(signal.ITIMER_REAL, 0) + except ValueError: + self.log.warning("timeout can't be used in the current context", exc_info=True) + + +timeout = TimeoutPosix