diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index be2f4f8eacd46..0589f12f866e3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -54,7 +54,7 @@
 from functools import cached_property
 from pathlib import Path
 from socket import socket
-from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union, overload
 from uuid import UUID
 
 import attrs
@@ -90,6 +90,12 @@
 )
 from airflow.sdk.exceptions import ErrorType
 
+try:
+    from socket import recv_fds
+except ImportError:
+    # Available on Unix and Windows (so "everywhere") but let's be safe
+    recv_fds = None  # type: ignore[assignment]
+
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger
 
@@ -180,10 +186,29 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
         bytes = frame.as_bytes()
 
         self.socket.sendall(bytes)
+        if isinstance(msg, ResendLoggingFD):
+            if recv_fds is None:
+                return None
+            # We need special handling here! The server can't send us the fd number, as the number on the
+            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
+            frame, fds = self._read_frame(maxfds=1)
+            resp = self._from_frame(frame)
+            if TYPE_CHECKING:
+                assert isinstance(resp, SentFDs)
+            resp.fds = fds
+            # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
+            # always be in the return type union
+            return resp  # type: ignore[return-value]
 
         return self._get_response()
 
-    def _read_frame(self):
+    @overload
+    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
+
+    @overload
+    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...
+
+    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
         """
         Get a message from the parent.
 
@@ -191,7 +216,13 @@ def _read_frame(self):
         """
         if self.socket:
             self.socket.setblocking(True)
-        len_bytes = self.socket.recv(4)
+        fds = None
+        if maxfds is not None:
+            if recv_fds is None:
+                raise RuntimeError("recv_fds is not available on this platform")
+            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
+        else:
+            len_bytes = self.socket.recv(4)
 
         if len_bytes == b"":
             raise EOFError("Request socket closed before length")
@@ -207,7 +238,10 @@ def _read_frame(self):
         if nread == 0:
             raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
 
-        return self.resp_decoder.decode(buffer)
+        resp = self.resp_decoder.decode(buffer)
+        if maxfds is not None:
+            return resp, fds or []
+        return resp
 
     def _from_frame(self, frame) -> ReceiveMsgType | None:
         from airflow.sdk.exceptions import AirflowRuntimeError
@@ -520,6 +554,13 @@ class OKResponse(BaseModel):
     type: Literal["OKResponse"] = "OKResponse"
 
 
+class SentFDs(BaseModel):
+    """Response to ``ResendLoggingFD``; ``fds`` is overwritten by the receiver with the FDs it was passed."""
+
+    type: Literal["SentFDs"] = "SentFDs"
+    fds: list[int]
+
+
 ToTask = Annotated[
     Union[
         AssetResult,
@@ -529,6 +570,7 @@ class OKResponse(BaseModel):
         DRCount,
         ErrorResponse,
         PrevSuccessfulDagRunResult,
+        SentFDs,
         StartupDetails,
         TaskRescheduleStartDate,
         TICount,
@@ -710,6 +752,12 @@ class DeleteVariable(BaseModel):
     type: Literal["DeleteVariable"] = "DeleteVariable"
 
 
+class ResendLoggingFD(BaseModel):
+    """Request that the supervisor open a new log socket and pass its file descriptor back to us."""
+
+    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
+
+
 class SetRenderedFields(BaseModel):
     """Payload for setting RTIF for a task instance."""
 
@@ -829,6 +877,7 @@ class GetDRCount(BaseModel):
         TaskState,
         TriggerDagRun,
         DeleteVariable,
+        ResendLoggingFD,
     ],
     Field(discriminator="type"),
 ]
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 2ae554cace1a5..70b384a8c96b2 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -94,7 +94,9 @@
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SetXCom,
     SkipDownstreamTasks,
@@ -115,6 +117,11 @@
 )
 from airflow.sdk.execution_time.secrets_masker import mask_secret
 
+try:
+    from socket import send_fds
+except ImportError:
+    send_fds = None  # type: ignore[assignment]
+
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
@@ -1218,6 +1225,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
             inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
             resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
             dump_opts = {"exclude_unset": True}
