Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import selectors
import signal
import sys
import threading
import time
import weakref
from collections import deque
Expand Down Expand Up @@ -1475,6 +1476,44 @@ def request(self, *args, **kwargs):
# Bypass the tenacity retries!
return super().request.__wrapped__(self, *args, **kwargs) # type: ignore[attr-defined]

def _check_subprocess_exit(
self, raise_on_timeout: bool = False, expect_signal: None | int = None
) -> int | None:
# InProcessSupervisor has no subprocess, so we don't need to poll anything. This is called from
# _handle_socket_comms, so we need to override it
return None

def _handle_socket_comms(self):
while self._open_sockets:
self._service_subprocess(1.0)

@contextlib.contextmanager
def _setup_subprocess_socket(self):
    """
    Provide a supervisor comms socket for the duration of the ``with`` block.

    On entry this creates a socket pair, tracks the supervisor-side end in
    ``self._open_sockets`` / ``self.stdin``, registers it with ``self.selector`` for
    read events, marks the child-side end inheritable, and publishes the child-side
    file descriptor number in the ``__AIRFLOW_SUPERVISOR_FD`` environment variable.
    A daemon thread running ``_handle_socket_comms`` is started to service requests.

    On exit the supervisor-side socket is handed to ``_on_socket_closed``, both
    sockets are closed, and the environment variable is removed.

    :return: Context manager yielding the child-side socket.
    """
    requests, child_sock = socketpair()

    self._open_sockets[requests] = "requests"
    self.stdin = requests

    self.selector.register(
        requests,
        selectors.EVENT_READ,
        length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed),
    )
    # The child end must survive exec so a re-launched task process can reconnect via the env var.
    os.set_inheritable(child_sock.fileno(), True)
    os.environ["__AIRFLOW_SUPERVISOR_FD"] = str(child_sock.fileno())

    comms_thread = threading.Thread(target=self._handle_socket_comms, daemon=True)
    try:
        comms_thread.start()
        yield child_sock
    finally:
        # Drop our bookkeeping for the socket *before* closing it: the selectors docs require a
        # file object to be unregistered prior to being closed.
        # NOTE(review): presumably _on_socket_closed unregisters from self.selector and removes the
        # entry from self._open_sockets -- confirm against its definition.
        self._on_socket_closed(requests)
        requests.close()
        child_sock.close()
        # Non-blocking join: the daemon thread exits on its own once no open sockets remain.
        comms_thread.join(0)
        os.environ.pop("__AIRFLOW_SUPERVISOR_FD", None)

@classmethod
def start( # type: ignore[override]
cls,
Expand Down Expand Up @@ -1527,16 +1566,19 @@ def start( # type: ignore[override]
start_date=start_date,
state=TaskInstanceState.RUNNING,
)
context = ti.get_template_context()
log = structlog.get_logger(logger_name="task")

state, msg, error = run(ti, context, log)
finalize(ti, state, context, log, error)
# Create a socketpair preemptively, in case the task process runs VirtualEnv operator or run_as_user
with supervisor._setup_subprocess_socket():
context = ti.get_template_context()
log = structlog.get_logger(logger_name="task")

state, msg, error = run(ti, context, log)
finalize(ti, state, context, log, error)

# In the normal subprocess model, the task runner calls this before exiting.
# Since we're running in-process, we manually notify the API server that
# the task has finished—unless the terminal state was already sent explicitly.
supervisor.update_task_state_if_needed()
# In the normal subprocess model, the task runner calls this before exiting.
# Since we're running in-process, we manually notify the API server that
# the task has finished—unless the terminal state was already sent explicitly.
supervisor.update_task_state_if_needed()

return TaskRunResult(ti=ti, state=state, msg=msg, error=error)

Expand Down
6 changes: 5 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,11 +1475,15 @@ def reinit_supervisor_comms() -> None:
run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks
can continue to access variables etc.
"""
import socket

if "SUPERVISOR_COMMS" not in globals():
global SUPERVISOR_COMMS
log = structlog.get_logger(logger_name="task")

SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))

SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))

logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
if isinstance(logs, SentFDs):
Expand Down