From 48f33a05f43f09ebf7d3deba7e6697d940ab5745 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 10 Jun 2025 11:55:46 +0100
Subject: [PATCH 01/21] Switch the Supervisor/task process from line-based to length-prefixed

The existing JSON Lines based approach had three major drawbacks:

1. In the case of really large lines (in the region of 10 or 20MB) the Python line buffering could _sometimes_ result in a partial read
2. The JSON based approach didn't have the ability to add any metadata (such as errors).
3. Not every message type/call-site waited for a response, which meant those client functions could never get told about an error

One of the ways this line-based approach fell down was that if you suddenly tried to run 100s of triggers at the same time you would get an error like this:

```
Traceback (most recent call last):
  File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 568, in readline
    line = await self.readuntil(sep)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 663, in readuntil
    raise exceptions.LimitOverrunError(
asyncio.exceptions.LimitOverrunError: Separator is found, but chunk is longer than limit
```

The other way this caused problems was when parsing a large DAG (as in one with 20k tasks or more): the DagFileProcessor could end up getting a partial read, which would be invalid JSON.

This changes the communications protocol in a couple of ways.

First, at the Python level, the separate send and receive methods in the client/task side have been removed and replaced with a single `send()` that sends the request, reads the response and raises an error if one is returned. (But note, right now almost nothing in the supervisor side sets the error, that will be a future PR.)

Secondly, the JSON Lines approach has been changed from a line-based protocol to a binary "frame" one. The protocol (which is the same for whichever side is sending) is length-prefixed, i.e. we first send the length of the data as a 4-byte big-endian integer, followed by the data itself. This should remove the possibility of JSON parse errors due to reading incomplete lines.

Finally, the last change made in this PR is to remove the "extra" requests socket/channel. Upon closer examination of this comms path I realised that this socket is unnecessary: since we are in 100% control of the client side we can make use of the bi-directional nature of `socketpair` and save file handles. This also happens to help the `run_as_user` feature which is currently broken, as without extra config in the `sudoers` file, `sudo` will close all filehandles other than stdin, stdout, and stderr -- so by introducing this change we make it easier to re-add run_as_user support.

To support this in the DagFileProcessor (since the proc manager uses a single selector for multiple processes), I have moved the `on_close` callback to be part of the object we store in the `selector` object in the supervisors; previously it was the "on_read" callback, now we store a tuple of `(on_read, on_close)` and `on_close` is called once universally.

This also changes the way comms are handled from the (async) TriggerRunner process. Previously we had a sync+async lock, but that made it possible to end up deadlocking things.
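For illustration, here is a minimal sketch of the length-prefixed framing described above. The helper names are hypothetical and it uses raw bytes for simplicity; the actual implementation in this PR wraps msgpack-encoded `_RequestFrame`/`_ResponseFrame` structs and reads via `socket.recv_into` / `asyncio.StreamReader.readexactly`.

```python
import socket
import struct


def send_frame(sock: socket.socket, payload: bytes) -> None:
    # Prefix the payload with its length as a 4-byte big-endian integer,
    # then send the whole frame in one call.
    if len(payload) >= 2**32:
        raise OverflowError("Cannot send messages larger than 4GiB")
    sock.sendall(struct.pack(">I", len(payload)) + payload)


def recv_exactly(sock: socket.socket, n: int) -> bytes:
    # Keep reading until exactly n bytes have arrived -- a short recv()
    # is what caused the truncated-JSON reads with the line-based protocol.
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("Socket closed before the full frame was read")
        buf.extend(chunk)
    return bytes(buf)


def recv_frame(sock: socket.socket) -> bytes:
    (length,) = struct.unpack(">I", recv_exactly(sock, 4))
    return recv_exactly(sock, length)
```

The key property is that the reader always knows exactly how many bytes to expect, so a partial read can never be mistaken for a complete message.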
The change now is to have `send` on `TriggerCommsDecoder` "go back" to the async even loop via `async_to_sync`, so that only async code deals with the socket, and we can use an async lock (rather than the hybrid sync and async lock we tried before). This seems to help the deadlock issue, but I'm not 100% sure it will remove it entirely, but it makes it much much harder to hit - I've not been able to reprouce it with this change --- .../src/airflow/dag_processing/manager.py | 9 +- .../src/airflow/dag_processing/processor.py | 33 +- .../src/airflow/jobs/triggerer_job_runner.py | 124 ++++--- .../tests/unit/dag_processing/test_manager.py | 82 +++-- .../unit/dag_processing/test_processor.py | 51 +-- .../tests/unit/jobs/test_triggerer_job.py | 34 +- .../src/tests_common/pytest_plugin.py | 9 +- task-sdk/pyproject.toml | 1 - task-sdk/src/airflow/sdk/bases/xcom.py | 98 +++--- .../sdk/definitions/asset/decorators.py | 12 +- .../src/airflow/sdk/execution_time/comms.py | 164 ++++++++- .../src/airflow/sdk/execution_time/context.py | 74 ++--- .../sdk/execution_time/lazy_sequence.py | 60 ++-- .../airflow/sdk/execution_time/supervisor.py | 255 ++++++++------ .../airflow/sdk/execution_time/task_runner.py | 215 ++++-------- .../task_sdk/execution_time/test_comms.py | 83 +++++ .../execution_time/test_supervisor.py | 312 ++++++++++-------- .../execution_time/test_task_runner.py | 51 --- 18 files changed, 909 insertions(+), 758 deletions(-) create mode 100644 task-sdk/tests/task_sdk/execution_time/test_comms.py diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 272e086d90903..8af35af62b7ef 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -77,6 +77,8 @@ from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: + from socket import socket + from sqlalchemy.orm import Session from airflow.callbacks.callback_requests import CallbackRequest @@ -388,7 +390,7 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): """ events = self.selector.select(timeout=timeout) for key, _ in events: - socket_handler = key.data + socket_handler, on_close = key.data # BrokenPipeError should be caught and treated as if the handler returned false, similar # to EOF case @@ -397,8 +399,9 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): except BrokenPipeError: need_more = False if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() def _queue_requested_files_for_parsing(self) -> None: """Queue any files requested for parsing as requested by users via UI/API.""" diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 5f69082c758aa..4341d6d91390e 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel): bundle_path: Path """Passing bundle path around lets us figure out relative file path.""" - requests_fd: int callback_requests: list[CallbackRequest] = Field(default_factory=list) type: Literal["DagFileParseRequest"] = "DagFileParseRequest" @@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel): def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import task_runner + from 
airflow.sdk.execution_time import comms, task_runner # Parse DAG file, send JSON back up! - comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( - input=sys.stdin, - decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), + comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) - msg = comms_decoder.get_message() + msg = comms_decoder._get_response() if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") @@ -125,7 +122,7 @@ def _parse_file_entrypoint(): result = _parse_file(msg, log) if result is not None: - comms_decoder.send_request(log, result) + comms_decoder.send(result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: @@ -266,20 +263,18 @@ def _on_child_started( msg = DagFileParseRequest( file=os.fspath(path), bundle_path=bundle_path, - requests_fd=self._requests_fd, callback_requests=callbacks, ) - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) - def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse resp: BaseModel | None = None dump_opts = {} if isinstance(msg, DagFileParsingResult): self.parsing_result = msg - return - if isinstance(msg, GetConnection): + elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) @@ -301,10 +296,16 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # resp = self.client.variables.delete(msg.key) else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) @property def is_ready(self) -> bool: @@ -312,7 +313,7 @@ def is_ready(self) -> bool: # Process still alive, def can't be finished yet return False - return self._num_open_sockets == 0 + return not self._open_sockets def wait(self) -> int: raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5df6a03261523..e829a4fa6eae0 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -28,6 +28,7 @@ from collections.abc import Generator, Iterable from contextlib import suppress from datetime import datetime +from socket import socket from traceback import format_exception from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, Union @@ -43,6 +44,7 @@ from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger from airflow.sdk.execution_time.comms import ( + CommsDecoder, ConnectionResult, DagRunStateResult, DRCount, @@ -58,6 +60,7 @@ TICount, VariableResult, XComResult, + _RequestFrame, ) from 
airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from airflow.stats import Stats @@ -70,8 +73,6 @@ from airflow.utils.session import provide_session if TYPE_CHECKING: - from socket import socket - from sqlalchemy.orm import Session from structlog.typing import FilteringBoundLogger, WrappedLogger @@ -181,7 +182,6 @@ class messages: class StartTriggerer(BaseModel): """Tell the async trigger runner process to start, and where to send status update messages.""" - requests_fd: int type: Literal["StartTriggerer"] = "StartTriggerer" class TriggerStateChanges(BaseModel): @@ -295,7 +295,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): """ TriggerRunnerSupervisor is responsible for monitoring the subprocess and marshalling DB access. - This class (which runs in the main process) is responsible for querying the DB, sending RunTrigger + This class (which runs in the main/sync process) is responsible for querying the DB, sending RunTrigger workload messages to the subprocess, and collecting results and updating them in the DB. """ @@ -342,8 +342,8 @@ def start( # type: ignore[override] ): proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) - msg = messages.StartTriggerer(requests_fd=proc._requests_fd) - proc.send_msg(msg) + msg = messages.StartTriggerer() + proc.send_msg(msg, in_response_to=0) return proc @functools.cached_property @@ -355,7 +355,7 @@ def client(self) -> Client: client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, TaskStatesResponse, @@ -454,8 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - else: raise ValueError(f"Unknown message type {type(msg)}") - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def run(self) -> None: """Run synchronously and handle all database reads/writes.""" @@ -628,7 +627,7 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke ), ) - def _process_log_messages_from_subprocess(self) -> Generator[None, bytes, None]: + def _process_log_messages_from_subprocess(self) -> Generator[None, bytes | bytearray, None]: import msgspec from structlog.stdlib import NAME_TO_LEVEL @@ -691,14 +690,58 @@ class TriggerDetails(TypedDict): events: int +@attrs.define(kw_only=True) +class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): + _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer") + _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader") + + body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field( + factory=lambda: TypeAdapter(ToTriggerRunner), repr=False + ) + + _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + + def _read_frame(self): + from asgiref.sync import async_to_sync + + return async_to_sync(self._aread_frame)() + + def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + from asgiref.sync import async_to_sync + + return async_to_sync(self.asend)(msg) + + async def _aread_frame(self): + len_bytes = await self._async_reader.readexactly(4) + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = await 
self._async_reader.readexactly(len) + return self.resp_decoder.decode(buffer) + + async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None: + frame = await self._aread_frame() + if frame.id != expect_id: + # Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with + # this explicit error return the wrong type of message back to a Trigger + raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}") + return self._from_frame(frame) + + async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + async with self._lock: + self._async_writer.write(bytes) + + return await self._aget_response(frame.id) + + class TriggerRunner: """ Runtime environment for all triggers. - Mainly runs inside its own thread, where it hands control off to an asyncio - event loop, but is also sometimes interacted with from the main thread - (where all the DB queries are done). All communication between threads is - done via Deques. + Mainly runs inside its own process, where it hands control off to an asyncio + event loop. All communication between this and it's (sync) supervisor is done via sockets """ # Maps trigger IDs to their running tasks and other info @@ -726,10 +769,7 @@ class TriggerRunner: # TODO: connect this to the parent process log: FilteringBoundLogger = structlog.get_logger() - requests_sock: asyncio.StreamWriter - response_sock: asyncio.StreamReader - - decoder: TypeAdapter[ToTriggerRunner] + comms_decoder: TriggerCommsDecoder def __init__(self): super().__init__() @@ -740,7 +780,6 @@ def __init__(self): self.events = deque() self.failed_triggers = deque() self.job_id = None - self.decoder = TypeAdapter(ToTriggerRunner) def run(self): """Sync entrypoint - just run a run in an async loop.""" @@ -796,36 +835,21 @@ async def init_comms(self): """ from airflow.sdk.execution_time import task_runner - loop = asyncio.get_event_loop() + # Yes, we read and write to stdin! It's a socket, not a normal stdin. 
+ reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) - comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]( - input=sys.stdin, - decoder=self.decoder, + self.comms_decoder = TriggerCommsDecoder( + async_writer=writer, + async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = comms_decoder - - async def connect_stdin() -> asyncio.StreamReader: - reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, sys.stdin) - return reader - - self.response_sock = await connect_stdin() + task_runner.SUPERVISOR_COMMS = self.comms_decoder - line = await self.response_sock.readline() + msg = await self.comms_decoder._aget_response(expect_id=0) - msg = self.decoder.validate_json(line) if not isinstance(msg, messages.StartTriggerer): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - writer_transport, writer_protocol = await loop.connect_write_pipe( - lambda: asyncio.streams.FlowControlMixin(loop=loop), - comms_decoder.request_socket, - ) - self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop) - async def create_triggers(self): """Drain the to_create queue and create all new triggers that have been requested in the DB.""" while self.to_create: @@ -934,8 +958,6 @@ async def cleanup_finished_triggers(self) -> list[int]: return finished_ids async def sync_state_to_supervisor(self, finished_ids: list[int]): - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Copy out of our deques in threadsafe manner to sync state with parent events_to_send = [] while self.events: @@ -961,19 +983,17 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]): if not finished_ids: msg.finished = None - # Block triggers from making any requests for the duration of this - async with SUPERVISOR_COMMS.lock: - # Tell the monitor that we've finished triggers so it can update things - self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n") - line = await self.response_sock.readline() - - if line == b"": # EoF received! 
+ # Tell the monitor that we've finished triggers so it can update things + try: + resp = await self.comms_decoder.asend(msg) + except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") + return + raise - resp = self.decoder.validate_json(line) if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}") + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") self.to_create.extend(resp.to_create) self.to_cancel.extend(resp.to_cancel) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c9e974b8cdb9a..98204033ced9e 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -30,10 +30,11 @@ from datetime import datetime, timedelta from logging.config import dictConfig from pathlib import Path -from socket import socket +from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock +import msgspec import pytest import time_machine from sqlalchemy import func, select @@ -54,7 +55,6 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk.execution_time.supervisor import mkpipe from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -138,20 +138,19 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces logger_filehandle = MagicMock() proc.create_time.return_value = time.time() proc.wait.return_value = 0 - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socketpair() ret = DagFileProcessorProcess( process_log=MagicMock(), id=uuid7(), pid=1234, process=proc, stdin=write_end, - requests_fd=123, logger_filehandle=logger_filehandle, client=MagicMock(), ) if start_time: ret.start_time = start_time - ret._num_open_sockets = 0 + ret._open_sockets.clear() return ret, read_end @pytest.fixture @@ -552,18 +551,17 @@ def test_kill_timed_out_processors_no_kill(self): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.parametrize( - ["callbacks", "path", "expected_buffer"], + ["callbacks", "path", "expected_body"], [ pytest.param( [], "/opt/airflow/dags/test_dag.py", - b"{" - b'"file":"/opt/airflow/dags/test_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,' - b'"callback_requests":[],' - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/test_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [], + "type": "DagFileParseRequest", + }, ), pytest.param( [ @@ -577,44 +575,40 @@ def test_kill_timed_out_processors_no_kill(self): ) ], "/opt/airflow/dags/dag_callback_dag.py", - b"{" - b'"file":"/opt/airflow/dags/dag_callback_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,"callback_requests":' - b"[" - b"{" - b'"filepath":"dag_callback_dag.py",' - b'"bundle_name":"testing",' - b'"bundle_version":null,' - b'"msg":null,' - b'"dag_id":"dag_id",' - b'"run_id":"run_id",' - b'"is_failure_callback":false,' - b'"type":"DagCallbackRequest"' - b"}" - b"]," - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/dag_callback_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [ + { + 
"filepath": "dag_callback_dag.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": None, + "dag_id": "dag_id", + "run_id": "run_id", + "is_failure_callback": False, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ), ], ) - def test_serialize_callback_requests(self, callbacks, path, expected_buffer): + def test_serialize_callback_requests(self, callbacks, path, expected_body): + from airflow.sdk.execution_time.comms import _ResponseFrame + processor, read_socket = self.mock_processor() processor._on_child_started(callbacks, path, bundle_path=Path("/opt/airflow/dags")) read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(4096) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError): - # no response written, valid for some message types. - pass - - assert val == expected_buffer + # Read response from the read end of the socket + read_socket.settimeout(0.1) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.body == expected_body @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 28ce7a8c23fe6..eb6a58dfcfbfb 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -42,7 +42,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.api.client import Client -from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -87,7 +87,6 @@ def _process_file( DagFileParseRequest( file=file_path, bundle_path=TEST_DAG_FOLDER, - requests_fd=1, callback_requests=callback_requests or [], ), log=structlog.get_logger(), @@ -393,26 +392,38 @@ def disable_capturing(): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.usefixtures("disable_capturing") -def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency): +def test_parse_file_entrypoint_parses_dag_callbacks(mocker): r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":' - + str(w2.fileno()).encode("ascii") - + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' - b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": ' - b'"manual__2024-12-30T21:02:55.203691+00:00", ' - b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n' + + frame = comms._ResponseFrame( + id=1, + body={ + "file": "/files/dags/wait.py", + "bundle_path": "/files/dags", + "callback_requests": [ + { + "filepath": "wait.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": "task_failure", + "dag_id": "wait_to_fail", + "run_id": "manual__2024-12-30T21:02:55.203691+00:00", + "is_failure_callback": True, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ) + bytes = frame.as_bytes() + w.sendall(bytes) - decoder = CommsDecoder[DagFileParseRequest, 
DagFileParsingResult]( - input=r.makefile("r"), - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), + decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + request_socket=r, + body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), ) - msg = decoder.get_message() + msg = decoder._get_response() assert isinstance(msg, DagFileParseRequest) assert msg.file == "/files/dags/wait.py" assert msg.callback_requests == [ @@ -455,7 +466,7 @@ def fake_collect_dags(self, *args, **kwargs): ) ] _parse_file( - DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests), + DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests), log=structlog.get_logger(), ) @@ -489,8 +500,6 @@ def fake_collect_dags(self, *args, **kwargs): bundle_version=None, ) ] - _parse_file( - DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() - ) + _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) assert called is True diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3c9ae6f27dc9..d3cecb5fc94e6 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -34,6 +34,7 @@ from airflow.hooks.base import BaseHook from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( + TriggerCommsDecoder, TriggererJobRunner, TriggerRunner, TriggerRunnerSupervisor, @@ -172,7 +173,6 @@ def builder(job=None): pid=process.pid, stdin=mocker.Mock(), process=process, - requests_fd=-1, capacity=10, ) # Mock the selector @@ -302,10 +302,9 @@ async def test_invalid_trigger(self, supervisor_builder): id=1, ti=None, classpath="fake.classpath", encrypted_kwargs={} ) trigger_runner = TriggerRunner() - trigger_runner.requests_sock = MagicMock() - trigger_runner.response_sock = AsyncMock() - trigger_runner.response_sock.readline.return_value = ( - b'{"type": "TriggerStateSync", "to_create": [], "to_cancel": []}\n' + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( + to_create=[], to_cancel=[] ) trigger_runner.to_create.append(workload) @@ -316,9 +315,9 @@ async def test_invalid_trigger(self, supervisor_builder): await trigger_runner.sync_state_to_supervisor(ids) # Check that we sent the right info in the failure message - assert trigger_runner.requests_sock.write.call_count == 1 - blob = trigger_runner.requests_sock.write.mock_calls[0].args[0] - msg = messages.TriggerStateChanges.model_validate_json(blob) + assert trigger_runner.comms_decoder.asend.call_count == 1 + msg = trigger_runner.comms_decoder.asend.mock_calls[0].args[0] + assert isinstance(msg, messages.TriggerStateChanges) assert msg.events is None assert msg.failures is not None @@ -552,6 +551,7 @@ def test_failed_trigger(session, dag_maker, supervisor_builder): ) ], ), + req_id=1, log=MagicMock(), ) @@ -622,10 +622,6 @@ def handle_events(self): super().handle_events() -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. 
TODO: AIP-72" -) @pytest.mark.asyncio @pytest.mark.execution_timeout(20) async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): @@ -726,13 +722,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"count": dag_run_states_count, "dag_run_state": dag_run_state}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" # Create the test DAG and task @@ -822,13 +813,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count, "task_states": task_states}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b5ae53625e256..73d6c82bcafa2 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1963,8 +1963,13 @@ def mock_supervisor_comms(): if not AIRFLOW_V_3_0_PLUS: yield None return + + from airflow.sdk.execution_time.comms import CommsDecoder + with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", + create=True, + spec=CommsDecoder, ) as supervisor_comms: yield supervisor_comms @@ -1991,7 +1996,6 @@ def mocked_parse(spy_agency): id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1 ), file="", - requests_fd=0, ), "example_dag_id", CustomOperator(task_id="hello"), @@ -2198,7 +2202,6 @@ def _create_task_instance( ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, ti_context=ti_context, start_date=start_date, # type: ignore ) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 65d8166e14a77..39322def2d81d 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -47,7 +47,6 @@ classifiers = [ dependencies = [ "apache-airflow-core<3.2.0,>=3.1.0", - "aiologic>=0.14.0", "attrs>=24.2.0, !=25.2.0", "fsspec>=2023.10.0", "httpx>=0.27.0", diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index c9b777daca32e..770dbf53df769 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -77,9 +77,8 @@ def set( map_index=map_index, ) - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -114,9 +113,8 @@ def _set_xcom_in_db( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, 
value=value, dag_id=dag_id, @@ -188,23 +186,16 @@ def _get_xcom_db_ref( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -248,23 +239,16 @@ def get_one( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -307,24 +291,17 @@ def get_all( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. 
- with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - start=None, - stop=None, - step=None, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + msg=GetXComSequenceSlice( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + start=None, + stop=None, + step=None, + ), + ) if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") @@ -379,9 +356,8 @@ def delete( map_index=map_index, ) cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( + SUPERVISOR_COMMS.send( + DeleteXCom( key=key, dag_id=dag_id, task_id=task_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index d899dd3718336..bfe5ac2fc960f 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -69,18 +69,14 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> ) def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: - import structlog - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name=self.__class__.__qualname__) - def _fetch_asset(name: str) -> Asset: - SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name)) - if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse): - raise AirflowRuntimeError(msg) - return Asset(**msg.model_dump(exclude={"type"})) + resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) + return Asset(**resp.model_dump(exclude={"type"})) value: Any for key, param in inspect.signature(self.python_callable).parameters.items(): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0622cf6ffc86..67a298a6c14ab 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -19,12 +19,17 @@ Communication protocol between the Supervisor and the task process ================================================================== -* All communication is done over stdout/stdin in the form of "JSON lines" (each - message is a single JSON document terminated by `\n` character) -* Messages from the subprocess are all log messages and are sent directly to the log +* All communication is done over the subprocesses stdin in the form of a binary length-prefixed msgpack frame + (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same + encoding +* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) * No messages are sent to task process except in response to a request. (This is because the task process will be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom value etc.) +* Every request returns a response, even if the frame is otherwise empty. +* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a + bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending + requests. 
The reason this communication protocol exists, rather than the task process speaking directly to the Task Execution API server is because: @@ -43,15 +48,20 @@ from __future__ import annotations +import itertools from collections.abc import Iterator from datetime import datetime from functools import cached_property -from typing import Annotated, Any, Literal, Union +from pathlib import Path +from socket import socket +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from uuid import UUID import attrs +import msgspec +import structlog from fastapi import Body -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, @@ -80,6 +90,144 @@ ) from airflow.sdk.exceptions import ErrorType +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + +SendMsgType = TypeVar("SendMsgType", bound=BaseModel) +ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) + + +def _msgpack_enc_hook(obj: Any) -> Any: + import pendulum + + if isinstance(obj, pendulum.DateTime): + # convert the complex to a tuple of real, imag + return datetime( + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo + ) + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, BaseModel): + return obj.model_dump(exclude_unset=True) + + # Raise a NotImplementedError for other types + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def _new_encoder() -> msgspec.msgpack.Encoder: + return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook) + + +class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The request id, set by the sender. + + This is used to allow "pipeling" of requests and to be able to tie response to requests, which is + particularly useful in the Triggerer where multiple async tasks can send a requests concurrently. 
+ """ + body: dict[str, Any] | None + + req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder() + + def as_bytes(self) -> bytearray: + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + buffer = bytearray(256) + + self.req_encoder.encode_into(self, buffer, 4) + + n = len(buffer) - 4 + if n > 2**32: + raise OverflowError("Cannot send messages larger than 4GiB") + buffer[:4] = n.to_bytes(4, byteorder="big") + + return buffer + + +class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The id of the request this is a response to + """ + body: dict[str, Any] | None = None + error: dict[str, Any] | None = None + + +@attrs.define() +class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): + """Handle communication between the task in this process and the supervisor parent process.""" + + log: Logger = attrs.field(repr=False, factory=structlog.get_logger) + request_socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + + resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( + factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False + ) + + id_counter: Iterator[int] = attrs.field(factory=itertools.count) + + # We could be "clever" here and set the default to this based type parameters and a custom + # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a + # "sort of wrong default" + body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: + """Send a request to the parent and block until the response is received.""" + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + self.request_socket.sendall(bytes) + + return self._get_response() + + def _read_frame(self): + """ + Get a message from the parent. + + This will block until the message has been received. + """ + if self.request_socket: + self.request_socket.setblocking(True) + len_bytes = self.request_socket.recv(4) + + if len_bytes == b"": + raise EOFError("Request socket closed before length") + + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = bytearray(len) + nread = self.request_socket.recv_into(buffer) + if nread != len: + raise RuntimeError( + f"unable to read full response in child. 
(We read {nread}, but expected {len})" + ) + if nread == 0: + raise EOFError("Request socket closed before response was complete") + + return self.resp_decoder.decode(buffer) + + def _from_frame(self, frame) -> ReceiveMsgType | None: + from airflow.sdk.exceptions import AirflowRuntimeError + + if frame.error is not None: + err = self.err_decoder.validate_python(frame.error) + raise AirflowRuntimeError(error=err) + + if frame.body is None: + return None + + try: + return self.body_decoder.validate_python(frame.body) + except Exception: + self.log.exception("Unable to decode message") + raise + + def _get_response(self) -> ReceiveMsgType | None: + frame = self._read_frame() + return self._from_frame(frame) + class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -87,13 +235,7 @@ class StartupDetails(BaseModel): ti: TaskInstance dag_rel_path: str bundle_info: BundleInfo - requests_fd: int start_date: datetime - """ - The channel for the task to send requests over. - - Responses will come back on stdin - """ ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 438687bdeb870..0216ff0ce050b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -150,13 +150,7 @@ def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -203,13 +197,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -259,11 +247,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. 
A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) def _delete_variable(key: str) -> None: @@ -275,12 +259,7 @@ def _delete_variable(key: str) -> None: from airflow.sdk.execution_time.comms import DeleteVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -383,22 +362,28 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: @staticmethod def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: from airflow.sdk.definitions.asset import Asset - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetAssetByName, + GetAssetByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if name: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + msg = GetAssetByName(name=name) elif uri: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + msg = GetAssetByUri(uri=uri) else: raise ValueError("Either name or uri must be provided") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetResult) + assert isinstance(resp, AssetResult) return Asset(**msg.model_dump(exclude={"type"})) @@ -533,9 +518,11 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve ErrorResponse, GetAssetEventByAsset, GetAssetEventByAssetAlias, + ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if isinstance(key, int): # Support index access; it's easier for trivial cases. 
obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -545,31 +532,33 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve if isinstance(obj, Asset): asset = self._assets[AssetUniqueKey.from_asset(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri)) + msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri) elif isinstance(obj, AssetNameRef): try: asset = next(a for k, a in self._assets.items() if k.name == obj.name) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None)) + msg = GetAssetEventByAsset(name=asset.name, uri=None) elif isinstance(obj, AssetUriRef): try: asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri)) + msg = GetAssetEventByAsset(name=None, uri=asset.uri) elif isinstance(obj, AssetAlias): asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name)) + msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name) + else: + raise TypeError(f"`key` is of unknown type ({type(key).__name__})") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetEventsResult) + assert isinstance(resp, AssetEventsResult) - return list(msg.iter_asset_event_results()) + return list(resp.iter_asset_event_results()) @attrs.define @@ -626,8 +615,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 9cf9acfac81bb..d8b356870c7aa 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -91,16 +91,14 @@ def __len__(self) -> int: task = self._xcom_arg.operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComCount( + msg = SUPERVISOR_COMMS.send( + GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, run_id=self._ti.run_id, task_id=task.task_id, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): @@ -127,43 +125,37 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if isinstance(key, slice): start, stop, step = _coerce_slice(key) - with SUPERVISOR_COMMS.lock: - source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=xcom_arg.key, - dag_id=source.dag_id, - task_id=source.task_id, - run_id=self._ti.run_id, - start=start, - stop=stop, - step=step, - ), - ) - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComSequenceSliceResult): - raise TypeError(f"Got unexpected 
response to GetXComSequenceSlice: {msg!r}") - return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] - - if not isinstance(key, int): - if (index := getattr(key, "__index__", None)) is not None: - key = index() - raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") - - with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceItem( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, run_id=self._ti.run_id, - offset=key, + start=start, + stop=stop, + step=step, ), ) - msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") + return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] + + if not isinstance(key, int): + if (index := getattr(key, "__index__", None)) is not None: + key = index() + raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") + + source = (xcom_arg := self._xcom_arg).operator + msg = SUPERVISOR_COMMS.send( + GetXComSequenceItem( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + offset=key, + ), + ) if isinstance(msg, ErrorResponse): raise IndexError(key) if not isinstance(msg, XComSequenceIndexResult): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index dc63daeef7bb7..cf37be96a318d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,12 +27,13 @@ import signal import sys import time +import weakref from collections import deque from collections.abc import Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus -from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair +from socket import SO_SNDBUF, SOL_SOCKET, socket, socketpair from typing import ( TYPE_CHECKING, BinaryIO, @@ -64,6 +65,7 @@ XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -109,6 +111,8 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -193,6 +197,8 @@ def mkpipe( local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE) # set nonblocking to True so that send or sendall waits till all data is sent local.setblocking(True) + else: + local.setblocking(False) return remote, local @@ -224,14 +230,13 @@ def _configure_logs_over_json_channel(log_fd: int): def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, fd, sock, mode in ( - ("stdin", 0, child_stdin, "r"), + # Yes, we want to re-open stdin in write mode! This is cause it is a bi-directional socket, so we can + # read and write to it. 
+ ("stdin", 0, child_stdin, "w"), ("stdout", 1, child_stdout, "w"), ("stderr", 2, child_stderr, "w"), ): - handle = getattr(sys, handle_name) - handle.close() os.dup2(sock.fileno(), fd) del sock @@ -319,7 +324,7 @@ def __getattr__(name: str): def _fork_main( - child_stdin: socket, + requests: socket, child_stdout: socket, child_stderr: socket, log_fd: int, @@ -346,10 +351,12 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() + # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) + _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) - _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + _reopen_std_io_handles(requests, child_stdout, child_stderr) def exit(n: int) -> NoReturn: with suppress(ValueError, OSError): @@ -360,11 +367,11 @@ def exit(n: int) -> NoReturn: last_chance_stderr.flush() # Explicitly close the child-end of our supervisor sockets so - # the parent sees EOF on both "requests" and "logs" channels. + # the parent sees EOF on "logs" channel. with suppress(OSError): os.close(log_fd) with suppress(OSError): - os.close(child_stdin.fileno()) + os.close(requests.fileno()) os._exit(n) if hasattr(atexit, "_clear"): @@ -432,16 +439,18 @@ class WatchedSubprocess: """The decoder to use for incoming messages from the child process.""" _process: psutil.Process = attrs.field(repr=False) - _requests_fd: int """File descriptor for request handling.""" - _num_open_sockets: int = 4 _exit_code: int | None = attrs.field(default=None, init=False) _process_exit_monotonic: float | None = attrs.field(default=None, init=False) - _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False) + _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field( + factory=weakref.WeakKeyDictionary, init=False + ) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False) + _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False) + process_log: FilteringBoundLogger = attrs.field(repr=False) subprocess_logs_to_stdout: bool = False @@ -460,18 +469,19 @@ def start( ) -> Self: """Fork and start a new subprocess with the specified target function.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess - child_stdin, feed_stdin = mkpipe(remote_read=True) - child_stdout, read_stdout = mkpipe() - child_stderr, read_stderr = mkpipe() + child_stdout, read_stdout = socketpair() + child_stderr, read_stderr = socketpair() + + # Place for child to send requests/read responses, and the server side to read/respond + child_requests, read_requests = socketpair() - # Open these socketpair before forking off the child, so that it is open when we fork. - child_comms, read_msgs = mkpipe() - child_logs, read_logs = mkpipe() + # Open the socketpair before forking off the child, so that it is open when we fork. + child_logs, read_logs = socketpair() pid = os.fork() if pid == 0: # Close and delete of the parent end of the sockets. - cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs) + cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs) # Python GC should delete these for us, but lets make double sure that we don't keep anything # around in the forked processes, especially things that might involve open files or sockets! 
@@ -480,28 +490,28 @@ def start( try: # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target) except BaseException as e: + import traceback + with suppress(BaseException): # We can't use log here, as if we except out of _fork_main something _weird_ went on. - print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + print("Exception in _fork_main, exiting with code 124", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) # It's really super super important we never exit this block. We are in the forked child, and if we # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) os._exit(124) - requests_fd = child_comms.fileno() - # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the # other end of the pair open - cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) + cls._close_unused_sockets(child_stdout, child_stderr, child_logs) logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc = cls( pid=pid, - stdin=feed_stdin, + stdin=read_requests, process=psutil.Process(pid), - requests_fd=requests_fd, process_log=logger, start_time=time.monotonic(), **constructor_kwargs, @@ -510,7 +520,7 @@ def start( proc._register_pipe_readers( stdout=read_stdout, stderr=read_stderr, - requests=read_msgs, + requests=read_requests, logs=read_logs, ) @@ -523,24 +533,26 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - # Track socket types for debugging - self._fd_to_socket_type = { - stdout.fileno(): "stdout", - stderr.fileno(): "stderr", - requests.fileno(): "requests", - logs.fileno(): "logs", - } + # Track the open sockets, and for debugging what type each one is + self._open_sockets.update( + ( + (stdout, "stdout"), + (stderr, "stderr"), + (logs, "logs"), + (requests, "requests"), + ) + ) target_loggers: tuple[FilteringBoundLogger, ...] 
= (self.process_log,) if self.subprocess_logs_to_stdout: target_loggers += (log,) self.selector.register( - stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout") + stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout") ) self.selector.register( stderr, selectors.EVENT_READ, - self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR), + self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR), ) self.selector.register( logs, @@ -552,37 +564,47 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke self.selector.register( requests, selectors.EVENT_READ, - make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed), + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), ) - def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: """Create a socket handler that forwards logs to a logger.""" return make_buffered_socket_reader( forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed ) - def _on_socket_closed(self): + def _on_socket_closed(self, sock: socket): # We want to keep servicing this process until we've read up to EOF from all the sockets. - self._num_open_sockets -= 1 - def send_msg(self, msg: BaseModel, **dump_opts): - """Send the given pydantic message to the subprocess at once by encoding it and adding a line break.""" - b = msg.model_dump_json(**dump_opts).encode() + b"\n" - self.stdin.sendall(b) + with suppress(KeyError): + self.selector.unregister(sock) + del self._open_sockets[sock] - def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): + """Send the msg as a length-prefixed response frame.""" + if msg: + frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts)) + else: + err_resp = error.model_dump() if error else None + frame = _ResponseFrame(id=in_response_to, error=err_resp) + + self.stdin.sendall(frame.as_bytes()) + + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]: """Handle incoming requests from the task process, respond with the appropriate data.""" while True: - line = yield + request = yield try: - msg = self.decoder.validate_json(line) + msg = self.decoder.validate_python(request.body) except Exception: - log.exception("Unable to decode message", line=line) + log.exception("Unable to decode message", body=request.body) continue try: - self._handle_request(msg, log) + self._handle_request(msg, log, request.id) except ServerResponseError as e: error_details = e.response.json() if e.response else None log.error( @@ -594,27 +616,26 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N # Send error response back to task so that the error appears in the task logs self.send_msg( - ErrorResponse( + msg=None, + error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={ "status_code": e.response.status_code, "message": str(e), "detail": error_details, }, - ) + ), + in_response_to=request.id, ) + return - def _handle_request(self, msg, log: FilteringBoundLogger) -> None: + def _handle_request(self, msg, log: 
FilteringBoundLogger, req_id: int) -> None: raise NotImplementedError() @staticmethod def _close_unused_sockets(*sockets): """Close unused ends of sockets after fork.""" for sock in sockets: - if isinstance(sock, SocketIO): - # If we have the socket IO object, we need to close the underlying socket foricebly here too, - # else we get unclosed socket warnings, and likely leaking FDs too - sock._sock.close() sock.close() def _cleanup_open_sockets(self): @@ -624,20 +645,18 @@ def _cleanup_open_sockets(self): # sockets the supervisor would wait forever thinking they are still # active. This cleanup ensures we always release resources and exit. stuck_sockets = [] - for key in list(self.selector.get_map().values()): - socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}") - stuck_sockets.append(f"{socket_type}({key.fd})") - with suppress(Exception): - self.selector.unregister(key.fileobj) + for sock, socket_type in self._open_sockets.items(): + fileno = "unknown" with suppress(Exception): - key.fileobj.close() # type: ignore[union-attr] + fileno = sock.fileno() + sock.close() + stuck_sockets.append(f"{socket_type}(fd={fileno})") if stuck_sockets: log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets) self.selector.close() - self._close_unused_sockets(self.stdin) - self._num_open_sockets = 0 + self.stdin.close() def kill( self, @@ -736,7 +755,7 @@ def _service_subprocess( events = self.selector.select(timeout=timeout) for key, _ in events: # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) - socket_handler = key.data + socket_handler, on_close = key.data # Example of handler behavior: # If the subprocess writes "Hello, World!" to stdout: @@ -753,8 +772,9 @@ def _service_subprocess( # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors # are removed. if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -773,16 +793,16 @@ def _check_subprocess_exit( raise else: self._process_exit_monotonic = time.monotonic() - self._close_unused_sockets(self.stdin) - # Put a message in the viewable task logs if expect_signal is not None and self._exit_code == -expect_signal: # Bypass logging, the caller expected us to exit with this return self._exit_code - # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): + # Put a message in the viewable task logs + if self._exit_code == -signal.SIGSEGV: self.process_log.critical(SIGSEGV_MESSAGE) + # psutil turns signal exit codes into an enum for us. Handy. 
(Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): elif name := getattr(self._exit_code, "name", None): message = "Process terminated by signal" level = logging.ERROR @@ -861,7 +881,6 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st ti=ti, dag_rel_path=os.fspath(dag_rel_path), bundle_info=bundle_info, - requests_fd=self._requests_fd, ti_context=ti_context, start_date=start_date, ) @@ -870,7 +889,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st log.debug("Sending", msg=msg) try: - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) except BrokenPipeError: # Debug is fine, the process will have shown _something_ in it's last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -930,7 +949,7 @@ def _monitor_subprocess(self): - Processes events triggered on the monitored file objects, such as data availability or EOF. - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - while self._exit_code is None or self._num_open_sockets > 0: + while self._exit_code is None or self._open_sockets: last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. @@ -946,16 +965,11 @@ def _monitor_subprocess(self): # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None - if self._exit_code is not None and self._num_open_sockets > 0: + if self._exit_code is not None and self._open_sockets: if ( self._process_exit_monotonic and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT ): - log.debug( - "Forcefully closing remaining sockets", - open_sockets=self._num_open_sockets, - pid=self.pid, - ) self._cleanup_open_sockets() if alive: @@ -1051,7 +1065,7 @@ def final_state(self): return SERVER_TERMINATED return TaskInstanceState.FAILED - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} @@ -1224,10 +1238,17 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): dump_opts = {"exclude_unset": True} else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def in_process_api_server(): @@ -1254,7 +1275,7 @@ def send_request(self, log, msg: BaseModel): log.debug("Sending request", msg=msg) with set_supervisor_comms(None): - self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + self.supervisor._handle_request(msg, log, 0) # type: ignore[arg-type] @attrs.define @@ -1298,7 +1319,6 @@ def start( # type: ignore[override] id=what.id, pid=os.getpid(), # Use current process process=psutil.Process(), # Current process - requests_fd=-1, # Not used in in-process mode process_log=logger or structlog.get_logger(logger_name="task").bind(), client=cls._api_client(task.dag), **kwargs, @@ -1363,9 +1383,12 @@ def 
_api_client(dag=None): client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def send_msg(self, msg: BaseModel, **dump_opts): + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): """Override to use in-process comms.""" - self.comms.messages.append(msg) + if msg is not None: + self.comms.messages.append(msg) @property def final_state(self): @@ -1421,9 +1444,9 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: # to a (sync) generator def make_buffered_socket_reader( gen: Generator[None, bytes | bytearray, None], - on_close: Callable, + on_close: Callable[[socket], None], buffer_size: int = 4096, -) -> Callable[[socket], bool]: +): buffer = bytearray() # This will hold our accumulated binary data read_buffer = bytearray(buffer_size) # Temporary buffer for each read @@ -1441,7 +1464,6 @@ def cb(sock: socket): with suppress(StopIteration): gen.send(buffer) # Tell loop to close this selector - on_close() return False buffer.extend(read_buffer[:n_received]) @@ -1452,18 +1474,62 @@ def cb(sock: socket): try: gen.send(line) except StopIteration: - on_close() return False buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data return True - return cb + return cb, on_close + + +def length_prefixed_frame_reader( + gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None] +): + length_needed: int | None = None + # This will hold our accumulated/partial binary frame if it doesn't come in a single read + buffer: memoryview | None = None + # position in the buffer to store next read + pos = 0 + decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame) + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, length_needed, pos + + if length_needed is None: + # Read the 32bit length of the frame + bytes = sock.recv(4) + if bytes == b"": + return False + + length_needed = int.from_bytes(bytes, byteorder="big") + buffer = memoryview(bytearray(length_needed)) + if length_needed and buffer: + n = sock.recv_into(buffer[pos:]) + if n == 0: + # EOF + return False + pos += n + + if pos >= length_needed: + request = decoder.decode(buffer) + buffer = None + pos = 0 + length_needed = None + try: + gen.send(request) + except StopIteration: + return False + return True + + return cb, on_close def process_log_messages_from_subprocess( loggers: tuple[FilteringBoundLogger, ...], -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: from structlog.stdlib import NAME_TO_LEVEL while True: @@ -1499,10 +1565,9 @@ def process_log_messages_from_subprocess( def forward_to_log( target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: while True: - buf = yield - line = bytes(buf) + line = yield # Strip off new line line = line.rstrip() try: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 2374520e63f21..0268e24f2177b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -28,16 +28,14 @@ from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from datetime import datetime, timezone -from io import FileIO from itertools import product from pathlib import 
Path -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Literal -import aiologic import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -59,6 +57,7 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, + CommsDecoder, DagRunStateResult, DeferTask, DRCount, @@ -424,10 +423,9 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - SUPERVISOR_COMMS.send_request( - log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) + response = SUPERVISOR_COMMS.send( + msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -445,22 +443,17 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTICount( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, TICount) @@ -477,21 +470,16 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTaskStates( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTaskStates( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + ), + ) if TYPE_CHECKING: assert isinstance(response, TaskStatesResult) @@ -506,19 +494,14 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, DRCount) @@ -528,10 +511,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: 
"""Return the state of the DAG run with the given Run ID.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -650,62 +630,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -SendMsgType = TypeVar("SendMsgType", bound=BaseModel) -ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) - - -@attrs.define() -class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): - """Handle communication between the task in this process and the supervisor parent process.""" - - input: TextIO - - request_socket: FileIO = attrs.field(init=False, default=None) - - # We could be "clever" here and set the default to this based type parameters and a custom - # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a - # "sort of wrong default" - decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) - - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) - - def get_message(self) -> ReceiveMsgType: - """ - Get a message from the parent. - - This will block until the message has been received. - """ - line = None - - # TODO: Investigate why some empty lines are sent to the processes stdin. - # That was highlighted when working on https://github.com/apache/airflow/issues/48183 - # and is maybe related to deferred/triggerer only context. - while not line: - line = self.input.readline() - - try: - msg = self.decoder.validate_json(line) - except Exception: - structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) - raise - - if isinstance(msg, StartupDetails): - # If we read a startup message, pull out the FDs we care about! - if msg.requests_fd > 0: - self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR: - structlog.get_logger(logger_name="task").error("Error response from the API Server") - raise AirflowRuntimeError(error=msg) - - return msg - - def send_request(self, log: Logger, msg: SendMsgType): - encoded_msg = msg.model_dump_json().encode() + b"\n" - - log.debug("Sending request", json=encoded_msg) - self.request_socket.write(encoded_msg) - - # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # @@ -725,31 +649,33 @@ def send_request(self, log: Logger, msg: SendMsgType): def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - msg = SUPERVISOR_COMMS.get_message() + # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent + # in response to us sending a request. 
+ msg = SUPERVISOR_COMMS._get_response() + + if not isinstance(msg, StartupDetails): + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") log = structlog.get_logger(logger_name="task") + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + log.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + try: get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) except Exception: log.exception("error calling listener") - if isinstance(msg, StartupDetails): - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - log.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle(f"airflow worker -- {msg.ti.id}") - - with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg, log) - ti.log_url = get_log_url_from_ti(ti) - log.debug("DAG file parsed", file=msg.dag_rel_path) - else: - raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): + ti = parse(msg, log) + ti.log_url = get_log_url_from_ti(ti) + log.debug("DAG file parsed", file=msg.dag_rel_path) return ti, ti.get_template_context(), log @@ -797,7 +723,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) _validate_task_inlets_and_outlets(ti=ti, log=log) @@ -817,8 +743,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log) - inactive_assets_resp = SUPERVISOR_COMMS.get_message() + inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -914,7 +839,7 @@ def run( except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -976,7 +901,7 @@ def run( error = e finally: if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + SUPERVISOR_COMMS.send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1014,9 +939,8 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - SUPERVISOR_COMMS.send_request( - log=log, - msg=TriggerDagRun( + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, 
logical_date=drte.logical_date, @@ -1025,7 +949,6 @@ def _handle_trigger_dag_run( ), ) - comms_msg = SUPERVISOR_COMMS.get_message() if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: if drte.skip_when_already_exists: log.info( @@ -1080,10 +1003,9 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - SUPERVISOR_COMMS.send_request( - log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) + comms_msg = SUPERVISOR_COMMS.send( + GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) - comms_msg = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1272,10 +1194,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)), - ) + SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1316,9 +1235,11 @@ def finalize( def main(): - # TODO: add an exception here, it causes an oof of a stack trace! + # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! + log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) try: ti, context, log = startup() @@ -1329,11 +1250,9 @@ def main(): state, msg, error = run(ti, context, log) finalize(ti, state, context, log, error) except KeyboardInterrupt: - log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") exit(2) except Exception: - log = structlog.get_logger(logger_name="task") log.exception("Top level error") exit(1) finally: diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py new file mode 100644 index 0000000000000..ee2f956b063c9 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
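(Editor's illustration, not part of the patch.) As background for the test below, a toy sketch — names are assumptions, not SDK API — of why length-prefixed framing avoids decoding partially-read data: the reader buffers until the advertised length has arrived and never decodes a half-received payload.

```python
import msgspec


class FrameBuffer:
    """Toy incremental reader: only yields frames once fully received."""

    def __init__(self, decoder: msgspec.msgpack.Decoder):
        self._decoder = decoder
        self._buf = bytearray()

    def feed(self, chunk: bytes) -> list:
        self._buf.extend(chunk)
        frames = []
        while len(self._buf) >= 4:
            length = int.from_bytes(self._buf[:4], byteorder="big")
            if len(self._buf) < 4 + length:
                break  # partial frame: keep buffering, never decode half a payload
            frames.append(self._decoder.decode(bytes(self._buf[4 : 4 + length])))
            del self._buf[: 4 + length]
        return frames


if __name__ == "__main__":
    payload = msgspec.msgpack.encode({"type": "GetConnection", "conn_id": "demo"})
    wire = len(payload).to_bytes(4, byteorder="big") + payload
    reader = FrameBuffer(msgspec.msgpack.Decoder(dict))
    for byte in wire:  # worst case: one byte arrives per read
        for frame in reader.feed(bytes([byte])):
            print(frame)
```

Whatever chunk sizes the socket happens to deliver, the decoded frames come out identical, which is the guarantee framing provides that a newline-delimited stream cannot.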
+ +from __future__ import annotations + +import uuid +from socket import socketpair + +import msgspec +import pytest + +from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame +from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.utils import timezone + + +class TestCommsDecoder: + """Test the communication between the subprocess and the "supervisor".""" + + @pytest.mark.usefixtures("disable_capturing") + def test_recv_StartupDetails(self): + r, w = socketpair() + + msg = { + "type": "StartupDetails", + "ti": { + "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + "task_id": "a", + "try_number": 1, + "run_id": "b", + "dag_id": "c", + }, + "ti_context": { + "dag_run": { + "dag_id": "c", + "run_id": "b", + "logical_date": "2024-12-01T01:00:00Z", + "data_interval_start": "2024-12-01T00:00:00Z", + "data_interval_end": "2024-12-01T01:00:00Z", + "start_date": "2024-12-01T01:00:00Z", + "run_after": "2024-12-01T01:00:00Z", + "end_date": None, + "run_type": "manual", + "conf": None, + "consumed_asset_events": [], + }, + "max_tries": 0, + "should_retry": False, + "variables": None, + "connections": None, + }, + "file": "/dev/null", + "start_date": "2024-12-01T01:00:00Z", + "dag_rel_path": "/dev/null", + "bundle_info": {"name": "any-name", "version": "any-version"}, + } + bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) + w.sendall(len(bytes).to_bytes(4, byteorder="big") + bytes) + + decoder = CommsDecoder(request_socket=r, log=None) + + msg = decoder._get_response() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.dag_id == "c" + assert msg.dag_rel_path == "/dev/null" + assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") + assert msg.start_date == timezone.datetime(2024, 12, 1, 1) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4f5e4cfc7ac9d..3fa05e7d7573f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,13 +27,14 @@ import socket import sys import time -from io import BytesIO from operator import attrgetter +from random import randint from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import httpx +import msgspec import psutil import pytest from pytest_unordered import unordered @@ -56,6 +57,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, DagRunStateResult, DeferTask, @@ -97,17 +99,16 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.supervisor import ( - BUFFER_SIZE, ActivitySubprocess, InProcessSupervisorComms, InProcessTestSupervisor, - mkpipe, set_supervisor_comms, supervise, ) -from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"): } +@pytest.fixture +def client_with_ti_start(make_ti_context): + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @pytest.fixture(autouse=True) def disable_log_upload(self, 
spy_agency): spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) - def test_reading_from_pipes(self, captured_logs, time_machine): + def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): def subprocess_main(): # This is run in the subprocess! - # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + # Ensure we follow the "protocol" and get the startup message before we do anything else + CommsDecoder()._get_response() import logging import warnings @@ -180,7 +188,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -228,12 +236,12 @@ def subprocess_main(): ] ) - def test_subprocess_sigkilled(self): + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid() def subprocess_main(): # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() assert os.getpid() != main_pid os.kill(os.getpid(), signal.SIGKILL) @@ -248,7 +256,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -285,7 +293,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -314,7 +322,7 @@ def test_no_heartbeat_in_overtime(self, spy_agency: kgb.SpyAgency, monkeypatch, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -340,7 +348,7 @@ def _on_child_started(self, *args, **kwargs): assert proc.wait() == 0 spy_agency.assert_spy_not_called(heartbeat_spy) - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -355,11 +363,6 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker try_number=1, ) - # Create a mock client to assert calls to the client - # We assume the implementation of the client is correct and only need to check the calls - mock_client = mocker.Mock(spec=sdk_client.Client) - mock_client.task_instances.start.return_value = make_ti_context() - bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -368,7 +371,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker token="", server="", dry_run=True, - client=mock_client, + client=client_with_ti_start, bundle_info=bundle_info, ) assert exit_code == 0, captured_logs @@ -498,7 +501,7 @@ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, m monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() sleep(5) # Shouldn't get here exit(5) @@ -611,7 +614,6 @@ def 
test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t stdin=mocker.MagicMock(), client=client, process=mock_process, - requests_fd=-1, ) time_now = tz.datetime(2024, 11, 28, 12, 0, 0) @@ -701,7 +703,6 @@ def test_overtime_handling( stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock(), - requests_fd=-1, ) # Set the terminal state and task end datetime @@ -738,7 +739,7 @@ def test_overtime_handling( ), ), ) - def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog): + def test_exit_by_signal(self, signal_to_raise, log_pattern, cap_structlog, client_with_ti_start): def subprocess_main(): import faulthandler import os @@ -748,7 +749,7 @@ def subprocess_main(): faulthandler.disable() # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() os.kill(os.getpid(), signal_to_raise) @@ -762,7 +763,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -791,26 +792,26 @@ def test_cleanup_sockets_after_delay(self, monkeypatch, mocker, time_machine): stdin=mocker.MagicMock(), client=mocker.MagicMock(), process=mock_process, - requests_fd=-1, ) proc.selector = mocker.MagicMock() proc.selector.select.return_value = [] proc._exit_code = 0 - proc._num_open_sockets = 1 + # Create a dummy placeholder in the open socket weekref + proc._open_sockets[mocker.MagicMock()] = "test placeholder" proc._process_exit_monotonic = time.monotonic() mocker.patch.object( ActivitySubprocess, "_cleanup_open_sockets", - side_effect=lambda: setattr(proc, "_num_open_sockets", 0), + side_effect=lambda: setattr(proc, "_open_sockets", {}), ) time_machine.shift(2) proc._monitor_subprocess() - assert proc._num_open_sockets == 0 + assert len(proc._open_sockets) == 0 class TestWatchedSubprocessKill: @@ -829,7 +830,6 @@ def watched_subprocess(self, mocker, mock_process): stdin=mocker.Mock(), client=mocker.Mock(), process=mock_process, - requests_fd=-1, ) # Mock the selector mock_selector = mocker.Mock(spec=selectors.DefaultSelector) @@ -888,7 +888,7 @@ def test_kill_process_custom_signal(self, watched_subprocess, mock_process): ), ], ) - def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch): + def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): def subprocess_main(): import signal @@ -905,7 +905,7 @@ def _handler(sig, frame): signal.signal(signal.SIGINT, _handler) signal.signal(signal.SIGTERM, _handler) try: - sys.stdin.readline() + CommsDecoder()._get_response() print("Ready") sleep(10) except Exception as e: @@ -919,7 +919,7 @@ def _handler(sig, frame): dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -1073,16 +1073,15 @@ def test_max_wait_time_calculation_edge_cases( class TestHandleRequest: @pytest.fixture def watched_subprocess(self, mocker): - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socket.socketpair() subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), id=TI_ID, pid=12345, - stdin=write_end, # this is the writer side + stdin=write_end, client=mocker.Mock(), process=mocker.Mock(), - requests_fd=-1, ) return subprocess, read_end @@ -1091,7 +1090,7 @@ def 
watched_subprocess(self, mocker): @pytest.mark.parametrize( [ "message", - "expected_buffer", + "expected_body", "client_attr_path", "method_arg", "method_kwarg", @@ -1101,7 +1100,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1111,7 +1110,12 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + { + "conn_id": "test_conn", + "conn_type": "mysql", + "password": "password", + "type": "ConnectionResult", + }, "connections.get", ("test_conn",), {}, @@ -1121,7 +1125,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1131,7 +1135,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', + {"key": "test_key", "value": "test_value", "type": "VariableResult"}, "variables.get", ("test_key",), {}, @@ -1141,7 +1145,7 @@ def watched_subprocess(self, mocker): ), pytest.param( PutVariable(key="test_key", value="test_value", description="test_description"), - b"", + None, "variables.set", ("test_key", "test_value", "test_description"), {}, @@ -1151,7 +1155,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeleteVariable(key="test_key"), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "variables.delete", ("test_key",), {}, @@ -1161,7 +1165,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", + None, "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, @@ -1174,7 +1178,7 @@ def watched_subprocess(self, mocker): reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - b"", + None, "task_instances.reschedule", ( TI_ID, @@ -1190,7 +1194,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1202,7 +1206,7 @@ def watched_subprocess(self, mocker): GetXCom( dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, @@ -1212,7 +1216,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1228,7 
+1232,7 @@ def watched_subprocess(self, mocker): key="test_key", include_prior_dates=True, ), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, True), {}, @@ -1244,7 +1248,7 @@ def watched_subprocess(self, mocker): key="test_key", value='{"key": "test_key", "value": {"key2": "value2"}}', ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1269,7 +1273,7 @@ def watched_subprocess(self, mocker): value='{"key": "test_key", "value": {"key2": "value2"}}', map_index=2, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1295,7 +1299,7 @@ def watched_subprocess(self, mocker): map_index=2, mapped_length=3, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1319,15 +1323,9 @@ def watched_subprocess(self, mocker): key="test_key", map_index=2, ), - b"", + None, "xcoms.delete", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - 2, - ), + ("test_dag", "test_run", "test_task", "test_key", 2), {}, OKResponse(ok=True), None, @@ -1337,7 +1335,7 @@ def watched_subprocess(self, mocker): # if it can handle TaskState message pytest.param( TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - b"", + None, "", (), {}, @@ -1349,7 +1347,7 @@ def watched_subprocess(self, mocker): RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" ), - b"", + None, "task_instances.retry", (), { @@ -1363,7 +1361,7 @@ def watched_subprocess(self, mocker): ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - b"", + None, "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, @@ -1373,7 +1371,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByName(name="asset"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"name": "asset"}, @@ -1383,7 +1381,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByUri(uri="s3://bucket/obj"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"uri": "s3://bucket/obj"}, @@ -1393,11 +1391,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": "test"}, @@ -1416,11 +1420,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", 
"uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": None}, @@ -1439,11 +1449,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri=None, name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": None, "name": "test"}, @@ -1462,11 +1478,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAssetAlias(alias_name="test_alias"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"alias_name": "test_alias"}, @@ -1485,7 +1507,10 @@ def watched_subprocess(self, mocker): ), pytest.param( ValidateInletsAndOutlets(ti_id=TI_ID), - b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n', + { + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, "task_instances.validate_inlets_and_outlets", (TI_ID,), {}, @@ -1499,7 +1524,7 @@ def watched_subprocess(self, mocker): SucceedTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" ), - b"", + None, "task_instances.succeed", (), { @@ -1515,11 +1540,13 @@ def watched_subprocess(self, mocker): ), pytest.param( GetPrevSuccessfulDagRun(ti_id=TI_ID), - ( - b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",' - b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",' - b'"type":"PrevSuccessfulDagRunResult"}\n' - ), + { + "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), + "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), + "start_date": timezone.parse("2025-01-10T12:00:00Z"), + "end_date": timezone.parse("2025-01-10T14:00:00Z"), + "type": "PrevSuccessfulDagRunResult", + }, "task_instances.get_previous_successful_dagrun", (TI_ID,), {}, @@ -1540,7 +1567,7 @@ def watched_subprocess(self, mocker): logical_date=timezone.datetime(2025, 1, 1), reset_dag_run=True, ), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "dag_runs.trigger", ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), {}, @@ -1549,8 +1576,9 @@ def watched_subprocess(self, mocker): id="dag_run_trigger", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. 
Fix this before PR TriggerDagRun(dag_id="test_dag", run_id="test_run"), - b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n', + {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, "dag_runs.trigger", ("test_dag", "test_run", None, None, False), {}, @@ -1560,7 +1588,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDagRunState(dag_id="test_dag", run_id="test_run"), - b'{"state":"running","type":"DagRunStateResult"}\n', + {"state": "running", "type": "DagRunStateResult"}, "dag_runs.get_state", ("test_dag", "test_run"), {}, @@ -1570,7 +1598,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskRescheduleStartDate(ti_id=TI_ID), - b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n', + {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"}, "task_instances.get_reschedule_start_date", (TI_ID, 1), {}, @@ -1580,7 +1608,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), - b'{"count":2,"type":"TICount"}\n', + {"count": 2, "type": "TICount"}, "task_instances.get_count", (), { @@ -1598,7 +1626,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDRCount(dag_id="test_dag", states=["success", "failed"]), - b'{"count":2,"type":"DRCount"}\n', + {"count": 2, "type": "DRCount"}, "dag_runs.get_count", (), { @@ -1613,7 +1641,10 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), - b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n', + { + "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, + "type": "TaskStatesResult", + }, "task_instances.get_task_states", (), { @@ -1636,7 +1667,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=0, ), - b'{"root":"test_value","type":"XComSequenceIndexResult"}\n', + {"root": "test_value", "type": "XComSequenceIndexResult"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 0), {}, @@ -1645,6 +1676,7 @@ def watched_subprocess(self, mocker): id="get_xcom_seq_item", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. 
Fix this before PR GetXComSequenceItem( key="test_key", dag_id="test_dag", @@ -1652,7 +1684,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=2, ), - b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n', + {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 2), {}, @@ -1670,7 +1702,7 @@ def watched_subprocess(self, mocker): stop=None, step=None, ), - b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n', + {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", ("test_dag", "test_run", "test_task", "test_key", None, None, None), {}, @@ -1687,7 +1719,7 @@ def test_handle_requests( mocker, time_machine, message, - expected_buffer, + expected_body, client_attr_path, method_arg, method_kwarg, @@ -1715,8 +1747,9 @@ def test_handle_requests( generator = watched_subprocess.handle_requests(log=mocker.Mock()) # Initialize the generator next(generator) - msg = message.model_dump_json().encode() + b"\n" - generator.send(msg) + + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) + generator.send(req_frame) if mask_secret_args: mock_mask_secret.assert_called_with(*mask_secret_args) @@ -1729,33 +1762,23 @@ def test_handle_requests( # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(BUFFER_SIZE) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError, socket.timeout): - # no response written, valid for some message types like setters and TI operations. - pass + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id # Verify the response was added to the buffer - assert val == expected_buffer + assert frame.body == expected_body # Verify the response is correctly decoded # This is important because the subprocess/task runner will read the response # and deserialize it to the correct message type - # Only decode the buffer if it contains data. An empty buffer implies no response was written. - if not val and (mock_response and not isinstance(mock_response, OKResponse)): - pytest.fail("Expected a response, but got an empty buffer.") - - if val: + if frame.body is not None: # Using BytesIO to simulate a readable stream for CommsDecoder. 
- input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) - assert decoder.get_message() == mock_response + decoder = CommsDecoder(request_socket=None).body_decoder + assert decoder.validate_python(frame.body) == mock_response def test_handle_requests_api_server_error(self, watched_subprocess, mocker): """Test that API server errors are properly handled and sent back to the task.""" @@ -1777,28 +1800,32 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): next(generator) - msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n" - generator.send(msg) + msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")) + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg) + generator.send(req_frame) - # Read response directly from the reader socket + # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - val += read_socket.recv(4096) - except (BlockingIOError, TimeoutError): - pass - - assert val == ( - b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",' - b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n' - ) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id + + assert frame.error == { + "error": "API_SERVER_ERROR", + "detail": { + "status_code": 500, + "message": "API Server Error", + "detail": {"detail": "Internal Server Error"}, + }, + "type": "ErrorResponse", + } # Verify the error can be decoded correctly - input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) + comms = CommsDecoder(request_socket=None) with pytest.raises(AirflowRuntimeError) as exc_info: - decoder.get_message() + comms._from_frame(frame) assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR assert exc_info.value.error.detail == { @@ -1871,7 +1898,6 @@ def _handle_request(self, msg, log): supervisor = MinimalSupervisor( id="test", pid=123, - requests_fd=-1, process=MagicMock(), process_log=MagicMock(), client=MagicMock(), diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index d786fdaa5b2bc..caecc1038818a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -22,11 +22,9 @@ import json import os import textwrap -import uuid from collections.abc import Iterable from datetime import datetime, timedelta from pathlib import Path -from socket import socketpair from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -103,7 +101,6 @@ VariableAccessor, ) from airflow.sdk.execution_time.task_runner import ( - CommsDecoder, RuntimeTaskInstance, TaskRunnerMarker, _push_xcom_if_needed, @@ -139,47 +136,6 @@ def execute(self, context): print(f"Hello World {task_id}!") -class TestCommsDecoder: - """Test the communication between the subprocess and the "supervisor".""" - - @pytest.mark.usefixtures("disable_capturing") - def test_recv_StartupDetails(self): - r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"type":"StartupDetails", "ti": {' - b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", ' - b'"dag_id": "c"}, 
"ti_context":{"dag_run":{"dag_id":"c","run_id":"b",' - b'"logical_date":"2024-12-01T01:00:00Z",' - b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' - b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,' - b'"run_type":"manual","conf":null,"consumed_asset_events":[]},' - b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file": "/dev/null",' - b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' - b'"any-name", "version": "any-version"}, "requests_fd": ' - + str(w2.fileno()).encode("ascii") - + b"}\n" - ) - - decoder = CommsDecoder(input=r.makefile("r")) - - msg = decoder.get_message() - assert isinstance(msg, StartupDetails) - assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") - assert msg.ti.task_id == "a" - assert msg.ti.dag_id == "c" - assert msg.dag_rel_path == "/dev/null" - assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") - assert msg.start_date == timezone.datetime(2024, 12, 1, 1) - - # Since this was a StartupDetails message, the decoder should open the other socket - assert decoder.request_socket is not None - assert decoder.request_socket.writable() - assert decoder.request_socket.fileno() == w2.fileno() - - def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( @@ -192,7 +148,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -248,7 +203,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -302,7 +256,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -573,7 +526,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ), bundle_info=FAKE_BUNDLE, dag_rel_path="", - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -689,7 +641,6 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -834,7 +785,6 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), dag_rel_path="dag_parsing_context.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(dag_id=dag_id, run_id="c"), start_date=timezone.utcnow(), ) @@ -2234,7 +2184,6 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) From 22a44cbce14b0ce5fdaab0466549c8c336f5dc9a Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 13 Jun 2025 17:43:47 +0100 Subject: [PATCH 02/21] fixup! 
Switch the Supervisor/task process from line-based to length-prefixed --- task-sdk/src/airflow/sdk/definitions/asset/decorators.py | 2 ++ task-sdk/src/airflow/sdk/execution_time/supervisor.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index bfe5ac2fc960f..44479fbb9cd42 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -74,6 +74,8 @@ def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: def _fetch_asset(name: str) -> Asset: resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if resp is None: + raise RuntimeError("Empty non-error response received") if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) return Asset(**resp.model_dump(exclude={"type"})) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index cf37be96a318d..3c0103741d527 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -45,7 +45,6 @@ ) from uuid import UUID -import aiologic import attrs import httpx import msgspec @@ -1264,7 +1263,6 @@ class InProcessSupervisorComms: supervisor: InProcessTestSupervisor messages: deque[BaseModel] = attrs.field(factory=deque) - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock) def get_message(self) -> BaseModel: """Get a message from the supervisor. Blocks until a message is available.""" @@ -1293,7 +1291,6 @@ class InProcessTestSupervisor(ActivitySubprocess): """A supervisor that runs tasks in-process for easier testing.""" comms: InProcessSupervisorComms = attrs.field(init=False) - stdin = attrs.field(init=False) @classmethod def start( # type: ignore[override] From 1372214abfd52a59e8a468adf06160876495c424 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 13 Jun 2025 19:35:13 +0530 Subject: [PATCH 03/21] User impersonation wip --- airflow-core/src/airflow/settings.py | 4 +- .../airflow/sdk/execution_time/task_runner.py | 40 ++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index c137c80b56fba..08d7ac7af5ba7 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -616,7 +616,9 @@ def initialize(): # The webservers import this file from models.py with the default settings. 
if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None): - configure_orm() + is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" + if not is_worker: + configure_orm() configure_action_logging() # mask the sensitive_config_values diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 0268e24f2177b..a6fe61b0cf781 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -35,11 +35,12 @@ import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, ConfigDict, Field, JsonValue +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter +from airflow.configuration import conf from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException +from airflow.exceptions import AirflowConfigException, AirflowInactiveAssetInInletOrOutletException from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, @@ -640,7 +641,8 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # deeply nested execution stack. # - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. -SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] +log = structlog.get_logger(logger_name="task") +SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) # State machine! # 1. Start up (receive details from supervisor) @@ -651,13 +653,21 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent # in response to us sending a request. 
- msg = SUPERVISOR_COMMS._get_response() + log = structlog.get_logger(logger_name="task") + + if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"): + # re exec process + log.info("Using serialized startup message from environment") + msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"]) + log.info("Trying to open in rexec", fd=msg) + SUPERVISOR_COMMS.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + else: + msg = SUPERVISOR_COMMS._get_response() + log.info("Received startup message", msg_type=type(msg).__name__) if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") - log = structlog.get_logger(logger_name="task") - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 os_type = sys.platform if os_type == "darwin": @@ -677,6 +687,21 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: ti.log_url = get_log_url_from_ti(ti) log.debug("DAG file parsed", file=msg.dag_rel_path) + try: + run_as_user = getattr(ti.task, "run_as_user", None) or conf.get("core", "default_impersonation") + except AirflowConfigException: + run_as_user = None + + if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user: + # re-exec process + os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" + # store statrup messgae + os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() + os.set_inheritable(SUPERVISOR_COMMS.request_socket.fileno(), True) + log.info("Running command", command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) + os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) + return None, None, None + return ti, ti.get_template_context(), log @@ -1238,9 +1263,6 @@ def main(): # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! log = structlog.get_logger(logger_name="task") - global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) - try: ti, context, log = startup() with BundleVersionLock( From dcc491c235525e9dd05cb4f6dc3c4dabc12467e3 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 13 Jun 2025 19:49:46 +0530 Subject: [PATCH 04/21] xcom push doesnt work -- rest is ok --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index a6fe61b0cf781..effb4db8abcd1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -642,7 +642,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. log = structlog.get_logger(logger_name="task") -SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) +SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] # Will be initialized when needed # State machine! # 1. 
Start up (receive details from supervisor) @@ -660,7 +660,6 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.info("Using serialized startup message from environment") msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"]) log.info("Trying to open in rexec", fd=msg) - SUPERVISOR_COMMS.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) else: msg = SUPERVISOR_COMMS._get_response() log.info("Received startup message", msg_type=type(msg).__name__) @@ -1263,6 +1262,9 @@ def main(): # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) + try: ti, context, log = startup() with BundleVersionLock( From d48608201037ea180404777e64307617b5825307 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 16 Jun 2025 11:02:46 +0530 Subject: [PATCH 05/21] cleaning up the code --- .../airflow/sdk/execution_time/task_runner.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index effb4db8abcd1..018b29e3e9c5c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -641,8 +641,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # deeply nested execution stack. # - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. -log = structlog.get_logger(logger_name="task") -SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] # Will be initialized when needed +SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] # State machine! # 1. 
Start up (receive details from supervisor) @@ -656,13 +655,13 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log = structlog.get_logger(logger_name="task") if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"): - # re exec process - log.info("Using serialized startup message from environment") + # entrypoint of re-exec process msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"]) - log.info("Trying to open in rexec", fd=msg) + log.info("Using serialized startup message from environment", msg=msg) else: + # normal entry point msg = SUPERVISOR_COMMS._get_response() - log.info("Received startup message", msg_type=type(msg).__name__) + log.info("Received startup message from supervisor", msg=msg) if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -692,13 +691,16 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: run_as_user = None if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user: - # re-exec process + # enters here for re-exec process os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" - # store statrup messgae + # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() os.set_inheritable(SUPERVISOR_COMMS.request_socket.fileno(), True) + log.info("Running command", command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) + + # ideally, we should never reach here, but if we do, we should return None, None, None return None, None, None return ti, ti.get_template_context(), log From 595b5500e2ad9e198e6a965b3632d982658adc22 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 16 Jun 2025 17:55:53 +0530 Subject: [PATCH 06/21] using getattr instead --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 018b29e3e9c5c..430e53778b8a3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -643,6 +643,14 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # accessible wherever needed during task execution without modifying every layer of the call stack. SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] +def __getattr__(name: str) -> CommsDecoder[ToTask, ToSupervisor]: + """Lazy initialization of supervisor comms.""" + global SUPERVISOR_COMMS + if name == "SUPERVISOR_COMMS": + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=structlog.get_logger(logger_name="task")) + return SUPERVISOR_COMMS + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + # State machine! # 1. Start up (receive details from supervisor) # 2. Execution (run task code, possibly send requests) From 4975725d5d6e34ccfbeec614c5cff5785137379b Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 10 Jun 2025 11:55:46 +0100 Subject: [PATCH 07/21] Switch the Supervisor/task process from line-based to length-prefixed The existing JSON Lines based approach had two major drawbacks 1. In the case of really large lines (in the region of 10 or 20MB) the python line buffering could _sometimes_ result in a partial read 2. 
The JSON-based approach didn't have the ability to add any metadata (such as errors). 3. Not every message type/call-site waited for a response, which meant those client functions could never get told about an error. One of the ways this line-based approach fell down was if you suddenly tried to run 100s of triggers at the same time you would get an error like this: ``` Traceback (most recent call last): File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 568, in readline line = await self.readuntil(sep) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 663, in readuntil raise exceptions.LimitOverrunError( asyncio.exceptions.LimitOverrunError: Separator is found, but chunk is longer than limit ``` The other way this caused problems was if you parse a large dag (as in one with 20k tasks or more) the DagFileProcessor could end up getting a partial read which would be invalid JSON. This changes the communications protocol in a couple of ways. First off, at the Python level, the separate send and receive methods in the client/task side have been removed and replaced with a single `send()` that sends the request, reads the response and raises an error if one is returned. (But note, right now almost nothing in the supervisor side sets the error; that will be a future PR.) Secondly, the JSON Lines approach has been changed from a line-based protocol to a binary "frame" one. The protocol (which is the same for whichever side is sending) is length-prefixed, i.e. we first send the length of the data as a 4-byte big-endian integer, followed by the data itself. This should remove the possibility of JSON parse errors due to reading incomplete lines. Finally, the last change made in this PR is to remove the "extra" requests socket/channel. Upon closer examination of this comms path I realised that this socket is unnecessary: since we are in 100% control of the client side we can make use of the bi-directional nature of `socketpair` and save file handles. This also happens to help the `run_as_user` feature, which is currently broken: without extra config in the `sudoers` file, `sudo` will close all file handles other than stdin, stdout, and stderr -- so by introducing this change we make it easier to re-add run_as_user support. In order to support this in the DagFileProcessor (since the proc manager uses a single selector for multiple processes), I have moved the `on_close` callback to be part of the object we store in the `selector` object in the supervisors; previously it was the "on_read" callback, now we store a tuple of `(on_read, on_close)` and on_close is called once universally. This also changes the way comms are handled from the (async) TriggerRunner process. Previously we had a sync+async lock, but that made it possible to end up deadlocking things. The change now is to have `send` on `TriggerCommsDecoder` "go back" to the async event loop via `async_to_sync`, so that only async code deals with the socket, and we can use an async lock (rather than the hybrid sync and async lock we tried before).
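As a rough sketch of what that pattern looks like (illustrative only, not code from this patch; the class and helper names here are made up), the sync `send` hops back onto the event loop with `async_to_sync`, and the async side guards the socket with a plain `asyncio.Lock` while writing a length-prefixed frame (4-byte big-endian length, then the body) and reading the reply:

```python
import asyncio

from asgiref.sync import async_to_sync


class SketchComms:
    """Illustrative only: round-trips one length-prefixed frame per request."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self._reader = reader
        self._writer = writer
        self._lock = asyncio.Lock()  # async-only lock; no hybrid sync/async lock needed

    def send(self, payload: bytes) -> bytes:
        # Sync entry point: jump back onto the running event loop.
        return async_to_sync(self.asend)(payload)

    async def asend(self, payload: bytes) -> bytes:
        async with self._lock:
            # Frame = 4-byte big-endian length prefix followed by the payload.
            self._writer.write(len(payload).to_bytes(4, "big") + payload)
            await self._writer.drain()
            length = int.from_bytes(await self._reader.readexactly(4), "big")
            return await self._reader.readexactly(length)
```

Here the lock covers the whole write-then-read round trip for simplicity, so two concurrent triggers can never interleave their frames or read each other's responses.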
This seems to help the deadlock issue, but I'm not 100% sure it will remove it entirely, but it makes it much much harder to hit - I've not been able to reprouce it with this change --- .../src/airflow/dag_processing/manager.py | 9 +- .../src/airflow/dag_processing/processor.py | 33 +- .../src/airflow/jobs/triggerer_job_runner.py | 124 ++++--- .../tests/unit/dag_processing/test_manager.py | 82 ++--- .../unit/dag_processing/test_processor.py | 51 +-- .../tests/unit/jobs/test_triggerer_job.py | 34 +- .../src/tests_common/pytest_plugin.py | 15 +- task-sdk/pyproject.toml | 1 - task-sdk/src/airflow/sdk/bases/xcom.py | 98 ++---- .../sdk/definitions/asset/decorators.py | 14 +- .../src/airflow/sdk/execution_time/comms.py | 164 ++++++++- .../src/airflow/sdk/execution_time/context.py | 76 ++-- .../sdk/execution_time/lazy_sequence.py | 60 ++-- .../airflow/sdk/execution_time/supervisor.py | 286 ++++++++------- .../airflow/sdk/execution_time/task_runner.py | 215 ++++-------- .../task_sdk/execution_time/test_comms.py | 83 +++++ .../task_sdk/execution_time/test_context.py | 85 ++--- .../execution_time/test_lazy_sequence.py | 140 ++++---- .../execution_time/test_supervisor.py | 328 ++++++++++-------- .../execution_time/test_task_runner.py | 293 +++++----------- 20 files changed, 1113 insertions(+), 1078 deletions(-) create mode 100644 task-sdk/tests/task_sdk/execution_time/test_comms.py diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 272e086d90903..8af35af62b7ef 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -77,6 +77,8 @@ from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: + from socket import socket + from sqlalchemy.orm import Session from airflow.callbacks.callback_requests import CallbackRequest @@ -388,7 +390,7 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): """ events = self.selector.select(timeout=timeout) for key, _ in events: - socket_handler = key.data + socket_handler, on_close = key.data # BrokenPipeError should be caught and treated as if the handler returned false, similar # to EOF case @@ -397,8 +399,9 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): except BrokenPipeError: need_more = False if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() def _queue_requested_files_for_parsing(self) -> None: """Queue any files requested for parsing as requested by users via UI/API.""" diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 5f69082c758aa..4341d6d91390e 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel): bundle_path: Path """Passing bundle path around lets us figure out relative file path.""" - requests_fd: int callback_requests: list[CallbackRequest] = Field(default_factory=list) type: Literal["DagFileParseRequest"] = "DagFileParseRequest" @@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel): def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time import comms, task_runner # Parse DAG file, send JSON back up! 
- comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( - input=sys.stdin, - decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), + comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) - msg = comms_decoder.get_message() + msg = comms_decoder._get_response() if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") @@ -125,7 +122,7 @@ def _parse_file_entrypoint(): result = _parse_file(msg, log) if result is not None: - comms_decoder.send_request(log, result) + comms_decoder.send(result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: @@ -266,20 +263,18 @@ def _on_child_started( msg = DagFileParseRequest( file=os.fspath(path), bundle_path=bundle_path, - requests_fd=self._requests_fd, callback_requests=callbacks, ) - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) - def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse resp: BaseModel | None = None dump_opts = {} if isinstance(msg, DagFileParsingResult): self.parsing_result = msg - return - if isinstance(msg, GetConnection): + elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) @@ -301,10 +296,16 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # resp = self.client.variables.delete(msg.key) else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) @property def is_ready(self) -> bool: @@ -312,7 +313,7 @@ def is_ready(self) -> bool: # Process still alive, def can't be finished yet return False - return self._num_open_sockets == 0 + return not self._open_sockets def wait(self) -> int: raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5df6a03261523..e829a4fa6eae0 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -28,6 +28,7 @@ from collections.abc import Generator, Iterable from contextlib import suppress from datetime import datetime +from socket import socket from traceback import format_exception from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, Union @@ -43,6 +44,7 @@ from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger from airflow.sdk.execution_time.comms import ( + CommsDecoder, ConnectionResult, DagRunStateResult, DRCount, @@ -58,6 +60,7 @@ TICount, VariableResult, XComResult, + _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader 
from airflow.stats import Stats @@ -70,8 +73,6 @@ from airflow.utils.session import provide_session if TYPE_CHECKING: - from socket import socket - from sqlalchemy.orm import Session from structlog.typing import FilteringBoundLogger, WrappedLogger @@ -181,7 +182,6 @@ class messages: class StartTriggerer(BaseModel): """Tell the async trigger runner process to start, and where to send status update messages.""" - requests_fd: int type: Literal["StartTriggerer"] = "StartTriggerer" class TriggerStateChanges(BaseModel): @@ -295,7 +295,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): """ TriggerRunnerSupervisor is responsible for monitoring the subprocess and marshalling DB access. - This class (which runs in the main process) is responsible for querying the DB, sending RunTrigger + This class (which runs in the main/sync process) is responsible for querying the DB, sending RunTrigger workload messages to the subprocess, and collecting results and updating them in the DB. """ @@ -342,8 +342,8 @@ def start( # type: ignore[override] ): proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) - msg = messages.StartTriggerer(requests_fd=proc._requests_fd) - proc.send_msg(msg) + msg = messages.StartTriggerer() + proc.send_msg(msg, in_response_to=0) return proc @functools.cached_property @@ -355,7 +355,7 @@ def client(self) -> Client: client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, TaskStatesResponse, @@ -454,8 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - else: raise ValueError(f"Unknown message type {type(msg)}") - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def run(self) -> None: """Run synchronously and handle all database reads/writes.""" @@ -628,7 +627,7 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke ), ) - def _process_log_messages_from_subprocess(self) -> Generator[None, bytes, None]: + def _process_log_messages_from_subprocess(self) -> Generator[None, bytes | bytearray, None]: import msgspec from structlog.stdlib import NAME_TO_LEVEL @@ -691,14 +690,58 @@ class TriggerDetails(TypedDict): events: int +@attrs.define(kw_only=True) +class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): + _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer") + _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader") + + body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field( + factory=lambda: TypeAdapter(ToTriggerRunner), repr=False + ) + + _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + + def _read_frame(self): + from asgiref.sync import async_to_sync + + return async_to_sync(self._aread_frame)() + + def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + from asgiref.sync import async_to_sync + + return async_to_sync(self.asend)(msg) + + async def _aread_frame(self): + len_bytes = await self._async_reader.readexactly(4) + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = await self._async_reader.readexactly(len) + return self.resp_decoder.decode(buffer) + + async def 
_aget_response(self, expect_id: int) -> ToTriggerRunner | None: + frame = await self._aread_frame() + if frame.id != expect_id: + # Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with + # this explicit error return the wrong type of message back to a Trigger + raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}") + return self._from_frame(frame) + + async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + async with self._lock: + self._async_writer.write(bytes) + + return await self._aget_response(frame.id) + + class TriggerRunner: """ Runtime environment for all triggers. - Mainly runs inside its own thread, where it hands control off to an asyncio - event loop, but is also sometimes interacted with from the main thread - (where all the DB queries are done). All communication between threads is - done via Deques. + Mainly runs inside its own process, where it hands control off to an asyncio + event loop. All communication between this and it's (sync) supervisor is done via sockets """ # Maps trigger IDs to their running tasks and other info @@ -726,10 +769,7 @@ class TriggerRunner: # TODO: connect this to the parent process log: FilteringBoundLogger = structlog.get_logger() - requests_sock: asyncio.StreamWriter - response_sock: asyncio.StreamReader - - decoder: TypeAdapter[ToTriggerRunner] + comms_decoder: TriggerCommsDecoder def __init__(self): super().__init__() @@ -740,7 +780,6 @@ def __init__(self): self.events = deque() self.failed_triggers = deque() self.job_id = None - self.decoder = TypeAdapter(ToTriggerRunner) def run(self): """Sync entrypoint - just run a run in an async loop.""" @@ -796,36 +835,21 @@ async def init_comms(self): """ from airflow.sdk.execution_time import task_runner - loop = asyncio.get_event_loop() + # Yes, we read and write to stdin! It's a socket, not a normal stdin. 
+ reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) - comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]( - input=sys.stdin, - decoder=self.decoder, + self.comms_decoder = TriggerCommsDecoder( + async_writer=writer, + async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = comms_decoder - - async def connect_stdin() -> asyncio.StreamReader: - reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, sys.stdin) - return reader - - self.response_sock = await connect_stdin() + task_runner.SUPERVISOR_COMMS = self.comms_decoder - line = await self.response_sock.readline() + msg = await self.comms_decoder._aget_response(expect_id=0) - msg = self.decoder.validate_json(line) if not isinstance(msg, messages.StartTriggerer): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - writer_transport, writer_protocol = await loop.connect_write_pipe( - lambda: asyncio.streams.FlowControlMixin(loop=loop), - comms_decoder.request_socket, - ) - self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop) - async def create_triggers(self): """Drain the to_create queue and create all new triggers that have been requested in the DB.""" while self.to_create: @@ -934,8 +958,6 @@ async def cleanup_finished_triggers(self) -> list[int]: return finished_ids async def sync_state_to_supervisor(self, finished_ids: list[int]): - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Copy out of our deques in threadsafe manner to sync state with parent events_to_send = [] while self.events: @@ -961,19 +983,17 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]): if not finished_ids: msg.finished = None - # Block triggers from making any requests for the duration of this - async with SUPERVISOR_COMMS.lock: - # Tell the monitor that we've finished triggers so it can update things - self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n") - line = await self.response_sock.readline() - - if line == b"": # EoF received! 
+ # Tell the monitor that we've finished triggers so it can update things + try: + resp = await self.comms_decoder.asend(msg) + except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") + return + raise - resp = self.decoder.validate_json(line) if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}") + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") self.to_create.extend(resp.to_create) self.to_cancel.extend(resp.to_cancel) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c9e974b8cdb9a..98204033ced9e 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -30,10 +30,11 @@ from datetime import datetime, timedelta from logging.config import dictConfig from pathlib import Path -from socket import socket +from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock +import msgspec import pytest import time_machine from sqlalchemy import func, select @@ -54,7 +55,6 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk.execution_time.supervisor import mkpipe from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -138,20 +138,19 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces logger_filehandle = MagicMock() proc.create_time.return_value = time.time() proc.wait.return_value = 0 - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socketpair() ret = DagFileProcessorProcess( process_log=MagicMock(), id=uuid7(), pid=1234, process=proc, stdin=write_end, - requests_fd=123, logger_filehandle=logger_filehandle, client=MagicMock(), ) if start_time: ret.start_time = start_time - ret._num_open_sockets = 0 + ret._open_sockets.clear() return ret, read_end @pytest.fixture @@ -552,18 +551,17 @@ def test_kill_timed_out_processors_no_kill(self): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.parametrize( - ["callbacks", "path", "expected_buffer"], + ["callbacks", "path", "expected_body"], [ pytest.param( [], "/opt/airflow/dags/test_dag.py", - b"{" - b'"file":"/opt/airflow/dags/test_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,' - b'"callback_requests":[],' - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/test_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [], + "type": "DagFileParseRequest", + }, ), pytest.param( [ @@ -577,44 +575,40 @@ def test_kill_timed_out_processors_no_kill(self): ) ], "/opt/airflow/dags/dag_callback_dag.py", - b"{" - b'"file":"/opt/airflow/dags/dag_callback_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,"callback_requests":' - b"[" - b"{" - b'"filepath":"dag_callback_dag.py",' - b'"bundle_name":"testing",' - b'"bundle_version":null,' - b'"msg":null,' - b'"dag_id":"dag_id",' - b'"run_id":"run_id",' - b'"is_failure_callback":false,' - b'"type":"DagCallbackRequest"' - b"}" - b"]," - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/dag_callback_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [ + { + 
"filepath": "dag_callback_dag.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": None, + "dag_id": "dag_id", + "run_id": "run_id", + "is_failure_callback": False, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ), ], ) - def test_serialize_callback_requests(self, callbacks, path, expected_buffer): + def test_serialize_callback_requests(self, callbacks, path, expected_body): + from airflow.sdk.execution_time.comms import _ResponseFrame + processor, read_socket = self.mock_processor() processor._on_child_started(callbacks, path, bundle_path=Path("/opt/airflow/dags")) read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(4096) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError): - # no response written, valid for some message types. - pass - - assert val == expected_buffer + # Read response from the read end of the socket + read_socket.settimeout(0.1) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.body == expected_body @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 28ce7a8c23fe6..eb6a58dfcfbfb 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -42,7 +42,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.api.client import Client -from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -87,7 +87,6 @@ def _process_file( DagFileParseRequest( file=file_path, bundle_path=TEST_DAG_FOLDER, - requests_fd=1, callback_requests=callback_requests or [], ), log=structlog.get_logger(), @@ -393,26 +392,38 @@ def disable_capturing(): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.usefixtures("disable_capturing") -def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency): +def test_parse_file_entrypoint_parses_dag_callbacks(mocker): r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":' - + str(w2.fileno()).encode("ascii") - + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' - b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": ' - b'"manual__2024-12-30T21:02:55.203691+00:00", ' - b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n' + + frame = comms._ResponseFrame( + id=1, + body={ + "file": "/files/dags/wait.py", + "bundle_path": "/files/dags", + "callback_requests": [ + { + "filepath": "wait.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": "task_failure", + "dag_id": "wait_to_fail", + "run_id": "manual__2024-12-30T21:02:55.203691+00:00", + "is_failure_callback": True, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ) + bytes = frame.as_bytes() + w.sendall(bytes) - decoder = CommsDecoder[DagFileParseRequest, 
DagFileParsingResult]( - input=r.makefile("r"), - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), + decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + request_socket=r, + body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), ) - msg = decoder.get_message() + msg = decoder._get_response() assert isinstance(msg, DagFileParseRequest) assert msg.file == "/files/dags/wait.py" assert msg.callback_requests == [ @@ -455,7 +466,7 @@ def fake_collect_dags(self, *args, **kwargs): ) ] _parse_file( - DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests), + DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests), log=structlog.get_logger(), ) @@ -489,8 +500,6 @@ def fake_collect_dags(self, *args, **kwargs): bundle_version=None, ) ] - _parse_file( - DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() - ) + _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) assert called is True diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3c9ae6f27dc9..d3cecb5fc94e6 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -34,6 +34,7 @@ from airflow.hooks.base import BaseHook from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( + TriggerCommsDecoder, TriggererJobRunner, TriggerRunner, TriggerRunnerSupervisor, @@ -172,7 +173,6 @@ def builder(job=None): pid=process.pid, stdin=mocker.Mock(), process=process, - requests_fd=-1, capacity=10, ) # Mock the selector @@ -302,10 +302,9 @@ async def test_invalid_trigger(self, supervisor_builder): id=1, ti=None, classpath="fake.classpath", encrypted_kwargs={} ) trigger_runner = TriggerRunner() - trigger_runner.requests_sock = MagicMock() - trigger_runner.response_sock = AsyncMock() - trigger_runner.response_sock.readline.return_value = ( - b'{"type": "TriggerStateSync", "to_create": [], "to_cancel": []}\n' + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( + to_create=[], to_cancel=[] ) trigger_runner.to_create.append(workload) @@ -316,9 +315,9 @@ async def test_invalid_trigger(self, supervisor_builder): await trigger_runner.sync_state_to_supervisor(ids) # Check that we sent the right info in the failure message - assert trigger_runner.requests_sock.write.call_count == 1 - blob = trigger_runner.requests_sock.write.mock_calls[0].args[0] - msg = messages.TriggerStateChanges.model_validate_json(blob) + assert trigger_runner.comms_decoder.asend.call_count == 1 + msg = trigger_runner.comms_decoder.asend.mock_calls[0].args[0] + assert isinstance(msg, messages.TriggerStateChanges) assert msg.events is None assert msg.failures is not None @@ -552,6 +551,7 @@ def test_failed_trigger(session, dag_maker, supervisor_builder): ) ], ), + req_id=1, log=MagicMock(), ) @@ -622,10 +622,6 @@ def handle_events(self): super().handle_events() -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. 
TODO: AIP-72" -) @pytest.mark.asyncio @pytest.mark.execution_timeout(20) async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): @@ -726,13 +722,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"count": dag_run_states_count, "dag_run_state": dag_run_state}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" # Create the test DAG and task @@ -822,13 +813,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count, "task_states": task_states}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b5ae53625e256..d542bfa034b05 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1956,17 +1956,20 @@ def override_caplog(request): @pytest.fixture -def mock_supervisor_comms(): +def mock_supervisor_comms(monkeypatch): # for back-compat from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if not AIRFLOW_V_3_0_PLUS: yield None return - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - yield supervisor_comms + + import airflow.sdk.execution_time.task_runner + from airflow.sdk.execution_time.comms import CommsDecoder + + comms = mock.create_autospec(CommsDecoder) + monkeypatch.setattr(airflow.sdk.execution_time.task_runner, "SUPERVISOR_COMMS", comms, raising=False) + yield comms @pytest.fixture @@ -1991,7 +1994,6 @@ def mocked_parse(spy_agency): id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1 ), file="", - requests_fd=0, ), "example_dag_id", CustomOperator(task_id="hello"), @@ -2198,7 +2200,6 @@ def _create_task_instance( ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, ti_context=ti_context, start_date=start_date, # type: ignore ) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index a3fa147910e3a..468a6308e6032 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -47,7 +47,6 @@ classifiers = [ dependencies = [ "apache-airflow-core<3.2.0,>=3.1.0", - "aiologic>=0.14.0", "attrs>=24.2.0, !=25.2.0", "fsspec>=2023.10.0", "httpx>=0.27.0", diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index c9b777daca32e..770dbf53df769 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -77,9 +77,8 @@ def set( map_index=map_index, ) - SUPERVISOR_COMMS.send_request( - log=log, - 
msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -114,9 +113,8 @@ def _set_xcom_in_db( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -188,23 +186,16 @@ def _get_xcom_db_ref( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -248,23 +239,16 @@ def get_one( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -307,24 +291,17 @@ def get_all( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. 
- with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - start=None, - stop=None, - step=None, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + msg=GetXComSequenceSlice( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + start=None, + stop=None, + step=None, + ), + ) if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") @@ -379,9 +356,8 @@ def delete( map_index=map_index, ) cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( + SUPERVISOR_COMMS.send( + DeleteXCom( key=key, dag_id=dag_id, task_id=task_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index d899dd3718336..44479fbb9cd42 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -69,18 +69,16 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> ) def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: - import structlog - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name=self.__class__.__qualname__) - def _fetch_asset(name: str) -> Asset: - SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name)) - if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse): - raise AirflowRuntimeError(msg) - return Asset(**msg.model_dump(exclude={"type"})) + resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if resp is None: + raise RuntimeError("Empty non-error response received") + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) + return Asset(**resp.model_dump(exclude={"type"})) value: Any for key, param in inspect.signature(self.python_callable).parameters.items(): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0622cf6ffc86..67a298a6c14ab 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -19,12 +19,17 @@ Communication protocol between the Supervisor and the task process ================================================================== -* All communication is done over stdout/stdin in the form of "JSON lines" (each - message is a single JSON document terminated by `\n` character) -* Messages from the subprocess are all log messages and are sent directly to the log +* All communication is done over the subprocesses stdin in the form of a binary length-prefixed msgpack frame + (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same + encoding +* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) * No messages are sent to task process except in response to a request. (This is because the task process will be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom value etc.) +* Every request returns a response, even if the frame is otherwise empty. +* Requests are written by the subprocess to fd0/stdin. 
This is making use of the fact that stdin is a + bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending + requests. The reason this communication protocol exists, rather than the task process speaking directly to the Task Execution API server is because: @@ -43,15 +48,20 @@ from __future__ import annotations +import itertools from collections.abc import Iterator from datetime import datetime from functools import cached_property -from typing import Annotated, Any, Literal, Union +from pathlib import Path +from socket import socket +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from uuid import UUID import attrs +import msgspec +import structlog from fastapi import Body -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, @@ -80,6 +90,144 @@ ) from airflow.sdk.exceptions import ErrorType +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + +SendMsgType = TypeVar("SendMsgType", bound=BaseModel) +ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) + + +def _msgpack_enc_hook(obj: Any) -> Any: + import pendulum + + if isinstance(obj, pendulum.DateTime): + # convert the complex to a tuple of real, imag + return datetime( + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo + ) + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, BaseModel): + return obj.model_dump(exclude_unset=True) + + # Raise a NotImplementedError for other types + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def _new_encoder() -> msgspec.msgpack.Encoder: + return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook) + + +class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The request id, set by the sender. + + This is used to allow "pipeling" of requests and to be able to tie response to requests, which is + particularly useful in the Triggerer where multiple async tasks can send a requests concurrently. 
+ """ + body: dict[str, Any] | None + + req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder() + + def as_bytes(self) -> bytearray: + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + buffer = bytearray(256) + + self.req_encoder.encode_into(self, buffer, 4) + + n = len(buffer) - 4 + if n > 2**32: + raise OverflowError("Cannot send messages larger than 4GiB") + buffer[:4] = n.to_bytes(4, byteorder="big") + + return buffer + + +class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The id of the request this is a response to + """ + body: dict[str, Any] | None = None + error: dict[str, Any] | None = None + + +@attrs.define() +class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): + """Handle communication between the task in this process and the supervisor parent process.""" + + log: Logger = attrs.field(repr=False, factory=structlog.get_logger) + request_socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + + resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( + factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False + ) + + id_counter: Iterator[int] = attrs.field(factory=itertools.count) + + # We could be "clever" here and set the default to this based type parameters and a custom + # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a + # "sort of wrong default" + body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: + """Send a request to the parent and block until the response is received.""" + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + self.request_socket.sendall(bytes) + + return self._get_response() + + def _read_frame(self): + """ + Get a message from the parent. + + This will block until the message has been received. + """ + if self.request_socket: + self.request_socket.setblocking(True) + len_bytes = self.request_socket.recv(4) + + if len_bytes == b"": + raise EOFError("Request socket closed before length") + + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = bytearray(len) + nread = self.request_socket.recv_into(buffer) + if nread != len: + raise RuntimeError( + f"unable to read full response in child. 
(We read {nread}, but expected {len})" + ) + if nread == 0: + raise EOFError("Request socket closed before response was complete") + + return self.resp_decoder.decode(buffer) + + def _from_frame(self, frame) -> ReceiveMsgType | None: + from airflow.sdk.exceptions import AirflowRuntimeError + + if frame.error is not None: + err = self.err_decoder.validate_python(frame.error) + raise AirflowRuntimeError(error=err) + + if frame.body is None: + return None + + try: + return self.body_decoder.validate_python(frame.body) + except Exception: + self.log.exception("Unable to decode message") + raise + + def _get_response(self) -> ReceiveMsgType | None: + frame = self._read_frame() + return self._from_frame(frame) + class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -87,13 +235,7 @@ class StartupDetails(BaseModel): ti: TaskInstance dag_rel_path: str bundle_info: BundleInfo - requests_fd: int start_date: datetime - """ - The channel for the task to send requests over. - - Responses will come back on stdin - """ ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 438687bdeb870..54fb0c9e4920e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -150,13 +150,7 @@ def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -203,13 +197,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -259,11 +247,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. 
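To make the new wire format concrete: `_RequestFrame.as_bytes` writes a 4-byte big-endian length followed by the msgpack-encoded, array-like frame, and `CommsDecoder._read_frame` reads it back. As a rough, self-contained sketch (the `Frame`, `write_frame`, `read_exact`, and `read_frame` names are illustrative only, not code from this patch), the framing boils down to the following; the read loop is one way to guard against the short reads that motivated this change:

```
import msgspec
from socket import socket, socketpair


class Frame(msgspec.Struct, array_like=True):
    """Stand-in for the _RequestFrame/_ResponseFrame structs."""

    id: int
    body: dict | None = None


def write_frame(sock: socket, frame: Frame) -> None:
    payload = msgspec.msgpack.encode(frame)
    if len(payload) >= 2**32:
        raise OverflowError("Cannot send messages larger than 4GiB")
    # 4-byte big-endian length prefix, then the msgpack-encoded frame
    sock.sendall(len(payload).to_bytes(4, byteorder="big") + payload)


def read_exact(sock: socket, n: int) -> bytes:
    # A single recv() may return fewer than n bytes, so loop until complete
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("socket closed mid-frame")
        buf.extend(chunk)
    return bytes(buf)


def read_frame(sock: socket) -> Frame:
    length = int.from_bytes(read_exact(sock, 4), byteorder="big")
    return msgspec.msgpack.Decoder(Frame).decode(read_exact(sock, length))


if __name__ == "__main__":
    a, b = socketpair()
    write_frame(a, Frame(id=1, body={"type": "GetVariable", "key": "foo"}))
    print(read_frame(b))  # Frame(id=1, body={'type': 'GetVariable', 'key': 'foo'})
```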
A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) def _delete_variable(key: str) -> None: @@ -275,12 +259,7 @@ def _delete_variable(key: str) -> None: from airflow.sdk.execution_time.comms import DeleteVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -383,23 +362,29 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: @staticmethod def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: from airflow.sdk.definitions.asset import Asset - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetAssetByName, + GetAssetByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if name: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + msg = GetAssetByName(name=name) elif uri: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + msg = GetAssetByUri(uri=uri) else: raise ValueError("Either name or uri must be provided") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetResult) - return Asset(**msg.model_dump(exclude={"type"})) + assert isinstance(resp, AssetResult) + return Asset(**resp.model_dump(exclude={"type"})) @attrs.define @@ -533,9 +518,11 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve ErrorResponse, GetAssetEventByAsset, GetAssetEventByAssetAlias, + ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if isinstance(key, int): # Support index access; it's easier for trivial cases. 
obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -545,31 +532,33 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve if isinstance(obj, Asset): asset = self._assets[AssetUniqueKey.from_asset(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri)) + msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri) elif isinstance(obj, AssetNameRef): try: asset = next(a for k, a in self._assets.items() if k.name == obj.name) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None)) + msg = GetAssetEventByAsset(name=asset.name, uri=None) elif isinstance(obj, AssetUriRef): try: asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri)) + msg = GetAssetEventByAsset(name=None, uri=asset.uri) elif isinstance(obj, AssetAlias): asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name)) + msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name) + else: + raise TypeError(f"`key` is of unknown type ({type(key).__name__})") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetEventsResult) + assert isinstance(resp, AssetEventsResult) - return list(msg.iter_asset_event_results()) + return list(resp.iter_asset_event_results()) @attrs.define @@ -626,8 +615,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 9cf9acfac81bb..d8b356870c7aa 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -91,16 +91,14 @@ def __len__(self) -> int: task = self._xcom_arg.operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComCount( + msg = SUPERVISOR_COMMS.send( + GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, run_id=self._ti.run_id, task_id=task.task_id, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): @@ -127,43 +125,37 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if isinstance(key, slice): start, stop, step = _coerce_slice(key) - with SUPERVISOR_COMMS.lock: - source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=xcom_arg.key, - dag_id=source.dag_id, - task_id=source.task_id, - run_id=self._ti.run_id, - start=start, - stop=stop, - step=step, - ), - ) - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComSequenceSliceResult): - raise TypeError(f"Got unexpected 
response to GetXComSequenceSlice: {msg!r}") - return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] - - if not isinstance(key, int): - if (index := getattr(key, "__index__", None)) is not None: - key = index() - raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") - - with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceItem( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, run_id=self._ti.run_id, - offset=key, + start=start, + stop=stop, + step=step, ), ) - msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") + return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] + + if not isinstance(key, int): + if (index := getattr(key, "__index__", None)) is not None: + key = index() + raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") + + source = (xcom_arg := self._xcom_arg).operator + msg = SUPERVISOR_COMMS.send( + GetXComSequenceItem( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + offset=key, + ), + ) if isinstance(msg, ErrorResponse): raise IndexError(key) if not isinstance(msg, XComSequenceIndexResult): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index dc63daeef7bb7..e70f9154d8c20 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,12 +27,13 @@ import signal import sys import time +import weakref from collections import deque from collections.abc import Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus -from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair +from socket import socket, socketpair from typing import ( TYPE_CHECKING, BinaryIO, @@ -44,7 +45,6 @@ ) from uuid import UUID -import aiologic import attrs import httpx import msgspec @@ -64,6 +64,7 @@ XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -109,6 +110,8 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -180,23 +183,6 @@ ********************************************************************************************************""" -def mkpipe( - remote_read: bool = False, -) -> tuple[socket, socket]: - """Create a pair of connected sockets.""" - rsock, wsock = socketpair() - local, remote = (wsock, rsock) if remote_read else (rsock, wsock) - - if remote_read: - # Setting a 4KB buffer here if possible, if not, it still works, so we will suppress all exceptions - with suppress(Exception): - local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE) - # set nonblocking to True so that send or sendall waits till all data is sent - local.setblocking(True) - - return remote, local - - def _subprocess_main(): from airflow.sdk.execution_time.task_runner import main @@ -224,14 +210,13 @@ def _configure_logs_over_json_channel(log_fd: int): def _reopen_std_io_handles(child_stdin, child_stdout, 
child_stderr): # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, fd, sock, mode in ( - ("stdin", 0, child_stdin, "r"), + # Yes, we want to re-open stdin in write mode! This is cause it is a bi-directional socket, so we can + # read and write to it. + ("stdin", 0, child_stdin, "w"), ("stdout", 1, child_stdout, "w"), ("stderr", 2, child_stderr, "w"), ): - handle = getattr(sys, handle_name) - handle.close() os.dup2(sock.fileno(), fd) del sock @@ -319,7 +304,7 @@ def __getattr__(name: str): def _fork_main( - child_stdin: socket, + requests: socket, child_stdout: socket, child_stderr: socket, log_fd: int, @@ -346,10 +331,12 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() + # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) + _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) - _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + _reopen_std_io_handles(requests, child_stdout, child_stderr) def exit(n: int) -> NoReturn: with suppress(ValueError, OSError): @@ -360,11 +347,11 @@ def exit(n: int) -> NoReturn: last_chance_stderr.flush() # Explicitly close the child-end of our supervisor sockets so - # the parent sees EOF on both "requests" and "logs" channels. + # the parent sees EOF on "logs" channel. with suppress(OSError): os.close(log_fd) with suppress(OSError): - os.close(child_stdin.fileno()) + os.close(requests.fileno()) os._exit(n) if hasattr(atexit, "_clear"): @@ -432,16 +419,18 @@ class WatchedSubprocess: """The decoder to use for incoming messages from the child process.""" _process: psutil.Process = attrs.field(repr=False) - _requests_fd: int """File descriptor for request handling.""" - _num_open_sockets: int = 4 _exit_code: int | None = attrs.field(default=None, init=False) _process_exit_monotonic: float | None = attrs.field(default=None, init=False) - _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False) + _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field( + factory=weakref.WeakKeyDictionary, init=False + ) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False) + _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False) + process_log: FilteringBoundLogger = attrs.field(repr=False) subprocess_logs_to_stdout: bool = False @@ -460,18 +449,19 @@ def start( ) -> Self: """Fork and start a new subprocess with the specified target function.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess - child_stdin, feed_stdin = mkpipe(remote_read=True) - child_stdout, read_stdout = mkpipe() - child_stderr, read_stderr = mkpipe() + child_stdout, read_stdout = socketpair() + child_stderr, read_stderr = socketpair() - # Open these socketpair before forking off the child, so that it is open when we fork. - child_comms, read_msgs = mkpipe() - child_logs, read_logs = mkpipe() + # Place for child to send requests/read responses, and the server side to read/respond + child_requests, read_requests = socketpair() + + # Open the socketpair before forking off the child, so that it is open when we fork. + child_logs, read_logs = socketpair() pid = os.fork() if pid == 0: # Close and delete of the parent end of the sockets. 
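A quick illustration of why re-opening stdin in write mode works (illustrative snippet, not code from this patch): both ends of a `socketpair` are readable and writable, so once the child's end has been `dup2`'d onto fd 0 the child can write requests to, and read responses from, the very same descriptor whose other end the supervisor holds.

```
from socket import socketpair

supervisor_end, child_stdin = socketpair()  # both ends are read/write

child_stdin.sendall(b"request")             # the "stdin" end can send requests...
print(supervisor_end.recv(7))               # b'request'

supervisor_end.sendall(b"response")         # ...and the reply comes back on the same fd
print(child_stdin.recv(8))                  # b'response'
```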
- cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs) + cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs) # Python GC should delete these for us, but lets make double sure that we don't keep anything # around in the forked processes, especially things that might involve open files or sockets! @@ -480,28 +470,28 @@ def start( try: # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target) except BaseException as e: + import traceback + with suppress(BaseException): # We can't use log here, as if we except out of _fork_main something _weird_ went on. - print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + print("Exception in _fork_main, exiting with code 124", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) # It's really super super important we never exit this block. We are in the forked child, and if we # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) os._exit(124) - requests_fd = child_comms.fileno() - # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the # other end of the pair open - cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) + cls._close_unused_sockets(child_stdout, child_stderr, child_logs) logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc = cls( pid=pid, - stdin=feed_stdin, + stdin=read_requests, process=psutil.Process(pid), - requests_fd=requests_fd, process_log=logger, start_time=time.monotonic(), **constructor_kwargs, @@ -510,7 +500,7 @@ def start( proc._register_pipe_readers( stdout=read_stdout, stderr=read_stderr, - requests=read_msgs, + requests=read_requests, logs=read_logs, ) @@ -523,24 +513,26 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - # Track socket types for debugging - self._fd_to_socket_type = { - stdout.fileno(): "stdout", - stderr.fileno(): "stderr", - requests.fileno(): "requests", - logs.fileno(): "logs", - } + # Track the open sockets, and for debugging what type each one is + self._open_sockets.update( + ( + (stdout, "stdout"), + (stderr, "stderr"), + (logs, "logs"), + (requests, "requests"), + ) + ) target_loggers: tuple[FilteringBoundLogger, ...] 
= (self.process_log,) if self.subprocess_logs_to_stdout: target_loggers += (log,) self.selector.register( - stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout") + stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout") ) self.selector.register( stderr, selectors.EVENT_READ, - self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR), + self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR), ) self.selector.register( logs, @@ -552,37 +544,47 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke self.selector.register( requests, selectors.EVENT_READ, - make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed), + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), ) - def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: """Create a socket handler that forwards logs to a logger.""" return make_buffered_socket_reader( forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed ) - def _on_socket_closed(self): + def _on_socket_closed(self, sock: socket): # We want to keep servicing this process until we've read up to EOF from all the sockets. - self._num_open_sockets -= 1 - def send_msg(self, msg: BaseModel, **dump_opts): - """Send the given pydantic message to the subprocess at once by encoding it and adding a line break.""" - b = msg.model_dump_json(**dump_opts).encode() + b"\n" - self.stdin.sendall(b) + with suppress(KeyError): + self.selector.unregister(sock) + del self._open_sockets[sock] + + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): + """Send the msg as a length-prefixed response frame.""" + if msg: + frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts)) + else: + err_resp = error.model_dump() if error else None + frame = _ResponseFrame(id=in_response_to, error=err_resp) + + self.stdin.sendall(frame.as_bytes()) - def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]: """Handle incoming requests from the task process, respond with the appropriate data.""" while True: - line = yield + request = yield try: - msg = self.decoder.validate_json(line) + msg = self.decoder.validate_python(request.body) except Exception: - log.exception("Unable to decode message", line=line) + log.exception("Unable to decode message", body=request.body) continue try: - self._handle_request(msg, log) + self._handle_request(msg, log, request.id) except ServerResponseError as e: error_details = e.response.json() if e.response else None log.error( @@ -594,27 +596,25 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N # Send error response back to task so that the error appears in the task logs self.send_msg( - ErrorResponse( + msg=None, + error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={ "status_code": e.response.status_code, "message": str(e), "detail": error_details, }, - ) + ), + in_response_to=request.id, ) - def _handle_request(self, msg, log: FilteringBoundLogger) -> None: + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: 
int) -> None: raise NotImplementedError() @staticmethod def _close_unused_sockets(*sockets): """Close unused ends of sockets after fork.""" for sock in sockets: - if isinstance(sock, SocketIO): - # If we have the socket IO object, we need to close the underlying socket foricebly here too, - # else we get unclosed socket warnings, and likely leaking FDs too - sock._sock.close() sock.close() def _cleanup_open_sockets(self): @@ -624,20 +624,18 @@ def _cleanup_open_sockets(self): # sockets the supervisor would wait forever thinking they are still # active. This cleanup ensures we always release resources and exit. stuck_sockets = [] - for key in list(self.selector.get_map().values()): - socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}") - stuck_sockets.append(f"{socket_type}({key.fd})") + for sock, socket_type in self._open_sockets.items(): + fileno = "unknown" with suppress(Exception): - self.selector.unregister(key.fileobj) - with suppress(Exception): - key.fileobj.close() # type: ignore[union-attr] + fileno = sock.fileno() + sock.close() + stuck_sockets.append(f"{socket_type}(fd={fileno})") if stuck_sockets: log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets) self.selector.close() - self._close_unused_sockets(self.stdin) - self._num_open_sockets = 0 + self.stdin.close() def kill( self, @@ -736,7 +734,7 @@ def _service_subprocess( events = self.selector.select(timeout=timeout) for key, _ in events: # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) - socket_handler = key.data + socket_handler, on_close = key.data # Example of handler behavior: # If the subprocess writes "Hello, World!" to stdout: @@ -753,8 +751,9 @@ def _service_subprocess( # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors # are removed. if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -773,16 +772,16 @@ def _check_subprocess_exit( raise else: self._process_exit_monotonic = time.monotonic() - self._close_unused_sockets(self.stdin) - # Put a message in the viewable task logs if expect_signal is not None and self._exit_code == -expect_signal: # Bypass logging, the caller expected us to exit with this return self._exit_code - # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): + # Put a message in the viewable task logs + if self._exit_code == -signal.SIGSEGV: self.process_log.critical(SIGSEGV_MESSAGE) + # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): elif name := getattr(self._exit_code, "name", None): message = "Process terminated by signal" level = logging.ERROR @@ -809,7 +808,7 @@ class ActivitySubprocess(WatchedSubprocess): _last_heartbeat_attempt: float = attrs.field(default=0, init=False) # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we - # will kill the process. This is to handle temporary network issues etc. ensuring that the process + # will kill theprocess. This is to handle temporary network issues etc. ensuring that the process # does not hang around forever. 
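The selector dispatch convention here matches the DagFileProcessor manager change: `key.data` now holds an `(on_read, on_close)` pair, the read callback returns `False` once it has seen EOF, and only then is the socket unregistered and closed. A rough self-contained sketch of that loop (hypothetical callback names, not the patch code):

```
import selectors
from socket import socketpair

sel = selectors.DefaultSelector()
r, w = socketpair()


def on_read(sock) -> bool:
    data = sock.recv(4096)
    if not data:
        return False              # EOF: tell the loop we need no more reads
    print("read:", data)
    return True


def on_close(sock) -> None:
    sel.unregister(sock)          # roughly what _on_socket_closed does


# key.data is the (read handler, close handler) pair
sel.register(r, selectors.EVENT_READ, (on_read, on_close))
w.sendall(b"hello")
w.close()

while sel.get_map():
    for key, _ in sel.select(timeout=1):
        handler, closer = key.data
        if not handler(key.fileobj):
            closer(key.fileobj)
            key.fileobj.close()
```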
failed_heartbeats: int = attrs.field(default=0, init=False) @@ -861,7 +860,6 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st ti=ti, dag_rel_path=os.fspath(dag_rel_path), bundle_info=bundle_info, - requests_fd=self._requests_fd, ti_context=ti_context, start_date=start_date, ) @@ -870,7 +868,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st log.debug("Sending", msg=msg) try: - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) except BrokenPipeError: # Debug is fine, the process will have shown _something_ in it's last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -930,7 +928,7 @@ def _monitor_subprocess(self): - Processes events triggered on the monitored file objects, such as data availability or EOF. - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - while self._exit_code is None or self._num_open_sockets > 0: + while self._exit_code is None or self._open_sockets: last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. @@ -946,16 +944,11 @@ def _monitor_subprocess(self): # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None - if self._exit_code is not None and self._num_open_sockets > 0: + if self._exit_code is not None and self._open_sockets: if ( self._process_exit_monotonic and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT ): - log.debug( - "Forcefully closing remaining sockets", - open_sockets=self._num_open_sockets, - pid=self.pid, - ) self._cleanup_open_sockets() if alive: @@ -1051,7 +1044,7 @@ def final_state(self): return SERVER_TERMINATED return TaskInstanceState.FAILED - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} @@ -1224,10 +1217,17 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): dump_opts = {"exclude_unset": True} else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def in_process_api_server(): @@ -1237,24 +1237,26 @@ def in_process_api_server(): return api -@attrs.define +@attrs.define(kw_only=True) class InProcessSupervisorComms: """In-process communication handler that uses deques instead of sockets.""" + log: FilteringBoundLogger = attrs.field(repr=False, factory=structlog.get_logger) supervisor: InProcessTestSupervisor - messages: deque[BaseModel] = attrs.field(factory=deque) - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock) + messages: deque[BaseModel | None] = attrs.field(factory=deque) - def get_message(self) -> BaseModel: + def _get_response(self) -> BaseModel | None: """Get a message from the supervisor. 
Blocks until a message is available.""" return self.messages.popleft() - def send_request(self, log, msg: BaseModel): + def send(self, msg: BaseModel): """Send a request to the supervisor.""" - log.debug("Sending request", msg=msg) + self.log.debug("Sending request", msg=msg) with set_supervisor_comms(None): - self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + self.supervisor._handle_request(msg, log, 0) # type: ignore[arg-type] + + return self._get_response() @attrs.define @@ -1272,7 +1274,8 @@ class InProcessTestSupervisor(ActivitySubprocess): """A supervisor that runs tasks in-process for easier testing.""" comms: InProcessSupervisorComms = attrs.field(init=False) - stdin = attrs.field(init=False) + + stdin: socket = attrs.field(init=False) @classmethod def start( # type: ignore[override] @@ -1298,7 +1301,6 @@ def start( # type: ignore[override] id=what.id, pid=os.getpid(), # Use current process process=psutil.Process(), # Current process - requests_fd=-1, # Not used in in-process mode process_log=logger or structlog.get_logger(logger_name="task").bind(), client=cls._api_client(task.dag), **kwargs, @@ -1363,7 +1365,9 @@ def _api_client(dag=None): client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def send_msg(self, msg: BaseModel, **dump_opts): + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): """Override to use in-process comms.""" self.comms.messages.append(msg) @@ -1421,9 +1425,9 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: # to a (sync) generator def make_buffered_socket_reader( gen: Generator[None, bytes | bytearray, None], - on_close: Callable, + on_close: Callable[[socket], None], buffer_size: int = 4096, -) -> Callable[[socket], bool]: +): buffer = bytearray() # This will hold our accumulated binary data read_buffer = bytearray(buffer_size) # Temporary buffer for each read @@ -1441,7 +1445,6 @@ def cb(sock: socket): with suppress(StopIteration): gen.send(buffer) # Tell loop to close this selector - on_close() return False buffer.extend(read_buffer[:n_received]) @@ -1452,18 +1455,62 @@ def cb(sock: socket): try: gen.send(line) except StopIteration: - on_close() return False buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data return True - return cb + return cb, on_close + + +def length_prefixed_frame_reader( + gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None] +): + length_needed: int | None = None + # This will hold our accumulated/partial binary frame if it doesn't come in a single read + buffer: memoryview | None = None + # position in the buffer to store next read + pos = 0 + decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame) + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, length_needed, pos + + if length_needed is None: + # Read the 32bit length of the frame + bytes = sock.recv(4) + if bytes == b"": + return False + + length_needed = int.from_bytes(bytes, byteorder="big") + buffer = memoryview(bytearray(length_needed)) + if length_needed and buffer: + n = sock.recv_into(buffer[pos:]) + if n == 0: + # EOF + return False + pos += n + + if pos >= length_needed: + request = decoder.decode(buffer) + buffer = None + pos = 0 + length_needed = None + try: + gen.send(request) + except StopIteration: + return False + return True + + return cb, on_close def 
process_log_messages_from_subprocess( loggers: tuple[FilteringBoundLogger, ...], -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: from structlog.stdlib import NAME_TO_LEVEL while True: @@ -1499,10 +1546,9 @@ def process_log_messages_from_subprocess( def forward_to_log( target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: while True: - buf = yield - line = bytes(buf) + line = yield # Strip off new line line = line.rstrip() try: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 2374520e63f21..0268e24f2177b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -28,16 +28,14 @@ from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from datetime import datetime, timezone -from io import FileIO from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Literal -import aiologic import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -59,6 +57,7 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, + CommsDecoder, DagRunStateResult, DeferTask, DRCount, @@ -424,10 +423,9 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - SUPERVISOR_COMMS.send_request( - log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) + response = SUPERVISOR_COMMS.send( + msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -445,22 +443,17 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTICount( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, TICount) @@ -477,21 +470,16 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTaskStates( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - 
logical_dates=logical_dates, - run_ids=run_ids, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTaskStates( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + ), + ) if TYPE_CHECKING: assert isinstance(response, TaskStatesResult) @@ -506,19 +494,14 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, DRCount) @@ -528,10 +511,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the DAG run with the given Run ID.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -650,62 +630,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -SendMsgType = TypeVar("SendMsgType", bound=BaseModel) -ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) - - -@attrs.define() -class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): - """Handle communication between the task in this process and the supervisor parent process.""" - - input: TextIO - - request_socket: FileIO = attrs.field(init=False, default=None) - - # We could be "clever" here and set the default to this based type parameters and a custom - # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a - # "sort of wrong default" - decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) - - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) - - def get_message(self) -> ReceiveMsgType: - """ - Get a message from the parent. - - This will block until the message has been received. - """ - line = None - - # TODO: Investigate why some empty lines are sent to the processes stdin. - # That was highlighted when working on https://github.com/apache/airflow/issues/48183 - # and is maybe related to deferred/triggerer only context. - while not line: - line = self.input.readline() - - try: - msg = self.decoder.validate_json(line) - except Exception: - structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) - raise - - if isinstance(msg, StartupDetails): - # If we read a startup message, pull out the FDs we care about! 
- if msg.requests_fd > 0: - self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR: - structlog.get_logger(logger_name="task").error("Error response from the API Server") - raise AirflowRuntimeError(error=msg) - - return msg - - def send_request(self, log: Logger, msg: SendMsgType): - encoded_msg = msg.model_dump_json().encode() + b"\n" - - log.debug("Sending request", json=encoded_msg) - self.request_socket.write(encoded_msg) - - # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # @@ -725,31 +649,33 @@ def send_request(self, log: Logger, msg: SendMsgType): def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - msg = SUPERVISOR_COMMS.get_message() + # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent + # in response to us sending a request. + msg = SUPERVISOR_COMMS._get_response() + + if not isinstance(msg, StartupDetails): + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") log = structlog.get_logger(logger_name="task") + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + log.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + try: get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) except Exception: log.exception("error calling listener") - if isinstance(msg, StartupDetails): - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - log.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle(f"airflow worker -- {msg.ti.id}") - - with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg, log) - ti.log_url = get_log_url_from_ti(ti) - log.debug("DAG file parsed", file=msg.dag_rel_path) - else: - raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): + ti = parse(msg, log) + ti.log_url = get_log_url_from_ti(ti) + log.debug("DAG file parsed", file=msg.dag_rel_path) return ti, ti.get_template_context(), log @@ -797,7 +723,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) _validate_task_inlets_and_outlets(ti=ti, log=log) @@ -817,8 +743,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log) - inactive_assets_resp = SUPERVISOR_COMMS.get_message() + inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -914,7 +839,7 @@ def run( except DownstreamTasksSkipped as skip: log.info("Skipping 
downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -976,7 +901,7 @@ def run( error = e finally: if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + SUPERVISOR_COMMS.send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1014,9 +939,8 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - SUPERVISOR_COMMS.send_request( - log=log, - msg=TriggerDagRun( + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, logical_date=drte.logical_date, @@ -1025,7 +949,6 @@ def _handle_trigger_dag_run( ), ) - comms_msg = SUPERVISOR_COMMS.get_message() if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: if drte.skip_when_already_exists: log.info( @@ -1080,10 +1003,9 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - SUPERVISOR_COMMS.send_request( - log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) + comms_msg = SUPERVISOR_COMMS.send( + GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) - comms_msg = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1272,10 +1194,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)), - ) + SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1316,9 +1235,11 @@ def finalize( def main(): - # TODO: add an exception here, it causes an oof of a stack trace! + # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! + log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) try: ti, context, log = startup() @@ -1329,11 +1250,9 @@ def main(): state, msg, error = run(ti, context, log) finalize(ti, state, context, log, error) except KeyboardInterrupt: - log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") exit(2) except Exception: - log = structlog.get_logger(logger_name="task") log.exception("Top level error") exit(1) finally: diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py new file mode 100644 index 0000000000000..ee2f956b063c9 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import uuid +from socket import socketpair + +import msgspec +import pytest + +from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame +from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.utils import timezone + + +class TestCommsDecoder: + """Test the communication between the subprocess and the "supervisor".""" + + @pytest.mark.usefixtures("disable_capturing") + def test_recv_StartupDetails(self): + r, w = socketpair() + + msg = { + "type": "StartupDetails", + "ti": { + "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + "task_id": "a", + "try_number": 1, + "run_id": "b", + "dag_id": "c", + }, + "ti_context": { + "dag_run": { + "dag_id": "c", + "run_id": "b", + "logical_date": "2024-12-01T01:00:00Z", + "data_interval_start": "2024-12-01T00:00:00Z", + "data_interval_end": "2024-12-01T01:00:00Z", + "start_date": "2024-12-01T01:00:00Z", + "run_after": "2024-12-01T01:00:00Z", + "end_date": None, + "run_type": "manual", + "conf": None, + "consumed_asset_events": [], + }, + "max_tries": 0, + "should_retry": False, + "variables": None, + "connections": None, + }, + "file": "/dev/null", + "start_date": "2024-12-01T01:00:00Z", + "dag_rel_path": "/dev/null", + "bundle_info": {"name": "any-name", "version": "any-version"}, + } + bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) + w.sendall(len(bytes).to_bytes(4, byteorder="big") + bytes) + + decoder = CommsDecoder(request_socket=r, log=None) + + msg = decoder._get_response() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.dag_id == "c" + assert msg.dag_rel_path == "/dev/null" + assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") + assert msg.start_date == timezone.datetime(2024, 12, 1, 1) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 59266da758892..6880c656b01cd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -190,7 +190,7 @@ def test_getattr_connection(self, mock_supervisor_comms): # Conn from the supervisor / API Server conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection; triggers __getattr__ conn = accessor.mysql_conn @@ -203,7 +203,7 @@ def test_get_method_valid_connection(self, mock_supervisor_comms): accessor = ConnectionAccessor() conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result conn = accessor.get("mysql_conn") 
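These test updates all follow the same mechanical pattern: where a test previously stubbed `get_message.return_value` (or `side_effect`) and asserted on `send_request.mock_calls`, it now stubs and asserts on the single `send` call. A tiny sketch of the idea using a plain `MagicMock` (hypothetical request/response objects, not the real fixture):

```
from unittest import mock

comms = mock.MagicMock()
comms.send.return_value = {"host": "mysql"}        # hypothetical response object

# code under test now makes one round-trip call...
response = comms.send({"conn_id": "mysql_conn"})   # hypothetical request object

# ...so tests stub one attribute and assert on one call
comms.send.assert_called_once()
assert response == {"host": "mysql"}
```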
assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) @@ -216,7 +216,7 @@ def test_get_method_with_default(self, mock_supervisor_comms): error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"} ) - mock_supervisor_comms.get_message.return_value = error_response + mock_supervisor_comms.send.return_value = error_response conn = accessor.get("nonexistent_conn", default_conn=default_conn) assert conn == default_conn @@ -233,7 +233,7 @@ def test_getattr_connection_for_extra_dejson(self, mock_supervisor_comms): extra='{"extra_key": "extra_value"}', ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -251,7 +251,7 @@ def test_getattr_connection_for_extra_dejson_decode_error(self, mock_log, mock_s conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306, extra="This is not JSON!" ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -274,7 +274,7 @@ def test_getattr_variable(self, mock_supervisor_comms): # Variable from the supervisor / API Server var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result # Fetch the variable; triggers __getattr__ value = accessor.test_key @@ -286,7 +286,7 @@ def test_get_method_valid_variable(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result val = accessor.get("test_key") assert val == var_result.value @@ -297,7 +297,7 @@ def test_get_method_with_default(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"}) - mock_supervisor_comms.get_message.return_value = error_response + mock_supervisor_comms.send.return_value = error_response val = accessor.get("nonexistent_var_key", default="default_value") assert val == "default_value" @@ -367,7 +367,7 @@ class TestOutletEventAccessor: ), ) def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={}) outlet_event_accessor.add(add_arg) @@ -398,7 +398,7 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): ), ) def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) outlet_event_accessor.add(add_arg, extra={}) @@ -497,11 +497,11 @@ def test_getitem_name_ref( resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = 
resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(name=name)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByName(name=name) + + mock_supervisor_comms.send.assert_called_once_with(GetAssetByName(name=name, type="GetAssetByName")) assert _AssetRefResolutionMixin._asset_ref_cache @pytest.mark.parametrize( @@ -520,11 +520,10 @@ def test_getitem_uri_ref( resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(uri=uri)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByUri(uri=uri) + mock_supervisor_comms.send.assert_called_once_with(GetAssetByUri(uri=uri)) assert _AssetRefResolutionMixin._asset_ref_cache def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): @@ -534,22 +533,19 @@ def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): assert source == AssetEventSourceTaskInstance(dag_id="d1", task_id="t2", run_id="r1", map_index=-1) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="d1", - run_id="r1", - task_id="t2", - map_index=-1, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="d1", + run_id="r1", + task_id="t2", + map_index=-1, ), - ] + ) TEST_ASSET = Asset(name="test_uri", uri="test://test") @@ -593,7 +589,7 @@ def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms): assert len(outlet_event_accessors) == 0 # Asset from the API Server via the supervisor - mock_supervisor_comms.get_message.return_value = AssetResult( + mock_supervisor_comms.send.return_value = AssetResult( name=asset.name, uri=asset.uri, group=asset.group, @@ -628,7 +624,7 @@ def test_for_asset_alias(self, mocked__getitem__): class TestInletEventAccessor: @pytest.fixture def sample_inlet_evnets_accessor(self, mock_supervisor_comms): - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult(name="test_uri", uri="test://test", group="asset"), AssetResult(name="test_uri", uri="test://test", group="asset"), ] @@ -656,7 +652,7 @@ def test__get_item__(self, key, sample_inlet_evnets_accessor, mock_supervisor_co asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.get_message.side_effect = [events_result] * 4 + mock_supervisor_comms.send.side_effect = [events_result] * 4 assert sample_inlet_evnets_accessor[key] == [asset_event_resp] @@ -684,7 +680,7 @@ def test_for_asset_alias(self, mocked__getitem__, sample_inlet_evnets_accessor): assert mocked__getitem__.call_args[0][0] == TEST_ASSET_ALIAS def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock_supervisor_comms): 
- mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetEventsResult( asset_events=[ AssetEventResponse( @@ -707,9 +703,7 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) ] events = sample_inlet_evnets_accessor[Asset.ref(name="test_uri")] - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call(log=mock.ANY, msg=GetAssetEventByAsset(name="test_uri", uri=None)), - ] + mock_supervisor_comms.send.assert_called_once_with(GetAssetEventByAsset(name="test_uri", uri=None)) assert len(events) == 2 assert events[1].source_task_instance is None @@ -723,19 +717,16 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="__dag__", - run_id="__run__", - task_id="__task__", - map_index=0, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="__dag__", + run_id="__run__", + task_id="__task__", + map_index=0, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py index e4943196a09da..65dd30b2c7c7b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -17,7 +17,7 @@ from __future__ import annotations -from unittest.mock import ANY, Mock, call +from unittest.mock import Mock, call import pytest @@ -66,121 +66,109 @@ def deserialize_value(cls, xcom): def test_len(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3) + mock_supervisor_comms.send.return_value = XComCountResponse(len=3) assert len(lazy_sequence) == 3 - assert mock_supervisor_comms.send_request.mock_calls == [ - call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")), - ] + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run"), + ) def test_iter(mock_supervisor_comms, lazy_sequence): it = iter(lazy_sequence) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComSequenceIndexResult(root="f"), ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}), ] assert list(it) == ["f"] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=0, + mock_supervisor_comms.send.assert_has_calls( + [ + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=0, + ), ), - ), - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=1, + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=1, + ), ), - ), - ] + ] + ) def test_getitem_index(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="f") + 
mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="f") assert lazy_sequence[4] == "f" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) @conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"}) def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="some-value") + mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="some-value") xcom = resolve_xcom_backend() assert xcom.__name__ == "CustomXCom" monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom) assert lazy_sequence[4] == "Made with CustomXCom: some-value" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = ErrorResponse( + mock_supervisor_comms.send.return_value = ErrorResponse( error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}, ) with pytest.raises(IndexError) as ctx: lazy_sequence[4] assert ctx.value.args == (4,) - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_slice(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceSliceResult(root=[6, 4, 1]) + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=[6, 4, 1]) assert lazy_sequence[:5] == [6, 4, 1] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceSlice( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - start=None, - stop=5, - step=None, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceSlice( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + start=None, + stop=5, + step=None, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4f5e4cfc7ac9d..6e1c3b2e093cb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,13 +27,14 @@ import socket import sys import time -from io import BytesIO from operator import attrgetter +from random import randint from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import httpx +import msgspec import psutil import pytest from pytest_unordered import unordered @@ -56,6 +57,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + 
CommsDecoder, ConnectionResult, DagRunStateResult, DeferTask, @@ -97,17 +99,16 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.supervisor import ( - BUFFER_SIZE, ActivitySubprocess, InProcessSupervisorComms, InProcessTestSupervisor, - mkpipe, set_supervisor_comms, supervise, ) -from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"): } +@pytest.fixture +def client_with_ti_start(make_ti_context): + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @pytest.fixture(autouse=True) def disable_log_upload(self, spy_agency): spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) - def test_reading_from_pipes(self, captured_logs, time_machine): + def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): def subprocess_main(): # This is run in the subprocess! - # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + # Ensure we follow the "protocol" and get the startup message before we do anything else + CommsDecoder()._get_response() import logging import warnings @@ -180,7 +188,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -228,12 +236,12 @@ def subprocess_main(): ] ) - def test_subprocess_sigkilled(self): + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid() def subprocess_main(): # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() assert os.getpid() != main_pid os.kill(os.getpid(), signal.SIGKILL) @@ -248,7 +256,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -285,7 +293,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -314,7 +322,7 @@ def test_no_heartbeat_in_overtime(self, spy_agency: kgb.SpyAgency, monkeypatch, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -340,7 +348,7 @@ def _on_child_started(self, *args, **kwargs): assert proc.wait() == 0 spy_agency.assert_spy_not_called(heartbeat_spy) - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -355,11 +363,6 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker try_number=1, ) - # Create a mock client to assert calls to the client - # We assume the implementation of the client is correct and only 
need to check the calls - mock_client = mocker.Mock(spec=sdk_client.Client) - mock_client.task_instances.start.return_value = make_ti_context() - bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -368,7 +371,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker token="", server="", dry_run=True, - client=mock_client, + client=client_with_ti_start, bundle_info=bundle_info, ) assert exit_code == 0, captured_logs @@ -498,7 +501,7 @@ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, m monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() sleep(5) # Shouldn't get here exit(5) @@ -611,7 +614,6 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t stdin=mocker.MagicMock(), client=client, process=mock_process, - requests_fd=-1, ) time_now = tz.datetime(2024, 11, 28, 12, 0, 0) @@ -701,7 +703,6 @@ def test_overtime_handling( stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock(), - requests_fd=-1, ) # Set the terminal state and task end datetime @@ -738,7 +739,7 @@ def test_overtime_handling( ), ), ) - def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog): + def test_exit_by_signal(self, signal_to_raise, log_pattern, cap_structlog, client_with_ti_start): def subprocess_main(): import faulthandler import os @@ -748,7 +749,7 @@ def subprocess_main(): faulthandler.disable() # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() os.kill(os.getpid(), signal_to_raise) @@ -762,7 +763,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -791,26 +792,26 @@ def test_cleanup_sockets_after_delay(self, monkeypatch, mocker, time_machine): stdin=mocker.MagicMock(), client=mocker.MagicMock(), process=mock_process, - requests_fd=-1, ) proc.selector = mocker.MagicMock() proc.selector.select.return_value = [] proc._exit_code = 0 - proc._num_open_sockets = 1 + # Create a fake placeholder in the open socket weakref + proc._open_sockets[mocker.MagicMock()] = "test placeholder" proc._process_exit_monotonic = time.monotonic() mocker.patch.object( ActivitySubprocess, "_cleanup_open_sockets", - side_effect=lambda: setattr(proc, "_num_open_sockets", 0), + side_effect=lambda: setattr(proc, "_open_sockets", {}), ) time_machine.shift(2) proc._monitor_subprocess() - assert proc._num_open_sockets == 0 + assert len(proc._open_sockets) == 0 class TestWatchedSubprocessKill: @@ -829,7 +830,6 @@ def watched_subprocess(self, mocker, mock_process): stdin=mocker.Mock(), client=mocker.Mock(), process=mock_process, - requests_fd=-1, ) # Mock the selector mock_selector = mocker.Mock(spec=selectors.DefaultSelector) @@ -888,7 +888,7 @@ def test_kill_process_custom_signal(self, watched_subprocess, mock_process): ), ], ) - def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch): + def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): def subprocess_main(): import signal @@ -905,7 +905,7 @@ def _handler(sig, frame): signal.signal(signal.SIGINT, _handler) signal.signal(signal.SIGTERM, _handler) try: - 
sys.stdin.readline() + CommsDecoder()._get_response() print("Ready") sleep(10) except Exception as e: @@ -919,7 +919,7 @@ def _handler(sig, frame): dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -976,9 +976,11 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stdout_handler = mocker.Mock(return_value=False) # Simulate EOF for stdout mock_stderr_handler = mocker.Mock(return_value=True) # Continue processing for stderr + mock_on_close = mocker.Mock() + # Mock selector to return events - mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=mock_stdout_handler) - mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=mock_stderr_handler) + mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=(mock_stdout_handler, mock_on_close)) + mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=(mock_stderr_handler, mock_on_close)) watched_subprocess.selector.select.return_value = [(mock_key_stdout, None), (mock_key_stderr, None)] # Mock to simulate process exited successfully @@ -996,8 +998,7 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stderr_handler.assert_called_once_with(mock_stderr) # Validate unregistering and closing of EOF file object - watched_subprocess.selector.unregister.assert_called_once_with(mock_stdout) - mock_stdout.close.assert_called_once() + mock_on_close.assert_called_once_with(mock_stdout) # Validate that `_check_subprocess_exit` is called mock_process.wait.assert_called_once_with(timeout=0) @@ -1073,16 +1074,15 @@ def test_max_wait_time_calculation_edge_cases( class TestHandleRequest: @pytest.fixture def watched_subprocess(self, mocker): - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socket.socketpair() subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), id=TI_ID, pid=12345, - stdin=write_end, # this is the writer side + stdin=write_end, client=mocker.Mock(), process=mocker.Mock(), - requests_fd=-1, ) return subprocess, read_end @@ -1091,7 +1091,7 @@ def watched_subprocess(self, mocker): @pytest.mark.parametrize( [ "message", - "expected_buffer", + "expected_body", "client_attr_path", "method_arg", "method_kwarg", @@ -1101,7 +1101,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1111,7 +1111,12 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + { + "conn_id": "test_conn", + "conn_type": "mysql", + "password": "password", + "type": "ConnectionResult", + }, "connections.get", ("test_conn",), {}, @@ -1121,7 +1126,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1131,7 +1136,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetVariable(key="test_key"), - 
b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', + {"key": "test_key", "value": "test_value", "type": "VariableResult"}, "variables.get", ("test_key",), {}, @@ -1141,7 +1146,7 @@ def watched_subprocess(self, mocker): ), pytest.param( PutVariable(key="test_key", value="test_value", description="test_description"), - b"", + None, "variables.set", ("test_key", "test_value", "test_description"), {}, @@ -1151,7 +1156,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeleteVariable(key="test_key"), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "variables.delete", ("test_key",), {}, @@ -1161,7 +1166,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", + None, "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, @@ -1174,7 +1179,7 @@ def watched_subprocess(self, mocker): reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - b"", + None, "task_instances.reschedule", ( TI_ID, @@ -1190,7 +1195,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1202,7 +1207,7 @@ def watched_subprocess(self, mocker): GetXCom( dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, @@ -1212,7 +1217,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1228,7 +1233,7 @@ def watched_subprocess(self, mocker): key="test_key", include_prior_dates=True, ), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, True), {}, @@ -1244,7 +1249,7 @@ def watched_subprocess(self, mocker): key="test_key", value='{"key": "test_key", "value": {"key2": "value2"}}', ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1269,7 +1274,7 @@ def watched_subprocess(self, mocker): value='{"key": "test_key", "value": {"key2": "value2"}}', map_index=2, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1295,7 +1300,7 @@ def watched_subprocess(self, mocker): map_index=2, mapped_length=3, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1319,15 +1324,9 @@ def watched_subprocess(self, mocker): key="test_key", map_index=2, ), - b"", + None, "xcoms.delete", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - 2, - ), + ("test_dag", "test_run", "test_task", "test_key", 2), {}, OKResponse(ok=True), None, @@ -1337,7 +1336,7 @@ def watched_subprocess(self, mocker): # if it can handle TaskState message pytest.param( TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - b"", + None, "", (), {}, 
@@ -1349,7 +1348,7 @@ def watched_subprocess(self, mocker): RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" ), - b"", + None, "task_instances.retry", (), { @@ -1363,7 +1362,7 @@ def watched_subprocess(self, mocker): ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - b"", + None, "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, @@ -1373,7 +1372,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByName(name="asset"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"name": "asset"}, @@ -1383,7 +1382,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByUri(uri="s3://bucket/obj"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"uri": "s3://bucket/obj"}, @@ -1393,11 +1392,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": "test"}, @@ -1416,11 +1421,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": None}, @@ -1439,11 +1450,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri=None, name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": None, "name": "test"}, @@ -1462,11 +1479,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAssetAlias(alias_name="test_alias"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": 
"asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"alias_name": "test_alias"}, @@ -1485,7 +1508,10 @@ def watched_subprocess(self, mocker): ), pytest.param( ValidateInletsAndOutlets(ti_id=TI_ID), - b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n', + { + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, "task_instances.validate_inlets_and_outlets", (TI_ID,), {}, @@ -1499,7 +1525,7 @@ def watched_subprocess(self, mocker): SucceedTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" ), - b"", + None, "task_instances.succeed", (), { @@ -1515,11 +1541,13 @@ def watched_subprocess(self, mocker): ), pytest.param( GetPrevSuccessfulDagRun(ti_id=TI_ID), - ( - b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",' - b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",' - b'"type":"PrevSuccessfulDagRunResult"}\n' - ), + { + "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), + "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), + "start_date": timezone.parse("2025-01-10T12:00:00Z"), + "end_date": timezone.parse("2025-01-10T14:00:00Z"), + "type": "PrevSuccessfulDagRunResult", + }, "task_instances.get_previous_successful_dagrun", (TI_ID,), {}, @@ -1540,7 +1568,7 @@ def watched_subprocess(self, mocker): logical_date=timezone.datetime(2025, 1, 1), reset_dag_run=True, ), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "dag_runs.trigger", ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), {}, @@ -1549,8 +1577,9 @@ def watched_subprocess(self, mocker): id="dag_run_trigger", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. 
Fix this before PR TriggerDagRun(dag_id="test_dag", run_id="test_run"), - b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n', + {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, "dag_runs.trigger", ("test_dag", "test_run", None, None, False), {}, @@ -1560,7 +1589,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDagRunState(dag_id="test_dag", run_id="test_run"), - b'{"state":"running","type":"DagRunStateResult"}\n', + {"state": "running", "type": "DagRunStateResult"}, "dag_runs.get_state", ("test_dag", "test_run"), {}, @@ -1570,7 +1599,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskRescheduleStartDate(ti_id=TI_ID), - b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n', + {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"}, "task_instances.get_reschedule_start_date", (TI_ID, 1), {}, @@ -1580,7 +1609,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), - b'{"count":2,"type":"TICount"}\n', + {"count": 2, "type": "TICount"}, "task_instances.get_count", (), { @@ -1598,7 +1627,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDRCount(dag_id="test_dag", states=["success", "failed"]), - b'{"count":2,"type":"DRCount"}\n', + {"count": 2, "type": "DRCount"}, "dag_runs.get_count", (), { @@ -1613,7 +1642,10 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), - b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n', + { + "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, + "type": "TaskStatesResult", + }, "task_instances.get_task_states", (), { @@ -1636,7 +1668,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=0, ), - b'{"root":"test_value","type":"XComSequenceIndexResult"}\n', + {"root": "test_value", "type": "XComSequenceIndexResult"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 0), {}, @@ -1645,6 +1677,7 @@ def watched_subprocess(self, mocker): id="get_xcom_seq_item", ), pytest.param( + # TODO: This should raise an exception instead of returning an ErrorResponse.
Fix this before PR GetXComSequenceItem( key="test_key", dag_id="test_dag", @@ -1652,7 +1685,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=2, ), - b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n', + {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 2), {}, @@ -1670,7 +1703,7 @@ def watched_subprocess(self, mocker): stop=None, step=None, ), - b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n', + {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", ("test_dag", "test_run", "test_task", "test_key", None, None, None), {}, @@ -1687,7 +1720,7 @@ def test_handle_requests( mocker, time_machine, message, - expected_buffer, + expected_body, client_attr_path, method_arg, method_kwarg, @@ -1715,8 +1748,9 @@ def test_handle_requests( generator = watched_subprocess.handle_requests(log=mocker.Mock()) # Initialize the generator next(generator) - msg = message.model_dump_json().encode() + b"\n" - generator.send(msg) + + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) + generator.send(req_frame) if mask_secret_args: mock_mask_secret.assert_called_with(*mask_secret_args) @@ -1729,33 +1763,23 @@ def test_handle_requests( # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(BUFFER_SIZE) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError, socket.timeout): - # no response written, valid for some message types like setters and TI operations. - pass + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id # Verify the response was added to the buffer - assert val == expected_buffer + assert frame.body == expected_body # Verify the response is correctly decoded # This is important because the subprocess/task runner will read the response # and deserialize it to the correct message type - # Only decode the buffer if it contains data. An empty buffer implies no response was written. - if not val and (mock_response and not isinstance(mock_response, OKResponse)): - pytest.fail("Expected a response, but got an empty buffer.") - - if val: + if frame.body is not None: # Using BytesIO to simulate a readable stream for CommsDecoder. 
- input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) - assert decoder.get_message() == mock_response + decoder = CommsDecoder(request_socket=None).body_decoder + assert decoder.validate_python(frame.body) == mock_response def test_handle_requests_api_server_error(self, watched_subprocess, mocker): """Test that API server errors are properly handled and sent back to the task.""" @@ -1777,28 +1801,32 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): next(generator) - msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n" - generator.send(msg) + msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")) + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump()) + generator.send(req_frame) - # Read response directly from the reader socket + # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - val += read_socket.recv(4096) - except (BlockingIOError, TimeoutError): - pass - - assert val == ( - b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",' - b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n' - ) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id + + assert frame.error == { + "error": "API_SERVER_ERROR", + "detail": { + "status_code": 500, + "message": "API Server Error", + "detail": {"detail": "Internal Server Error"}, + }, + "type": "ErrorResponse", + } # Verify the error can be decoded correctly - input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) + comms = CommsDecoder(request_socket=None) with pytest.raises(AirflowRuntimeError) as exc_info: - decoder.get_message() + comms._from_frame(frame) assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR assert exc_info.value.error.detail == { @@ -1864,14 +1892,13 @@ def test_inprocess_supervisor_comms_roundtrip(self): """ class MinimalSupervisor(InProcessTestSupervisor): - def _handle_request(self, msg, log): + def _handle_request(self, msg, log, req_id): resp = VariableResult(key=msg.key, value="value") - self.send_msg(resp) + self.send_msg(resp, req_id) supervisor = MinimalSupervisor( id="test", pid=123, - requests_fd=-1, process=MagicMock(), process_log=MagicMock(), client=MagicMock(), @@ -1881,9 +1908,8 @@ def _handle_request(self, msg, log): test_msg = GetVariable(key="test_key") - comms.send_request(log=MagicMock(), msg=test_msg) + response = comms.send(test_msg) # Ensure we got back what we expect - response = comms.get_message() assert isinstance(response, VariableResult) assert response.value == "value" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index d786fdaa5b2bc..5ce460c455bf4 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -22,11 +22,9 @@ import json import os import textwrap -import uuid from collections.abc import Iterable from datetime import datetime, timedelta from pathlib import Path -from socket import socketpair from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -103,7 +101,6 @@ VariableAccessor, ) from airflow.sdk.execution_time.task_runner import ( - CommsDecoder, 
RuntimeTaskInstance, TaskRunnerMarker, _push_xcom_if_needed, @@ -139,47 +136,6 @@ def execute(self, context): print(f"Hello World {task_id}!") -class TestCommsDecoder: - """Test the communication between the subprocess and the "supervisor".""" - - @pytest.mark.usefixtures("disable_capturing") - def test_recv_StartupDetails(self): - r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"type":"StartupDetails", "ti": {' - b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", ' - b'"dag_id": "c"}, "ti_context":{"dag_run":{"dag_id":"c","run_id":"b",' - b'"logical_date":"2024-12-01T01:00:00Z",' - b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' - b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,' - b'"run_type":"manual","conf":null,"consumed_asset_events":[]},' - b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file": "/dev/null",' - b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' - b'"any-name", "version": "any-version"}, "requests_fd": ' - + str(w2.fileno()).encode("ascii") - + b"}\n" - ) - - decoder = CommsDecoder(input=r.makefile("r")) - - msg = decoder.get_message() - assert isinstance(msg, StartupDetails) - assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") - assert msg.ti.task_id == "a" - assert msg.ti.dag_id == "c" - assert msg.dag_rel_path == "/dev/null" - assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") - assert msg.start_date == timezone.datetime(2024, 12, 1, 1) - - # Since this was a StartupDetails message, the decoder should open the other socket - assert decoder.request_socket is not None - assert decoder.request_socket.writable() - assert decoder.request_socket.fileno() == w2.fileno() - - def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( @@ -192,7 +148,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -248,7 +203,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -302,7 +256,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -360,8 +313,8 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com assert ti.state == TaskInstanceState.DEFERRED - # send_request will only be called when the TaskDeferred exception is raised - mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) + # send will only be called when the TaskDeferred exception is raised + mock_supervisor_comms.send.assert_any_call(expected_defer_task) def test_run_downstream_skipped(mocked_parse, create_runtime_ti, mock_supervisor_comms): @@ -384,8 +337,8 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] 
log.info.assert_called_with("Skipping downstream tasks.") - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, msg=SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks") + mock_supervisor_comms.send.assert_any_call( + SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks") ) @@ -436,8 +389,8 @@ def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm assert ti.state == TaskInstanceState.SKIPPED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState(state=TaskInstanceState.SKIPPED, end_date=instant), log=mock.ANY + mock_supervisor_comms.send.assert_called_with( + TaskState(state=TaskInstanceState.SKIPPED, end_date=instant) ) @@ -458,12 +411,11 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( + mock_supervisor_comms.send.assert_called_with( msg=TaskState( state=TaskInstanceState.FAILED, end_date=instant, ), - log=mock.ANY, ) @@ -485,13 +437,7 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) log.exception.assert_not_called() log.error.assert_called_with(mock.ANY, exit_code=10) @@ -516,13 +462,7 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms): @@ -545,13 +485,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms assert ti.state == TaskInstanceState.FAILED # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): @@ -573,7 +507,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ), bundle_info=FAKE_BUNDLE, dag_rel_path="", - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -583,7 +516,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm spy_agency.spy_on(task.prepare_for_execution) assert not task._lock_for_execution - # mock_supervisor_comms.get_message.return_value = what run(ti, context=ti.get_template_context(), log=mock.Mock()) spy_agency.assert_spy_called(task.prepare_for_execution) @@ -591,7 +523,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm assert ti.task is not task, "ti.task should be a copy of the original task" assert ti.state == TaskInstanceState.SUCCESS - mock_supervisor_comms.send_request.assert_any_call( + 
mock_supervisor_comms.send.assert_any_call( msg=SetRenderedFields( rendered_fields={ "bash_command": "echo 'Logical date is 2024-12-01 01:00:00+00:00'", @@ -599,7 +531,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm "env": None, } ), - log=mock.ANY, ) @@ -689,7 +620,6 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -697,22 +627,18 @@ def execute(self, context): time_machine.move_to(instant, tick=False) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what run(*startup()) expected_calls = [ - mock.call.send_request( - msg=SetRenderedFields(rendered_fields=expected_rendered_fields), - log=mock.ANY, - ), - mock.call.send_request( + mock.call.send(SetRenderedFields(rendered_fields=expected_rendered_fields)), + mock.call.send( msg=SucceedTask( end_date=instant, state=TaskInstanceState.SUCCESS, task_outlets=[], outlet_events=[], ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -766,9 +692,8 @@ def execute(self, context): assert ti.state == TaskInstanceState.SUCCESS # Ensure the task is Successful - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send.assert_called_once_with( msg=SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), - log=mock.ANY, ) @@ -815,8 +740,8 @@ def execute(self, context): assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant), log=mock.ANY + mock_supervisor_comms.send.assert_called_once_with( + msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant) ) @@ -834,12 +759,11 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), dag_rel_path="dag_parsing_context.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(dag_id=dag_id, run_id="c"), start_date=timezone.utcnow(), ) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what # Set the environment variable for DAG bundles # We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test! 
@@ -955,7 +879,7 @@ def test_run_with_asset_outlets( validate_mock.assert_called_once() - mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY) + mock_supervisor_comms.send.assert_any_call(expected_msg) def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): @@ -967,7 +891,7 @@ def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.get_message.return_value = events_result + mock_supervisor_comms.send.return_value = events_result from airflow.providers.standard.operators.bash import BashOperator @@ -1121,7 +1045,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s dr = runtime_ti._ti_context_from_server.dag_run - mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult( + mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult( data_interval_end=dr.logical_date - timedelta(hours=1), data_interval_start=dr.logical_date - timedelta(hours=2), start_date=dr.start_date - timedelta(hours=1), @@ -1171,7 +1095,7 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, dag_id="basic_task") - mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult( + mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult( data_interval_end=timezone.datetime(2025, 1, 1, 2, 0, 0), data_interval_start=timezone.datetime(2025, 1, 1, 1, 0, 0), start_date=timezone.datetime(2025, 1, 1, 1, 0, 0), @@ -1181,13 +1105,13 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock context = runtime_ti.get_template_context() # Assert lazy attributes are not resolved initially - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access a lazy-loaded attribute to trigger computation assert context["prev_data_interval_start_success"] == timezone.datetime(2025, 1, 1, 1, 0, 0) # Now the lazy attribute should trigger the call - mock_supervisor_comms.get_message.assert_called_once() + mock_supervisor_comms.send.assert_called_once() def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_comms): """Test that the connection is fetched from the API server via the Supervisor lazily when accessed""" @@ -1206,22 +1130,18 @@ def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_co ) runtime_ti = create_runtime_ti(task=task, dag_id="test_get_connection_from_context") - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn context = runtime_ti.get_template_context() # Assert that the connection is not fetched from the API server yet! 
# The connection should be only fetched connection is accessed - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access the connection from the context conn_from_context = context["conn"].test_conn - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, msg=GetConnection(conn_id="test_conn") - ) - mock_supervisor_comms.get_message.assert_called_once_with() + mock_supervisor_comms.send.assert_called_once_with(GetConnection(conn_id="test_conn")) assert conn_from_context == Connection( conn_id="test_conn", @@ -1280,7 +1200,7 @@ def test_template_with_connection( extra='{"extra__asana__workspace": "extra1"}', ) - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn context = runtime_ti.get_template_context() result = runtime_ti.task.render_template(content, context) @@ -1307,22 +1227,18 @@ def test_get_variable_from_context( var = VariableResult(key="test_key", value=var_value) - mock_supervisor_comms.get_message.return_value = var + mock_supervisor_comms.send.return_value = var context = runtime_ti.get_template_context() # Assert that the variable is not fetched from the API server yet! # The variable should be only fetched connection is accessed - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access the variable from the context var_from_context = context["var"][accessor_type].test_key - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, msg=GetVariable(key="test_key") - ) - mock_supervisor_comms.get_message.assert_called_once_with() + mock_supervisor_comms.send.assert_called_once_with(GetVariable(key="test_key")) assert var_from_context == expected_value @@ -1390,16 +1306,14 @@ def execute(self, context): ser_value = BaseXCom.serialize_value(xcom_values) - def mock_get_message_side_effect(*args, **kwargs): - calls = mock_supervisor_comms.send_request.call_args_list - if calls: - last_call = calls[-1] - msg = last_call[1]["msg"] - if isinstance(msg, GetXComSequenceSlice): - return XComSequenceSliceResult(root=[ser_value]) + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + print(f"{args=}, {kwargs=}, {msg=}") + if isinstance(msg, GetXComSequenceSlice): + return XComSequenceSliceResult(root=[ser_value]) return XComResult(key="key", value=ser_value) - mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect + mock_supervisor_comms.send.side_effect = mock_send_side_effect run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) @@ -1415,8 +1329,7 @@ def mock_get_message_side_effect(*args, **kwargs): task_id = test_task_id for map_index in map_indexes: if map_index == NOTSET: - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( msg=GetXComSequenceSlice( key="key", dag_id="test_dag", @@ -1429,8 +1342,7 @@ def mock_get_message_side_effect(*args, **kwargs): ) else: expected_map_index = map_index if map_index is not None else None - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( msg=GetXCom( key="key", dag_id="test_dag", @@ -1591,11 +1503,8 @@ def execute(self, context): context = runtime_ti.get_template_context() run(runtime_ti, context=context, log=mock.MagicMock()) - 
mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: @@ -1644,9 +1553,8 @@ def __init__(self, bash_command, *args, **kwargs): log=mock.MagicMock(), ) - mock_supervisor_comms.send_request.assert_called_with( - msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd}), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_with( + msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd}) ) @pytest.mark.parametrize( @@ -1669,7 +1577,7 @@ def test_get_first_reschedule_date( task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count) - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate( + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate( start_date=timezone.datetime(2025, 1, 1) ) @@ -1678,7 +1586,7 @@ def test_get_first_reschedule_date( def test_get_ti_count(self, mock_supervisor_comms): """Test that get_ti_count sends the correct request and returns the count.""" - mock_supervisor_comms.get_message.return_value = TICount(count=2) + mock_supervisor_comms.send.return_value = TICount(count=2) count = RuntimeTaskInstance.get_ti_count( dag_id="test_dag", @@ -1689,8 +1597,7 @@ def test_get_ti_count(self, mock_supervisor_comms): states=["success", "failed"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetTICount( dag_id="test_dag", task_ids=["task1", "task2"], @@ -1704,7 +1611,7 @@ def test_get_ti_count(self, mock_supervisor_comms): def test_get_dr_count(self, mock_supervisor_comms): """Test that get_dr_count sends the correct request and returns the count.""" - mock_supervisor_comms.get_message.return_value = DRCount(count=2) + mock_supervisor_comms.send.return_value = DRCount(count=2) count = RuntimeTaskInstance.get_dr_count( dag_id="test_dag", @@ -1713,8 +1620,7 @@ def test_get_dr_count(self, mock_supervisor_comms): states=["success", "failed"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetDRCount( dag_id="test_dag", logical_dates=[timezone.datetime(2024, 1, 1)], @@ -1726,27 +1632,21 @@ def test_get_dr_count(self, mock_supervisor_comms): def test_get_dagrun_state(self, mock_supervisor_comms): """Test that get_dagrun_state sends the correct request and returns the state.""" - mock_supervisor_comms.get_message.return_value = DagRunStateResult(state="running") + mock_supervisor_comms.send.return_value = DagRunStateResult(state="running") state = RuntimeTaskInstance.get_dagrun_state( dag_id="test_dag", run_id="run1", ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, - msg=GetDagRunState( - dag_id="test_dag", - run_id="run1", - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagRunState(dag_id="test_dag", run_id="run1"), ) assert state == "running" def test_get_task_states(self, mock_supervisor_comms): """Test that get_task_states sends the correct request and returns the states.""" - mock_supervisor_comms.get_message.return_value = TaskStatesResult( - task_states={"run1": {"task1": 
"running"}} - ) + mock_supervisor_comms.send.return_value = TaskStatesResult(task_states={"run1": {"task1": "running"}}) states = RuntimeTaskInstance.get_task_states( dag_id="test_dag", @@ -1754,8 +1654,7 @@ def test_get_task_states(self, mock_supervisor_comms): run_ids=["run1"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetTaskStates( dag_id="test_dag", task_ids=["task1"], @@ -1929,7 +1828,6 @@ def execute(self, context): assert not any( x == mock.call( - log=mock.ANY, msg=SetXCom( key="key", value="pushing to xcom backend!", @@ -1939,7 +1837,7 @@ def execute(self, context): map_index=-1, ), ) - for x in mock_supervisor_comms.send_request.call_args_list + for x in mock_supervisor_comms.send.call_args_list ) def test_xcom_pull_from_custom_xcom_backend( @@ -1966,7 +1864,6 @@ def execute(self, context): assert not any( x == mock.call( - log=mock.ANY, msg=GetXCom( key="key", dag_id="test_dag", @@ -1975,7 +1872,7 @@ def execute(self, context): map_index=-1, ), ) - for x in mock_supervisor_comms.send_request.call_args_list + for x in mock_supervisor_comms.send.call_args_list ) @@ -2006,11 +1903,8 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dag_overwrite(self, create_runtime_ti, mock_supervisor_comms, time_machine): @@ -2035,11 +1929,8 @@ def execute(self, context): task=task, dag_id="dag_with_dag_params_overwrite", conf={"value": "new_value"} ) run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine): @@ -2062,11 +1953,8 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task, dag_id="dag_with_dag_params_default") run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_resolves( @@ -2097,11 +1985,8 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dagrun_parameterized( @@ -2136,11 +2021,8 @@ def 
execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) @pytest.mark.parametrize("value", [VALUE, 0]) @@ -2167,11 +2049,8 @@ def return_num(num): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_any_call( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) @@ -2234,12 +2113,11 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what mocked_parse(what, "basic_dag", task) runtime_ti, context, log = startup() @@ -2539,17 +2417,15 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): assert msg.state == TaskInstanceState.SUCCESS expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", run_id="test_run_id", reset_dag_run=False, logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=SetXCom( key="trigger_run_id", value="test_run_id", @@ -2558,7 +2434,6 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): run_id="test_run", map_index=-1, ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -2586,23 +2461,21 @@ def test_handle_trigger_dag_run_conflict( ti = create_runtime_ti(dag_id="test_handle_trigger_dag_run_conflict", run_id="test_run", task=task) log = mock.MagicMock() - mock_supervisor_comms.get_message.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) + mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) state, msg, _ = run(ti, ti.get_template_context(), log) assert state == expected_state assert msg.state == expected_state expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), run_id="test_run_id", reset_dag_run=False, ), - log=mock.ANY, ), - mock.call.get_message(), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -2647,13 +2520,19 @@ def test_handle_trigger_dag_run_wait_for_completion( ) log = mock.MagicMock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ + # Set RTIF + None, # Successful Dag Run trigger OKResponse(ok=True), + # Set XCOM, + None, # Dag Run is still running DagRunStateResult(state=DagRunState.RUNNING), # Dag Run completes execution on the next poll DagRunStateResult(state=target_dr_state), + # Succeed/Fail task + None, ] with mock.patch("time.sleep", return_value=None): state, msg, _ = run(ti, ti.get_template_context(), log) @@ -2662,16 +2541,14 @@ def test_handle_trigger_dag_run_wait_for_completion( assert msg.state 
== expected_task_state expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", run_id="test_run_id", logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=SetXCom( key="trigger_run_id", value="test_run_id", @@ -2680,22 +2557,18 @@ def test_handle_trigger_dag_run_wait_for_completion( run_id="test_run", map_index=-1, ), - log=mock.ANY, ), - mock.call.send_request( + mock.call.send( msg=GetDagRunState( dag_id="test_dag", run_id="test_run_id", ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=GetDagRunState( dag_id="test_dag", run_id="test_run_id", ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) From 8ad4db594ba7b953c2c5047d6f938a4f81f719e3 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 10 Jun 2025 11:55:46 +0100 Subject: [PATCH 08/21] Switch the Supervisor/task process from line-based to length-prefixed The existing JSON Lines based approach had two major drawbacks 1. In the case of really large lines (in the region of 10 or 20MB) the python line buffering could _sometimes_ result in a partial read 2. The JSON based approach didn't have the ability to add any metadata (such as errors). 3. Not every message type/call-site waited for a response, which meant those client functions could never get told about an error One of the ways this line-based approach fell down was if you suddenly tried to run 100s of triggers at the same time you would get an error like this: ``` Traceback (most recent call last): File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 568, in readline line = await self.readuntil(sep) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/ash/.local/share/uv/python/cpython-3.12.7-macos-aarch64-none/lib/python3.12/asyncio/streams.py", line 663, in readuntil raise exceptions.LimitOverrunError( asyncio.exceptions.LimitOverrunError: Separator is found, but chunk is longer than limit ``` The other way this caused problems was if you parse a large dag (as in one with 20k tasks or more) the DagFileProcessor could end up getting a partial read which would be invalid JSON. This changes the communications protocol in in a couple of ways. First off at the python level the separate send and receive methods in the client/task side have been removed and replaced with a single `send()` that sends the request, reads the response and raises an error if one is returned. (But note, right now almost nothing in the supervisor side sets the error, that will be a future PR.) Secondly the JSON Lines approach has been changed from a line-based protocol to a binary "frame" one. The protocol (which is the same for whichever side is sending) is length-prefixed, i.e. we first send the length of the data as a 4byte big-endian integer, followed by the data itself. This should remove the possibility of JSON parse errors due to reading incomplete lines Finally the last change made in this PR is to remove the "extra" requests socket/channel. Upon closer examination with this comms path I realised that this socket is unnecessary: Since we are in 100% control of the client side we can make use of the bi-directional nature of `socketpair` and save file handles. 
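To make the framing concrete, here is a rough sketch of the idea in plain Python. It is not the actual implementation: the real code exchanges `_RequestFrame`/`_ResponseFrame` objects (the tests decode them with msgspec's msgpack decoder) rather than bare JSON dicts, and the `GetConnection` payload and `demo` connection id below are placeholders.

```python
# Illustrative sketch only: shows the 4-byte big-endian length prefix + body framing
# over one bidirectional socketpair. The real frames are _RequestFrame/_ResponseFrame
# objects encoded with msgspec, not the plain JSON used here.
import json
import socket
import struct


def send_frame(sock: socket.socket, payload: dict) -> None:
    """Send a length-prefixed frame: 4-byte big-endian length, then the body."""
    body = json.dumps(payload).encode()
    sock.sendall(struct.pack(">I", len(body)) + body)


def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes so a frame can never be split by buffering."""
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("socket closed mid-frame")
        buf += chunk
    return buf


def recv_frame(sock: socket.socket) -> dict:
    """Read one complete frame and decode the body."""
    (length,) = struct.unpack(">I", recv_exact(sock, 4))
    return json.loads(recv_exact(sock, length))


if __name__ == "__main__":
    # One bidirectional socketpair carries both requests and responses.
    supervisor_end, task_end = socket.socketpair()
    send_frame(task_end, {"id": 1, "body": {"type": "GetConnection", "conn_id": "demo"}})
    print(recv_frame(supervisor_end))
    send_frame(supervisor_end, {"id": 1, "body": {"conn_id": "demo", "conn_type": "http"}})
    print(recv_frame(task_end))
```

Because the reader always consumes exactly the advertised number of bytes it can never observe half a message, which the old `readline`-based reads could not guarantee, and both directions travel over the single socketpair rather than a separate requests channel.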
This also happens to help the `run_as_user` feature which is currently broken, as without extra config to `sudoers` file, `sudo` will close all filehandles other than stdin, stdout, and stderr -- so by introducing this change we make it easier to re-add run_as_user support. In order to support this in the DagFileProcessor (as the fact that the proc manager uses a single selector for multiple processes) means I have moved the `on_close` callback to be part of the object we store in the `selector` object in the supervisors, previoulsy it was the "on_read" callback, now we store a tuple of `(on_read, on_close)` and on_close is called once universally. This also changes the way comms are handled from the (async) TriggerRunner process. Previously we had a sync+async lock, but that made it possible to end up deadlocking things. The change now is to have `send` on `TriggerCommsDecoder` "go back" to the async even loop via `async_to_sync`, so that only async code deals with the socket, and we can use an async lock (rather than the hybrid sync and async lock we tried before). This seems to help the deadlock issue, but I'm not 100% sure it will remove it entirely, but it makes it much much harder to hit - I've not been able to reprouce it with this change --- .../src/airflow/dag_processing/manager.py | 9 +- .../src/airflow/dag_processing/processor.py | 33 +- .../src/airflow/jobs/triggerer_job_runner.py | 124 ++++--- .../tests/unit/dag_processing/test_manager.py | 82 ++--- .../unit/dag_processing/test_processor.py | 51 +-- airflow-core/tests/unit/hooks/test_base.py | 13 +- .../tests/unit/jobs/test_triggerer_job.py | 34 +- .../tests/unit/models/test_taskinstance.py | 4 +- .../src/tests_common/pytest_plugin.py | 15 +- .../unit/amazon/aws/links/test_athena.py | 2 +- .../tests/unit/amazon/aws/links/test_batch.py | 6 +- .../unit/amazon/aws/links/test_comprehend.py | 4 +- .../unit/amazon/aws/links/test_datasync.py | 4 +- .../tests/unit/amazon/aws/links/test_ec2.py | 4 +- .../tests/unit/amazon/aws/links/test_emr.py | 12 +- .../tests/unit/amazon/aws/links/test_glue.py | 2 +- .../tests/unit/amazon/aws/links/test_logs.py | 2 +- .../unit/amazon/aws/links/test_sagemaker.py | 2 +- .../links/test_sagemaker_unified_studio.py | 2 +- .../amazon/aws/links/test_step_function.py | 4 +- .../tests/unit/common/io/xcom/test_backend.py | 15 +- .../unit/dbt/cloud/operators/test_dbt.py | 2 +- .../unit/google/cloud/links/test_cloud_run.py | 2 +- .../unit/google/cloud/links/test_dataplex.py | 20 +- .../unit/google/cloud/links/test_translate.py | 8 +- .../google/cloud/operators/test_dataproc.py | 20 +- .../azure/operators/test_data_factory.py | 2 +- .../microsoft/azure/operators/test_synapse.py | 2 +- task-sdk/pyproject.toml | 1 - task-sdk/src/airflow/sdk/bases/xcom.py | 98 ++---- .../sdk/definitions/asset/decorators.py | 14 +- .../src/airflow/sdk/execution_time/comms.py | 164 ++++++++- .../src/airflow/sdk/execution_time/context.py | 76 ++-- .../sdk/execution_time/lazy_sequence.py | 60 ++-- .../airflow/sdk/execution_time/supervisor.py | 290 +++++++++------- .../airflow/sdk/execution_time/task_runner.py | 215 ++++-------- task-sdk/tests/task_sdk/bases/test_sensor.py | 10 +- .../tests/task_sdk/definitions/conftest.py | 6 +- .../definitions/test_asset_decorators.py | 16 +- .../task_sdk/definitions/test_connections.py | 2 +- .../definitions/test_mappedoperator.py | 41 +-- .../task_sdk/definitions/test_variables.py | 5 +- .../task_sdk/definitions/test_xcom_arg.py | 36 +- .../task_sdk/execution_time/test_comms.py | 83 +++++ 
.../task_sdk/execution_time/test_context.py | 85 ++--- .../execution_time/test_lazy_sequence.py | 140 ++++---- .../execution_time/test_supervisor.py | 328 ++++++++++-------- .../execution_time/test_task_runner.py | 293 +++++----------- 48 files changed, 1229 insertions(+), 1214 deletions(-) create mode 100644 task-sdk/tests/task_sdk/execution_time/test_comms.py diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 272e086d90903..8af35af62b7ef 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -77,6 +77,8 @@ from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: + from socket import socket + from sqlalchemy.orm import Session from airflow.callbacks.callback_requests import CallbackRequest @@ -388,7 +390,7 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): """ events = self.selector.select(timeout=timeout) for key, _ in events: - socket_handler = key.data + socket_handler, on_close = key.data # BrokenPipeError should be caught and treated as if the handler returned false, similar # to EOF case @@ -397,8 +399,9 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): except BrokenPipeError: need_more = False if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() def _queue_requested_files_for_parsing(self) -> None: """Queue any files requested for parsing as requested by users via UI/API.""" diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 5f69082c758aa..4341d6d91390e 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel): bundle_path: Path """Passing bundle path around lets us figure out relative file path.""" - requests_fd: int callback_requests: list[CallbackRequest] = Field(default_factory=list) type: Literal["DagFileParseRequest"] = "DagFileParseRequest" @@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel): def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time import comms, task_runner # Parse DAG file, send JSON back up! 
- comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( - input=sys.stdin, - decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), + comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) - msg = comms_decoder.get_message() + msg = comms_decoder._get_response() if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") @@ -125,7 +122,7 @@ def _parse_file_entrypoint(): result = _parse_file(msg, log) if result is not None: - comms_decoder.send_request(log, result) + comms_decoder.send(result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: @@ -266,20 +263,18 @@ def _on_child_started( msg = DagFileParseRequest( file=os.fspath(path), bundle_path=bundle_path, - requests_fd=self._requests_fd, callback_requests=callbacks, ) - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) - def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse resp: BaseModel | None = None dump_opts = {} if isinstance(msg, DagFileParsingResult): self.parsing_result = msg - return - if isinstance(msg, GetConnection): + elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) @@ -301,10 +296,16 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # resp = self.client.variables.delete(msg.key) else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) @property def is_ready(self) -> bool: @@ -312,7 +313,7 @@ def is_ready(self) -> bool: # Process still alive, def can't be finished yet return False - return self._num_open_sockets == 0 + return not self._open_sockets def wait(self) -> int: raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5df6a03261523..e829a4fa6eae0 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -28,6 +28,7 @@ from collections.abc import Generator, Iterable from contextlib import suppress from datetime import datetime +from socket import socket from traceback import format_exception from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, Union @@ -43,6 +44,7 @@ from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger from airflow.sdk.execution_time.comms import ( + CommsDecoder, ConnectionResult, DagRunStateResult, DRCount, @@ -58,6 +60,7 @@ TICount, VariableResult, XComResult, + _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader 
from airflow.stats import Stats @@ -70,8 +73,6 @@ from airflow.utils.session import provide_session if TYPE_CHECKING: - from socket import socket - from sqlalchemy.orm import Session from structlog.typing import FilteringBoundLogger, WrappedLogger @@ -181,7 +182,6 @@ class messages: class StartTriggerer(BaseModel): """Tell the async trigger runner process to start, and where to send status update messages.""" - requests_fd: int type: Literal["StartTriggerer"] = "StartTriggerer" class TriggerStateChanges(BaseModel): @@ -295,7 +295,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): """ TriggerRunnerSupervisor is responsible for monitoring the subprocess and marshalling DB access. - This class (which runs in the main process) is responsible for querying the DB, sending RunTrigger + This class (which runs in the main/sync process) is responsible for querying the DB, sending RunTrigger workload messages to the subprocess, and collecting results and updating them in the DB. """ @@ -342,8 +342,8 @@ def start( # type: ignore[override] ): proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) - msg = messages.StartTriggerer(requests_fd=proc._requests_fd) - proc.send_msg(msg) + msg = messages.StartTriggerer() + proc.send_msg(msg, in_response_to=0) return proc @functools.cached_property @@ -355,7 +355,7 @@ def client(self) -> Client: client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, TaskStatesResponse, @@ -454,8 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - else: raise ValueError(f"Unknown message type {type(msg)}") - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def run(self) -> None: """Run synchronously and handle all database reads/writes.""" @@ -628,7 +627,7 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke ), ) - def _process_log_messages_from_subprocess(self) -> Generator[None, bytes, None]: + def _process_log_messages_from_subprocess(self) -> Generator[None, bytes | bytearray, None]: import msgspec from structlog.stdlib import NAME_TO_LEVEL @@ -691,14 +690,58 @@ class TriggerDetails(TypedDict): events: int +@attrs.define(kw_only=True) +class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): + _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer") + _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader") + + body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field( + factory=lambda: TypeAdapter(ToTriggerRunner), repr=False + ) + + _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + + def _read_frame(self): + from asgiref.sync import async_to_sync + + return async_to_sync(self._aread_frame)() + + def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + from asgiref.sync import async_to_sync + + return async_to_sync(self.asend)(msg) + + async def _aread_frame(self): + len_bytes = await self._async_reader.readexactly(4) + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = await self._async_reader.readexactly(len) + return self.resp_decoder.decode(buffer) + + async def 
_aget_response(self, expect_id: int) -> ToTriggerRunner | None: + frame = await self._aread_frame() + if frame.id != expect_id: + # Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with + # this explicit error than return the wrong type of message back to a Trigger + raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}") + return self._from_frame(frame) + + async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + async with self._lock: + self._async_writer.write(bytes) + + return await self._aget_response(frame.id) + + class TriggerRunner: """ Runtime environment for all triggers. - Mainly runs inside its own thread, where it hands control off to an asyncio - event loop, but is also sometimes interacted with from the main thread - (where all the DB queries are done). All communication between threads is - done via Deques. + Mainly runs inside its own process, where it hands control off to an asyncio + event loop. All communication between this and its (sync) supervisor is done via sockets. """ # Maps trigger IDs to their running tasks and other info @@ -726,10 +769,7 @@ class TriggerRunner: # TODO: connect this to the parent process log: FilteringBoundLogger = structlog.get_logger() - requests_sock: asyncio.StreamWriter - response_sock: asyncio.StreamReader - - decoder: TypeAdapter[ToTriggerRunner] + comms_decoder: TriggerCommsDecoder def __init__(self): super().__init__() @@ -740,7 +780,6 @@ def __init__(self): self.events = deque() self.failed_triggers = deque() self.job_id = None - self.decoder = TypeAdapter(ToTriggerRunner) def run(self): """Sync entrypoint - just run a run in an async loop.""" @@ -796,36 +835,21 @@ async def init_comms(self): """ from airflow.sdk.execution_time import task_runner - loop = asyncio.get_event_loop() + # Yes, we read and write to stdin! It's a socket, not a normal stdin.
+ reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) - comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]( - input=sys.stdin, - decoder=self.decoder, + self.comms_decoder = TriggerCommsDecoder( + async_writer=writer, + async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = comms_decoder - - async def connect_stdin() -> asyncio.StreamReader: - reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, sys.stdin) - return reader - - self.response_sock = await connect_stdin() + task_runner.SUPERVISOR_COMMS = self.comms_decoder - line = await self.response_sock.readline() + msg = await self.comms_decoder._aget_response(expect_id=0) - msg = self.decoder.validate_json(line) if not isinstance(msg, messages.StartTriggerer): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - writer_transport, writer_protocol = await loop.connect_write_pipe( - lambda: asyncio.streams.FlowControlMixin(loop=loop), - comms_decoder.request_socket, - ) - self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop) - async def create_triggers(self): """Drain the to_create queue and create all new triggers that have been requested in the DB.""" while self.to_create: @@ -934,8 +958,6 @@ async def cleanup_finished_triggers(self) -> list[int]: return finished_ids async def sync_state_to_supervisor(self, finished_ids: list[int]): - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Copy out of our deques in threadsafe manner to sync state with parent events_to_send = [] while self.events: @@ -961,19 +983,17 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]): if not finished_ids: msg.finished = None - # Block triggers from making any requests for the duration of this - async with SUPERVISOR_COMMS.lock: - # Tell the monitor that we've finished triggers so it can update things - self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n") - line = await self.response_sock.readline() - - if line == b"": # EoF received! 
+ # Tell the monitor that we've finished triggers so it can update things + try: + resp = await self.comms_decoder.asend(msg) + except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") + return + raise - resp = self.decoder.validate_json(line) if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}") + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") self.to_create.extend(resp.to_create) self.to_cancel.extend(resp.to_cancel) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c9e974b8cdb9a..98204033ced9e 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -30,10 +30,11 @@ from datetime import datetime, timedelta from logging.config import dictConfig from pathlib import Path -from socket import socket +from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock +import msgspec import pytest import time_machine from sqlalchemy import func, select @@ -54,7 +55,6 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk.execution_time.supervisor import mkpipe from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -138,20 +138,19 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces logger_filehandle = MagicMock() proc.create_time.return_value = time.time() proc.wait.return_value = 0 - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socketpair() ret = DagFileProcessorProcess( process_log=MagicMock(), id=uuid7(), pid=1234, process=proc, stdin=write_end, - requests_fd=123, logger_filehandle=logger_filehandle, client=MagicMock(), ) if start_time: ret.start_time = start_time - ret._num_open_sockets = 0 + ret._open_sockets.clear() return ret, read_end @pytest.fixture @@ -552,18 +551,17 @@ def test_kill_timed_out_processors_no_kill(self): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.parametrize( - ["callbacks", "path", "expected_buffer"], + ["callbacks", "path", "expected_body"], [ pytest.param( [], "/opt/airflow/dags/test_dag.py", - b"{" - b'"file":"/opt/airflow/dags/test_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,' - b'"callback_requests":[],' - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/test_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [], + "type": "DagFileParseRequest", + }, ), pytest.param( [ @@ -577,44 +575,40 @@ def test_kill_timed_out_processors_no_kill(self): ) ], "/opt/airflow/dags/dag_callback_dag.py", - b"{" - b'"file":"/opt/airflow/dags/dag_callback_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,"callback_requests":' - b"[" - b"{" - b'"filepath":"dag_callback_dag.py",' - b'"bundle_name":"testing",' - b'"bundle_version":null,' - b'"msg":null,' - b'"dag_id":"dag_id",' - b'"run_id":"run_id",' - b'"is_failure_callback":false,' - b'"type":"DagCallbackRequest"' - b"}" - b"]," - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/dag_callback_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [ + { + 
"filepath": "dag_callback_dag.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": None, + "dag_id": "dag_id", + "run_id": "run_id", + "is_failure_callback": False, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ), ], ) - def test_serialize_callback_requests(self, callbacks, path, expected_buffer): + def test_serialize_callback_requests(self, callbacks, path, expected_body): + from airflow.sdk.execution_time.comms import _ResponseFrame + processor, read_socket = self.mock_processor() processor._on_child_started(callbacks, path, bundle_path=Path("/opt/airflow/dags")) read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(4096) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError): - # no response written, valid for some message types. - pass - - assert val == expected_buffer + # Read response from the read end of the socket + read_socket.settimeout(0.1) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.body == expected_body @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 28ce7a8c23fe6..eb6a58dfcfbfb 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -42,7 +42,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.api.client import Client -from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -87,7 +87,6 @@ def _process_file( DagFileParseRequest( file=file_path, bundle_path=TEST_DAG_FOLDER, - requests_fd=1, callback_requests=callback_requests or [], ), log=structlog.get_logger(), @@ -393,26 +392,38 @@ def disable_capturing(): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.usefixtures("disable_capturing") -def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency): +def test_parse_file_entrypoint_parses_dag_callbacks(mocker): r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":' - + str(w2.fileno()).encode("ascii") - + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' - b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": ' - b'"manual__2024-12-30T21:02:55.203691+00:00", ' - b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n' + + frame = comms._ResponseFrame( + id=1, + body={ + "file": "/files/dags/wait.py", + "bundle_path": "/files/dags", + "callback_requests": [ + { + "filepath": "wait.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": "task_failure", + "dag_id": "wait_to_fail", + "run_id": "manual__2024-12-30T21:02:55.203691+00:00", + "is_failure_callback": True, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ) + bytes = frame.as_bytes() + w.sendall(bytes) - decoder = CommsDecoder[DagFileParseRequest, 
DagFileParsingResult]( - input=r.makefile("r"), - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), + decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + request_socket=r, + body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), ) - msg = decoder.get_message() + msg = decoder._get_response() assert isinstance(msg, DagFileParseRequest) assert msg.file == "/files/dags/wait.py" assert msg.callback_requests == [ @@ -455,7 +466,7 @@ def fake_collect_dags(self, *args, **kwargs): ) ] _parse_file( - DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests), + DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests), log=structlog.get_logger(), ) @@ -489,8 +500,6 @@ def fake_collect_dags(self, *args, **kwargs): bundle_version=None, ) ] - _parse_file( - DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() - ) + _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) assert called is True diff --git a/airflow-core/tests/unit/hooks/test_base.py b/airflow-core/tests/unit/hooks/test_base.py index 8c2d44f84854b..6d54827e67ced 100644 --- a/airflow-core/tests/unit/hooks/test_base.py +++ b/airflow-core/tests/unit/hooks/test_base.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -from unittest import mock - import pytest from airflow.exceptions import AirflowNotFoundException @@ -54,18 +52,18 @@ def test_get_connection(self, mock_supervisor_comms): extra='{"extra_key": "extra_value"}', ) - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn hook = BaseHook(logger_name="") hook.get_connection(conn_id="test_conn") - mock_supervisor_comms.send_request.assert_called_once_with( - msg=GetConnection(conn_id="test_conn"), log=mock.ANY + mock_supervisor_comms.send.assert_called_once_with( + msg=GetConnection(conn_id="test_conn"), ) def test_get_connection_not_found(self, mock_supervisor_comms): conn_id = "test_conn" hook = BaseHook() - mock_supervisor_comms.get_message.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) with pytest.raises(AirflowNotFoundException, match=rf".*{conn_id}.*"): hook.get_connection(conn_id=conn_id) @@ -85,5 +83,4 @@ def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, assert retrieved_conn.conn_id == "CONN_A" - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3c9ae6f27dc9..d3cecb5fc94e6 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -34,6 +34,7 @@ from airflow.hooks.base import BaseHook from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( + TriggerCommsDecoder, TriggererJobRunner, TriggerRunner, TriggerRunnerSupervisor, @@ -172,7 +173,6 @@ def builder(job=None): pid=process.pid, stdin=mocker.Mock(), process=process, - requests_fd=-1, capacity=10, ) # Mock the selector @@ -302,10 +302,9 @@ async def test_invalid_trigger(self, supervisor_builder): id=1, ti=None, classpath="fake.classpath", encrypted_kwargs={} ) trigger_runner = TriggerRunner() 
- trigger_runner.requests_sock = MagicMock() - trigger_runner.response_sock = AsyncMock() - trigger_runner.response_sock.readline.return_value = ( - b'{"type": "TriggerStateSync", "to_create": [], "to_cancel": []}\n' + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( + to_create=[], to_cancel=[] ) trigger_runner.to_create.append(workload) @@ -316,9 +315,9 @@ async def test_invalid_trigger(self, supervisor_builder): await trigger_runner.sync_state_to_supervisor(ids) # Check that we sent the right info in the failure message - assert trigger_runner.requests_sock.write.call_count == 1 - blob = trigger_runner.requests_sock.write.mock_calls[0].args[0] - msg = messages.TriggerStateChanges.model_validate_json(blob) + assert trigger_runner.comms_decoder.asend.call_count == 1 + msg = trigger_runner.comms_decoder.asend.mock_calls[0].args[0] + assert isinstance(msg, messages.TriggerStateChanges) assert msg.events is None assert msg.failures is not None @@ -552,6 +551,7 @@ def test_failed_trigger(session, dag_maker, supervisor_builder): ) ], ), + req_id=1, log=MagicMock(), ) @@ -622,10 +622,6 @@ def handle_events(self): super().handle_events() -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio @pytest.mark.execution_timeout(20) async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): @@ -726,13 +722,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"count": dag_run_states_count, "dag_run_state": dag_run_state}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" # Create the test DAG and task @@ -822,13 +813,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count, "task_states": task_states}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. 
TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 401302e64eea4..5c6e4f69a5256 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -1886,7 +1886,7 @@ def producer_with_inactive(*, outlet_events): def test_inlet_asset_extra(self, dag_maker, session, mock_supervisor_comms): from airflow.sdk.definitions.asset import Asset - mock_supervisor_comms.get_message.return_value = AssetEventsResult( + mock_supervisor_comms.send.return_value = AssetEventsResult( asset_events=[ AssetEventResponse( id=1, @@ -1960,7 +1960,7 @@ def read(*, inlet_events): @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_comms): asset_alias_name = "test_inlet_asset_extra_asset_alias" - mock_supervisor_comms.get_message.return_value = AssetEventsResult(asset_events=[]) + mock_supervisor_comms.send.return_value = AssetEventsResult(asset_events=[]) asset_alias_model = AssetAliasModel(name=asset_alias_name) session.add(asset_alias_model) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b5ae53625e256..d542bfa034b05 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1956,17 +1956,20 @@ def override_caplog(request): @pytest.fixture -def mock_supervisor_comms(): +def mock_supervisor_comms(monkeypatch): # for back-compat from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if not AIRFLOW_V_3_0_PLUS: yield None return - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - yield supervisor_comms + + import airflow.sdk.execution_time.task_runner + from airflow.sdk.execution_time.comms import CommsDecoder + + comms = mock.create_autospec(CommsDecoder) + monkeypatch.setattr(airflow.sdk.execution_time.task_runner, "SUPERVISOR_COMMS", comms, raising=False) + yield comms @pytest.fixture @@ -1991,7 +1994,6 @@ def mocked_parse(spy_agency): id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1 ), file="", - requests_fd=0, ), "example_dag_id", CustomOperator(task_id="hello"), @@ -2198,7 +2200,6 @@ def _create_task_instance( ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, ti_context=ti_context, start_date=start_date, # type: ignore ) diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py index 8abe65cf2cb8c..99a3536d17838 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py @@ -30,7 +30,7 @@ class TestAthenaQueryResultsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=AthenaQueryResultsLink.key, value={ "region_name": 
"eu-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py index 8ecf7022f2162..70cd65655bfec 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py @@ -34,7 +34,7 @@ class TestBatchJobDefinitionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -58,7 +58,7 @@ class TestBatchJobDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "cn-north-1", @@ -80,7 +80,7 @@ class TestBatchJobQueueLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py index 2c60aa27d5085..9b88270b5bd30 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py @@ -34,7 +34,7 @@ class TestComprehendPiiEntitiesDetectionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): test_job_id = "123-345-678" if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -61,7 +61,7 @@ def test_extra_link(self, mock_supervisor_comms): "arn:aws:comprehend:us-east-1:0123456789:document-classifier/test-custom-document-classifier" ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py index cd7f5be5cbe63..79c8469b701f7 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py @@ -34,7 +34,7 @@ class TestDataSyncTaskLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): task_id = TASK_ID if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", @@ -56,7 +56,7 @@ class TestDataSyncTaskExecutionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git 
a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py index 79680d3caa136..f451c910058cf 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py @@ -32,7 +32,7 @@ class TestEC2InstanceLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -66,7 +66,7 @@ def test_instance_id_filter(self): def test_extra_link(self, mock_supervisor_comms): instance_list = ",:".join(self.INSTANCE_IDS) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py index cbe2544bd86db..feda067f7cc7a 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py @@ -45,7 +45,7 @@ class TestEmrClusterLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", @@ -82,7 +82,7 @@ class TestEmrLogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-2", @@ -127,7 +127,7 @@ def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "conn_id": "aws-test", @@ -160,7 +160,7 @@ def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "conn_id": "aws-test", @@ -254,7 +254,7 @@ class TestEmrServerlessS3LogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", @@ -282,7 +282,7 @@ class TestEmrServerlessCloudWatchLogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": 
"us-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py index b73c65182f662..2b1f076e149df 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py @@ -30,7 +30,7 @@ class TestGlueJobRunDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "ap-southeast-2", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py index 8c642e55c1a6a..2c90eecd232ad 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py @@ -30,7 +30,7 @@ class TestCloudWatchEventsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py index 25bae4efdfa62..f08d7df93d509 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py @@ -31,7 +31,7 @@ class TestSageMakerTransformDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="sagemaker_transform_job_details", value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py index 1b3cbed1f142d..bb749727323e2 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py @@ -30,7 +30,7 @@ class TestSageMakerUnifiedStudioLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py index ca4505855e9f8..acfad7e98e96c 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py @@ -47,7 +47,7 @@ class TestStateMachineDetailsLink(BaseAwsLinksTestCase): ) def test_extra_link(self, state_machine_arn, expected_url: str, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -82,7 +82,7 @@ class 
TestStateMachineExecutionsDetailsLink(BaseAwsLinksTestCase): ) def test_extra_link(self, execution_arn, expected_url: str, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py index 99fb46a66c7e5..fdecf31a436d3 100644 --- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py +++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py @@ -106,9 +106,7 @@ def test_value_db(self, task_instance, mock_supervisor_comms, session): if AIRFLOW_V_3_0_PLUS: # When using XComObjectStorageBackend, the value is stored in the db is serialized with json dumps # so we need to mimic that same behavior below. - mock_supervisor_comms.get_message.return_value = XComResult( - key="return_value", value={"key": "value"} - ) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value={"key": "value"}) value = XCom.get_value( key=XCOM_RETURN_KEY, @@ -169,7 +167,7 @@ def test_value_storage(self, task_instance, mock_supervisor_comms, session): assert p.exists() is True if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "bigvaluebigvaluebigvalue" * 100} ) @@ -213,7 +211,8 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): ) if AIRFLOW_V_3_0_PLUS: - path = mock_supervisor_comms.send_request.call_args_list[-1].kwargs["msg"].value + last_call = mock_supervisor_comms.send.call_args_list[-1] + path = (last_call.kwargs.get("msg") or last_call.args[0]).value XComModel.set( key=XCOM_RETURN_KEY, value=path, @@ -251,7 +250,7 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): assert p.exists() is True if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} ) value = XCom.get_value( @@ -261,7 +260,7 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): assert value if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult(key=XCOM_RETURN_KEY, value=path) + mock_supervisor_comms.send.return_value = XComResult(key=XCOM_RETURN_KEY, value=path) XCom.delete( dag_id=task_instance.dag_id, task_id=task_instance.task_id, @@ -356,7 +355,7 @@ def test_compression(self, task_instance, session, mock_supervisor_comms): assert data.endswith(".gz") if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} ) diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py index 9e8408c56955c..616423d5ba984 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py @@ -665,7 +665,7 @@ def test_run_job_operator_link( ti.xcom_push(key="job_run_url", value=_run_response["data"]["href"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + 
mock_supervisor_comms.send.return_value = XComResult( key="job_run_url", value=EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( account_id=account_id or DEFAULT_ACCOUNT_ID, diff --git a/providers/google/tests/unit/google/cloud/links/test_cloud_run.py b/providers/google/tests/unit/google/cloud/links/test_cloud_run.py index 7b115655102e2..5f3c348698c19 100644 --- a/providers/google/tests/unit/google/cloud/links/test_cloud_run.py +++ b/providers/google/tests/unit/google/cloud/links/test_cloud_run.py @@ -68,7 +68,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, log_uri=TEST_LOG_URI) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value=TEST_LOG_URI, ) diff --git a/providers/google/tests/unit/google/cloud/links/test_dataplex.py b/providers/google/tests/unit/google/cloud/links/test_dataplex.py index 3ec2905751390..babdf9127c8e1 100644 --- a/providers/google/tests/unit/google/cloud/links/test_dataplex.py +++ b/providers/google/tests/unit/google/cloud/links/test_dataplex.py @@ -123,7 +123,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "lake_id": ti.task.lake_id, @@ -153,7 +153,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "project_id": ti.task.project_id, @@ -183,7 +183,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "lake_id": ti.task.lake_id, @@ -212,7 +212,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_group_id": ti.task.entry_group_id, @@ -242,7 +242,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -270,7 +270,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_type_id": 
ti.task.entry_type_id, @@ -300,7 +300,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -328,7 +328,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "aspect_type_id": ti.task.aspect_type_id, @@ -358,7 +358,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -387,7 +387,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_id": ti.task.entry_id, diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py index ddfc3e205e8e2..640b958a3841e 100644 --- a/providers/google/tests/unit/google/cloud/links/test_translate.py +++ b/providers/google/tests/unit/google/cloud/links/test_translate.py @@ -62,7 +62,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, dataset_id=DATASET, project_id=GCP_PROJECT_ID) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={"location": ti.task.location, "dataset_id": DATASET, "project_id": GCP_PROJECT_ID}, ) @@ -85,7 +85,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, project_id=GCP_PROJECT_ID) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "project_id": GCP_PROJECT_ID, @@ -121,7 +121,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis project_id=GCP_PROJECT_ID, ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -158,7 +158,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis project_id=GCP_PROJECT_ID, ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, diff --git 
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index e10eba067b6a8..986cfba4113f5 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -1136,7 +1136,7 @@ def test_create_cluster_operator_extra_links( assert operator_extra_link.name == "Dataproc Cluster" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value="", ) @@ -1146,7 +1146,7 @@ def test_create_cluster_operator_extra_links( ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={"cluster_id": "cluster_name", "project_id": "test-project", "region": "test-location"}, ) @@ -2021,7 +2021,7 @@ def test_submit_job_operator_extra_links( assert operator_extra_link.name == "Dataproc Job" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_job", value="", ) @@ -2032,7 +2032,7 @@ def test_submit_job_operator_extra_links( ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_job", value=DATAPROC_JOB_EXPECTED, ) @@ -2237,7 +2237,7 @@ def test_update_cluster_operator_extra_links( assert operator_extra_link.name == "Dataproc Cluster" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_cluster", value="", ) @@ -2247,7 +2247,7 @@ def test_update_cluster_operator_extra_links( ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, ) @@ -2463,7 +2463,7 @@ def test_instantiate_workflow_operator_extra_links( assert operator_extra_link.name == "Dataproc Workflow" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value="", ) @@ -2472,7 +2472,7 @@ def test_instantiate_workflow_operator_extra_links( ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED, ) @@ -3148,7 +3148,7 @@ def test_instantiate_inline_workflow_operator_extra_links( operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value="", ) @@ -3157,7 +3157,7 @@ def test_instantiate_inline_workflow_operator_extra_links( ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + 
mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py index 788f52efc3772..23d7d18d4a407 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py @@ -253,7 +253,7 @@ def test_run_pipeline_operator_link( ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"], ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py index 45c538c2c400c..62a1021ead8a5 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py @@ -292,7 +292,7 @@ def test_run_pipeline_operator_link(self, create_task_instance_of_operator, mock ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"], ) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index a3fa147910e3a..468a6308e6032 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -47,7 +47,6 @@ classifiers = [ dependencies = [ "apache-airflow-core<3.2.0,>=3.1.0", - "aiologic>=0.14.0", "attrs>=24.2.0, !=25.2.0", "fsspec>=2023.10.0", "httpx>=0.27.0", diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index c9b777daca32e..770dbf53df769 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -77,9 +77,8 @@ def set( map_index=map_index, ) - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -114,9 +113,8 @@ def _set_xcom_in_db( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -188,23 +186,16 @@ def _get_xcom_db_ref( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. 
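The lock-plus-two-calls pattern being removed here collapses into a single blocking round trip. A rough sketch of the calling convention these XCom call sites move to (the argument values below are purely illustrative):

```
# Illustrative only: one send() call writes the request frame and blocks for the matching response.
from airflow.sdk.execution_time.comms import GetXCom, XComResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

msg = SUPERVISOR_COMMS.send(
    GetXCom(key="my_key", dag_id="example_dag", task_id="example_task", run_id="manual__1", map_index=-1)
)
if not isinstance(msg, XComResult):
    raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")
value = msg.value
```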
- with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -248,23 +239,16 @@ def get_one( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -307,24 +291,17 @@ def get_all( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. 
- with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - start=None, - stop=None, - step=None, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + msg=GetXComSequenceSlice( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + start=None, + stop=None, + step=None, + ), + ) if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") @@ -379,9 +356,8 @@ def delete( map_index=map_index, ) cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( + SUPERVISOR_COMMS.send( + DeleteXCom( key=key, dag_id=dag_id, task_id=task_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index d899dd3718336..44479fbb9cd42 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -69,18 +69,16 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> ) def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: - import structlog - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name=self.__class__.__qualname__) - def _fetch_asset(name: str) -> Asset: - SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name)) - if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse): - raise AirflowRuntimeError(msg) - return Asset(**msg.model_dump(exclude={"type"})) + resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if resp is None: + raise RuntimeError("Empty non-error response received") + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) + return Asset(**resp.model_dump(exclude={"type"})) value: Any for key, param in inspect.signature(self.python_callable).parameters.items(): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0622cf6ffc86..67a298a6c14ab 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -19,12 +19,17 @@ Communication protocol between the Supervisor and the task process ================================================================== -* All communication is done over stdout/stdin in the form of "JSON lines" (each - message is a single JSON document terminated by `\n` character) -* Messages from the subprocess are all log messages and are sent directly to the log +* All communication is done over the subprocesses stdin in the form of a binary length-prefixed msgpack frame + (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same + encoding +* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) * No messages are sent to task process except in response to a request. (This is because the task process will be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom value etc.) +* Every request returns a response, even if the frame is otherwise empty. +* Requests are written by the subprocess to fd0/stdin. 
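As a standalone illustration of the framing described above (this sketch uses msgspec directly with plain dict payloads rather than the `_RequestFrame`/`_ResponseFrame` structs; the helper names are illustrative):

```
# Illustrative only: a 4-byte big-endian length prefix followed by a msgpack-encoded body.
import socket

import msgspec


def write_frame(sock: socket.socket, payload: dict) -> None:
    body = msgspec.msgpack.encode(payload)
    sock.sendall(len(body).to_bytes(4, byteorder="big") + body)


def read_frame(sock: socket.socket) -> dict:
    length = int.from_bytes(sock.recv(4, socket.MSG_WAITALL), byteorder="big")
    return msgspec.msgpack.decode(sock.recv(length, socket.MSG_WAITALL))
```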
This is making use of the fact that stdin is a + bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending + requests. The reason this communication protocol exists, rather than the task process speaking directly to the Task Execution API server is because: @@ -43,15 +48,20 @@ from __future__ import annotations +import itertools from collections.abc import Iterator from datetime import datetime from functools import cached_property -from typing import Annotated, Any, Literal, Union +from pathlib import Path +from socket import socket +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from uuid import UUID import attrs +import msgspec +import structlog from fastapi import Body -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, @@ -80,6 +90,144 @@ ) from airflow.sdk.exceptions import ErrorType +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + +SendMsgType = TypeVar("SendMsgType", bound=BaseModel) +ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) + + +def _msgpack_enc_hook(obj: Any) -> Any: + import pendulum + + if isinstance(obj, pendulum.DateTime): + # convert pendulum's DateTime subclass into a plain stdlib datetime that msgspec can encode + return datetime( + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo + ) + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, BaseModel): + return obj.model_dump(exclude_unset=True) + + # Raise a NotImplementedError for other types + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def _new_encoder() -> msgspec.msgpack.Encoder: + return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook) + + +class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The request id, set by the sender. + + This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is
+ """ + body: dict[str, Any] | None + + req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder() + + def as_bytes(self) -> bytearray: + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + buffer = bytearray(256) + + self.req_encoder.encode_into(self, buffer, 4) + + n = len(buffer) - 4 + if n > 2**32: + raise OverflowError("Cannot send messages larger than 4GiB") + buffer[:4] = n.to_bytes(4, byteorder="big") + + return buffer + + +class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The id of the request this is a response to + """ + body: dict[str, Any] | None = None + error: dict[str, Any] | None = None + + +@attrs.define() +class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): + """Handle communication between the task in this process and the supervisor parent process.""" + + log: Logger = attrs.field(repr=False, factory=structlog.get_logger) + request_socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + + resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( + factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False + ) + + id_counter: Iterator[int] = attrs.field(factory=itertools.count) + + # We could be "clever" here and set the default to this based type parameters and a custom + # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a + # "sort of wrong default" + body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: + """Send a request to the parent and block until the response is received.""" + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + self.request_socket.sendall(bytes) + + return self._get_response() + + def _read_frame(self): + """ + Get a message from the parent. + + This will block until the message has been received. + """ + if self.request_socket: + self.request_socket.setblocking(True) + len_bytes = self.request_socket.recv(4) + + if len_bytes == b"": + raise EOFError("Request socket closed before length") + + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = bytearray(len) + nread = self.request_socket.recv_into(buffer) + if nread != len: + raise RuntimeError( + f"unable to read full response in child. 
(We read {nread}, but expected {len})" + ) + if nread == 0: + raise EOFError("Request socket closed before response was complete") + + return self.resp_decoder.decode(buffer) + + def _from_frame(self, frame) -> ReceiveMsgType | None: + from airflow.sdk.exceptions import AirflowRuntimeError + + if frame.error is not None: + err = self.err_decoder.validate_python(frame.error) + raise AirflowRuntimeError(error=err) + + if frame.body is None: + return None + + try: + return self.body_decoder.validate_python(frame.body) + except Exception: + self.log.exception("Unable to decode message") + raise + + def _get_response(self) -> ReceiveMsgType | None: + frame = self._read_frame() + return self._from_frame(frame) + class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -87,13 +235,7 @@ class StartupDetails(BaseModel): ti: TaskInstance dag_rel_path: str bundle_info: BundleInfo - requests_fd: int start_date: datetime - """ - The channel for the task to send requests over. - - Responses will come back on stdin - """ ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 438687bdeb870..54fb0c9e4920e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -150,13 +150,7 @@ def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -203,13 +197,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -259,11 +247,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. 
A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) def _delete_variable(key: str) -> None: @@ -275,12 +259,7 @@ def _delete_variable(key: str) -> None: from airflow.sdk.execution_time.comms import DeleteVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -383,23 +362,29 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: @staticmethod def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: from airflow.sdk.definitions.asset import Asset - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetAssetByName, + GetAssetByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if name: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + msg = GetAssetByName(name=name) elif uri: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + msg = GetAssetByUri(uri=uri) else: raise ValueError("Either name or uri must be provided") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetResult) - return Asset(**msg.model_dump(exclude={"type"})) + assert isinstance(resp, AssetResult) + return Asset(**resp.model_dump(exclude={"type"})) @attrs.define @@ -533,9 +518,11 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve ErrorResponse, GetAssetEventByAsset, GetAssetEventByAssetAlias, + ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if isinstance(key, int): # Support index access; it's easier for trivial cases. 
obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -545,31 +532,33 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve if isinstance(obj, Asset): asset = self._assets[AssetUniqueKey.from_asset(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri)) + msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri) elif isinstance(obj, AssetNameRef): try: asset = next(a for k, a in self._assets.items() if k.name == obj.name) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None)) + msg = GetAssetEventByAsset(name=asset.name, uri=None) elif isinstance(obj, AssetUriRef): try: asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri)) + msg = GetAssetEventByAsset(name=None, uri=asset.uri) elif isinstance(obj, AssetAlias): asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name)) + msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name) + else: + raise TypeError(f"`key` is of unknown type ({type(key).__name__})") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetEventsResult) + assert isinstance(resp, AssetEventsResult) - return list(msg.iter_asset_event_results()) + return list(resp.iter_asset_event_results()) @attrs.define @@ -626,8 +615,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 9cf9acfac81bb..d8b356870c7aa 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -91,16 +91,14 @@ def __len__(self) -> int: task = self._xcom_arg.operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComCount( + msg = SUPERVISOR_COMMS.send( + GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, run_id=self._ti.run_id, task_id=task.task_id, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): @@ -127,43 +125,37 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if isinstance(key, slice): start, stop, step = _coerce_slice(key) - with SUPERVISOR_COMMS.lock: - source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=xcom_arg.key, - dag_id=source.dag_id, - task_id=source.task_id, - run_id=self._ti.run_id, - start=start, - stop=stop, - step=step, - ), - ) - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComSequenceSliceResult): - raise TypeError(f"Got unexpected 
response to GetXComSequenceSlice: {msg!r}") - return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] - - if not isinstance(key, int): - if (index := getattr(key, "__index__", None)) is not None: - key = index() - raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") - - with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceItem( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, run_id=self._ti.run_id, - offset=key, + start=start, + stop=stop, + step=step, ), ) - msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") + return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] + + if not isinstance(key, int): + if (index := getattr(key, "__index__", None)) is not None: + key = index() + raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") + + source = (xcom_arg := self._xcom_arg).operator + msg = SUPERVISOR_COMMS.send( + GetXComSequenceItem( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + offset=key, + ), + ) if isinstance(msg, ErrorResponse): raise IndexError(key) if not isinstance(msg, XComSequenceIndexResult): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index dc63daeef7bb7..a19acf3d13d18 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,12 +27,13 @@ import signal import sys import time +import weakref from collections import deque from collections.abc import Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus -from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair +from socket import socket, socketpair from typing import ( TYPE_CHECKING, BinaryIO, @@ -44,7 +45,6 @@ ) from uuid import UUID -import aiologic import attrs import httpx import msgspec @@ -64,6 +64,7 @@ XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -109,6 +110,8 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -180,23 +183,6 @@ ********************************************************************************************************""" -def mkpipe( - remote_read: bool = False, -) -> tuple[socket, socket]: - """Create a pair of connected sockets.""" - rsock, wsock = socketpair() - local, remote = (wsock, rsock) if remote_read else (rsock, wsock) - - if remote_read: - # Setting a 4KB buffer here if possible, if not, it still works, so we will suppress all exceptions - with suppress(Exception): - local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE) - # set nonblocking to True so that send or sendall waits till all data is sent - local.setblocking(True) - - return remote, local - - def _subprocess_main(): from airflow.sdk.execution_time.task_runner import main @@ -224,14 +210,13 @@ def _configure_logs_over_json_channel(log_fd: int): def _reopen_std_io_handles(child_stdin, child_stdout, 
child_stderr): # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, fd, sock, mode in ( - ("stdin", 0, child_stdin, "r"), + # Yes, we want to re-open stdin in write mode! This is cause it is a bi-directional socket, so we can + # read and write to it. + ("stdin", 0, child_stdin, "w"), ("stdout", 1, child_stdout, "w"), ("stderr", 2, child_stderr, "w"), ): - handle = getattr(sys, handle_name) - handle.close() os.dup2(sock.fileno(), fd) del sock @@ -319,7 +304,7 @@ def __getattr__(name: str): def _fork_main( - child_stdin: socket, + requests: socket, child_stdout: socket, child_stderr: socket, log_fd: int, @@ -346,10 +331,12 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() + # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) + _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) - _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + _reopen_std_io_handles(requests, child_stdout, child_stderr) def exit(n: int) -> NoReturn: with suppress(ValueError, OSError): @@ -360,11 +347,11 @@ def exit(n: int) -> NoReturn: last_chance_stderr.flush() # Explicitly close the child-end of our supervisor sockets so - # the parent sees EOF on both "requests" and "logs" channels. + # the parent sees EOF on "logs" channel. with suppress(OSError): os.close(log_fd) with suppress(OSError): - os.close(child_stdin.fileno()) + os.close(requests.fileno()) os._exit(n) if hasattr(atexit, "_clear"): @@ -432,16 +419,18 @@ class WatchedSubprocess: """The decoder to use for incoming messages from the child process.""" _process: psutil.Process = attrs.field(repr=False) - _requests_fd: int """File descriptor for request handling.""" - _num_open_sockets: int = 4 _exit_code: int | None = attrs.field(default=None, init=False) _process_exit_monotonic: float | None = attrs.field(default=None, init=False) - _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False) + _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field( + factory=weakref.WeakKeyDictionary, init=False + ) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False) + _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False) + process_log: FilteringBoundLogger = attrs.field(repr=False) subprocess_logs_to_stdout: bool = False @@ -460,18 +449,19 @@ def start( ) -> Self: """Fork and start a new subprocess with the specified target function.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess - child_stdin, feed_stdin = mkpipe(remote_read=True) - child_stdout, read_stdout = mkpipe() - child_stderr, read_stderr = mkpipe() + child_stdout, read_stdout = socketpair() + child_stderr, read_stderr = socketpair() - # Open these socketpair before forking off the child, so that it is open when we fork. - child_comms, read_msgs = mkpipe() - child_logs, read_logs = mkpipe() + # Place for child to send requests/read responses, and the server side to read/respond + child_requests, read_requests = socketpair() + + # Open the socketpair before forking off the child, so that it is open when we fork. + child_logs, read_logs = socketpair() pid = os.fork() if pid == 0: # Close and delete of the parent end of the sockets. 
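The reason re-opening fd 0 in write mode works at all is that `socketpair()` returns two connected, bi-directional sockets; a tiny standalone sketch of the idea (not code from this change):

```
# Illustrative only: the child's fd 0 doubles as a request/response channel because it is a socket.
import os
import socket

parent_end, child_end = socket.socketpair()
if os.fork() == 0:
    parent_end.close()
    os.dup2(child_end.fileno(), 0)     # fd 0 now points at the child's end of the socketpair
    os.write(0, b"request")            # the child can write a request to its own "stdin"...
    os.read(0, 1024)                   # ...and read the supervisor's response back from the same fd
    os._exit(0)
child_end.close()
parent_end.recv(1024)                  # parent reads the request...
parent_end.sendall(b"response")        # ...and answers on the very same socket
os.wait()
```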
- cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs) + cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs) # Python GC should delete these for us, but lets make double sure that we don't keep anything # around in the forked processes, especially things that might involve open files or sockets! @@ -480,28 +470,28 @@ def start( try: # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target) except BaseException as e: + import traceback + with suppress(BaseException): # We can't use log here, as if we except out of _fork_main something _weird_ went on. - print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + print("Exception in _fork_main, exiting with code 124", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) # It's really super super important we never exit this block. We are in the forked child, and if we # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) os._exit(124) - requests_fd = child_comms.fileno() - # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the # other end of the pair open - cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) + cls._close_unused_sockets(child_stdout, child_stderr, child_logs) logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc = cls( pid=pid, - stdin=feed_stdin, + stdin=read_requests, process=psutil.Process(pid), - requests_fd=requests_fd, process_log=logger, start_time=time.monotonic(), **constructor_kwargs, @@ -510,7 +500,7 @@ def start( proc._register_pipe_readers( stdout=read_stdout, stderr=read_stderr, - requests=read_msgs, + requests=read_requests, logs=read_logs, ) @@ -523,24 +513,26 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - # Track socket types for debugging - self._fd_to_socket_type = { - stdout.fileno(): "stdout", - stderr.fileno(): "stderr", - requests.fileno(): "requests", - logs.fileno(): "logs", - } + # Track the open sockets, and for debugging what type each one is + self._open_sockets.update( + ( + (stdout, "stdout"), + (stderr, "stderr"), + (logs, "logs"), + (requests, "requests"), + ) + ) target_loggers: tuple[FilteringBoundLogger, ...] 
= (self.process_log,) if self.subprocess_logs_to_stdout: target_loggers += (log,) self.selector.register( - stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout") + stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout") ) self.selector.register( stderr, selectors.EVENT_READ, - self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR), + self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR), ) self.selector.register( logs, @@ -552,37 +544,47 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke self.selector.register( requests, selectors.EVENT_READ, - make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed), + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), ) - def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: """Create a socket handler that forwards logs to a logger.""" return make_buffered_socket_reader( forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed ) - def _on_socket_closed(self): + def _on_socket_closed(self, sock: socket): # We want to keep servicing this process until we've read up to EOF from all the sockets. - self._num_open_sockets -= 1 - def send_msg(self, msg: BaseModel, **dump_opts): - """Send the given pydantic message to the subprocess at once by encoding it and adding a line break.""" - b = msg.model_dump_json(**dump_opts).encode() + b"\n" - self.stdin.sendall(b) + with suppress(KeyError): + self.selector.unregister(sock) + del self._open_sockets[sock] + + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): + """Send the msg as a length-prefixed response frame.""" + if msg: + frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts)) + else: + err_resp = error.model_dump() if error else None + frame = _ResponseFrame(id=in_response_to, error=err_resp) + + self.stdin.sendall(frame.as_bytes()) - def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]: """Handle incoming requests from the task process, respond with the appropriate data.""" while True: - line = yield + request = yield try: - msg = self.decoder.validate_json(line) + msg = self.decoder.validate_python(request.body) except Exception: - log.exception("Unable to decode message", line=line) + log.exception("Unable to decode message", body=request.body) continue try: - self._handle_request(msg, log) + self._handle_request(msg, log, request.id) except ServerResponseError as e: error_details = e.response.json() if e.response else None log.error( @@ -594,27 +596,25 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N # Send error response back to task so that the error appears in the task logs self.send_msg( - ErrorResponse( + msg=None, + error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={ "status_code": e.response.status_code, "message": str(e), "detail": error_details, }, - ) + ), + in_response_to=request.id, ) - def _handle_request(self, msg, log: FilteringBoundLogger) -> None: + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: 
int) -> None: raise NotImplementedError() @staticmethod def _close_unused_sockets(*sockets): """Close unused ends of sockets after fork.""" for sock in sockets: - if isinstance(sock, SocketIO): - # If we have the socket IO object, we need to close the underlying socket foricebly here too, - # else we get unclosed socket warnings, and likely leaking FDs too - sock._sock.close() sock.close() def _cleanup_open_sockets(self): @@ -624,20 +624,18 @@ def _cleanup_open_sockets(self): # sockets the supervisor would wait forever thinking they are still # active. This cleanup ensures we always release resources and exit. stuck_sockets = [] - for key in list(self.selector.get_map().values()): - socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}") - stuck_sockets.append(f"{socket_type}({key.fd})") + for sock, socket_type in self._open_sockets.items(): + fileno = "unknown" with suppress(Exception): - self.selector.unregister(key.fileobj) - with suppress(Exception): - key.fileobj.close() # type: ignore[union-attr] + fileno = sock.fileno() + sock.close() + stuck_sockets.append(f"{socket_type}(fd={fileno})") if stuck_sockets: log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets) self.selector.close() - self._close_unused_sockets(self.stdin) - self._num_open_sockets = 0 + self.stdin.close() def kill( self, @@ -736,7 +734,7 @@ def _service_subprocess( events = self.selector.select(timeout=timeout) for key, _ in events: # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) - socket_handler = key.data + socket_handler, on_close = key.data # Example of handler behavior: # If the subprocess writes "Hello, World!" to stdout: @@ -746,15 +744,16 @@ def _service_subprocess( # to EOF case try: need_more = socket_handler(key.fileobj) - except BrokenPipeError: + except (BrokenPipeError, ConnectionResetError): need_more = False # If the handler signals that the file object is no longer needed (EOF, closed, etc.) # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors # are removed. if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -773,16 +772,16 @@ def _check_subprocess_exit( raise else: self._process_exit_monotonic = time.monotonic() - self._close_unused_sockets(self.stdin) - # Put a message in the viewable task logs if expect_signal is not None and self._exit_code == -expect_signal: # Bypass logging, the caller expected us to exit with this return self._exit_code - # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): + # Put a message in the viewable task logs + if self._exit_code == -signal.SIGSEGV: self.process_log.critical(SIGSEGV_MESSAGE) + # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) if exit_code and (name := getattr(exit_code, "name")): elif name := getattr(self._exit_code, "name", None): message = "Process terminated by signal" level = logging.ERROR @@ -809,7 +808,7 @@ class ActivitySubprocess(WatchedSubprocess): _last_heartbeat_attempt: float = attrs.field(default=0, init=False) # After the failure of a heartbeat, we'll increment this counter. 
If it reaches `MAX_FAILED_HEARTBEATS`, we - # will kill the process. This is to handle temporary network issues etc. ensuring that the process + # will kill theprocess. This is to handle temporary network issues etc. ensuring that the process # does not hang around forever. failed_heartbeats: int = attrs.field(default=0, init=False) @@ -861,7 +860,6 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st ti=ti, dag_rel_path=os.fspath(dag_rel_path), bundle_info=bundle_info, - requests_fd=self._requests_fd, ti_context=ti_context, start_date=start_date, ) @@ -870,8 +868,8 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st log.debug("Sending", msg=msg) try: - self.send_msg(msg) - except BrokenPipeError: + self.send_msg(msg, in_response_to=0) + except (BrokenPipeError, ConnectionResetError): # Debug is fine, the process will have shown _something_ in it's last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -930,7 +928,7 @@ def _monitor_subprocess(self): - Processes events triggered on the monitored file objects, such as data availability or EOF. - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - while self._exit_code is None or self._num_open_sockets > 0: + while self._exit_code is None or self._open_sockets: last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. @@ -946,16 +944,11 @@ def _monitor_subprocess(self): # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None - if self._exit_code is not None and self._num_open_sockets > 0: + if self._exit_code is not None and self._open_sockets: if ( self._process_exit_monotonic and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT ): - log.debug( - "Forcefully closing remaining sockets", - open_sockets=self._num_open_sockets, - pid=self.pid, - ) self._cleanup_open_sockets() if alive: @@ -1051,7 +1044,7 @@ def final_state(self): return SERVER_TERMINATED return TaskInstanceState.FAILED - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} @@ -1224,10 +1217,17 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): dump_opts = {"exclude_unset": True} else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) def in_process_api_server(): @@ -1237,24 +1237,26 @@ def in_process_api_server(): return api -@attrs.define +@attrs.define(kw_only=True) class InProcessSupervisorComms: """In-process communication handler that uses deques instead of sockets.""" + log: FilteringBoundLogger = attrs.field(repr=False, factory=structlog.get_logger) supervisor: InProcessTestSupervisor - messages: deque[BaseModel] = attrs.field(factory=deque) - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock) + 
messages: deque[BaseModel | None] = attrs.field(factory=deque) - def get_message(self) -> BaseModel: + def _get_response(self) -> BaseModel | None: """Get a message from the supervisor. Blocks until a message is available.""" return self.messages.popleft() - def send_request(self, log, msg: BaseModel): + def send(self, msg: BaseModel): """Send a request to the supervisor.""" - log.debug("Sending request", msg=msg) + self.log.debug("Sending request", msg=msg) with set_supervisor_comms(None): - self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + self.supervisor._handle_request(msg, log, 0) # type: ignore[arg-type] + + return self._get_response() @attrs.define @@ -1272,7 +1274,8 @@ class InProcessTestSupervisor(ActivitySubprocess): """A supervisor that runs tasks in-process for easier testing.""" comms: InProcessSupervisorComms = attrs.field(init=False) - stdin = attrs.field(init=False) + + stdin: socket = attrs.field(init=False) @classmethod def start( # type: ignore[override] @@ -1298,7 +1301,6 @@ def start( # type: ignore[override] id=what.id, pid=os.getpid(), # Use current process process=psutil.Process(), # Current process - requests_fd=-1, # Not used in in-process mode process_log=logger or structlog.get_logger(logger_name="task").bind(), client=cls._api_client(task.dag), **kwargs, @@ -1363,7 +1365,9 @@ def _api_client(dag=None): client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def send_msg(self, msg: BaseModel, **dump_opts): + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): """Override to use in-process comms.""" self.comms.messages.append(msg) @@ -1421,9 +1425,9 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: # to a (sync) generator def make_buffered_socket_reader( gen: Generator[None, bytes | bytearray, None], - on_close: Callable, + on_close: Callable[[socket], None], buffer_size: int = 4096, -) -> Callable[[socket], bool]: +): buffer = bytearray() # This will hold our accumulated binary data read_buffer = bytearray(buffer_size) # Temporary buffer for each read @@ -1441,7 +1445,6 @@ def cb(sock: socket): with suppress(StopIteration): gen.send(buffer) # Tell loop to close this selector - on_close() return False buffer.extend(read_buffer[:n_received]) @@ -1452,18 +1455,62 @@ def cb(sock: socket): try: gen.send(line) except StopIteration: - on_close() return False buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data return True - return cb + return cb, on_close + + +def length_prefixed_frame_reader( + gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None] +): + length_needed: int | None = None + # This will hold our accumulated/partial binary frame if it doesn't come in a single read + buffer: memoryview | None = None + # position in the buffer to store next read + pos = 0 + decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame) + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, length_needed, pos + + if length_needed is None: + # Read the 32bit length of the frame + bytes = sock.recv(4) + if bytes == b"": + return False + + length_needed = int.from_bytes(bytes, byteorder="big") + buffer = memoryview(bytearray(length_needed)) + if length_needed and buffer: + n = sock.recv_into(buffer[pos:]) + if n == 0: + # EOF + return False + pos += n + + if pos >= length_needed: + request = 
decoder.decode(buffer) + buffer = None + pos = 0 + length_needed = None + try: + gen.send(request) + except StopIteration: + return False + return True + + return cb, on_close def process_log_messages_from_subprocess( loggers: tuple[FilteringBoundLogger, ...], -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: from structlog.stdlib import NAME_TO_LEVEL while True: @@ -1499,10 +1546,9 @@ def process_log_messages_from_subprocess( def forward_to_log( target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: while True: - buf = yield - line = bytes(buf) + line = yield # Strip off new line line = line.rstrip() try: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d2cc479d5f4e2..c545c5eebc7b7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -28,16 +28,14 @@ from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from datetime import datetime, timezone -from io import FileIO from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Literal -import aiologic import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -59,6 +57,7 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, + CommsDecoder, DagRunStateResult, DeferTask, DRCount, @@ -424,10 +423,9 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - SUPERVISOR_COMMS.send_request( - log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) + response = SUPERVISOR_COMMS.send( + msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -445,22 +443,17 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTICount( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, TICount) @@ -477,21 +470,16 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with 
SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTaskStates( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTaskStates( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + ), + ) if TYPE_CHECKING: assert isinstance(response, TaskStatesResult) @@ -506,19 +494,14 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, DRCount) @@ -528,10 +511,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the DAG run with the given Run ID.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -650,62 +630,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -SendMsgType = TypeVar("SendMsgType", bound=BaseModel) -ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) - - -@attrs.define() -class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): - """Handle communication between the task in this process and the supervisor parent process.""" - - input: TextIO - - request_socket: FileIO = attrs.field(init=False, default=None) - - # We could be "clever" here and set the default to this based type parameters and a custom - # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a - # "sort of wrong default" - decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) - - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) - - def get_message(self) -> ReceiveMsgType: - """ - Get a message from the parent. - - This will block until the message has been received. - """ - line = None - - # TODO: Investigate why some empty lines are sent to the processes stdin. - # That was highlighted when working on https://github.com/apache/airflow/issues/48183 - # and is maybe related to deferred/triggerer only context. - while not line: - line = self.input.readline() - - try: - msg = self.decoder.validate_json(line) - except Exception: - structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) - raise - - if isinstance(msg, StartupDetails): - # If we read a startup message, pull out the FDs we care about! 
- if msg.requests_fd > 0: - self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR: - structlog.get_logger(logger_name="task").error("Error response from the API Server") - raise AirflowRuntimeError(error=msg) - - return msg - - def send_request(self, log: Logger, msg: SendMsgType): - encoded_msg = msg.model_dump_json().encode() + b"\n" - - log.debug("Sending request", json=encoded_msg) - self.request_socket.write(encoded_msg) - - # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # @@ -725,31 +649,33 @@ def send_request(self, log: Logger, msg: SendMsgType): def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - msg = SUPERVISOR_COMMS.get_message() + # The parent sends us a StartupDetails message unprompted. After this, every single message is only sent + # in response to us sending a request. + msg = SUPERVISOR_COMMS._get_response() + + if not isinstance(msg, StartupDetails): + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") log = structlog.get_logger(logger_name="task") + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + log.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + try: get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) except Exception: log.exception("error calling listener") - if isinstance(msg, StartupDetails): - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - log.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle(f"airflow worker -- {msg.ti.id}") - - with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg, log) - ti.log_url = get_log_url_from_ti(ti) - log.debug("DAG file parsed", file=msg.dag_rel_path) - else: - raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): + ti = parse(msg, log) + ti.log_url = get_log_url_from_ti(ti) + log.debug("DAG file parsed", file=msg.dag_rel_path) return ti, ti.get_template_context(), log @@ -797,7 +723,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) _validate_task_inlets_and_outlets(ti=ti, log=log) @@ -817,8 +743,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log) - inactive_assets_resp = SUPERVISOR_COMMS.get_message() + inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -914,7 +839,7 @@ def run( except DownstreamTasksSkipped as skip: log.info("Skipping 
downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -974,7 +899,7 @@ def run( error = e finally: if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + SUPERVISOR_COMMS.send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1012,9 +937,8 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - SUPERVISOR_COMMS.send_request( - log=log, - msg=TriggerDagRun( + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, logical_date=drte.logical_date, @@ -1023,7 +947,6 @@ def _handle_trigger_dag_run( ), ) - comms_msg = SUPERVISOR_COMMS.get_message() if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: if drte.skip_when_already_exists: log.info( @@ -1078,10 +1001,9 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - SUPERVISOR_COMMS.send_request( - log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) + comms_msg = SUPERVISOR_COMMS.send( + GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) - comms_msg = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1270,10 +1192,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)), - ) + SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1314,9 +1233,11 @@ def finalize( def main(): - # TODO: add an exception here, it causes an oof of a stack trace! + # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! 
+ log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) try: ti, context, log = startup() @@ -1327,11 +1248,9 @@ def main(): state, msg, error = run(ti, context, log) finalize(ti, state, context, log, error) except KeyboardInterrupt: - log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") exit(2) except Exception: - log = structlog.get_logger(logger_name="task") log.exception("Top level error") exit(1) finally: diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py b/task-sdk/tests/task_sdk/bases/test_sensor.py index 2c0a82783d4d6..d0e2b430d2c74 100644 --- a/task-sdk/tests/task_sdk/bases/test_sensor.py +++ b/task-sdk/tests/task_sdk/bases/test_sensor.py @@ -205,7 +205,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su time_machine.coordinates.shift(sensor.poke_interval) # Mocking values from DB/API-server - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=date1) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1) state, msg, error = run_task(task=sensor, context_update={"task_reschedule_count": 1}) assert state == State.FAILED @@ -227,7 +227,7 @@ def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mo time_machine.coordinates.shift(sensor.poke_interval) # Mocking values from DB/API-server - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=date1) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1) state, msg, _ = run_task(task=sensor, context_update={"task_reschedule_count": 1}) assert state == State.SKIPPED @@ -254,7 +254,7 @@ def run_duration(): return (timezone.utcnow() - task_start_date).total_seconds() new_interval = 0 - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=task_start_date) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=task_start_date) # loop poke returns false for _poke_count in range(1, false_count + 1): @@ -516,9 +516,9 @@ def _run_task(): # For timeout calculation, we need to use the first reschedule date # This ensures the timeout is calculated from the start of the task if test_state["first_reschedule_date"] is None: - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=None) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=None) else: - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate( + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate( start_date=test_state["first_reschedule_date"] ) diff --git a/task-sdk/tests/task_sdk/definitions/conftest.py b/task-sdk/tests/task_sdk/definitions/conftest.py index c0ac385628388..3f89f34b4d2da 100644 --- a/task-sdk/tests/task_sdk/definitions/conftest.py +++ b/task-sdk/tests/task_sdk/definitions/conftest.py @@ -36,12 +36,12 @@ def run(dag: DAG, task_id: str, map_index: int): log = structlog.get_logger(__name__) - mock_supervisor_comms.send_request.reset_mock() + mock_supervisor_comms.send.reset_mock() ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) run(ti, ti.get_template_context(), log) - for call in mock_supervisor_comms.send_request.mock_calls: - msg = call.kwargs["msg"] + for call in mock_supervisor_comms.send.mock_calls: + msg = call.kwargs.get("msg") or call.args[0] if 
isinstance(msg, (TaskState, SucceedTask)): return msg.state raise RuntimeError("Unable to find call to TaskState") diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py index 264cfee06290a..c066b70d99c9f 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py @@ -295,7 +295,7 @@ def test_determine_kwargs( example_asset_func_with_valid_arg_as_inlet_asset ) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult( name="example_asset_func", uri="s3://bucket/object", @@ -326,12 +326,9 @@ def test_determine_kwargs( } assert mock_supervisor_comms.mock_calls == [ - mock.call.send_request(mock.ANY, GetAssetByName(name="example_asset_func")), - mock.call.get_message(), - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")), - mock.call.get_message(), - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_2")), - mock.call.get_message(), + mock.call.send(GetAssetByName(name="example_asset_func")), + mock.call.send(GetAssetByName(name="inlet_asset_1")), + mock.call.send(GetAssetByName(name="inlet_asset_2")), ] @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) @@ -342,7 +339,7 @@ def test_determine_kwargs_defaults( ): asset_definition = asset(schedule=None)(example_asset_func_with_valid_arg_as_inlet_asset_and_default) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult(name="inlet_asset_1", uri="s3://bucket/object1", group="asset", extra=None), ] @@ -360,6 +357,5 @@ def test_determine_kwargs_defaults( } assert mock_supervisor_comms.mock_calls == [ - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")), - mock.call.get_message(), + mock.call.send(GetAssetByName(name="inlet_asset_1")), ] diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py index 102e85d36b30d..3bbb63a769788 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connections.py +++ b/task-sdk/tests/task_sdk/definitions/test_connections.py @@ -104,7 +104,7 @@ def test_get_uri(self): def test_conn_get(self, mock_supervisor_comms): conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result conn = Connection.get(conn_id="mysql_conn") assert conn is not None diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index cdb9c954fd916..5c81b64b605b3 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -251,7 +251,7 @@ def execute(self, context): ) mapped = callable(mapped, task1.output) - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["{{ ds }}"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["{{ ds }}"]) mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1}) @@ -299,7 +299,7 @@ def test_expand_kwargs_render_template_fields_validating_operator( task1 = BaseOperator(task_id="op1") mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - 
mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="return_value", value=[{"arg1": "{{ ds }}"}, {"arg1": 2}] ) @@ -427,16 +427,14 @@ def show(number, letter): show.expand(number=emit_numbers(), letter=emit_letters()) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - task = dag.get_task(last_request.task_id) + task = dag.get_task(msg.task_id) value = task.python_callable() return XComResult(key="return_value", value=value) - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(6)] assert states == [TaskInstanceState.SUCCESS] * 6 @@ -467,16 +465,14 @@ def show(a, b): emit_task = emit_numbers() show.expand(a=emit_task, b=emit_task) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - task = dag.get_task(last_request.task_id) + task = dag.get_task(msg.task_id) value = task.python_callable() return XComResult(key="return_value", value=value) - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(4)] assert states == [TaskInstanceState.SUCCESS] * 4 @@ -594,22 +590,20 @@ def tg(va): # Aggregates results from task group. 
t.override(task_id="t3")(tg1) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - key = (last_request.task_id, last_request.map_index) + key = (msg.task_id, msg.map_index) if key in expected_values: value = expected_values[key] return XComResult(key="return_value", value=value) - if last_request.map_index is None: + if msg.map_index is None: # Get all mapped XComValues for this ti - value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] + value = [v for k, v in expected_values.items() if k[0] == msg.task_id] return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get expected_values = { ("tg.t1", 0): ["a", "b"], @@ -683,10 +677,9 @@ def group(x): ti.task.execute(context) assert ti - mock_supervisor_comms.send_request.assert_has_calls( + mock_supervisor_comms.send.assert_has_calls( [ mock.call( - log=mock.ANY, msg=SetXCom( key="skipmixin_key", value={"skipped": ["group.empty_task"]}, diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 8bcf9ef28199a..c85924df6a6d9 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -51,7 +51,7 @@ class TestVariables: ) def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_comms): var_result = VariableResult(key="my_key", value=value) - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result var = Variable.get(key="my_key", deserialize_json=deserialize_json) assert var is not None @@ -83,8 +83,7 @@ def test_var_set(self, key, value, description, serialize_json, mock_supervisor_ if serialize_json: expected_value = json.dumps(value, indent=2) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=PutVariable( key=key, value=expected_value, description=description, serialize_json=serialize_json ), diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index d313b45dae265..73468dcb9e12b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -52,7 +52,7 @@ def pull(value): assert set(dag.task_dict) == {"push", "pull"} # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) for map_index in range(3): assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS @@ -81,7 +81,7 @@ def c_to_none(v): pull.expand(value=push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Run "pull". This should automatically convert "c" to None. 
for map_index in range(3): @@ -111,7 +111,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # The first two "pull" tis should succeed. for map_index in range(2): @@ -165,7 +165,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # The third one (for "c") will fail. assert run_ti(dag, "pull", 2) == TaskInstanceState.FAILED @@ -209,7 +209,7 @@ def pull(value): pull.expand_kwargs(converted) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Now "pull" should apply the mapping functions in order. for map_index in range(3): @@ -243,20 +243,18 @@ def convert_zipped(zipped): pull.expand(value=combined.map(convert_zipped)) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - if last_request.task_id == "push_letters": + if msg.task_id == "push_letters": value = push_letters.function() return XComResult(key="return_value", value=value) - if last_request.task_id == "push_numbers": + if msg.task_id == "push_numbers": value = push_numbers.function() return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get # Run "pull". for map_index in range(4): @@ -286,7 +284,7 @@ def skip_c(v): forward.expand_kwargs(push().map(skip_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Run "forward". This should automatically skip "c". states = [run_ti(dag, "forward", map_index) for map_index in range(3)] @@ -341,20 +339,18 @@ def pull_all(value): pull_one.expand(value=pushed_values) pull_all(pushed_values) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - if last_request.task_id == "push_letters": + if msg.task_id == "push_letters": value = push_letters.function() return XComResult(key="return_value", value=value) - if last_request.task_id == "push_numbers": + if msg.task_id == "push_numbers": value = push_numbers.function() return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get # Run "pull_one" and "pull_all". 
assert run_ti(dag, "pull_all", None) == TaskInstanceState.SUCCESS diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py new file mode 100644 index 0000000000000..ee2f956b063c9 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import uuid +from socket import socketpair + +import msgspec +import pytest + +from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame +from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.utils import timezone + + +class TestCommsDecoder: + """Test the communication between the subprocess and the "supervisor".""" + + @pytest.mark.usefixtures("disable_capturing") + def test_recv_StartupDetails(self): + r, w = socketpair() + + msg = { + "type": "StartupDetails", + "ti": { + "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + "task_id": "a", + "try_number": 1, + "run_id": "b", + "dag_id": "c", + }, + "ti_context": { + "dag_run": { + "dag_id": "c", + "run_id": "b", + "logical_date": "2024-12-01T01:00:00Z", + "data_interval_start": "2024-12-01T00:00:00Z", + "data_interval_end": "2024-12-01T01:00:00Z", + "start_date": "2024-12-01T01:00:00Z", + "run_after": "2024-12-01T01:00:00Z", + "end_date": None, + "run_type": "manual", + "conf": None, + "consumed_asset_events": [], + }, + "max_tries": 0, + "should_retry": False, + "variables": None, + "connections": None, + }, + "file": "/dev/null", + "start_date": "2024-12-01T01:00:00Z", + "dag_rel_path": "/dev/null", + "bundle_info": {"name": "any-name", "version": "any-version"}, + } + bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) + w.sendall(len(bytes).to_bytes(4, byteorder="big") + bytes) + + decoder = CommsDecoder(request_socket=r, log=None) + + msg = decoder._get_response() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.dag_id == "c" + assert msg.dag_rel_path == "/dev/null" + assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") + assert msg.start_date == timezone.datetime(2024, 12, 1, 1) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 59266da758892..6880c656b01cd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -190,7 +190,7 @@ def test_getattr_connection(self, mock_supervisor_comms): # Conn from the supervisor / API Server conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - 
mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection; triggers __getattr__ conn = accessor.mysql_conn @@ -203,7 +203,7 @@ def test_get_method_valid_connection(self, mock_supervisor_comms): accessor = ConnectionAccessor() conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result conn = accessor.get("mysql_conn") assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) @@ -216,7 +216,7 @@ def test_get_method_with_default(self, mock_supervisor_comms): error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"} ) - mock_supervisor_comms.get_message.return_value = error_response + mock_supervisor_comms.send.return_value = error_response conn = accessor.get("nonexistent_conn", default_conn=default_conn) assert conn == default_conn @@ -233,7 +233,7 @@ def test_getattr_connection_for_extra_dejson(self, mock_supervisor_comms): extra='{"extra_key": "extra_value"}', ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -251,7 +251,7 @@ def test_getattr_connection_for_extra_dejson_decode_error(self, mock_log, mock_s conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306, extra="This is not JSON!" ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -274,7 +274,7 @@ def test_getattr_variable(self, mock_supervisor_comms): # Variable from the supervisor / API Server var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result # Fetch the variable; triggers __getattr__ value = accessor.test_key @@ -286,7 +286,7 @@ def test_get_method_valid_variable(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result val = accessor.get("test_key") assert val == var_result.value @@ -297,7 +297,7 @@ def test_get_method_with_default(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"}) - mock_supervisor_comms.get_message.return_value = error_response + mock_supervisor_comms.send.return_value = error_response val = accessor.get("nonexistent_var_key", default="default_value") assert val == "default_value" @@ -367,7 +367,7 @@ class TestOutletEventAccessor: ), ) def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={}) outlet_event_accessor.add(add_arg) @@ -398,7 +398,7 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): ), ) def test_add_with_db(self, 
add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) outlet_event_accessor.add(add_arg, extra={}) @@ -497,11 +497,11 @@ def test_getitem_name_ref( resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(name=name)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByName(name=name) + + mock_supervisor_comms.send.assert_called_once_with(GetAssetByName(name=name, type="GetAssetByName")) assert _AssetRefResolutionMixin._asset_ref_cache @pytest.mark.parametrize( @@ -520,11 +520,10 @@ def test_getitem_uri_ref( resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(uri=uri)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByUri(uri=uri) + mock_supervisor_comms.send.assert_called_once_with(GetAssetByUri(uri=uri)) assert _AssetRefResolutionMixin._asset_ref_cache def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): @@ -534,22 +533,19 @@ def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): assert source == AssetEventSourceTaskInstance(dag_id="d1", task_id="t2", run_id="r1", map_index=-1) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="d1", - run_id="r1", - task_id="t2", - map_index=-1, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="d1", + run_id="r1", + task_id="t2", + map_index=-1, ), - ] + ) TEST_ASSET = Asset(name="test_uri", uri="test://test") @@ -593,7 +589,7 @@ def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms): assert len(outlet_event_accessors) == 0 # Asset from the API Server via the supervisor - mock_supervisor_comms.get_message.return_value = AssetResult( + mock_supervisor_comms.send.return_value = AssetResult( name=asset.name, uri=asset.uri, group=asset.group, @@ -628,7 +624,7 @@ def test_for_asset_alias(self, mocked__getitem__): class TestInletEventAccessor: @pytest.fixture def sample_inlet_evnets_accessor(self, mock_supervisor_comms): - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult(name="test_uri", uri="test://test", group="asset"), AssetResult(name="test_uri", uri="test://test", group="asset"), ] @@ -656,7 +652,7 @@ def test__get_item__(self, key, sample_inlet_evnets_accessor, mock_supervisor_co 
asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.get_message.side_effect = [events_result] * 4 + mock_supervisor_comms.send.side_effect = [events_result] * 4 assert sample_inlet_evnets_accessor[key] == [asset_event_resp] @@ -684,7 +680,7 @@ def test_for_asset_alias(self, mocked__getitem__, sample_inlet_evnets_accessor): assert mocked__getitem__.call_args[0][0] == TEST_ASSET_ALIAS def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock_supervisor_comms): - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetEventsResult( asset_events=[ AssetEventResponse( @@ -707,9 +703,7 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) ] events = sample_inlet_evnets_accessor[Asset.ref(name="test_uri")] - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call(log=mock.ANY, msg=GetAssetEventByAsset(name="test_uri", uri=None)), - ] + mock_supervisor_comms.send.assert_called_once_with(GetAssetEventByAsset(name="test_uri", uri=None)) assert len(events) == 2 assert events[1].source_task_instance is None @@ -723,19 +717,16 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="__dag__", - run_id="__run__", - task_id="__task__", - map_index=0, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="__dag__", + run_id="__run__", + task_id="__task__", + map_index=0, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py index e4943196a09da..65dd30b2c7c7b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -17,7 +17,7 @@ from __future__ import annotations -from unittest.mock import ANY, Mock, call +from unittest.mock import Mock, call import pytest @@ -66,121 +66,109 @@ def deserialize_value(cls, xcom): def test_len(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3) + mock_supervisor_comms.send.return_value = XComCountResponse(len=3) assert len(lazy_sequence) == 3 - assert mock_supervisor_comms.send_request.mock_calls == [ - call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")), - ] + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run"), + ) def test_iter(mock_supervisor_comms, lazy_sequence): it = iter(lazy_sequence) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComSequenceIndexResult(root="f"), ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}), ] assert list(it) == ["f"] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=0, + 
mock_supervisor_comms.send.assert_has_calls( + [ + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=0, + ), ), - ), - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=1, + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=1, + ), ), - ), - ] + ] + ) def test_getitem_index(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="f") + mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="f") assert lazy_sequence[4] == "f" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) @conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"}) def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="some-value") + mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="some-value") xcom = resolve_xcom_backend() assert xcom.__name__ == "CustomXCom" monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom) assert lazy_sequence[4] == "Made with CustomXCom: some-value" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = ErrorResponse( + mock_supervisor_comms.send.return_value = ErrorResponse( error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}, ) with pytest.raises(IndexError) as ctx: lazy_sequence[4] assert ctx.value.args == (4,) - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_slice(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceSliceResult(root=[6, 4, 1]) + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=[6, 4, 1]) assert lazy_sequence[:5] == [6, 4, 1] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceSlice( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - start=None, - stop=5, - step=None, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceSlice( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + start=None, + stop=5, + step=None, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 
4f5e4cfc7ac9d..6e1c3b2e093cb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,13 +27,14 @@ import socket import sys import time -from io import BytesIO from operator import attrgetter +from random import randint from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import httpx +import msgspec import psutil import pytest from pytest_unordered import unordered @@ -56,6 +57,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, DagRunStateResult, DeferTask, @@ -97,17 +99,16 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.supervisor import ( - BUFFER_SIZE, ActivitySubprocess, InProcessSupervisorComms, InProcessTestSupervisor, - mkpipe, set_supervisor_comms, supervise, ) -from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"): } +@pytest.fixture +def client_with_ti_start(make_ti_context): + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @pytest.fixture(autouse=True) def disable_log_upload(self, spy_agency): spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) - def test_reading_from_pipes(self, captured_logs, time_machine): + def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): def subprocess_main(): # This is run in the subprocess! 
- # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + # Ensure we follow the "protocol" and get the startup message before we do anything else + CommsDecoder()._get_response() import logging import warnings @@ -180,7 +188,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -228,12 +236,12 @@ def subprocess_main(): ] ) - def test_subprocess_sigkilled(self): + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid() def subprocess_main(): # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() assert os.getpid() != main_pid os.kill(os.getpid(), signal.SIGKILL) @@ -248,7 +256,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -285,7 +293,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -314,7 +322,7 @@ def test_no_heartbeat_in_overtime(self, spy_agency: kgb.SpyAgency, monkeypatch, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -340,7 +348,7 @@ def _on_child_started(self, *args, **kwargs): assert proc.wait() == 0 spy_agency.assert_spy_not_called(heartbeat_spy) - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -355,11 +363,6 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker try_number=1, ) - # Create a mock client to assert calls to the client - # We assume the implementation of the client is correct and only need to check the calls - mock_client = mocker.Mock(spec=sdk_client.Client) - mock_client.task_instances.start.return_value = make_ti_context() - bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -368,7 +371,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker token="", server="", dry_run=True, - client=mock_client, + client=client_with_ti_start, bundle_info=bundle_info, ) assert exit_code == 0, captured_logs @@ -498,7 +501,7 @@ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, m monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() sleep(5) # Shouldn't get here exit(5) @@ -611,7 +614,6 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t stdin=mocker.MagicMock(), client=client, process=mock_process, - requests_fd=-1, ) time_now = tz.datetime(2024, 11, 28, 12, 0, 0) @@ -701,7 +703,6 @@ def test_overtime_handling( stdin=mocker.Mock(), process=mocker.Mock(), 
client=mocker.Mock(), - requests_fd=-1, ) # Set the terminal state and task end datetime @@ -738,7 +739,7 @@ def test_overtime_handling( ), ), ) - def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog): + def test_exit_by_signal(self, signal_to_raise, log_pattern, cap_structlog, client_with_ti_start): def subprocess_main(): import faulthandler import os @@ -748,7 +749,7 @@ def subprocess_main(): faulthandler.disable() # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() os.kill(os.getpid(), signal_to_raise) @@ -762,7 +763,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -791,26 +792,26 @@ def test_cleanup_sockets_after_delay(self, monkeypatch, mocker, time_machine): stdin=mocker.MagicMock(), client=mocker.MagicMock(), process=mock_process, - requests_fd=-1, ) proc.selector = mocker.MagicMock() proc.selector.select.return_value = [] proc._exit_code = 0 - proc._num_open_sockets = 1 + # Create a fake placeholder in the open socket weakref + proc._open_sockets[mocker.MagicMock()] = "test placeholder" proc._process_exit_monotonic = time.monotonic() mocker.patch.object( ActivitySubprocess, "_cleanup_open_sockets", - side_effect=lambda: setattr(proc, "_num_open_sockets", 0), + side_effect=lambda: setattr(proc, "_open_sockets", {}), ) time_machine.shift(2) proc._monitor_subprocess() - assert proc._num_open_sockets == 0 + assert len(proc._open_sockets) == 0 class TestWatchedSubprocessKill: @@ -829,7 +830,6 @@ def watched_subprocess(self, mocker, mock_process): stdin=mocker.Mock(), client=mocker.Mock(), process=mock_process, - requests_fd=-1, ) # Mock the selector mock_selector = mocker.Mock(spec=selectors.DefaultSelector) @@ -888,7 +888,7 @@ def test_kill_process_custom_signal(self, watched_subprocess, mock_process): ), ], ) - def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch): + def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): def subprocess_main(): import signal @@ -905,7 +905,7 @@ def _handler(sig, frame): signal.signal(signal.SIGINT, _handler) signal.signal(signal.SIGTERM, _handler) try: - sys.stdin.readline() + CommsDecoder()._get_response() print("Ready") sleep(10) except Exception as e: @@ -919,7 +919,7 @@ def _handler(sig, frame): dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -976,9 +976,11 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stdout_handler = mocker.Mock(return_value=False) # Simulate EOF for stdout mock_stderr_handler = mocker.Mock(return_value=True) # Continue processing for stderr + mock_on_close = mocker.Mock() + # Mock selector to return events - mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=mock_stdout_handler) - mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=mock_stderr_handler) + mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=(mock_stdout_handler, mock_on_close)) + mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=(mock_stderr_handler, mock_on_close)) watched_subprocess.selector.select.return_value = [(mock_key_stdout, None), (mock_key_stderr, None)] # Mock to simulate process exited 
successfully @@ -996,8 +998,7 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stderr_handler.assert_called_once_with(mock_stderr) # Validate unregistering and closing of EOF file object - watched_subprocess.selector.unregister.assert_called_once_with(mock_stdout) - mock_stdout.close.assert_called_once() + mock_on_close.assert_called_once_with(mock_stdout) # Validate that `_check_subprocess_exit` is called mock_process.wait.assert_called_once_with(timeout=0) @@ -1073,16 +1074,15 @@ def test_max_wait_time_calculation_edge_cases( class TestHandleRequest: @pytest.fixture def watched_subprocess(self, mocker): - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socket.socketpair() subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), id=TI_ID, pid=12345, - stdin=write_end, # this is the writer side + stdin=write_end, client=mocker.Mock(), process=mocker.Mock(), - requests_fd=-1, ) return subprocess, read_end @@ -1091,7 +1091,7 @@ def watched_subprocess(self, mocker): @pytest.mark.parametrize( [ "message", - "expected_buffer", + "expected_body", "client_attr_path", "method_arg", "method_kwarg", @@ -1101,7 +1101,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1111,7 +1111,12 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + { + "conn_id": "test_conn", + "conn_type": "mysql", + "password": "password", + "type": "ConnectionResult", + }, "connections.get", ("test_conn",), {}, @@ -1121,7 +1126,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1131,7 +1136,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', + {"key": "test_key", "value": "test_value", "type": "VariableResult"}, "variables.get", ("test_key",), {}, @@ -1141,7 +1146,7 @@ def watched_subprocess(self, mocker): ), pytest.param( PutVariable(key="test_key", value="test_value", description="test_description"), - b"", + None, "variables.set", ("test_key", "test_value", "test_description"), {}, @@ -1151,7 +1156,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeleteVariable(key="test_key"), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "variables.delete", ("test_key",), {}, @@ -1161,7 +1166,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", + None, "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, @@ -1174,7 +1179,7 @@ def watched_subprocess(self, mocker): reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - b"", + None, "task_instances.reschedule", ( TI_ID, @@ -1190,7 +1195,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", 
run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1202,7 +1207,7 @@ def watched_subprocess(self, mocker): GetXCom( dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, @@ -1212,7 +1217,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1228,7 +1233,7 @@ def watched_subprocess(self, mocker): key="test_key", include_prior_dates=True, ), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, True), {}, @@ -1244,7 +1249,7 @@ def watched_subprocess(self, mocker): key="test_key", value='{"key": "test_key", "value": {"key2": "value2"}}', ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1269,7 +1274,7 @@ def watched_subprocess(self, mocker): value='{"key": "test_key", "value": {"key2": "value2"}}', map_index=2, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1295,7 +1300,7 @@ def watched_subprocess(self, mocker): map_index=2, mapped_length=3, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1319,15 +1324,9 @@ def watched_subprocess(self, mocker): key="test_key", map_index=2, ), - b"", + None, "xcoms.delete", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - 2, - ), + ("test_dag", "test_run", "test_task", "test_key", 2), {}, OKResponse(ok=True), None, @@ -1337,7 +1336,7 @@ def watched_subprocess(self, mocker): # if it can handle TaskState message pytest.param( TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - b"", + None, "", (), {}, @@ -1349,7 +1348,7 @@ def watched_subprocess(self, mocker): RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" ), - b"", + None, "task_instances.retry", (), { @@ -1363,7 +1362,7 @@ def watched_subprocess(self, mocker): ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - b"", + None, "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, @@ -1373,7 +1372,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByName(name="asset"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"name": "asset"}, @@ -1383,7 +1382,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByUri(uri="s3://bucket/obj"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"uri": "s3://bucket/obj"}, @@ -1393,11 +1392,17 @@ def watched_subprocess(self, mocker): ), pytest.param( 
GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": "test"}, @@ -1416,11 +1421,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": None}, @@ -1439,11 +1450,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri=None, name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": None, "name": "test"}, @@ -1462,11 +1479,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAssetAlias(alias_name="test_alias"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"alias_name": "test_alias"}, @@ -1485,7 +1508,10 @@ def watched_subprocess(self, mocker): ), pytest.param( ValidateInletsAndOutlets(ti_id=TI_ID), - b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n', + { + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, "task_instances.validate_inlets_and_outlets", (TI_ID,), {}, @@ -1499,7 +1525,7 @@ def watched_subprocess(self, mocker): SucceedTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" ), - b"", + None, "task_instances.succeed", (), { @@ -1515,11 +1541,13 @@ def watched_subprocess(self, mocker): ), pytest.param( GetPrevSuccessfulDagRun(ti_id=TI_ID), - ( - b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",' - b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",' - b'"type":"PrevSuccessfulDagRunResult"}\n' - ), + { + "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), + "data_interval_end": 
timezone.parse("2025-01-10T14:00:00Z"), + "start_date": timezone.parse("2025-01-10T12:00:00Z"), + "end_date": timezone.parse("2025-01-10T14:00:00Z"), + "type": "PrevSuccessfulDagRunResult", + }, "task_instances.get_previous_successful_dagrun", (TI_ID,), {}, @@ -1540,7 +1568,7 @@ def watched_subprocess(self, mocker): logical_date=timezone.datetime(2025, 1, 1), reset_dag_run=True, ), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "dag_runs.trigger", ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), {}, @@ -1549,8 +1577,9 @@ def watched_subprocess(self, mocker): id="dag_run_trigger", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. Fix this before PR TriggerDagRun(dag_id="test_dag", run_id="test_run"), - b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n', + {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, "dag_runs.trigger", ("test_dag", "test_run", None, None, False), {}, @@ -1560,7 +1589,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDagRunState(dag_id="test_dag", run_id="test_run"), - b'{"state":"running","type":"DagRunStateResult"}\n', + {"state": "running", "type": "DagRunStateResult"}, "dag_runs.get_state", ("test_dag", "test_run"), {}, @@ -1570,7 +1599,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskRescheduleStartDate(ti_id=TI_ID), - b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n', + {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"}, "task_instances.get_reschedule_start_date", (TI_ID, 1), {}, @@ -1580,7 +1609,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), - b'{"count":2,"type":"TICount"}\n', + {"count": 2, "type": "TICount"}, "task_instances.get_count", (), { @@ -1598,7 +1627,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetDRCount(dag_id="test_dag", states=["success", "failed"]), - b'{"count":2,"type":"DRCount"}\n', + {"count": 2, "type": "DRCount"}, "dag_runs.get_count", (), { @@ -1613,7 +1642,10 @@ def watched_subprocess(self, mocker): ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), - b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n', + { + "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, + "type": "TaskStatesResult", + }, "task_instances.get_task_states", (), { @@ -1636,7 +1668,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=0, ), - b'{"root":"test_value","type":"XComSequenceIndexResult"}\n', + {"root": "test_value", "type": "XComSequenceIndexResult"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 0), {}, @@ -1645,6 +1677,7 @@ def watched_subprocess(self, mocker): id="get_xcom_seq_item", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. 
Fix this before PR GetXComSequenceItem( key="test_key", dag_id="test_dag", @@ -1652,7 +1685,7 @@ def watched_subprocess(self, mocker): task_id="test_task", offset=2, ), - b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n', + {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 2), {}, @@ -1670,7 +1703,7 @@ def watched_subprocess(self, mocker): stop=None, step=None, ), - b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n', + {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", ("test_dag", "test_run", "test_task", "test_key", None, None, None), {}, @@ -1687,7 +1720,7 @@ def test_handle_requests( mocker, time_machine, message, - expected_buffer, + expected_body, client_attr_path, method_arg, method_kwarg, @@ -1715,8 +1748,9 @@ def test_handle_requests( generator = watched_subprocess.handle_requests(log=mocker.Mock()) # Initialize the generator next(generator) - msg = message.model_dump_json().encode() + b"\n" - generator.send(msg) + + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) + generator.send(req_frame) if mask_secret_args: mock_mask_secret.assert_called_with(*mask_secret_args) @@ -1729,33 +1763,23 @@ def test_handle_requests( # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(BUFFER_SIZE) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError, socket.timeout): - # no response written, valid for some message types like setters and TI operations. - pass + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id # Verify the response was added to the buffer - assert val == expected_buffer + assert frame.body == expected_body # Verify the response is correctly decoded # This is important because the subprocess/task runner will read the response # and deserialize it to the correct message type - # Only decode the buffer if it contains data. An empty buffer implies no response was written. - if not val and (mock_response and not isinstance(mock_response, OKResponse)): - pytest.fail("Expected a response, but got an empty buffer.") - - if val: + if frame.body is not None: # Using BytesIO to simulate a readable stream for CommsDecoder. 
- input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) - assert decoder.get_message() == mock_response + decoder = CommsDecoder(request_socket=None).body_decoder + assert decoder.validate_python(frame.body) == mock_response def test_handle_requests_api_server_error(self, watched_subprocess, mocker): """Test that API server errors are properly handled and sent back to the task.""" @@ -1777,28 +1801,32 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): next(generator) - msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n" - generator.send(msg) + msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")) + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump()) + generator.send(req_frame) - # Read response directly from the reader socket + # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - val += read_socket.recv(4096) - except (BlockingIOError, TimeoutError): - pass - - assert val == ( - b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",' - b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n' - ) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id + + assert frame.error == { + "error": "API_SERVER_ERROR", + "detail": { + "status_code": 500, + "message": "API Server Error", + "detail": {"detail": "Internal Server Error"}, + }, + "type": "ErrorResponse", + } # Verify the error can be decoded correctly - input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) + comms = CommsDecoder(request_socket=None) with pytest.raises(AirflowRuntimeError) as exc_info: - decoder.get_message() + comms._from_frame(frame) assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR assert exc_info.value.error.detail == { @@ -1864,14 +1892,13 @@ def test_inprocess_supervisor_comms_roundtrip(self): """ class MinimalSupervisor(InProcessTestSupervisor): - def _handle_request(self, msg, log): + def _handle_request(self, msg, log, req_id): resp = VariableResult(key=msg.key, value="value") - self.send_msg(resp) + self.send_msg(resp, req_id) supervisor = MinimalSupervisor( id="test", pid=123, - requests_fd=-1, process=MagicMock(), process_log=MagicMock(), client=MagicMock(), @@ -1881,9 +1908,8 @@ def _handle_request(self, msg, log): test_msg = GetVariable(key="test_key") - comms.send_request(log=MagicMock(), msg=test_msg) + response = comms.send(test_msg) # Ensure we got back what we expect - response = comms.get_message() assert isinstance(response, VariableResult) assert response.value == "value" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index d786fdaa5b2bc..5ce460c455bf4 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -22,11 +22,9 @@ import json import os import textwrap -import uuid from collections.abc import Iterable from datetime import datetime, timedelta from pathlib import Path -from socket import socketpair from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -103,7 +101,6 @@ VariableAccessor, ) from airflow.sdk.execution_time.task_runner import ( - CommsDecoder, 
RuntimeTaskInstance, TaskRunnerMarker, _push_xcom_if_needed, @@ -139,47 +136,6 @@ def execute(self, context): print(f"Hello World {task_id}!") -class TestCommsDecoder: - """Test the communication between the subprocess and the "supervisor".""" - - @pytest.mark.usefixtures("disable_capturing") - def test_recv_StartupDetails(self): - r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"type":"StartupDetails", "ti": {' - b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", ' - b'"dag_id": "c"}, "ti_context":{"dag_run":{"dag_id":"c","run_id":"b",' - b'"logical_date":"2024-12-01T01:00:00Z",' - b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' - b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,' - b'"run_type":"manual","conf":null,"consumed_asset_events":[]},' - b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file": "/dev/null",' - b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' - b'"any-name", "version": "any-version"}, "requests_fd": ' - + str(w2.fileno()).encode("ascii") - + b"}\n" - ) - - decoder = CommsDecoder(input=r.makefile("r")) - - msg = decoder.get_message() - assert isinstance(msg, StartupDetails) - assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") - assert msg.ti.task_id == "a" - assert msg.ti.dag_id == "c" - assert msg.dag_rel_path == "/dev/null" - assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") - assert msg.start_date == timezone.datetime(2024, 12, 1, 1) - - # Since this was a StartupDetails message, the decoder should open the other socket - assert decoder.request_socket is not None - assert decoder.request_socket.writable() - assert decoder.request_socket.fileno() == w2.fileno() - - def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( @@ -192,7 +148,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -248,7 +203,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -302,7 +256,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -360,8 +313,8 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com assert ti.state == TaskInstanceState.DEFERRED - # send_request will only be called when the TaskDeferred exception is raised - mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY) + # send will only be called when the TaskDeferred exception is raised + mock_supervisor_comms.send.assert_any_call(expected_defer_task) def test_run_downstream_skipped(mocked_parse, create_runtime_ti, mock_supervisor_comms): @@ -384,8 +337,8 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] 
log.info.assert_called_with("Skipping downstream tasks.") - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, msg=SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks") + mock_supervisor_comms.send.assert_any_call( + SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks") ) @@ -436,8 +389,8 @@ def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm assert ti.state == TaskInstanceState.SKIPPED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState(state=TaskInstanceState.SKIPPED, end_date=instant), log=mock.ANY + mock_supervisor_comms.send.assert_called_with( + TaskState(state=TaskInstanceState.SKIPPED, end_date=instant) ) @@ -458,12 +411,11 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( + mock_supervisor_comms.send.assert_called_with( msg=TaskState( state=TaskInstanceState.FAILED, end_date=instant, ), - log=mock.ANY, ) @@ -485,13 +437,7 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) log.exception.assert_not_called() log.error.assert_called_with(mock.ANY, exit_code=10) @@ -516,13 +462,7 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms): @@ -545,13 +485,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms assert ti.state == TaskInstanceState.FAILED # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send_request.assert_called_with( - msg=TaskState( - state=TaskInstanceState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): @@ -573,7 +507,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ), bundle_info=FAKE_BUNDLE, dag_rel_path="", - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -583,7 +516,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm spy_agency.spy_on(task.prepare_for_execution) assert not task._lock_for_execution - # mock_supervisor_comms.get_message.return_value = what run(ti, context=ti.get_template_context(), log=mock.Mock()) spy_agency.assert_spy_called(task.prepare_for_execution) @@ -591,7 +523,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm assert ti.task is not task, "ti.task should be a copy of the original task" assert ti.state == TaskInstanceState.SUCCESS - mock_supervisor_comms.send_request.assert_any_call( + 
mock_supervisor_comms.send.assert_any_call( msg=SetRenderedFields( rendered_fields={ "bash_command": "echo 'Logical date is 2024-12-01 01:00:00+00:00'", @@ -599,7 +531,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm "env": None, } ), - log=mock.ANY, ) @@ -689,7 +620,6 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -697,22 +627,18 @@ def execute(self, context): time_machine.move_to(instant, tick=False) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what run(*startup()) expected_calls = [ - mock.call.send_request( - msg=SetRenderedFields(rendered_fields=expected_rendered_fields), - log=mock.ANY, - ), - mock.call.send_request( + mock.call.send(SetRenderedFields(rendered_fields=expected_rendered_fields)), + mock.call.send( msg=SucceedTask( end_date=instant, state=TaskInstanceState.SUCCESS, task_outlets=[], outlet_events=[], ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -766,9 +692,8 @@ def execute(self, context): assert ti.state == TaskInstanceState.SUCCESS # Ensure the task is Successful - mock_supervisor_comms.send_request.assert_called_once_with( + mock_supervisor_comms.send.assert_called_once_with( msg=SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), - log=mock.ANY, ) @@ -815,8 +740,8 @@ def execute(self, context): assert ti.state == TaskInstanceState.FAILED - mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant), log=mock.ANY + mock_supervisor_comms.send.assert_called_once_with( + msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant) ) @@ -834,12 +759,11 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), dag_rel_path="dag_parsing_context.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(dag_id=dag_id, run_id="c"), start_date=timezone.utcnow(), ) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what # Set the environment variable for DAG bundles # We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test! 
@@ -955,7 +879,7 @@ def test_run_with_asset_outlets( validate_mock.assert_called_once() - mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY) + mock_supervisor_comms.send.assert_any_call(expected_msg) def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): @@ -967,7 +891,7 @@ def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.get_message.return_value = events_result + mock_supervisor_comms.send.return_value = events_result from airflow.providers.standard.operators.bash import BashOperator @@ -1121,7 +1045,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s dr = runtime_ti._ti_context_from_server.dag_run - mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult( + mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult( data_interval_end=dr.logical_date - timedelta(hours=1), data_interval_start=dr.logical_date - timedelta(hours=2), start_date=dr.start_date - timedelta(hours=1), @@ -1171,7 +1095,7 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, dag_id="basic_task") - mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult( + mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult( data_interval_end=timezone.datetime(2025, 1, 1, 2, 0, 0), data_interval_start=timezone.datetime(2025, 1, 1, 1, 0, 0), start_date=timezone.datetime(2025, 1, 1, 1, 0, 0), @@ -1181,13 +1105,13 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock context = runtime_ti.get_template_context() # Assert lazy attributes are not resolved initially - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access a lazy-loaded attribute to trigger computation assert context["prev_data_interval_start_success"] == timezone.datetime(2025, 1, 1, 1, 0, 0) # Now the lazy attribute should trigger the call - mock_supervisor_comms.get_message.assert_called_once() + mock_supervisor_comms.send.assert_called_once() def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_comms): """Test that the connection is fetched from the API server via the Supervisor lazily when accessed""" @@ -1206,22 +1130,18 @@ def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_co ) runtime_ti = create_runtime_ti(task=task, dag_id="test_get_connection_from_context") - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn context = runtime_ti.get_template_context() # Assert that the connection is not fetched from the API server yet! 
# The connection should be only fetched connection is accessed - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access the connection from the context conn_from_context = context["conn"].test_conn - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, msg=GetConnection(conn_id="test_conn") - ) - mock_supervisor_comms.get_message.assert_called_once_with() + mock_supervisor_comms.send.assert_called_once_with(GetConnection(conn_id="test_conn")) assert conn_from_context == Connection( conn_id="test_conn", @@ -1280,7 +1200,7 @@ def test_template_with_connection( extra='{"extra__asana__workspace": "extra1"}', ) - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn context = runtime_ti.get_template_context() result = runtime_ti.task.render_template(content, context) @@ -1307,22 +1227,18 @@ def test_get_variable_from_context( var = VariableResult(key="test_key", value=var_value) - mock_supervisor_comms.get_message.return_value = var + mock_supervisor_comms.send.return_value = var context = runtime_ti.get_template_context() # Assert that the variable is not fetched from the API server yet! # The variable should be only fetched connection is accessed - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() # Access the variable from the context var_from_context = context["var"][accessor_type].test_key - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, msg=GetVariable(key="test_key") - ) - mock_supervisor_comms.get_message.assert_called_once_with() + mock_supervisor_comms.send.assert_called_once_with(GetVariable(key="test_key")) assert var_from_context == expected_value @@ -1390,16 +1306,14 @@ def execute(self, context): ser_value = BaseXCom.serialize_value(xcom_values) - def mock_get_message_side_effect(*args, **kwargs): - calls = mock_supervisor_comms.send_request.call_args_list - if calls: - last_call = calls[-1] - msg = last_call[1]["msg"] - if isinstance(msg, GetXComSequenceSlice): - return XComSequenceSliceResult(root=[ser_value]) + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + print(f"{args=}, {kwargs=}, {msg=}") + if isinstance(msg, GetXComSequenceSlice): + return XComSequenceSliceResult(root=[ser_value]) return XComResult(key="key", value=ser_value) - mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect + mock_supervisor_comms.send.side_effect = mock_send_side_effect run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) @@ -1415,8 +1329,7 @@ def mock_get_message_side_effect(*args, **kwargs): task_id = test_task_id for map_index in map_indexes: if map_index == NOTSET: - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( msg=GetXComSequenceSlice( key="key", dag_id="test_dag", @@ -1429,8 +1342,7 @@ def mock_get_message_side_effect(*args, **kwargs): ) else: expected_map_index = map_index if map_index is not None else None - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( msg=GetXCom( key="key", dag_id="test_dag", @@ -1591,11 +1503,8 @@ def execute(self, context): context = runtime_ti.get_template_context() run(runtime_ti, context=context, log=mock.MagicMock()) - 
mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: @@ -1644,9 +1553,8 @@ def __init__(self, bash_command, *args, **kwargs): log=mock.MagicMock(), ) - mock_supervisor_comms.send_request.assert_called_with( - msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd}), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_with( + msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd}) ) @pytest.mark.parametrize( @@ -1669,7 +1577,7 @@ def test_get_first_reschedule_date( task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count) - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate( + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate( start_date=timezone.datetime(2025, 1, 1) ) @@ -1678,7 +1586,7 @@ def test_get_first_reschedule_date( def test_get_ti_count(self, mock_supervisor_comms): """Test that get_ti_count sends the correct request and returns the count.""" - mock_supervisor_comms.get_message.return_value = TICount(count=2) + mock_supervisor_comms.send.return_value = TICount(count=2) count = RuntimeTaskInstance.get_ti_count( dag_id="test_dag", @@ -1689,8 +1597,7 @@ def test_get_ti_count(self, mock_supervisor_comms): states=["success", "failed"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetTICount( dag_id="test_dag", task_ids=["task1", "task2"], @@ -1704,7 +1611,7 @@ def test_get_ti_count(self, mock_supervisor_comms): def test_get_dr_count(self, mock_supervisor_comms): """Test that get_dr_count sends the correct request and returns the count.""" - mock_supervisor_comms.get_message.return_value = DRCount(count=2) + mock_supervisor_comms.send.return_value = DRCount(count=2) count = RuntimeTaskInstance.get_dr_count( dag_id="test_dag", @@ -1713,8 +1620,7 @@ def test_get_dr_count(self, mock_supervisor_comms): states=["success", "failed"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetDRCount( dag_id="test_dag", logical_dates=[timezone.datetime(2024, 1, 1)], @@ -1726,27 +1632,21 @@ def test_get_dr_count(self, mock_supervisor_comms): def test_get_dagrun_state(self, mock_supervisor_comms): """Test that get_dagrun_state sends the correct request and returns the state.""" - mock_supervisor_comms.get_message.return_value = DagRunStateResult(state="running") + mock_supervisor_comms.send.return_value = DagRunStateResult(state="running") state = RuntimeTaskInstance.get_dagrun_state( dag_id="test_dag", run_id="run1", ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, - msg=GetDagRunState( - dag_id="test_dag", - run_id="run1", - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagRunState(dag_id="test_dag", run_id="run1"), ) assert state == "running" def test_get_task_states(self, mock_supervisor_comms): """Test that get_task_states sends the correct request and returns the states.""" - mock_supervisor_comms.get_message.return_value = TaskStatesResult( - task_states={"run1": {"task1": 
"running"}} - ) + mock_supervisor_comms.send.return_value = TaskStatesResult(task_states={"run1": {"task1": "running"}}) states = RuntimeTaskInstance.get_task_states( dag_id="test_dag", @@ -1754,8 +1654,7 @@ def test_get_task_states(self, mock_supervisor_comms): run_ids=["run1"], ) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=GetTaskStates( dag_id="test_dag", task_ids=["task1"], @@ -1929,7 +1828,6 @@ def execute(self, context): assert not any( x == mock.call( - log=mock.ANY, msg=SetXCom( key="key", value="pushing to xcom backend!", @@ -1939,7 +1837,7 @@ def execute(self, context): map_index=-1, ), ) - for x in mock_supervisor_comms.send_request.call_args_list + for x in mock_supervisor_comms.send.call_args_list ) def test_xcom_pull_from_custom_xcom_backend( @@ -1966,7 +1864,6 @@ def execute(self, context): assert not any( x == mock.call( - log=mock.ANY, msg=GetXCom( key="key", dag_id="test_dag", @@ -1975,7 +1872,7 @@ def execute(self, context): map_index=-1, ), ) - for x in mock_supervisor_comms.send_request.call_args_list + for x in mock_supervisor_comms.send.call_args_list ) @@ -2006,11 +1903,8 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dag_overwrite(self, create_runtime_ti, mock_supervisor_comms, time_machine): @@ -2035,11 +1929,8 @@ def execute(self, context): task=task, dag_id="dag_with_dag_params_overwrite", conf={"value": "new_value"} ) run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine): @@ -2062,11 +1953,8 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task, dag_id="dag_with_dag_params_default") run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_resolves( @@ -2097,11 +1985,8 @@ def execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) def test_dag_param_dagrun_parameterized( @@ -2136,11 +2021,8 @@ def 
execute(self, context): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) @pytest.mark.parametrize("value", [VALUE, 0]) @@ -2167,11 +2049,8 @@ def return_num(num): run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_any_call( - msg=SucceedTask( - state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] - ), - log=mock.ANY, + mock_supervisor_comms.send.assert_any_call( + SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]), ) @@ -2234,12 +2113,11 @@ def execute(self, context): ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) - mock_supervisor_comms.get_message.return_value = what + mock_supervisor_comms._get_response.return_value = what mocked_parse(what, "basic_dag", task) runtime_ti, context, log = startup() @@ -2539,17 +2417,15 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): assert msg.state == TaskInstanceState.SUCCESS expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", run_id="test_run_id", reset_dag_run=False, logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=SetXCom( key="trigger_run_id", value="test_run_id", @@ -2558,7 +2434,6 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): run_id="test_run", map_index=-1, ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -2586,23 +2461,21 @@ def test_handle_trigger_dag_run_conflict( ti = create_runtime_ti(dag_id="test_handle_trigger_dag_run_conflict", run_id="test_run", task=task) log = mock.MagicMock() - mock_supervisor_comms.get_message.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) + mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) state, msg, _ = run(ti, ti.get_template_context(), log) assert state == expected_state assert msg.state == expected_state expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), run_id="test_run_id", reset_dag_run=False, ), - log=mock.ANY, ), - mock.call.get_message(), ] mock_supervisor_comms.assert_has_calls(expected_calls) @@ -2647,13 +2520,19 @@ def test_handle_trigger_dag_run_wait_for_completion( ) log = mock.MagicMock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ + # Set RTIF + None, # Successful Dag Run trigger OKResponse(ok=True), + # Set XCOM, + None, # Dag Run is still running DagRunStateResult(state=DagRunState.RUNNING), # Dag Run completes execution on the next poll DagRunStateResult(state=target_dr_state), + # Succeed/Fail task + None, ] with mock.patch("time.sleep", return_value=None): state, msg, _ = run(ti, ti.get_template_context(), log) @@ -2662,16 +2541,14 @@ def test_handle_trigger_dag_run_wait_for_completion( assert msg.state 
== expected_task_state expected_calls = [ - mock.call.send_request( + mock.call.send( msg=TriggerDagRun( dag_id="test_dag", run_id="test_run_id", logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=SetXCom( key="trigger_run_id", value="test_run_id", dag_id="test_dag", task_id="test_task", run_id="test_run", map_index=-1, ), - log=mock.ANY, ), - mock.call.send_request( + mock.call.send( msg=GetDagRunState( dag_id="test_dag", run_id="test_run_id", ), - log=mock.ANY, ), - mock.call.get_message(), - mock.call.send_request( + mock.call.send( msg=GetDagRunState( dag_id="test_dag", run_id="test_run_id", ), - log=mock.ANY, ), ] mock_supervisor_comms.assert_has_calls(expected_calls) From 05c78f1ae71520362491503ad8e16266d9666a69 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 16 Jun 2025 18:12:11 +0100 Subject: [PATCH 09/21] Deal with compat in tests This compat issue is only in tests, as nothing in the runtime of airflow-core imports/calls methods directly on SUPERVISOR_COMMS; we only import it in tests to make assertions about the behaviour/to stub the return values. --- devel-common/src/tests_common/pytest_plugin.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index d542bfa034b05..b1900be775d73 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1964,11 +1964,18 @@ def mock_supervisor_comms(monkeypatch): yield None return - import airflow.sdk.execution_time.task_runner - from airflow.sdk.execution_time.comms import CommsDecoder + from airflow.sdk.execution_time import comms, task_runner - comms = mock.create_autospec(CommsDecoder) - monkeypatch.setattr(airflow.sdk.execution_time.task_runner, "SUPERVISOR_COMMS", comms, raising=False) + # Deal with TaskSDK 1.0/1.1 vs 1.2+.
Annoying, and shouldn't need to exist once the separation between + # core and TaskSDK is finished + if CommsDecoder := getattr(comms, "CommsDecoder", None): + comms = mock.create_autospec(CommsDecoder) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + else: + CommsDecoder = getattr(task_runner, "CommsDecoder") + comms = mock.create_autospec(CommsDecoder) + comms.send = comms.get_response + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) yield comms From f45647afbcfcacd056f9bab0a314e1029bde4563 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 16 Jun 2025 18:39:29 +0100 Subject: [PATCH 10/21] Code review --- .../src/airflow/dag_processing/processor.py | 6 ++--- .../src/airflow/jobs/triggerer_job_runner.py | 6 +++-- .../tests/unit/dag_processing/test_manager.py | 1 - .../unit/dag_processing/test_processor.py | 2 +- .../src/airflow/sdk/execution_time/comms.py | 23 ++++++++-------- .../airflow/sdk/execution_time/supervisor.py | 26 ++++++++++--------- .../airflow/sdk/execution_time/task_runner.py | 6 ++--- .../task_sdk/execution_time/test_comms.py | 2 +- .../execution_time/test_supervisor.py | 5 ++-- 9 files changed, 40 insertions(+), 37 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 4341d6d91390e..011393f22c886 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -265,7 +265,7 @@ def _on_child_started( bundle_path=bundle_path, callback_requests=callbacks, ) - self.send_msg(msg, in_response_to=0) + self.send_msg(msg, request_id=0) def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse @@ -298,14 +298,14 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int log.error("Unhandled request", msg=msg) self.send_msg( None, - in_response_to=req_id, + request_id=req_id, error=ErrorResponse( detail={"status_code": 400, "message": "Unhandled request"}, ), ) return - self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) @property def is_ready(self) -> bool: diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index e829a4fa6eae0..adc87e802fa6f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -343,7 +343,7 @@ def start( # type: ignore[override] proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) msg = messages.StartTriggerer() - proc.send_msg(msg, in_response_to=0) + proc.send_msg(msg, request_id=0) return proc @functools.cached_property @@ -454,7 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r else: raise ValueError(f"Unknown message type {type(msg)}") - self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def run(self) -> None: """Run synchronously and handle all database reads/writes.""" @@ -714,6 +714,8 @@ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: async def _aread_frame(self): len_bytes = await self._async_reader.readexactly(4) len = int.from_bytes(len_bytes, byteorder="big") + if len 
>= 2**32: + raise OverflowError(f"Refusing to receive messages larger than 4GiB {len=}") buffer = await self._async_reader.readexactly(len) return self.resp_decoder.decode(buffer) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 98204033ced9e..02121f0194e82 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -603,7 +603,6 @@ def test_serialize_callback_requests(self, callbacks, path, expected_body): read_socket.settimeout(0.1) # Read response from the read end of the socket - read_socket.settimeout(0.1) frame_len = int.from_bytes(read_socket.recv(4), "big") bytes = read_socket.recv(frame_len) frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index eb6a58dfcfbfb..8d77da61cafb9 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -419,7 +419,7 @@ def test_parse_file_entrypoint_parses_dag_callbacks(mocker): w.sendall(bytes) decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( - request_socket=r, + socket=r, body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), ) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 67a298a6c14ab..be2f4f8eacd46 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -101,7 +101,8 @@ def _msgpack_enc_hook(obj: Any) -> Any: import pendulum if isinstance(obj, pendulum.DateTime): - # convert the complex to a tuple of real, imag + # convert the pendulum DateTime subclass into a raw datetime so that msgspec can use its native + # encoding return datetime( obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo ) @@ -137,14 +138,14 @@ def as_bytes(self) -> bytearray: self.req_encoder.encode_into(self, buffer, 4) n = len(buffer) - 4 - if n > 2**32: - raise OverflowError("Cannot send messages larger than 4GiB") + if n >= 2**32: + raise OverflowError(f"Cannot send messages larger than 4GiB {n=}") buffer[:4] = n.to_bytes(4, byteorder="big") return buffer -class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): +class _ResponseFrame(_RequestFrame, frozen=True): id: int """ The id of the request this is a response to @@ -158,7 +159,7 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): """Handle communication between the task in this process and the supervisor parent process.""" log: Logger = attrs.field(repr=False, factory=structlog.get_logger) - request_socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + socket: socket = attrs.field(factory=lambda: socket(fileno=0)) resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False @@ -178,7 +179,7 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None: frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) bytes = frame.as_bytes() - self.request_socket.sendall(bytes) + self.socket.sendall(bytes) return self._get_response() @@ -188,9 +189,9 @@ def _read_frame(self): This will block until the message has been received.
""" - if self.request_socket: - self.request_socket.setblocking(True) - len_bytes = self.request_socket.recv(4) + if self.socket: + self.socket.setblocking(True) + len_bytes = self.socket.recv(4) if len_bytes == b"": raise EOFError("Request socket closed before length") @@ -198,13 +199,13 @@ def _read_frame(self): len = int.from_bytes(len_bytes, byteorder="big") buffer = bytearray(len) - nread = self.request_socket.recv_into(buffer) + nread = self.socket.recv_into(buffer) if nread != len: raise RuntimeError( f"unable to read full response in child. (We read {nread}, but expected {len})" ) if nread == 0: - raise EOFError("Request socket closed before response was complete") + raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})") return self.resp_decoder.decode(buffer) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index a19acf3d13d18..2ae554cace1a5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -331,8 +331,6 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() - # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) - _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) @@ -561,14 +559,19 @@ def _on_socket_closed(self, sock: socket): del self._open_sockets[sock] def send_msg( - self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + self, msg: BaseModel | None, request_id: int, error: ErrorResponse | None = None, **dump_opts ): - """Send the msg as a length-prefixed response frame.""" + """ + Send the msg as a length-prefixed response frame. 
+ + ``request_id`` is the ID that the client sent in it's request, and has no meaning to the server + + """ if msg: - frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts)) + frame = _ResponseFrame(id=request_id, body=msg.model_dump(**dump_opts)) else: err_resp = error.model_dump() if error else None - frame = _ResponseFrame(id=in_response_to, error=err_resp) + frame = _ResponseFrame(id=request_id, error=err_resp) self.stdin.sendall(frame.as_bytes()) @@ -605,7 +608,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _Request "detail": error_details, }, ), - in_response_to=request.id, + request_id=request.id, ) def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: @@ -868,7 +871,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st log.debug("Sending", msg=msg) try: - self.send_msg(msg, in_response_to=0) + self.send_msg(msg, request_id=0) except (BrokenPipeError, ConnectionResetError): # Debug is fine, the process will have shown _something_ in it's last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -1219,7 +1222,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: log.error("Unhandled request", msg=msg) self.send_msg( None, - in_response_to=req_id, + request_id=req_id, error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={"status_code": 400, "message": "Unhandled request"}, @@ -1227,7 +1230,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: ) return - self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def in_process_api_server(): @@ -1366,7 +1369,7 @@ def _api_client(dag=None): return client def send_msg( - self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + self, msg: BaseModel | None, request_id: int, error: ErrorResponse | None = None, **dump_opts ): """Override to use in-process comms.""" self.comms.messages.append(msg) @@ -1444,7 +1447,6 @@ def cb(sock: socket): if len(buffer): with suppress(StopIteration): gen.send(buffer) - # Tell loop to close this selector return False buffer.extend(read_buffer[:n_received]) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index c545c5eebc7b7..5ad5d4df45feb 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -649,7 +649,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent + # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent # in response to us sending a request. msg = SUPERVISOR_COMMS._get_response() @@ -1256,9 +1256,9 @@ def main(): finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates. 
- if SUPERVISOR_COMMS and SUPERVISOR_COMMS.request_socket: + if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: with suppress(Exception): - SUPERVISOR_COMMS.request_socket.close() + SUPERVISOR_COMMS.socket.close() if __name__ == "__main__": diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index ee2f956b063c9..5adaa2562abc7 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -71,7 +71,7 @@ def test_recv_StartupDetails(self): bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) w.sendall(len(bytes).to_bytes(4, byteorder="big") + bytes) - decoder = CommsDecoder(request_socket=r, log=None) + decoder = CommsDecoder(socket=r, log=None) msg = decoder._get_response() assert isinstance(msg, StartupDetails) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 6e1c3b2e093cb..d2f8ef40c0881 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1777,8 +1777,7 @@ def test_handle_requests( # and deserialize it to the correct message type if frame.body is not None: - # Using BytesIO to simulate a readable stream for CommsDecoder. - decoder = CommsDecoder(request_socket=None).body_decoder + decoder = CommsDecoder(socket=None).body_decoder assert decoder.validate_python(frame.body) == mock_response def test_handle_requests_api_server_error(self, watched_subprocess, mocker): @@ -1824,7 +1823,7 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): } # Verify the error can be decoded correctly - comms = CommsDecoder(request_socket=None) + comms = CommsDecoder(socket=None) with pytest.raises(AirflowRuntimeError) as exc_info: comms._from_frame(frame) From 6b37dcd4afe9cb6e1bb7c3473fbfc98715712dae Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 17 Jun 2025 12:53:56 +0530 Subject: [PATCH 11/21] adding a testin task runner --- .../execution_time/test_task_runner.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 5ce460c455bf4..76f85b47c2688 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -644,6 +644,40 @@ def execute(self, context): mock_supervisor_comms.assert_has_calls(expected_calls) +def test_task_startup_with_user_impersonation( + mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + """Test startup of a task with run_as_user specified.""" + + class CustomOperator(BaseOperator): + def execute(self, context): + print("Hi from CustomOperator!") + + task = CustomOperator(task_id="impersonation_task", run_as_user="airflowuser") + instant = timezone.datetime(2024, 12, 3, 10, 0) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="impersonation_task", + dag_id="basic_dag", + run_id="c", + try_number=1, + ), + dag_rel_path="", + bundle_info=FAKE_BUNDLE, + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + mocked_parse(what, "basic_dag", task) + + time_machine.move_to(instant, tick=False) + + mock_supervisor_comms._get_response.return_value = what + + startup() + + @pytest.mark.parametrize( ["command", "rendered_command"], [ From f3b72393c646929637f71758f21d823bb3604204 Mon Sep 17 
00:00:00 2001 From: Amogh Desai Date: Tue, 17 Jun 2025 13:10:18 +0530 Subject: [PATCH 12/21] fixing mock --- devel-common/src/tests_common/pytest_plugin.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b1900be775d73..c164a185ee3b4 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1957,13 +1957,18 @@ def override_caplog(request): @pytest.fixture def mock_supervisor_comms(monkeypatch): - # for back-compat + import socket + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if not AIRFLOW_V_3_0_PLUS: yield None return + # Patch socket(fileno=0) before anything touches CommsDecoder + dummy_sock, _ = socket.socketpair() + monkeypatch.setattr("airflow.sdk.execution_time.comms.socket", lambda fileno: dummy_sock) + from airflow.sdk.execution_time import comms, task_runner # Deal with TaskSDK 1.0/1.1 vs 1.2+. Annoying, and shouldn't need to exist once the separation between From 5ba66b53d0a935b410db291b1ae54811ac49e08e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 17 Jun 2025 14:04:02 +0530 Subject: [PATCH 13/21] adding unit tests --- .../airflow/sdk/execution_time/task_runner.py | 4 ++-- .../execution_time/test_task_runner.py | 23 ++++++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bda6768550cc0..8ea8af820f336 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -670,7 +670,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.info("Using serialized startup message from environment", msg=msg) else: # normal entry point - msg = SUPERVISOR_COMMS._get_response() + msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] log.info("Received startup message from supervisor", msg=msg) if not isinstance(msg, StartupDetails): @@ -704,7 +704,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() - os.set_inheritable(SUPERVISOR_COMMS.request_socket.fileno(), True) + os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) log.info("Running command", command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 76f85b47c2688..e495ac9189e66 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -644,11 +644,11 @@ def execute(self, context): mock_supervisor_comms.assert_has_calls(expected_calls) -def test_task_startup_with_user_impersonation( - mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +@patch("os.execvp") +@patch("os.set_inheritable") +def test_user_impersonation_env_and_execvp( + mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms ): - """Test startup of a task with run_as_user specified.""" - class CustomOperator(BaseOperator): def execute(self, context): print("Hi from 
CustomOperator!") @@ -669,13 +669,24 @@ def execute(self, context): ti_context=make_ti_context(), start_date=timezone.utcnow(), ) - mocked_parse(what, "basic_dag", task) + mocked_parse(what, "basic_dag", task) time_machine.move_to(instant, tick=False) mock_supervisor_comms._get_response.return_value = what + mock_supervisor_comms.socket.fileno.return_value = 42 + + with mock.patch.dict(os.environ, {}, clear=True): + startup() + + assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1" + assert "_AIRFLOW__STARTUP_MSG" in os.environ - startup() + mock_set_inheritable.assert_called_once_with(42, True) + expected_suffix = "src/airflow/sdk/execution_time/task_runner.py" + actual_cmd = mock_execvp.call_args.args[1] + assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"] + assert actual_cmd[-1].endswith(expected_suffix) @pytest.mark.parametrize( From 54ee8c5f0a2e81295ba5da1d6ac9090e7813d269 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 17 Jun 2025 20:39:06 +0530 Subject: [PATCH 14/21] rebase errors --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index ff55bb890535d..2ae554cace1a5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -331,8 +331,6 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() - # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) - _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) @@ -612,7 +610,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _Request ), request_id=request.id, ) - return def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: raise NotImplementedError() @@ -1375,8 +1372,7 @@ def send_msg( self, msg: BaseModel | None, request_id: int, error: ErrorResponse | None = None, **dump_opts ): """Override to use in-process comms.""" - if msg is not None: - self.comms.messages.append(msg) + self.comms.messages.append(msg) @property def final_state(self): From e890bad8399624a08bf1f24b95e02dafb054567e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 16:31:43 +0530 Subject: [PATCH 15/21] nuke getattr and use function instead --- .../airflow/sdk/execution_time/task_runner.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 8ea8af820f336..2522200618216 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -644,15 +644,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] -def __getattr__(name: str) -> CommsDecoder[ToTask, ToSupervisor]: - """Lazy initialization of supervisor comms.""" - global SUPERVISOR_COMMS - if name == "SUPERVISOR_COMMS": - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=structlog.get_logger(logger_name="task")) - return SUPERVISOR_COMMS - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - # State machine! # 1. Start up (receive details from supervisor) # 2. 
Execution (run task code, possibly send requests) @@ -706,8 +697,18 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) - log.info("Running command", command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) - os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, __file__]) + # Import main directly from the module instead of re-executing the file. + # This ensures that when other parts modules import + # airflow.sdk.execution_time.task_runner, they get the same module instance + # with the properly initialized SUPERVISOR_COMMS global variable. + # If we re-executed the script, it would load as __main__ and future + # imports would get a fresh copy without the initialized globals. + rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" + log.info( + "Running command", + command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code], + ) + os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]) # ideally, we should never reach here, but if we do, we should return None, None, None return None, None, None From b84df96ae55ac577a21cacd7dbdc7d44e26127c9 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 16:33:59 +0530 Subject: [PATCH 16/21] removing unwanted code --- devel-common/src/tests_common/pytest_plugin.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index e4e9160f2d0dd..41bb7390ce74c 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1957,8 +1957,6 @@ def override_caplog(request): @pytest.fixture def mock_supervisor_comms(monkeypatch): - import socket - # for back-compat from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -1966,10 +1964,6 @@ def mock_supervisor_comms(monkeypatch): yield None return - # Patch socket(fileno=0) before anything touches CommsDecoder - dummy_sock, _ = socket.socketpair() - monkeypatch.setattr("airflow.sdk.execution_time.comms.socket", lambda fileno: dummy_sock) - from airflow.sdk.execution_time import comms, task_runner # Deal with TaskSDK 1.0/1.1 vs 1.2+. 
Annoying, and shouldn't need to exist once the separation between From 5a7f6d6e66a62dd63683124eed0b0d4400ad6b4d Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 16:37:06 +0530 Subject: [PATCH 17/21] fixing tests --- task-sdk/tests/task_sdk/execution_time/test_task_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index e495ac9189e66..8bcdbfa13a235 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -683,10 +683,9 @@ def execute(self, context): assert "_AIRFLOW__STARTUP_MSG" in os.environ mock_set_inheritable.assert_called_once_with(42, True) - expected_suffix = "src/airflow/sdk/execution_time/task_runner.py" actual_cmd = mock_execvp.call_args.args[1] assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"] - assert actual_cmd[-1].endswith(expected_suffix) + assert actual_cmd[-1] == "from airflow.sdk.execution_time.task_runner import main; main()" @pytest.mark.parametrize( From 5a8e13e88204ff709e81eddfd6f0efdb4c29bd59 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 18:06:25 +0530 Subject: [PATCH 18/21] Apply suggestions from code review Co-authored-by: Ash Berlin-Taylor --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 2522200618216..0978ca06f83a6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -651,18 +651,17 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent + # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent # in response to us sending a request. log = structlog.get_logger(logger_name="task") if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"): # entrypoint of re-exec process msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"]) - log.info("Using serialized startup message from environment", msg=msg) + log.debug("Using serialized startup message from environment", msg=msg) else: # normal entry point msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] - log.info("Received startup message from supervisor", msg=msg) if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -701,7 +700,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: # This ensures that when other parts modules import # airflow.sdk.execution_time.task_runner, they get the same module instance # with the properly initialized SUPERVISOR_COMMS global variable. - # If we re-executed the script, it would load as __main__ and future + # If we re-executed the module with `python -m`, it would load as __main__ and future # imports would get a fresh copy without the initialized globals. 
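The comment above is the crux of the ``-c`` re-exec choice. As a standalone toy illustration (not Airflow code; ``demo_mod.py`` is a hypothetical file name, and ``STATE`` merely stands in for ``SUPERVISOR_COMMS``), running a file as a script gives ``__main__`` its own module object, so a later import by dotted name loads a second, uninitialized copy:

```
# Toy demonstration of the __main__ duplication problem; save as demo_mod.py
# and run "python demo_mod.py".
import sys

STATE = None  # stands in for a module-level global like SUPERVISOR_COMMS


def show():
    import demo_mod  # when this file runs as a script, this pulls in a *second* copy

    print("same module object?", demo_mod is sys.modules["__main__"])  # False
    print("state seen via import:", demo_mod.STATE)  # None, not the value set in __main__


if __name__ == "__main__":
    STATE = "initialized in __main__"
    show()
```

Importing ``main`` by its dotted module path, as the ``-c`` snippet below does, avoids that split: the re-exec'd interpreter initializes ``airflow.sdk.execution_time.task_runner`` under its real name first, so every subsequent import resolves to the same, already-initialized module.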
rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" log.info( From 010d2b2744061942eb6a5192a42582710a6046d9 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 18:09:52 +0530 Subject: [PATCH 19/21] nits from ash --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 0978ca06f83a6..098beb8b33507 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -703,11 +703,12 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: # If we re-executed the module with `python -m`, it would load as __main__ and future # imports would get a fresh copy without the initialized globals. rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" + cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code] log.info( "Running command", - command=["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code], + command=cmd, ) - os.execvp("sudo", ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]) + os.execvp("sudo", cmd) # ideally, we should never reach here, but if we do, we should return None, None, None return None, None, None From 6ee1cb02839fac30819fb12d045cb314c8534ad9 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 20:01:16 +0530 Subject: [PATCH 20/21] nits from ash --- task-sdk/src/airflow/sdk/execution_time/comms.py | 1 + task-sdk/src/airflow/sdk/execution_time/task_runner.py | 9 ++++----- .../tests/task_sdk/execution_time/test_task_runner.py | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 46bb204329f50..be2f4f8eacd46 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -180,6 +180,7 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None: bytes = frame.as_bytes() self.socket.sendall(bytes) + return self._get_response() def _read_frame(self): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 098beb8b33507..9638edd2a250f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,7 +40,7 @@ from airflow.configuration import conf from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.exceptions import AirflowConfigException, AirflowInactiveAssetInInletOrOutletException +from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import ( AssetProfile, @@ -684,10 +684,9 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: ti.log_url = get_log_url_from_ti(ti) log.debug("DAG file parsed", file=msg.dag_rel_path) - try: - run_as_user = getattr(ti.task, "run_as_user", None) or conf.get("core", "default_impersonation") - except AirflowConfigException: - run_as_user = None + run_as_user = getattr(ti.task, "run_as_user", None) or conf.get( + "core", "default_impersonation", fallback=None 
+ ) if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user: # enters here for re-exec process diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 8bcdbfa13a235..34b30264520aa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -684,8 +684,10 @@ def execute(self, context): mock_set_inheritable.assert_called_once_with(42, True) actual_cmd = mock_execvp.call_args.args[1] + assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"] - assert actual_cmd[-1] == "from airflow.sdk.execution_time.task_runner import main; main()" + assert "python -c" in actual_cmd[5] + " " + actual_cmd[6] + assert actual_cmd[7] == "from airflow.sdk.execution_time.task_runner import main; main()" @pytest.mark.parametrize( From 706542f6b59fb79340bf7d495408fef19e16b06c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Jun 2025 20:58:04 +0530 Subject: [PATCH 21/21] addressing an edge case --- .../airflow/sdk/execution_time/task_runner.py | 3 +- .../execution_time/test_task_runner.py | 41 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 9638edd2a250f..e4292deca2a10 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -98,6 +98,7 @@ ) from airflow.sdk.execution_time.xcom import XCom from airflow.utils.net import get_hostname +from airflow.utils.platform import getuser from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: @@ -688,7 +689,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: "core", "default_impersonation", fallback=None ) - if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user: + if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser(): # enters here for re-exec process os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 34b30264520aa..9745297f8c7f0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -646,7 +646,7 @@ def execute(self, context): @patch("os.execvp") @patch("os.set_inheritable") -def test_user_impersonation_env_and_execvp( +def test_task_run_with_user_impersonation( mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms ): class CustomOperator(BaseOperator): @@ -690,6 +690,45 @@ def execute(self, context): assert actual_cmd[7] == "from airflow.sdk.execution_time.task_runner import main; main()" +@patch("airflow.sdk.execution_time.task_runner.getuser") +def test_task_run_with_user_impersonation_default_user( + mock_get_user, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + class CustomOperator(BaseOperator): + def execute(self, context): + print("Hi from CustomOperator!") + + task = CustomOperator(task_id="impersonation_task", run_as_user="default_user") + instant = timezone.datetime(2024, 12, 3, 10, 0) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="impersonation_task", + dag_id="basic_dag", + 
run_id="c", + try_number=1, + ), + dag_rel_path="", + bundle_info=FAKE_BUNDLE, + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + + mocked_parse(what, "basic_dag", task) + time_machine.move_to(instant, tick=False) + + mock_supervisor_comms._get_response.return_value = what + mock_supervisor_comms.socket.fileno.return_value = 42 + mock_get_user.return_value = "default_user" + + with mock.patch.dict(os.environ, {}, clear=True): + startup() + + assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ + assert "_AIRFLOW__STARTUP_MSG" not in os.environ + + @pytest.mark.parametrize( ["command", "rendered_command"], [