diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 9b4933de8080e..cb5554681b303 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,7 +27,6 @@
 import signal
 import sys
 import time
-import weakref
 from collections.abc import Generator
 from contextlib import suppress
 from datetime import datetime, timezone
@@ -64,6 +63,8 @@
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
+    from airflow.typing_compat import Self
+
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
@@ -263,7 +264,7 @@ def exit(n: int) -> NoReturn:
 
 @attrs.define()
 class WatchedSubprocess:
-    ti_id: UUID
+    id: UUID
 
     pid: int
     stdin: BinaryIO
@@ -292,20 +293,16 @@ class WatchedSubprocess:
 
     selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
 
-    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
-
-    def __attrs_post_init__(self):
-        self.procs[self.pid] = self
-
     @classmethod
     def start(
         cls,
         path: str | os.PathLike[str],
-        ti: TaskInstance,
+        what: TaskInstance,
         client: Client,
         target: Callable[[], None] = _subprocess_main,
         logger: FilteringBoundLogger | None = None,
-    ) -> WatchedSubprocess:
+        **constructor_kwargs,
+    ) -> Self:
        """Fork and start a new subprocess to execute the given task."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess
         child_stdin, feed_stdin = mkpipe(remote_read=True)
@@ -324,31 +321,27 @@ def start(
             # around in the forked processes, especially things that might involve open files or sockets!
             del path
             del client
-            del ti
+            del what
             del logger
 
             # Run the child entrypoint
             _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)
 
+        requests_fd = child_comms.fileno()
+
+        # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
+        # other end of the pair open
+        cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
+
         proc = cls(
-            ti_id=ti.id,
+            id=constructor_kwargs.get("id") or getattr(what, "id"),
             pid=pid,
             stdin=feed_stdin,
             process=psutil.Process(pid),
             client=client,
+            **constructor_kwargs,
         )
 
-        # We've forked, but the task won't start until we send it the StartupDetails message. But before we do
-        # that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
-        # reason)
-        try:
-            client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
-            proc._last_successful_heartbeat = time.monotonic()
-        except Exception:
-            # On any error kill that subprocess!
-            proc.kill(signal.SIGKILL)
-            raise
-
         logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind())
         proc._register_pipe_readers(
             logger=logger,
@@ -359,11 +352,8 @@ def start(
         )
 
         # Tell the task process what it needs to do!
-        proc._send_startup_message(ti, path, child_comms)
+        proc._on_child_started(what, path, requests_fd)
 
-        # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
-        # other end of the pair open
-        proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
         return proc
 
     def _register_pipe_readers(
@@ -401,12 +391,23 @@ def _close_unused_sockets(*sockets):
         for sock in sockets:
             sock.close()
 
-    def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket):
+    def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requests_fd: int):
         """Send startup message to the subprocess."""
+        try:
+            # We've forked, but the task won't start doing anything until we send it the StartupDetails
+            # message. But before we do that, we need to tell the server it's started (so it has the chance to
+            # tell us "no, stop!" for any reason)
+            self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
+            self._last_successful_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            self.kill(signal.SIGKILL)
+            raise
+
         msg = StartupDetails.model_construct(
             ti=ti,
-            file=str(path),
-            requests_fd=child_comms.fileno(),
+            file=os.fspath(path),
+            requests_fd=requests_fd,
         )
 
         # Send the message to tell the process what it needs to execute
@@ -490,7 +491,7 @@ def wait(self) -> int:
         # by the subprocess in the `handle_requests` method.
         if self.final_state in TerminalTIState:
             self.client.task_instances.finish(
-                id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
+                id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
             )
 
         return self._exit_code
@@ -525,9 +526,9 @@ def _monitor_subprocess(self):
             # logs
             self._send_heartbeat_if_needed()
 
-            self._handle_task_overtime_if_needed()
+            self._handle_process_overtime_if_needed()
 
-    def _handle_task_overtime_if_needed(self):
+    def _handle_process_overtime_if_needed(self):
         """Handle termination of auxiliary processes if the task exceeds the configured overtime."""
         # If the task has reached a terminal state, we can start monitoring the overtime
         if not self._terminal_state:
@@ -537,7 +538,7 @@ def _handle_task_overtime_if_needed(self):
             self._task_end_time_monotonic
             and (time.monotonic() - self._task_end_time_monotonic) > self.TASK_OVERTIME_THRESHOLD
         ):
-            log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
+            log.warning("Workload success overtime reached; terminating process", ti_id=self.id)
             self.kill(signal.SIGTERM, force=True)
 
     def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
@@ -579,7 +580,7 @@ def _check_subprocess_exit(self, raise_on_timeout: bool = False) -> int | None:
         if self._exit_code is None:
             try:
                 self._exit_code = self._process.wait(timeout=0)
-                log.debug("Task process exited", exit_code=self._exit_code)
+                log.debug("Workload process exited", exit_code=self._exit_code)
             except psutil.TimeoutExpired:
                 if raise_on_timeout:
                     raise
@@ -593,7 +594,7 @@ def _send_heartbeat_if_needed(self):
         self._last_heartbeat_attempt = time.monotonic()
 
         try:
-            self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
+            self.client.task_instances.heartbeat(self.id, pid=self._process.pid)
 
             # Update the last heartbeat time on success
             self._last_successful_heartbeat = time.monotonic()
@@ -619,7 +620,7 @@ def _handle_heartbeat_failures(self):
             log.warning(
                 "Failed to send heartbeat. Will be retried",
                 failed_heartbeats=self.failed_heartbeats,
-                ti_id=self.ti_id,
+                ti_id=self.id,
                 max_retries=MAX_FAILED_HEARTBEATS,
                 exc_info=True,
             )
@@ -646,7 +647,7 @@ def final_state(self):
         return TerminalTIState.FAILED
 
     def __rich_repr__(self):
-        yield "ti_id", self.ti_id
+        yield "id", self.id
         yield "pid", self.pid
         # only include this if it's not the default (third argument)
         yield "exit_code", self._exit_code, None
@@ -654,7 +655,7 @@ def __rich_repr__(self):
     __rich_repr__.angular = True  # type: ignore[attr-defined]
 
     def __repr__(self) -> str:
-        rep = f"<WatchedSubprocess ti_id={self.ti_id} pid={self.pid}"
+        rep = f"<WatchedSubprocess id={self.id} pid={self.pid}"
         if self._exit_code is not None:
             rep += f" exit_code={self._exit_code}"
         return rep + " >"
@@ -672,35 +673,38 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
                 log.exception("Unable to decode message", line=line)
                 continue
 
+            self._handle_request(msg, log)
+
+    def _handle_request(self, msg, log):
+        resp = None
+        if isinstance(msg, TaskState):
+            self._terminal_state = msg.state
+            self._task_end_time_monotonic = time.monotonic()
+        elif isinstance(msg, GetConnection):
+            conn = self.client.connections.get(msg.conn_id)
+            resp = conn.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, GetVariable):
+            var = self.client.variables.get(msg.key)
+            resp = var.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, GetXCom):
+            xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
+            resp = xcom.model_dump_json(exclude_unset=True).encode()
+        elif isinstance(msg, DeferTask):
+            self._terminal_state = IntermediateTIState.DEFERRED
+            self.client.task_instances.defer(self.id, msg)
+            resp = None
-            if isinstance(msg, TaskState):
-                self._terminal_state = msg.state
-                self._task_end_time_monotonic = time.monotonic()
-            elif isinstance(msg, GetConnection):
-                conn = self.client.connections.get(msg.conn_id)
-                resp = conn.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, GetVariable):
-                var = self.client.variables.get(msg.key)
-                resp = var.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, GetXCom):
-                xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
-                resp = xcom.model_dump_json(exclude_unset=True).encode()
-            elif isinstance(msg, DeferTask):
-                self._terminal_state = IntermediateTIState.DEFERRED
-                self.client.task_instances.defer(self.ti_id, msg)
-                resp = None
-            elif isinstance(msg, SetXCom):
-                self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
-                resp = None
-            elif isinstance(msg, PutVariable):
-                self.client.variables.set(msg.key, msg.value, msg.description)
-                resp = None
-            else:
-                log.error("Unhandled request", msg=msg)
-                continue
+        elif isinstance(msg, SetXCom):
+            self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
+            resp = None
+        elif isinstance(msg, PutVariable):
+            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp = None
+        else:
+            log.error("Unhandled request", msg=msg)
+            return
 
-            if resp:
-                self.stdin.write(resp + b"\n")
+        if resp:
+            self.stdin.write(resp + b"\n")
 
 # Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index db32398c7c5a9..d210e0011fe0e 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,11 +23,11 @@
 import sys
 from datetime import datetime, timezone
 from io import FileIO
-from typing import TYPE_CHECKING, TextIO
+from typing import TYPE_CHECKING, Generic, TextIO, TypeVar
 
 import attrs
 import structlog
-from pydantic import ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, TypeAdapter
 
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -77,17 +77,24 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
     return RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=task)
 
 
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
 @attrs.define()
-class CommsDecoder:
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
     """Handle communication between the task in this process and the supervisor parent process."""
 
     input: TextIO
 
     request_socket: FileIO = attrs.field(init=False, default=None)
 
-    decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda: TypeAdapter(ToTask))
+    # We could be "clever" here and set the default based on the type parameters and a custom
+    # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just
+    # use a "sort of wrong" default
+    decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)
 
-    def get_message(self) -> ToTask:
+    def get_message(self) -> ReceiveMsgType:
         """
         Get a message from the parent.
 
@@ -106,7 +113,7 @@ def get_message(self) -> ToTask:
         self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
         return msg
 
-    def send_request(self, log: Logger, msg: ToSupervisor):
+    def send_request(self, log: Logger, msg: SendMsgType):
         encoded_msg = msg.model_dump_json().encode() + b"\n"
 
         log.debug("Sending request", json=encoded_msg)
@@ -123,7 +130,7 @@ def send_request(self, log: Logger, msg: ToSupervisor):
 # deeply nested execution stack.
 # - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
 #   accessible wherever needed during task execution without modifying every layer of the call stack.
-SUPERVISOR_COMMS: CommsDecoder
+SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
 
 # State machine!
 # 1. Start up (receive details from supervisor)
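
For what the new type parameters buy: a subclass can pin its own receive/send message types and override the class-level decoder default, which is why the comment above is content to leave a "sort of wrong" default on the base class. A minimal sketch, with hypothetical `Ping`/`Pong` models standing in for a real message union; only `CommsDecoder` comes from this patch.

    # Standalone sketch -- Ping/Pong are invented, not SDK types.
    import sys

    import attrs
    from pydantic import BaseModel, TypeAdapter

    from airflow.sdk.execution_time.task_runner import CommsDecoder

    class Ping(BaseModel):  # hypothetical receive type
        kind: str = "ping"

    class Pong(BaseModel):  # hypothetical send type
        kind: str = "pong"

    @attrs.define()
    class PingPongComms(CommsDecoder[Ping, Pong]):
        # Replace the "sort of wrong" base default so decoding validates against Ping.
        decoder: TypeAdapter[Ping] = attrs.field(factory=lambda: TypeAdapter(Ping), repr=False)

    comms = PingPongComms(input=sys.stdin)

Type checkers then see `get_message() -> Ping` and `send_request(log, msg: Pong)` on the subclass, while `SUPERVISOR_COMMS` keeps the `ToTask`/`ToSupervisor` pairing it had before.
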
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index e44b4942e13ff..406b2ee26996f 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -98,7 +98,7 @@ def subprocess_main():
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id="4d828a62-a417-4936-a7a6-2b3fabacecab",
                 task_id="b",
                 dag_id="c",
@@ -165,7 +165,7 @@ def subprocess_main():
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id="4d828a62-a417-4936-a7a6-2b3fabacecab",
                 task_id="b",
                 dag_id="c",
@@ -188,7 +188,7 @@ def subprocess_main():
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id=uuid7(),
                 task_id="b",
                 dag_id="c",
@@ -224,7 +224,7 @@ def subprocess_main():
         spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(
+            what=TaskInstance(
                 id=ti_id,
                 task_id="b",
                 dag_id="c",
@@ -335,7 +335,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
         client = make_client(transport=httpx.MockTransport(handle_request))
 
         with pytest.raises(ServerResponseError, match="Server returned error") as err:
-            WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)
+            WatchedSubprocess.start(path=os.devnull, what=ti, client=client)
 
         assert err.value.response.status_code == 409
         assert err.value.detail == {
@@ -388,7 +388,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
+            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
             client=make_client(transport=httpx.MockTransport(handle_request)),
             target=subprocess_main,
         )
@@ -440,7 +440,7 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine):
         mock_kill = mocker.patch("airflow.sdk.execution_time.supervisor.WatchedSubprocess.kill")
 
         proc = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=mock_process.pid,
             stdin=mocker.MagicMock(),
             client=client,
@@ -528,7 +528,7 @@ def test_overtime_handling(
         monkeypatch.setattr(WatchedSubprocess, "TASK_OVERTIME_THRESHOLD", overtime_threshold)
 
         mock_watched_subprocess = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=12345,
             stdin=mocker.Mock(),
             process=mocker.Mock(),
@@ -541,13 +541,13 @@ def test_overtime_handling(
 
         # Call `wait` to trigger the overtime handling
         # This will call the `kill` method if the task has been running for too long
-        mock_watched_subprocess._handle_task_overtime_if_needed()
+        mock_watched_subprocess._handle_process_overtime_if_needed()
 
         # Validate process kill behavior and log messages
         if expected_kill:
             mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
             mock_logger.warning.assert_called_once_with(
-                "Task success overtime reached; terminating process",
+                "Workload success overtime reached; terminating process",
                 ti_id=TI_ID,
             )
         else:
@@ -565,7 +565,7 @@ def mock_process(self, mocker):
     @pytest.fixture
     def watched_subprocess(self, mocker, mock_process):
         proc = WatchedSubprocess(
-            ti_id=TI_ID,
+            id=TI_ID,
             pid=12345,
             stdin=mocker.Mock(),
             client=mocker.Mock(),
@@ -656,7 +656,7 @@ def _handler(sig, frame):
 
         proc = WatchedSubprocess.start(
             path=os.devnull,
-            ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
+            what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
dag_id="c", run_id="d", try_number=1), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, ) @@ -746,7 +746,7 @@ class TestHandleRequest: def watched_subprocess(self, mocker): """Fixture to provide a WatchedSubprocess instance.""" return WatchedSubprocess( - ti_id=TI_ID, + id=TI_ID, pid=12345, stdin=BytesIO(), client=mocker.Mock(),