+        elif isinstance(msg, ResendLoggingFD):
+            # We need special handling here!
+            if send_fds is not None:
+                self._send_new_log_fd(req_id)
+            # Since we've sent the message, return. Nothing else in this ifelse/switch should return directly
+            return
         else:
             log.error("Unhandled request", msg=msg)
             self.send_msg(
@@ -1232,6 +1245,42 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
 
         self.send_msg(resp, request_id=req_id, error=None, **dump_opts)
 
+    def _send_new_log_fd(self, req_id: int) -> None:
+        """
+        Open a fresh log socket pair and pass the child's end to the task process.
+
+        The read end is registered with our selector; the other end is sent as a file descriptor
+        alongside the ``SentFDs`` response frame for ``req_id``.
+
+        :raises RuntimeError: if ``socket.send_fds`` is not available on this platform.
+        """
+        if send_fds is None:
+            raise RuntimeError("send_fds is not available on this platform")
+        child_logs, read_logs = socketpair()
+
+        target_loggers: tuple[FilteringBoundLogger, ...] = (self.process_log,)
+        if self.subprocess_logs_to_stdout:
+            target_loggers += (log,)
+
+        self.selector.register(
+            read_logs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(
+                process_log_messages_from_subprocess(target_loggers), on_close=self._on_socket_closed
+            ),
+        )
+        # We don't explicitly close the old log socket, that will get handled for us if/when the other end is
+        # closed (such as `sudo` would do for us automatically.) This also means that this feature _can_ be
+        # used outside of an exec context if it is useful, as we can then have multiple things sending us logs,
+        # such as the task process and its launched subprocess.
+
+        frame = _ResponseFrame(id=req_id, body=SentFDs(fds=[child_logs.fileno()]).model_dump())
+        try:
+            send_fds(self.stdin, [frame.as_bytes()], [child_logs.fileno()])
+        finally:
+            # Always close our copy of the child's end, even if the send failed, so the socket is not leaked.
+            child_logs.close()
+
 
 def in_process_api_server():
     from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index e4292deca2a10..7ee837915a6f1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -70,7 +70,9 @@
     GetTICount,
     InactiveAssetsResult,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SkipDownstreamTasks,
     StartupDetails,
@@ -659,6 +661,18 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
     if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
         # entrypoint of re-exec process
         msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
+
+        logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
+        if isinstance(logs, SentFDs) and logs.fds:
+            from airflow.sdk.log import configure_logging
+
+            log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
+            configure_logging(enable_pretty_log=False, output=log_io, sending_to_supervisor=True)
+        else:
+            print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
+
+        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
+        # on stdout
         log.debug("Using serialized startup message from environment", msg=msg)
     else:
         # normal entry point
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index d2f8ef40c0881..742dba98c4d38 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -31,6 +31,7 @@
 from random import randint
 from time import sleep
 from typing import TYPE_CHECKING
+from unittest import mock
 from unittest.mock import MagicMock, patch
 
 import httpx
@@ -85,7 +86,9 @@
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SetXCom,
     SucceedTask,
@@ -236,6 +239,44 @@ def subprocess_main():
             ]
         )
 
+    def test_reopen_log_fd(self, captured_logs, client_with_ti_start):
+        def subprocess_main():
+            # This is run in the subprocess!
+
+            # Ensure we follow the "protocol" and get the startup message before we do anything else
+            comms = CommsDecoder()
+            comms._get_response()
+
+            logs = comms.send(ResendLoggingFD())
+            assert isinstance(logs, SentFDs)
+            fd = os.fdopen(logs.fds[0], "w")
+            logging.root.info("Log on old socket")
+            json.dump({"level": "info", "event": "Log on new socket"}, fp=fd)
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id="4d828a62-a417-4936-a7a6-2b3fabacecab",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=client_with_ti_start,
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == 0
+        assert captured_logs == unordered(
+            [
+                {"event": "Log on new socket", "level": "info", "logger": "task", "timestamp": mock.ANY},
+                {"event": "Log on old socket", "level": "info", "logger": "root", "timestamp": mock.ANY},
+            ]
+        )
+
     def test_subprocess_sigkilled(self, client_with_ti_start):
         main_pid = os.getpid()
 