diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 70e293cbbe467..b3c93ba965435 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,6 +27,7 @@ import selectors import signal import sys +import threading import time import weakref from collections import deque @@ -1475,6 +1476,44 @@ def request(self, *args, **kwargs): # Bypass the tenacity retries! return super().request.__wrapped__(self, *args, **kwargs) # type: ignore[attr-defined] + def _check_subprocess_exit( + self, raise_on_timeout: bool = False, expect_signal: None | int = None + ) -> int | None: + # InProcessSupervisor has no subprocess, so we don't need to poll anything. This is called from + # _handle_socket_comms, so we need to override it + return None + + def _handle_socket_comms(self): + while self._open_sockets: + self._service_subprocess(1.0) + + @contextlib.contextmanager + def _setup_subprocess_socket(self): + thread = threading.Thread(target=self._handle_socket_comms, daemon=True) + + requests, child_sock = socketpair() + + self._open_sockets[requests] = "requests" + self.stdin = requests + + self.selector.register( + requests, + selectors.EVENT_READ, + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), + ) + os.set_inheritable(child_sock.fileno(), True) + os.environ["__AIRFLOW_SUPERVISOR_FD"] = str(child_sock.fileno()) + + try: + thread.start() + yield child_sock + finally: + requests.close() + child_sock.close() + self._on_socket_closed(requests) + thread.join(0) + os.environ.pop("__AIRFLOW_SUPERVISOR_FD", None) + @classmethod def start( # type: ignore[override] cls, @@ -1527,16 +1566,19 @@ def start( # type: ignore[override] start_date=start_date, state=TaskInstanceState.RUNNING, ) - context = ti.get_template_context() - log = structlog.get_logger(logger_name="task") - state, msg, error = run(ti, context, log) - finalize(ti, state, context, log, error) + # Create a socketpair preemptively, in case the task process runs VirtualEnv operator or run_as_user + with supervisor._setup_subprocess_socket(): + context = ti.get_template_context() + log = structlog.get_logger(logger_name="task") + + state, msg, error = run(ti, context, log) + finalize(ti, state, context, log, error) - # In the normal subprocess model, the task runner calls this before exiting. - # Since we're running in-process, we manually notify the API server that - # the task has finished—unless the terminal state was already sent explicitly. - supervisor.update_task_state_if_needed() + # In the normal subprocess model, the task runner calls this before exiting. + # Since we're running in-process, we manually notify the API server that + # the task has finished—unless the terminal state was already sent explicitly. + supervisor.update_task_state_if_needed() return TaskRunResult(ti=ti, state=state, msg=msg, error=error) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 41b11084d7fdf..e6115ea915ba4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1475,11 +1475,15 @@ def reinit_supervisor_comms() -> None: run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks can continue to access variables etc. """ + import socket + if "SUPERVISOR_COMMS" not in globals(): global SUPERVISOR_COMMS log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) + fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) + + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) if isinstance(logs, SentFDs):