From a9eccd3c5b4db13d58aed6195b9c40047138459a Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Thu, 19 Jun 2025 12:56:54 +0100
Subject: [PATCH] Improve logging with `run_as_user` to avoid "double" logging/plain-text-over-stdout

When we are running normally (without impersonation) the supervisor sets up a
new socketpair for logging before forking, and the task process then configures
structlog in the forked process to send logs over that socket. This all works
because forking a process gives the new process a copy of all open file
descriptors.

However sudo by default will close all open file descriptors other than stdin,
stdout and stderr, so our logs socket gets closed too (sockets, like files, are
just file descriptors). We could ask people to change their `sudoers` config
file to add the [`closefrom_override`][1] option and invoke `sudo -C`, but many
people either might not have access to do this, or might not feel comfortable
making this change.

There is, however, another option open to us: on both Unix and Windows it is
possible to pass _open_ file descriptors (which, remember, is what sockets are)
between two processes!

So what this PR does is introduce a new Request and Response pair, and
customize the send+receive code to pass over a new FD that is configured to
receive and handle JSON logs. (We need a new FD because we've already closed
the child end of the original socket during normal start-up, before we knew the
task was actually going to run as another user, and we can't get it back, so we
just open another.)

[1]: https://linux.die.net/man/5/sudoers#:~:text=on%20by%20default.-,closefrom_override,is%20off%20by%20default.,-compress_io'%20If%20set
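To make the mechanism concrete, here is a minimal, self-contained sketch of FD
passing with nothing Airflow-specific in it: `socket.send_fds`/`socket.recv_fds`
carry descriptors as `SCM_RIGHTS` ancillary data over an `AF_UNIX` socket. Note
that the receiving process gets the descriptor back under a *new* number of the
kernel's choosing, which is exactly why the client code in this PR has to
rewrite the FD numbers in the response itself:

```python
import os
import socket

# Any connected AF_UNIX socket pair can carry FDs as SCM_RIGHTS ancillary data.
parent, child = socket.socketpair()

# The descriptor we want to hand over -- a socket works exactly the same way.
read_end, write_end = os.pipe()

# One byte of ordinary data plus the FD travel in a single sendmsg() call.
socket.send_fds(parent, [b"x"], [write_end])

# recv_fds returns (data, fds, msg_flags, address). The kernel dup()s the
# descriptor on arrival, so fds[0] is in general a *different* number to write_end.
data, fds, flags, addr = socket.recv_fds(child, 1, maxfds=1)

os.write(fds[0], b"hello over a passed FD\n")
print(os.read(read_end, 64))  # b'hello over a passed FD\n'
```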
---
 .../src/airflow/sdk/execution_time/comms.py   | 51 +++++++++++++++++--
 .../airflow/sdk/execution_time/supervisor.py  | 38 ++++++++++++++
 .../airflow/sdk/execution_time/task_runner.py | 14 +++++
 .../execution_time/test_supervisor.py         | 41 +++++++++++++++
 4 files changed, 140 insertions(+), 4 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index be2f4f8eacd46..0589f12f866e3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -54,7 +54,7 @@
 from functools import cached_property
 from pathlib import Path
 from socket import socket
-from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union, overload
 from uuid import UUID

 import attrs
@@ -90,6 +90,12 @@
 )
 from airflow.sdk.exceptions import ErrorType

+try:
+    from socket import recv_fds
+except ImportError:
+    # Not available on every platform (it needs sendmsg() with SCM_RIGHTS support), so let's be safe
+    recv_fds = None  # type: ignore[assignment]
+
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger

@@ -180,10 +186,29 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
         bytes = frame.as_bytes()
         self.socket.sendall(bytes)

+        if isinstance(msg, ResendLoggingFD):
+            if recv_fds is None:
+                return None
+            # We need special handling here! The server can't send us the FD number, as the number in the
+            # supervisor will be different from the one in this process, so we have to mutate the message ourselves.
+            frame, fds = self._read_frame(maxfds=1)
+            resp = self._from_frame(frame)
+            if TYPE_CHECKING:
+                assert isinstance(resp, SentFDs)
+            resp.fds = fds
+            # Since we know this is explicitly a SentFDs, and since this class is generic, SentFDs might
+            # not always be in the return type union
+            return resp  # type: ignore[return-value]
         return self._get_response()

-    def _read_frame(self):
+    @overload
+    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
+
+    @overload
+    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...
+
+    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
         """
         Get a message from the parent.

@@ -191,7 +216,11 @@ def _read_frame(self):
         """
         if self.socket:
             self.socket.setblocking(True)
-        len_bytes = self.socket.recv(4)
+        fds = None
+        if maxfds:
+            len_bytes, fds, flag, address = recv_fds(self.socket, 4, maxfds)
+        else:
+            len_bytes = self.socket.recv(4)

         if len_bytes == b"":
             raise EOFError("Request socket closed before length")
@@ -207,7 +236,10 @@
         if nread == 0:
             raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")

-        return self.resp_decoder.decode(buffer)
+        resp = self.resp_decoder.decode(buffer)
+        if maxfds:
+            return resp, fds or []
+        return resp

     def _from_frame(self, frame) -> ReceiveMsgType | None:
         from airflow.sdk.exceptions import AirflowRuntimeError
@@ -520,6 +552,11 @@ class OKResponse(BaseModel):
     type: Literal["OKResponse"] = "OKResponse"


+class SentFDs(BaseModel):
+    type: Literal["SentFDs"] = "SentFDs"
+    fds: list[int]
+
+
 ToTask = Annotated[
     Union[
         AssetResult,
@@ -529,6 +566,7 @@
         DRCount,
         ErrorResponse,
         PrevSuccessfulDagRunResult,
+        SentFDs,
         StartupDetails,
         TaskRescheduleStartDate,
         TICount,
@@ -710,6 +748,10 @@ class DeleteVariable(BaseModel):
     type: Literal["DeleteVariable"] = "DeleteVariable"


+class ResendLoggingFD(BaseModel):
+    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
+
+
 class SetRenderedFields(BaseModel):
     """Payload for setting RTIF for a task instance."""

@@ -829,6 +871,7 @@ class GetDRCount(BaseModel):
         TaskState,
         TriggerDagRun,
         DeleteVariable,
+        ResendLoggingFD,
     ],
     Field(discriminator="type"),
 ]
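A note on the `_read_frame` change above: since the ancillary FDs ride along
with ordinary stream data, the supervisor attaches the descriptor to the same
`sendmsg()` call that carries the response frame, and the client collects it on
its very first read, the 4-byte length prefix. A rough sketch of that framing
pattern (using a hypothetical `struct`-packed length prefix here, not the SDK's
actual `_ResponseFrame` encoding):

```python
import socket
import struct

def send_frame_with_fd(sock: socket.socket, payload: bytes, fd: int) -> None:
    # Length-prefix the payload and attach the FD to the same message.
    frame = struct.pack(">I", len(payload)) + payload
    socket.send_fds(sock, [frame], [fd])

def recv_frame_with_fd(sock: socket.socket) -> tuple[bytes, list[int]]:
    # The FD arrives with the first chunk we read (the 4-byte length prefix);
    # the rest of the payload follows as ordinary stream data.
    len_bytes, fds, _flags, _addr = socket.recv_fds(sock, 4, maxfds=1)
    (length,) = struct.unpack(">I", len_bytes)
    buf = b""
    while len(buf) < length:
        chunk = sock.recv(length - len(buf))
        if not chunk:
            raise EOFError("socket closed before frame was complete")
        buf += chunk
    return buf, fds
```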
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 2ae554cace1a5..70b384a8c96b2 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -94,7 +94,9 @@
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SetXCom,
     SkipDownstreamTasks,
@@ -115,6 +117,11 @@
 )
 from airflow.sdk.execution_time.secrets_masker import mask_secret

+try:
+    from socket import send_fds
+except ImportError:
+    send_fds = None  # type: ignore[assignment]
+
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger, WrappedLogger

@@ -1218,6 +1225,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
             inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
             resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
             dump_opts = {"exclude_unset": True}
+        elif isinstance(msg, ResendLoggingFD):
+            # We need special handling here!
+            if send_fds is not None:
+                self._send_new_log_fd(req_id)
+            # Since we've sent the message, return. Nothing else in this if/elif chain should return directly
+            return
         else:
             log.error("Unhandled request", msg=msg)
             self.send_msg(
@@ -1232,6 +1245,31 @@

         self.send_msg(resp, request_id=req_id, error=None, **dump_opts)

+    def _send_new_log_fd(self, req_id: int) -> None:
+        if send_fds is None:
+            raise RuntimeError("send_fds is not available on this platform")
+        child_logs, read_logs = socketpair()
+
+        target_loggers: tuple[FilteringBoundLogger, ...] = (self.process_log,)
+        if self.subprocess_logs_to_stdout:
+            target_loggers += (log,)
+
+        self.selector.register(
+            read_logs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(
+                process_log_messages_from_subprocess(target_loggers), on_close=self._on_socket_closed
+            ),
+        )
+        # We don't explicitly close the old log socket; that will get handled for us if/when the other
+        # end is closed (as `sudo` does for us automatically). This also means that this feature _can_ be
+        # used outside of an exec context if it is useful, as we can then have multiple things sending us
+        # logs, such as the task process and its launched subprocesses.
+
+        frame = _ResponseFrame(id=req_id, body=SentFDs(fds=[child_logs.fileno()]).model_dump())
+        send_fds(self.stdin, [frame.as_bytes()], [child_logs.fileno()])
+        child_logs.close()  # Close this end now.
+

 def in_process_api_server():
     from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index e4292deca2a10..7ee837915a6f1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -70,7 +70,9 @@
     GetTICount,
     InactiveAssetsResult,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SkipDownstreamTasks,
     StartupDetails,
@@ -659,6 +661,18 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
     if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
         # entrypoint of re-exec process
         msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
+
+        logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
+        if isinstance(logs, SentFDs):
+            from airflow.sdk.log import configure_logging
+
+            log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
+            configure_logging(enable_pretty_log=False, output=log_io, sending_to_supervisor=True)
+        else:
+            print("Unable to re-configure logging after sudo: we didn't get an FD back", file=sys.stderr)
+
+        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show
+        # up on stdout
         log.debug("Using serialized startup message from environment", msg=msg)
     else:
         # normal entry point
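For context on what `configure_logging` does with that FD in the hunk above:
conceptually it just points the JSON log renderer at a stream wrapped around
the received descriptor. A generic structlog sketch of the same idea, using
plain `structlog` rather than the SDK's `configure_logging` (the pipe is only
there to keep the sketch self-contained and runnable):

```python
import os

import structlog

# Stand-in for the descriptor received via SentFDs.
read_end, received_fd = os.pipe()

# Wrap the raw FD in a line-buffered text stream and emit JSON log lines to it.
log_io = os.fdopen(received_fd, "w", buffering=1)
structlog.configure(
    processors=[
        structlog.processors.add_log_level,
        structlog.processors.JSONRenderer(),
    ],
    logger_factory=structlog.PrintLoggerFactory(file=log_io),
)

structlog.get_logger().info("hello from the impersonated task")
print(os.read(read_end, 1024))  # e.g. b'{"event": "hello from the impersonated task", "level": "info"}\n'
```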
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index d2f8ef40c0881..742dba98c4d38 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -31,6 +31,7 @@
 from random import randint
 from time import sleep
 from typing import TYPE_CHECKING
+from unittest import mock
 from unittest.mock import MagicMock, patch

 import httpx
@@ -85,7 +86,9 @@
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
+    ResendLoggingFD,
     RetryTask,
+    SentFDs,
     SetRenderedFields,
     SetXCom,
     SucceedTask,
@@ -236,6 +239,44 @@
             ]
         )

+    def test_reopen_log_fd(self, captured_logs, client_with_ti_start):
+        def subprocess_main():
+            # This is run in the subprocess!
+
+            # Ensure we follow the "protocol" and get the startup message before we do anything else
+            comms = CommsDecoder()
+            comms._get_response()
+
+            logs = comms.send(ResendLoggingFD())
+            assert isinstance(logs, SentFDs)
+            fd = os.fdopen(logs.fds[0], "w")
+            logging.root.info("Log on old socket")
+            json.dump({"level": "info", "event": "Log on new socket"}, fp=fd)
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id="4d828a62-a417-4936-a7a6-2b3fabacecab",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=client_with_ti_start,
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == 0
+        assert captured_logs == unordered(
+            [
+                {"event": "Log on new socket", "level": "info", "logger": "task", "timestamp": mock.ANY},
+                {"event": "Log on old socket", "level": "info", "logger": "root", "timestamp": mock.ANY},
+            ]
+        )
+
     def test_subprocess_sigkilled(self, client_with_ti_start):
         main_pid = os.getpid()