diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 89a9fbb1aa746..c3b4f809411ce 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -77,6 +77,8 @@ from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: + from socket import socket + from sqlalchemy.orm import Session from airflow.callbacks.callback_requests import CallbackRequest @@ -388,7 +390,7 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): """ events = self.selector.select(timeout=timeout) for key, _ in events: - socket_handler = key.data + socket_handler, on_close = key.data # BrokenPipeError should be caught and treated as if the handler returned false, similar # to EOF case @@ -397,8 +399,9 @@ def _service_processor_sockets(self, timeout: float | None = 1.0): except (BrokenPipeError, ConnectionResetError): need_more = False if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() def _queue_requested_files_for_parsing(self) -> None: """Queue any files requested for parsing as requested by users via UI/API.""" diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 5f69082c758aa..011393f22c886 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel): bundle_path: Path """Passing bundle path around lets us figure out relative file path.""" - requests_fd: int callback_requests: list[CallbackRequest] = Field(default_factory=list) type: Literal["DagFileParseRequest"] = "DagFileParseRequest" @@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel): def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time import comms, task_runner # Parse DAG file, send JSON back up! 
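The manager hunk above changes the selector contract: `key.data` now carries a `(handler, on_close)` pair rather than a bare handler, so the owning process object can prune its own bookkeeping (the `_open_sockets` set that `is_ready` consults later in this patch) before the socket is closed. A minimal sketch of that convention, assuming hypothetical `register`/`service_once` helpers in place of the manager's real methods:

```python
import selectors
from socket import socket

sel = selectors.DefaultSelector()

def register(sock: socket, handler, on_close) -> None:
    # Store the callback pair as the selector key's data, as the patch does.
    sel.register(sock, selectors.EVENT_READ, data=(handler, on_close))

def service_once(timeout: float = 1.0) -> None:
    for key, _ in sel.select(timeout=timeout):
        handler, on_close = key.data
        try:
            # The handler returns True while it wants more data, False on EOF.
            need_more = handler(key.fileobj)
        except (BrokenPipeError, ConnectionResetError):
            need_more = False
        if not need_more:
            sock = key.fileobj
            on_close(sock)  # unregister + drop from _open_sockets-style bookkeeping
            sock.close()
```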
- comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( - input=sys.stdin, - decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), + comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) - msg = comms_decoder.get_message() + msg = comms_decoder._get_response() if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") @@ -125,7 +122,7 @@ def _parse_file_entrypoint(): result = _parse_file(msg, log) if result is not None: - comms_decoder.send_request(log, result) + comms_decoder.send(result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: @@ -266,20 +263,18 @@ def _on_child_started( msg = DagFileParseRequest( file=os.fspath(path), bundle_path=bundle_path, - requests_fd=self._requests_fd, callback_requests=callbacks, ) - self.send_msg(msg) + self.send_msg(msg, request_id=0) - def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse resp: BaseModel | None = None dump_opts = {} if isinstance(msg, DagFileParsingResult): self.parsing_result = msg - return - if isinstance(msg, GetConnection): + elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) @@ -301,10 +296,16 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # resp = self.client.variables.delete(msg.key) else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + request_id=req_id, + error=ErrorResponse( + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) @property def is_ready(self) -> bool: @@ -312,7 +313,7 @@ def is_ready(self) -> bool: # Process still alive, def can't be finished yet return False - return self._num_open_sockets == 0 + return not self._open_sockets def wait(self) -> int: raise NotImplementedError(f"Don't call wait on {type(self).__name__} objects") diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5df6a03261523..adc87e802fa6f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -28,6 +28,7 @@ from collections.abc import Generator, Iterable from contextlib import suppress from datetime import datetime +from socket import socket from traceback import format_exception from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, Union @@ -43,6 +44,7 @@ from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger from airflow.sdk.execution_time.comms import ( + CommsDecoder, ConnectionResult, DagRunStateResult, DRCount, @@ -58,6 +60,7 @@ TICount, VariableResult, XComResult, + _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader from 
airflow.stats import Stats @@ -70,8 +73,6 @@ from airflow.utils.session import provide_session if TYPE_CHECKING: - from socket import socket - from sqlalchemy.orm import Session from structlog.typing import FilteringBoundLogger, WrappedLogger @@ -181,7 +182,6 @@ class messages: class StartTriggerer(BaseModel): """Tell the async trigger runner process to start, and where to send status update messages.""" - requests_fd: int type: Literal["StartTriggerer"] = "StartTriggerer" class TriggerStateChanges(BaseModel): @@ -295,7 +295,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess): """ TriggerRunnerSupervisor is responsible for monitoring the subprocess and marshalling DB access. - This class (which runs in the main process) is responsible for querying the DB, sending RunTrigger + This class (which runs in the main/sync process) is responsible for querying the DB, sending RunTrigger workload messages to the subprocess, and collecting results and updating them in the DB. """ @@ -342,8 +342,8 @@ def start( # type: ignore[override] ): proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) - msg = messages.StartTriggerer(requests_fd=proc._requests_fd) - proc.send_msg(msg) + msg = messages.StartTriggerer() + proc.send_msg(msg, request_id=0) return proc @functools.cached_property @@ -355,7 +355,7 @@ def client(self) -> Client: client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override] + def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, TaskStatesResponse, @@ -454,8 +454,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) - else: raise ValueError(f"Unknown message type {type(msg)}") - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def run(self) -> None: """Run synchronously and handle all database reads/writes.""" @@ -628,7 +627,7 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke ), ) - def _process_log_messages_from_subprocess(self) -> Generator[None, bytes, None]: + def _process_log_messages_from_subprocess(self) -> Generator[None, bytes | bytearray, None]: import msgspec from structlog.stdlib import NAME_TO_LEVEL @@ -691,14 +690,60 @@ class TriggerDetails(TypedDict): events: int +@attrs.define(kw_only=True) +class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]): + _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer") + _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader") + + body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field( + factory=lambda: TypeAdapter(ToTriggerRunner), repr=False + ) + + _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + + def _read_frame(self): + from asgiref.sync import async_to_sync + + return async_to_sync(self._aread_frame)() + + def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: + from asgiref.sync import async_to_sync + + return async_to_sync(self.asend)(msg) + + async def _aread_frame(self): + len_bytes = await self._async_reader.readexactly(4) + len = int.from_bytes(len_bytes, byteorder="big") + if len >= 2**32: + raise OverflowError(f"Refusing to receive messages larger than 4GiB {len=}") + + buffer = await 
self._async_reader.readexactly(len)
+        return self.resp_decoder.decode(buffer)
+
+    async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None:
+        frame = await self._aread_frame()
+        if frame.id != expect_id:
+            # Given the lock we take out in `asend`, this _shouldn't_ be possible, but I'd rather fail with
+            # this explicit error than return the wrong type of message back to a Trigger
+            raise RuntimeError(f"Response read out of order! Got {frame.id=}, {expect_id=}")
+        return self._from_frame(frame)
+
+    async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+        bytes = frame.as_bytes()
+
+        async with self._lock:
+            self._async_writer.write(bytes)
+
+            return await self._aget_response(frame.id)
+
+
 class TriggerRunner:
     """
     Runtime environment for all triggers.
 
-    Mainly runs inside its own thread, where it hands control off to an asyncio
-    event loop, but is also sometimes interacted with from the main thread
-    (where all the DB queries are done). All communication between threads is
-    done via Deques.
+    Mainly runs inside its own process, where it hands control off to an asyncio
+    event loop. All communication between this and its (sync) supervisor is done via sockets.
     """
 
     # Maps trigger IDs to their running tasks and other info
@@ -726,10 +771,7 @@ class TriggerRunner:
     # TODO: connect this to the parent process
     log: FilteringBoundLogger = structlog.get_logger()
 
-    requests_sock: asyncio.StreamWriter
-    response_sock: asyncio.StreamReader
-
-    decoder: TypeAdapter[ToTriggerRunner]
+    comms_decoder: TriggerCommsDecoder
 
     def __init__(self):
         super().__init__()
@@ -740,7 +782,6 @@ def __init__(self):
         self.events = deque()
         self.failed_triggers = deque()
         self.job_id = None
-        self.decoder = TypeAdapter(ToTriggerRunner)
 
     def run(self):
         """Sync entrypoint - just run a run in an async loop."""
@@ -796,36 +837,21 @@ async def init_comms(self):
         """
         from airflow.sdk.execution_time import task_runner
 
-        loop = asyncio.get_event_loop()
+        # Yes, we read and write to stdin! It's a socket, not a normal stdin.
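`asend` above is the async side's single exit point for requests: the frame id comes from `id_counter`, and the write plus the matching read happen under one lock so concurrent triggers cannot interleave responses. A self-contained sketch of that round trip, using hypothetical `Frame`/`Comms` names in place of the SDK's `_RequestFrame` and the decoder above, before `init_comms` below wires the real decoder to the stdin socket:

```python
import asyncio
import itertools

import msgspec

class Frame(msgspec.Struct):
    id: int
    body: dict

enc = msgspec.msgpack.Encoder()
dec = msgspec.msgpack.Decoder(Frame)

class Comms:
    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.reader, self.writer = reader, writer
        self.id_counter = itertools.count(1)
        self.lock = asyncio.Lock()

    async def send(self, body: dict) -> Frame:
        frame = Frame(id=next(self.id_counter), body=body)
        payload = enc.encode(frame)
        # Hold the lock across the write *and* the read so each response is
        # consumed by the request that caused it.
        async with self.lock:
            self.writer.write(len(payload).to_bytes(4, "big") + payload)
            await self.writer.drain()
            length = int.from_bytes(await self.reader.readexactly(4), "big")
            resp = dec.decode(await self.reader.readexactly(length))
        if resp.id != frame.id:
            raise RuntimeError(f"Response read out of order! Got {resp.id=}, expected {frame.id}")
        return resp
```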
+ reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) - comms_decoder = task_runner.CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]( - input=sys.stdin, - decoder=self.decoder, + self.comms_decoder = TriggerCommsDecoder( + async_writer=writer, + async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = comms_decoder - - async def connect_stdin() -> asyncio.StreamReader: - reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, sys.stdin) - return reader - - self.response_sock = await connect_stdin() + task_runner.SUPERVISOR_COMMS = self.comms_decoder - line = await self.response_sock.readline() + msg = await self.comms_decoder._aget_response(expect_id=0) - msg = self.decoder.validate_json(line) if not isinstance(msg, messages.StartTriggerer): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - writer_transport, writer_protocol = await loop.connect_write_pipe( - lambda: asyncio.streams.FlowControlMixin(loop=loop), - comms_decoder.request_socket, - ) - self.requests_sock = asyncio.streams.StreamWriter(writer_transport, writer_protocol, None, loop) - async def create_triggers(self): """Drain the to_create queue and create all new triggers that have been requested in the DB.""" while self.to_create: @@ -934,8 +960,6 @@ async def cleanup_finished_triggers(self) -> list[int]: return finished_ids async def sync_state_to_supervisor(self, finished_ids: list[int]): - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Copy out of our deques in threadsafe manner to sync state with parent events_to_send = [] while self.events: @@ -961,19 +985,17 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]): if not finished_ids: msg.finished = None - # Block triggers from making any requests for the duration of this - async with SUPERVISOR_COMMS.lock: - # Tell the monitor that we've finished triggers so it can update things - self.requests_sock.write(msg.model_dump_json(exclude_none=True).encode() + b"\n") - line = await self.response_sock.readline() - - if line == b"": # EoF received! 
+ # Tell the monitor that we've finished triggers so it can update things + try: + resp = await self.comms_decoder.asend(msg) + except asyncio.IncompleteReadError: if task := asyncio.current_task(): task.cancel("EOF - shutting down") + return + raise - resp = self.decoder.validate_json(line) if not isinstance(resp, messages.TriggerStateSync): - raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got f{type(msg)}") + raise RuntimeError(f"Expected to get a TriggerStateSync message, instead we got {type(msg)}") self.to_create.extend(resp.to_create) self.to_cancel.extend(resp.to_cancel) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c9e974b8cdb9a..02121f0194e82 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -30,10 +30,11 @@ from datetime import datetime, timedelta from logging.config import dictConfig from pathlib import Path -from socket import socket +from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock +import msgspec import pytest import time_machine from sqlalchemy import func, select @@ -54,7 +55,6 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk.execution_time.supervisor import mkpipe from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -138,20 +138,19 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces logger_filehandle = MagicMock() proc.create_time.return_value = time.time() proc.wait.return_value = 0 - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socketpair() ret = DagFileProcessorProcess( process_log=MagicMock(), id=uuid7(), pid=1234, process=proc, stdin=write_end, - requests_fd=123, logger_filehandle=logger_filehandle, client=MagicMock(), ) if start_time: ret.start_time = start_time - ret._num_open_sockets = 0 + ret._open_sockets.clear() return ret, read_end @pytest.fixture @@ -552,18 +551,17 @@ def test_kill_timed_out_processors_no_kill(self): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.parametrize( - ["callbacks", "path", "expected_buffer"], + ["callbacks", "path", "expected_body"], [ pytest.param( [], "/opt/airflow/dags/test_dag.py", - b"{" - b'"file":"/opt/airflow/dags/test_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,' - b'"callback_requests":[],' - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/test_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [], + "type": "DagFileParseRequest", + }, ), pytest.param( [ @@ -577,44 +575,39 @@ def test_kill_timed_out_processors_no_kill(self): ) ], "/opt/airflow/dags/dag_callback_dag.py", - b"{" - b'"file":"/opt/airflow/dags/dag_callback_dag.py",' - b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,"callback_requests":' - b"[" - b"{" - b'"filepath":"dag_callback_dag.py",' - b'"bundle_name":"testing",' - b'"bundle_version":null,' - b'"msg":null,' - b'"dag_id":"dag_id",' - b'"run_id":"run_id",' - b'"is_failure_callback":false,' - b'"type":"DagCallbackRequest"' - b"}" - b"]," - b'"type":"DagFileParseRequest"' - b"}\n", + { + "file": "/opt/airflow/dags/dag_callback_dag.py", + "bundle_path": "/opt/airflow/dags", + "callback_requests": [ + { + 
"filepath": "dag_callback_dag.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": None, + "dag_id": "dag_id", + "run_id": "run_id", + "is_failure_callback": False, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ), ], ) - def test_serialize_callback_requests(self, callbacks, path, expected_buffer): + def test_serialize_callback_requests(self, callbacks, path, expected_body): + from airflow.sdk.execution_time.comms import _ResponseFrame + processor, read_socket = self.mock_processor() processor._on_child_started(callbacks, path, bundle_path=Path("/opt/airflow/dags")) read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(4096) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError): - # no response written, valid for some message types. - pass - - assert val == expected_buffer + # Read response from the read end of the socket + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.body == expected_body @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 28ce7a8c23fe6..8d77da61cafb9 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -42,7 +42,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.api.client import Client -from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.sdk.execution_time import comms from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -87,7 +87,6 @@ def _process_file( DagFileParseRequest( file=file_path, bundle_path=TEST_DAG_FOLDER, - requests_fd=1, callback_requests=callback_requests or [], ), log=structlog.get_logger(), @@ -393,26 +392,38 @@ def disable_capturing(): @pytest.mark.usefixtures("testing_dag_bundle") @pytest.mark.usefixtures("disable_capturing") -def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency): +def test_parse_file_entrypoint_parses_dag_callbacks(mocker): r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":' - + str(w2.fileno()).encode("ascii") - + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' - b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": ' - b'"manual__2024-12-30T21:02:55.203691+00:00", ' - b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n' + + frame = comms._ResponseFrame( + id=1, + body={ + "file": "/files/dags/wait.py", + "bundle_path": "/files/dags", + "callback_requests": [ + { + "filepath": "wait.py", + "bundle_name": "testing", + "bundle_version": None, + "msg": "task_failure", + "dag_id": "wait_to_fail", + "run_id": "manual__2024-12-30T21:02:55.203691+00:00", + "is_failure_callback": True, + "type": "DagCallbackRequest", + } + ], + "type": "DagFileParseRequest", + }, ) + bytes = frame.as_bytes() + w.sendall(bytes) - decoder = CommsDecoder[DagFileParseRequest, DagFileParsingResult]( - 
input=r.makefile("r"), - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), + decoder = comms.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( + socket=r, + body_decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), ) - msg = decoder.get_message() + msg = decoder._get_response() assert isinstance(msg, DagFileParseRequest) assert msg.file == "/files/dags/wait.py" assert msg.callback_requests == [ @@ -455,7 +466,7 @@ def fake_collect_dags(self, *args, **kwargs): ) ] _parse_file( - DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests), + DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests), log=structlog.get_logger(), ) @@ -489,8 +500,6 @@ def fake_collect_dags(self, *args, **kwargs): bundle_version=None, ) ] - _parse_file( - DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() - ) + _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) assert called is True diff --git a/airflow-core/tests/unit/hooks/test_base.py b/airflow-core/tests/unit/hooks/test_base.py index 8c2d44f84854b..6d54827e67ced 100644 --- a/airflow-core/tests/unit/hooks/test_base.py +++ b/airflow-core/tests/unit/hooks/test_base.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -from unittest import mock - import pytest from airflow.exceptions import AirflowNotFoundException @@ -54,18 +52,18 @@ def test_get_connection(self, mock_supervisor_comms): extra='{"extra_key": "extra_value"}', ) - mock_supervisor_comms.get_message.return_value = conn + mock_supervisor_comms.send.return_value = conn hook = BaseHook(logger_name="") hook.get_connection(conn_id="test_conn") - mock_supervisor_comms.send_request.assert_called_once_with( - msg=GetConnection(conn_id="test_conn"), log=mock.ANY + mock_supervisor_comms.send.assert_called_once_with( + msg=GetConnection(conn_id="test_conn"), ) def test_get_connection_not_found(self, mock_supervisor_comms): conn_id = "test_conn" hook = BaseHook() - mock_supervisor_comms.get_message.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) with pytest.raises(AirflowNotFoundException, match=rf".*{conn_id}.*"): hook.get_connection(conn_id=conn_id) @@ -85,5 +83,4 @@ def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, assert retrieved_conn.conn_id == "CONN_A" - mock_supervisor_comms.send_request.assert_not_called() - mock_supervisor_comms.get_message.assert_not_called() + mock_supervisor_comms.send.assert_not_called() diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3c9ae6f27dc9..d3cecb5fc94e6 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -34,6 +34,7 @@ from airflow.hooks.base import BaseHook from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( + TriggerCommsDecoder, TriggererJobRunner, TriggerRunner, TriggerRunnerSupervisor, @@ -172,7 +173,6 @@ def builder(job=None): pid=process.pid, stdin=mocker.Mock(), process=process, - requests_fd=-1, capacity=10, ) # Mock the selector @@ -302,10 +302,9 @@ async def test_invalid_trigger(self, supervisor_builder): id=1, ti=None, classpath="fake.classpath", encrypted_kwargs={} ) trigger_runner = TriggerRunner() - trigger_runner.requests_sock = 
MagicMock() - trigger_runner.response_sock = AsyncMock() - trigger_runner.response_sock.readline.return_value = ( - b'{"type": "TriggerStateSync", "to_create": [], "to_cancel": []}\n' + trigger_runner.comms_decoder = AsyncMock(spec=TriggerCommsDecoder) + trigger_runner.comms_decoder.asend.return_value = messages.TriggerStateSync( + to_create=[], to_cancel=[] ) trigger_runner.to_create.append(workload) @@ -316,9 +315,9 @@ async def test_invalid_trigger(self, supervisor_builder): await trigger_runner.sync_state_to_supervisor(ids) # Check that we sent the right info in the failure message - assert trigger_runner.requests_sock.write.call_count == 1 - blob = trigger_runner.requests_sock.write.mock_calls[0].args[0] - msg = messages.TriggerStateChanges.model_validate_json(blob) + assert trigger_runner.comms_decoder.asend.call_count == 1 + msg = trigger_runner.comms_decoder.asend.mock_calls[0].args[0] + assert isinstance(msg, messages.TriggerStateChanges) assert msg.events is None assert msg.failures is not None @@ -552,6 +551,7 @@ def test_failed_trigger(session, dag_maker, supervisor_builder): ) ], ), + req_id=1, log=MagicMock(), ) @@ -622,10 +622,6 @@ def handle_events(self): super().handle_events() -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio @pytest.mark.execution_timeout(20) async def test_trigger_can_access_variables_connections_and_xcoms(session, dag_maker): @@ -726,13 +722,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"count": dag_run_states_count, "dag_run_state": dag_run_state}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" # Create the test DAG and task @@ -822,13 +813,8 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: yield TriggerEvent({"ti_count": ti_count, "dr_count": dr_count, "task_states": task_states}) -@pytest.mark.xfail( - reason="We know that test is flaky and have no time to fix it before 3.0. " - "We should fix it later. 
TODO: AIP-72" -) @pytest.mark.asyncio -@pytest.mark.flaky(reruns=2, reruns_delay=10) -@pytest.mark.execution_timeout(30) +@pytest.mark.execution_timeout(10) async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index b33159fb9deea..cb2ab917a23f4 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -1885,7 +1885,7 @@ def producer_with_inactive(*, outlet_events): def test_inlet_asset_extra(self, dag_maker, session, mock_supervisor_comms): from airflow.sdk.definitions.asset import Asset - mock_supervisor_comms.get_message.return_value = AssetEventsResult( + mock_supervisor_comms.send.return_value = AssetEventsResult( asset_events=[ AssetEventResponse( id=1, @@ -1959,7 +1959,7 @@ def read(*, inlet_events): @pytest.mark.need_serialized_dag def test_inlet_unresolved_asset_alias(self, dag_maker, session, mock_supervisor_comms): asset_alias_name = "test_inlet_asset_extra_asset_alias" - mock_supervisor_comms.get_message.return_value = AssetEventsResult(asset_events=[]) + mock_supervisor_comms.send.return_value = AssetEventsResult(asset_events=[]) asset_alias_model = AssetAliasModel(name=asset_alias_name) session.add(asset_alias_model) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index d445243cc4f8d..2288fefc26880 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1948,17 +1948,27 @@ def override_caplog(request): @pytest.fixture -def mock_supervisor_comms(): +def mock_supervisor_comms(monkeypatch): # for back-compat from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if not AIRFLOW_V_3_0_PLUS: yield None return - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as supervisor_comms: - yield supervisor_comms + + from airflow.sdk.execution_time import comms, task_runner + + # Deal with TaskSDK 1.0/1.1 vs 1.2+. Annoying, and shouldn't need to exist once the separation between + # core and TaskSDK is finished + if CommsDecoder := getattr(comms, "CommsDecoder", None): + comms = mock.create_autospec(CommsDecoder) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + else: + CommsDecoder = getattr(task_runner, "CommsDecoder") + comms = mock.create_autospec(CommsDecoder) + comms.send = comms.get_message + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + yield comms @pytest.fixture @@ -1983,7 +1993,6 @@ def mocked_parse(spy_agency): id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1 ), file="", - requests_fd=0, ), "example_dag_id", CustomOperator(task_id="hello"), @@ -2190,9 +2199,10 @@ def _create_task_instance( ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, ti_context=ti_context, start_date=start_date, # type: ignore + # Back-compat of task-sdk. Only affects us when we manually create these objects in tests. 
+ **({"requests_fd": 0} if "requests_fd" in StartupDetails.model_fields else {}), # type: ignore ) ti = mocked_parse(startup_details, dag_id, task) diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py index 8abe65cf2cb8c..99a3536d17838 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_athena.py @@ -30,7 +30,7 @@ class TestAthenaQueryResultsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=AthenaQueryResultsLink.key, value={ "region_name": "eu-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py index 8ecf7022f2162..70cd65655bfec 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_batch.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_batch.py @@ -34,7 +34,7 @@ class TestBatchJobDefinitionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -58,7 +58,7 @@ class TestBatchJobDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "cn-north-1", @@ -80,7 +80,7 @@ class TestBatchJobQueueLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py index 2c60aa27d5085..9b88270b5bd30 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_comprehend.py @@ -34,7 +34,7 @@ class TestComprehendPiiEntitiesDetectionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): test_job_id = "123-345-678" if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -61,7 +61,7 @@ def test_extra_link(self, mock_supervisor_comms): "arn:aws:comprehend:us-east-1:0123456789:document-classifier/test-custom-document-classifier" ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py index cd7f5be5cbe63..79c8469b701f7 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py +++ 
b/providers/amazon/tests/unit/amazon/aws/links/test_datasync.py @@ -34,7 +34,7 @@ class TestDataSyncTaskLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): task_id = TASK_ID if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", @@ -56,7 +56,7 @@ class TestDataSyncTaskExecutionLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py index 79680d3caa136..f451c910058cf 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_ec2.py @@ -32,7 +32,7 @@ class TestEC2InstanceLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -66,7 +66,7 @@ def test_instance_id_filter(self): def test_extra_link(self, mock_supervisor_comms): instance_list = ",:".join(self.INSTANCE_IDS) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py index cbe2544bd86db..feda067f7cc7a 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_emr.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_emr.py @@ -45,7 +45,7 @@ class TestEmrClusterLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", @@ -82,7 +82,7 @@ class TestEmrLogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-2", @@ -127,7 +127,7 @@ def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "conn_id": "aws-test", @@ -160,7 +160,7 @@ def test_extra_link(self, mocked_emr_serverless_hook, mock_supervisor_comms): mocked_client.get_dashboard_for_job_run.return_value = {"url": "https://example.com/?authToken=1234"} if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + 
mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "conn_id": "aws-test", @@ -254,7 +254,7 @@ class TestEmrServerlessS3LogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", @@ -282,7 +282,7 @@ class TestEmrServerlessCloudWatchLogsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py index b73c65182f662..2b1f076e149df 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_glue.py @@ -30,7 +30,7 @@ class TestGlueJobRunDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "ap-southeast-2", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py index 8c642e55c1a6a..2c90eecd232ad 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_logs.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_logs.py @@ -30,7 +30,7 @@ class TestCloudWatchEventsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "us-west-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py index 25bae4efdfa62..f08d7df93d509 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker.py @@ -31,7 +31,7 @@ class TestSageMakerTransformDetailsLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="sagemaker_transform_job_details", value={ "region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py index 1b3cbed1f142d..bb749727323e2 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py @@ -30,7 +30,7 @@ class TestSageMakerUnifiedStudioLink(BaseAwsLinksTestCase): def test_extra_link(self, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ 
"region_name": "us-east-1", diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py index ca4505855e9f8..acfad7e98e96c 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py +++ b/providers/amazon/tests/unit/amazon/aws/links/test_step_function.py @@ -47,7 +47,7 @@ class TestStateMachineDetailsLink(BaseAwsLinksTestCase): ) def test_extra_link(self, state_machine_arn, expected_url: str, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", @@ -82,7 +82,7 @@ class TestStateMachineExecutionsDetailsLink(BaseAwsLinksTestCase): ) def test_extra_link(self, execution_arn, expected_url: str, mock_supervisor_comms): if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=self.link_class.key, value={ "region_name": "eu-west-1", diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py index 99fb46a66c7e5..50c71c99be938 100644 --- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py +++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py @@ -106,9 +106,7 @@ def test_value_db(self, task_instance, mock_supervisor_comms, session): if AIRFLOW_V_3_0_PLUS: # When using XComObjectStorageBackend, the value is stored in the db is serialized with json dumps # so we need to mimic that same behavior below. - mock_supervisor_comms.get_message.return_value = XComResult( - key="return_value", value={"key": "value"} - ) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value={"key": "value"}) value = XCom.get_value( key=XCOM_RETURN_KEY, @@ -169,7 +167,7 @@ def test_value_storage(self, task_instance, mock_supervisor_comms, session): assert p.exists() is True if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "bigvaluebigvaluebigvalue" * 100} ) @@ -213,7 +211,12 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): ) if AIRFLOW_V_3_0_PLUS: - path = mock_supervisor_comms.send_request.call_args_list[-1].kwargs["msg"].value + if hasattr(mock_supervisor_comms, "send_request"): + # Back-compat of task-sdk. Only affects us when we manually create these objects in tests. 
+ last_call = mock_supervisor_comms.send_request.call_args_list[-1] + else: + last_call = mock_supervisor_comms.send.call_args_list[-1] + path = (last_call.kwargs.get("msg") or last_call.args[0]).value XComModel.set( key=XCOM_RETURN_KEY, value=path, @@ -251,7 +254,7 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): assert p.exists() is True if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} ) value = XCom.get_value( @@ -261,7 +264,7 @@ def test_clear(self, task_instance, session, mock_supervisor_comms): assert value if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult(key=XCOM_RETURN_KEY, value=path) + mock_supervisor_comms.send.return_value = XComResult(key=XCOM_RETURN_KEY, value=path) XCom.delete( dag_id=task_instance.dag_id, task_id=task_instance.task_id, @@ -356,7 +359,7 @@ def test_compression(self, task_instance, session, mock_supervisor_comms): assert data.endswith(".gz") if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key=XCOM_RETURN_KEY, value={"key": "superlargevalue" * 100} ) diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py index 9e8408c56955c..616423d5ba984 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py @@ -665,7 +665,7 @@ def test_run_job_operator_link( ti.xcom_push(key="job_run_url", value=_run_response["data"]["href"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="job_run_url", value=EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( account_id=account_id or DEFAULT_ACCOUNT_ID, diff --git a/providers/google/tests/unit/google/cloud/links/test_dataplex.py b/providers/google/tests/unit/google/cloud/links/test_dataplex.py index 3ec2905751390..babdf9127c8e1 100644 --- a/providers/google/tests/unit/google/cloud/links/test_dataplex.py +++ b/providers/google/tests/unit/google/cloud/links/test_dataplex.py @@ -123,7 +123,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "lake_id": ti.task.lake_id, @@ -153,7 +153,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "project_id": ti.task.project_id, @@ -183,7 +183,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "lake_id": ti.task.lake_id, @@ -212,7 +212,7 @@ def test_get_link(self, 
create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_group_id": ti.task.entry_group_id, @@ -242,7 +242,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -270,7 +270,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_type_id": ti.task.entry_type_id, @@ -300,7 +300,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -328,7 +328,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "aspect_type_id": ti.task.aspect_type_id, @@ -358,7 +358,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -387,7 +387,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "entry_id": ti.task.entry_id, diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py index ddfc3e205e8e2..640b958a3841e 100644 --- a/providers/google/tests/unit/google/cloud/links/test_translate.py +++ b/providers/google/tests/unit/google/cloud/links/test_translate.py @@ -62,7 +62,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, dataset_id=DATASET, project_id=GCP_PROJECT_ID) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={"location": ti.task.location, "dataset_id": DATASET, "project_id": GCP_PROJECT_ID}, ) @@ 
-85,7 +85,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis session.commit() link.persist(context={"ti": ti}, task_instance=ti.task, project_id=GCP_PROJECT_ID) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "project_id": GCP_PROJECT_ID, @@ -121,7 +121,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis project_id=GCP_PROJECT_ID, ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, @@ -158,7 +158,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis project_id=GCP_PROJECT_ID, ) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={ "location": ti.task.location, diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 35072befd72e7..76a21a04be380 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -1132,7 +1132,7 @@ def test_create_cluster_operator_extra_links( assert operator_extra_link.name == "Dataproc Cluster" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value="", ) @@ -1142,7 +1142,7 @@ def test_create_cluster_operator_extra_links( ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="key", value={"cluster_id": "cluster_name", "project_id": "test-project", "region": "test-location"}, ) @@ -2014,7 +2014,7 @@ def test_submit_job_operator_extra_links( assert operator_extra_link.name == "Dataproc Job" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_job", value="", ) @@ -2025,7 +2025,7 @@ def test_submit_job_operator_extra_links( ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_job", value=DATAPROC_JOB_EXPECTED, ) @@ -2230,7 +2230,7 @@ def test_update_cluster_operator_extra_links( assert operator_extra_link.name == "Dataproc Cluster" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_cluster", value="", ) @@ -2240,7 +2240,7 @@ def test_update_cluster_operator_extra_links( ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, ) @@ -2456,7 +2456,7 @@ def test_instantiate_workflow_operator_extra_links( assert operator_extra_link.name == "Dataproc Workflow" if AIRFLOW_V_3_0_PLUS: - 
mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value="", ) @@ -2465,7 +2465,7 @@ def test_instantiate_workflow_operator_extra_links( ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED, ) @@ -3138,7 +3138,7 @@ def test_instantiate_inline_workflow_operator_extra_links( operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value="", ) @@ -3147,7 +3147,7 @@ def test_instantiate_inline_workflow_operator_extra_links( ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) if AIRFLOW_V_3_0_PLUS: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py index 788f52efc3772..23d7d18d4a407 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_data_factory.py @@ -253,7 +253,7 @@ def test_run_pipeline_operator_link( ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"], ) diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py index 45c538c2c400c..62a1021ead8a5 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_synapse.py @@ -292,7 +292,7 @@ def test_run_pipeline_operator_link(self, create_task_instance_of_operator, mock ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"], ) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 8f24273c50730..beb1cdce9502c 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -47,7 +47,6 @@ classifiers = [ dependencies = [ "apache-airflow-core<3.1.0,>=3.0.2", - "aiologic>=0.14.0", "attrs>=24.2.0, !=25.2.0", "fsspec>=2023.10.0", "httpx>=0.27.0", diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 0c330652956f2..c75b2e9c75d98 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -70,9 +70,8 @@ def set( map_index=map_index, ) - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -107,9 +106,8 @@ def 
_set_xcom_in_db( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -181,23 +179,16 @@ def _get_xcom_db_ref( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -241,23 +232,16 @@ def get_one( """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send( + GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -322,9 +306,8 @@ def delete( map_index=map_index, ) cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( + SUPERVISOR_COMMS.send( + DeleteXCom( key=key, dag_id=dag_id, task_id=task_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index d899dd3718336..44479fbb9cd42 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -69,18 +69,16 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> ) def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: - import structlog - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name=self.__class__.__qualname__) - def _fetch_asset(name: str) -> Asset: - SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name)) - if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse): - raise AirflowRuntimeError(msg) - return Asset(**msg.model_dump(exclude={"type"})) + resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if resp is None: + raise RuntimeError("Empty non-error response received") + if isinstance(resp, ErrorResponse): + raise 
AirflowRuntimeError(resp) + return Asset(**resp.model_dump(exclude={"type"})) value: Any for key, param in inspect.signature(self.python_callable).parameters.items(): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0622cf6ffc86..be2f4f8eacd46 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -19,12 +19,17 @@ Communication protocol between the Supervisor and the task process ================================================================== -* All communication is done over stdout/stdin in the form of "JSON lines" (each - message is a single JSON document terminated by `\n` character) -* Messages from the subprocess are all log messages and are sent directly to the log +* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame + (4-byte big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same + encoding +* Log messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) * No messages are sent to task process except in response to a request. (This is because the task process will be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom value etc.) +* Every request returns a response, even if the frame is otherwise empty. +* Requests are written by the subprocess to fd 0/stdin. This makes use of the fact that stdin is a + bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending + requests. The reason this communication protocol exists, rather than the task process speaking directly to the Task Execution API server is because: @@ -43,15 +48,20 @@ from __future__ import annotations +import itertools from collections.abc import Iterator from datetime import datetime from functools import cached_property -from typing import Annotated, Any, Literal, Union +from pathlib import Path +from socket import socket +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from uuid import UUID import attrs +import msgspec +import structlog from fastapi import Body -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, @@ -80,6 +90,145 @@ ) from airflow.sdk.exceptions import ErrorType +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + +SendMsgType = TypeVar("SendMsgType", bound=BaseModel) +ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) + + +def _msgpack_enc_hook(obj: Any) -> Any: + import pendulum + + if isinstance(obj, pendulum.DateTime): + # convert the pendulum DateTime subclass into a raw datetime so that msgspec can use its native + # encoding + return datetime( + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo + ) + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, BaseModel): + return obj.model_dump(exclude_unset=True) + + # Raise a NotImplementedError for other types + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def _new_encoder() -> msgspec.msgpack.Encoder: + return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook) + + +class
_RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The request id, set by the sender. + + This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is + particularly useful in the Triggerer where multiple async tasks can send requests concurrently. + """ + body: dict[str, Any] | None + + req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder() + + def as_bytes(self) -> bytearray: + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + buffer = bytearray(256) + + self.req_encoder.encode_into(self, buffer, 4) + + n = len(buffer) - 4 + if n >= 2**32: + raise OverflowError(f"Cannot send messages larger than 4GiB {n=}") + buffer[:4] = n.to_bytes(4, byteorder="big") + + return buffer + + +class _ResponseFrame(_RequestFrame, frozen=True): + id: int + """ + The id of the request this is a response to. + """ + body: dict[str, Any] | None = None + error: dict[str, Any] | None = None + + +@attrs.define() +class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): + """Handle communication between the task in this process and the supervisor parent process.""" + + log: Logger = attrs.field(repr=False, factory=structlog.get_logger) + socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + + resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( + factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False + ) + + id_counter: Iterator[int] = attrs.field(factory=itertools.count) + + # We could be "clever" here and set the default to this based on type parameters and a custom + # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just + # use a "sort of wrong default" + body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ErrorResponse), repr=False) + + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: + """Send a request to the parent and block until the response is received.""" + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + bytes = frame.as_bytes() + + self.socket.sendall(bytes) + + return self._get_response() + + def _read_frame(self): + """ + Get a message from the parent. + + This will block until the message has been received. + """ + if self.socket: + self.socket.setblocking(True) + len_bytes = self.socket.recv(4) + + if len_bytes == b"": + raise EOFError("Request socket closed before length") + + len = int.from_bytes(len_bytes, byteorder="big") + + buffer = bytearray(len) + nread = self.socket.recv_into(buffer) + if nread == 0: + raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})") + if nread != len: + raise RuntimeError( + f"unable to read full response in child. (We read {nread}, but expected {len})" + )
+ + return self.resp_decoder.decode(buffer) + + def _from_frame(self, frame) -> ReceiveMsgType | None: + from airflow.sdk.exceptions import AirflowRuntimeError + + if frame.error is not None: + err = self.err_decoder.validate_python(frame.error) + raise AirflowRuntimeError(error=err) + + if frame.body is None: + return None + + try: + return self.body_decoder.validate_python(frame.body) + except Exception: + self.log.exception("Unable to decode message") + raise + + def _get_response(self) -> ReceiveMsgType | None: + frame = self._read_frame() + return self._from_frame(frame) + class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -87,13 +236,7 @@ class StartupDetails(BaseModel): ti: TaskInstance dag_rel_path: str bundle_info: BundleInfo - requests_fd: int start_date: datetime - """ - The channel for the task to send requests over. - - Responses will come back on stdin - """ ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index de5c77cb54635..40890d05cc2a9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -150,13 +150,7 @@ def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -203,13 +197,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) - # we need to make sure that we "atomically" send a request and get the response to that - # back so that two triggers don't end up interleaving requests and create a possible - # race condition where the wrong trigger reads the response. - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -259,11 +247,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest.
A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) def _delete_variable(key: str) -> None: @@ -275,12 +259,7 @@ def _delete_variable(key: str) -> None: from airflow.sdk.execution_time.comms import DeleteVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was - # primarily added for triggers but it doesn't make sense to have it in some places - # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -383,23 +362,29 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: @staticmethod def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: from airflow.sdk.definitions.asset import Asset - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.comms import ( + ErrorResponse, + GetAssetByName, + GetAssetByUri, + ToSupervisor, + ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if name: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + msg = GetAssetByName(name=name) elif uri: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + msg = GetAssetByUri(uri=uri) else: raise ValueError("Either name or uri must be provided") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetResult) - return Asset(**msg.model_dump(exclude={"type"})) + assert isinstance(resp, AssetResult) + return Asset(**resp.model_dump(exclude={"type"})) @attrs.define @@ -565,9 +550,11 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve ErrorResponse, GetAssetEventByAsset, GetAssetEventByAssetAlias, + ToSupervisor, ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + msg: ToSupervisor if isinstance(key, int): # Support index access; it's easier for trivial cases. 
obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -577,31 +564,33 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve if isinstance(obj, Asset): asset = self._assets[AssetUniqueKey.from_asset(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri)) + msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri) elif isinstance(obj, AssetNameRef): try: asset = next(a for k, a in self._assets.items() if k.name == obj.name) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None)) + msg = GetAssetEventByAsset(name=asset.name, uri=None) elif isinstance(obj, AssetUriRef): try: asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri)) + msg = GetAssetEventByAsset(name=None, uri=asset.uri) elif isinstance(obj, AssetAlias): asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name)) + msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name) + else: + raise TypeError(f"`key` is of unknown type ({type(key).__name__})") - msg = SUPERVISOR_COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) if TYPE_CHECKING: - assert isinstance(msg, AssetEventsResult) + assert isinstance(resp, AssetEventsResult) - return list(msg.iter_asset_event_results()) + return list(resp.iter_asset_event_results()) @cache # Prevent multiple API access. 
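To make the new wire format concrete: each message is a 4-byte big-endian length followed by a msgpack-encoded frame, and `_RequestFrame` uses `array_like=True`, so a frame goes on the wire as `[id, body]`. The sketch below is illustrative only (`Frame`, `write_frame`, and `read_frame` are hypothetical helpers, not part of this change); it round-trips a frame over a `socketpair()`, the same kind of bi-directional socket the child gets as fd 0.

import msgspec
from socket import socketpair

class Frame(msgspec.Struct, array_like=True):  # same wire shape as _RequestFrame: [id, body]
    id: int
    body: dict | None = None

enc = msgspec.msgpack.Encoder()
dec = msgspec.msgpack.Decoder(Frame)

def write_frame(sock, frame: Frame) -> None:
    payload = enc.encode(frame)
    # 4-byte big-endian length prefix, then the msgpack payload
    sock.sendall(len(payload).to_bytes(4, byteorder="big") + payload)

def read_frame(sock) -> Frame:
    def read_exact(n: int) -> bytes:
        buf = bytearray()
        while len(buf) < n:  # recv() may legally return fewer bytes than asked for
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise EOFError("socket closed mid-frame")
            buf.extend(chunk)
        return bytes(buf)
    length = int.from_bytes(read_exact(4), byteorder="big")
    return dec.decode(read_exact(length))

parent, child = socketpair()  # bi-directional, like the child's repurposed stdin
write_frame(child, Frame(id=0, body={"key": "my_var"}))
print(read_frame(parent))  # Frame(id=0, body={'key': 'my_var'})

Note the short-read loop: the child-side `_read_frame` above assumes a single `recv_into` returns the whole frame, while the supervisor-side `length_prefixed_frame_reader` later in this patch accumulates partial reads across selector wake-ups.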
@@ -613,8 +602,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 9cf9acfac81bb..d8b356870c7aa 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -91,16 +91,14 @@ def __len__(self) -> int: task = self._xcom_arg.operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComCount( + msg = SUPERVISOR_COMMS.send( + GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, run_id=self._ti.run_id, task_id=task.task_id, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): @@ -127,43 +125,37 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if isinstance(key, slice): start, stop, step = _coerce_slice(key) - with SUPERVISOR_COMMS.lock: - source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( - key=xcom_arg.key, - dag_id=source.dag_id, - task_id=source.task_id, - run_id=self._ti.run_id, - start=start, - stop=stop, - step=step, - ), - ) - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComSequenceSliceResult): - raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") - return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] - - if not isinstance(key, int): - if (index := getattr(key, "__index__", None)) is not None: - key = index() - raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") - - with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceItem( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, run_id=self._ti.run_id, - offset=key, + start=start, + stop=stop, + step=step, ), ) - msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") + return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] + + if not isinstance(key, int): + if (index := getattr(key, "__index__", None)) is not None: + key = index() + raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") + + source = (xcom_arg := self._xcom_arg).operator + msg = SUPERVISOR_COMMS.send( + GetXComSequenceItem( + key=xcom_arg.key, + dag_id=source.dag_id, + task_id=source.task_id, + run_id=self._ti.run_id, + offset=key, + ), + ) if isinstance(msg, ErrorResponse): raise IndexError(key) if not isinstance(msg, XComSequenceIndexResult): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index dc63daeef7bb7..2ae554cace1a5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,12 +27,13 @@ import signal import sys import time +import weakref from collections import 
deque from collections.abc import Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus -from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair +from socket import socket, socketpair from typing import ( TYPE_CHECKING, BinaryIO, @@ -44,7 +45,6 @@ ) from uuid import UUID -import aiologic import attrs import httpx import msgspec @@ -64,6 +64,7 @@ XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -109,6 +110,8 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -180,23 +183,6 @@ ********************************************************************************************************""" -def mkpipe( - remote_read: bool = False, -) -> tuple[socket, socket]: - """Create a pair of connected sockets.""" - rsock, wsock = socketpair() - local, remote = (wsock, rsock) if remote_read else (rsock, wsock) - - if remote_read: - # Setting a 4KB buffer here if possible, if not, it still works, so we will suppress all exceptions - with suppress(Exception): - local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE) - # set nonblocking to True so that send or sendall waits till all data is sent - local.setblocking(True) - - return remote, local - - def _subprocess_main(): from airflow.sdk.execution_time.task_runner import main @@ -224,14 +210,13 @@ def _configure_logs_over_json_channel(log_fd: int): def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor for handle_name, fd, sock, mode in ( - ("stdin", 0, child_stdin, "r"), + # Yes, we want to re-open stdin in write mode! This is because it is a bi-directional socket, so we can + # read and write to it. + ("stdin", 0, child_stdin, "w"), ("stdout", 1, child_stdout, "w"), ("stderr", 2, child_stderr, "w"), ): - handle = getattr(sys, handle_name) - handle.close() os.dup2(sock.fileno(), fd) del sock @@ -319,7 +304,7 @@ def __getattr__(name: str): def _fork_main( - child_stdin: socket, + requests: socket, child_stdout: socket, child_stderr: socket, log_fd: int, @@ -349,7 +334,7 @@ def _fork_main( _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) - _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + _reopen_std_io_handles(requests, child_stdout, child_stderr) def exit(n: int) -> NoReturn: with suppress(ValueError, OSError): @@ -360,11 +345,11 @@ def exit(n: int) -> NoReturn: last_chance_stderr.flush() # Explicitly close the child-end of our supervisor sockets so - # the parent sees EOF on both "requests" and "logs" channels. + # the parent sees EOF on the "logs" channel.
with suppress(OSError): os.close(log_fd) with suppress(OSError): - os.close(child_stdin.fileno()) + os.close(requests.fileno()) os._exit(n) if hasattr(atexit, "_clear"): @@ -432,16 +417,18 @@ class WatchedSubprocess: """The decoder to use for incoming messages from the child process.""" _process: psutil.Process = attrs.field(repr=False) - _requests_fd: int """File descriptor for request handling.""" - _num_open_sockets: int = 4 _exit_code: int | None = attrs.field(default=None, init=False) _process_exit_monotonic: float | None = attrs.field(default=None, init=False) - _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False) + _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field( + factory=weakref.WeakKeyDictionary, init=False + ) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False) + _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False) + process_log: FilteringBoundLogger = attrs.field(repr=False) subprocess_logs_to_stdout: bool = False @@ -460,18 +447,19 @@ def start( ) -> Self: """Fork and start a new subprocess with the specified target function.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess - child_stdin, feed_stdin = mkpipe(remote_read=True) - child_stdout, read_stdout = mkpipe() - child_stderr, read_stderr = mkpipe() + child_stdout, read_stdout = socketpair() + child_stderr, read_stderr = socketpair() + + # Place for child to send requests/read responses, and the server side to read/respond + child_requests, read_requests = socketpair() - # Open these socketpair before forking off the child, so that it is open when we fork. - child_comms, read_msgs = mkpipe() - child_logs, read_logs = mkpipe() + # Open the socketpair before forking off the child, so that it is open when we fork. + child_logs, read_logs = socketpair() pid = os.fork() if pid == 0: # Close and delete of the parent end of the sockets. - cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs) + cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs) # Python GC should delete these for us, but lets make double sure that we don't keep anything # around in the forked processes, especially things that might involve open files or sockets! @@ -480,28 +468,28 @@ def start( try: # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target) except BaseException as e: + import traceback + with suppress(BaseException): # We can't use log here, as if we except out of _fork_main something _weird_ went on. - print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + print("Exception in _fork_main, exiting with code 124", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) # It's really super super important we never exit this block. We are in the forked child, and if we # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) os._exit(124) - requests_fd = child_comms.fileno() - # Close the remaining parent-end of the sockets we've passed to the child via fork. 
We still have the # other end of the pair open - cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) + cls._close_unused_sockets(child_stdout, child_stderr, child_logs) logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc = cls( pid=pid, - stdin=feed_stdin, + stdin=read_requests, process=psutil.Process(pid), - requests_fd=requests_fd, process_log=logger, start_time=time.monotonic(), **constructor_kwargs, @@ -510,7 +498,7 @@ def start( proc._register_pipe_readers( stdout=read_stdout, stderr=read_stderr, - requests=read_msgs, + requests=read_requests, logs=read_logs, ) @@ -523,24 +511,26 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - # Track socket types for debugging - self._fd_to_socket_type = { - stdout.fileno(): "stdout", - stderr.fileno(): "stderr", - requests.fileno(): "requests", - logs.fileno(): "logs", - } + # Track the open sockets, and for debugging what type each one is + self._open_sockets.update( + ( + (stdout, "stdout"), + (stderr, "stderr"), + (logs, "logs"), + (requests, "requests"), + ) + ) target_loggers: tuple[FilteringBoundLogger, ...] = (self.process_log,) if self.subprocess_logs_to_stdout: target_loggers += (log,) self.selector.register( - stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout") + stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout") ) self.selector.register( stderr, selectors.EVENT_READ, - self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR), + self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR), ) self.selector.register( logs, @@ -552,37 +542,52 @@ def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socke self.selector.register( requests, selectors.EVENT_READ, - make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed), + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), ) - def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: """Create a socket handler that forwards logs to a logger.""" return make_buffered_socket_reader( forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed ) - def _on_socket_closed(self): + def _on_socket_closed(self, sock: socket): # We want to keep servicing this process until we've read up to EOF from all the sockets. - self._num_open_sockets -= 1 - def send_msg(self, msg: BaseModel, **dump_opts): - """Send the given pydantic message to the subprocess at once by encoding it and adding a line break.""" - b = msg.model_dump_json(**dump_opts).encode() + b"\n" - self.stdin.sendall(b) + with suppress(KeyError): + self.selector.unregister(sock) + del self._open_sockets[sock] + + def send_msg( + self, msg: BaseModel | None, request_id: int, error: ErrorResponse | None = None, **dump_opts + ): + """ + Send the msg as a length-prefixed response frame. 
+ + ``request_id`` is the ID that the client sent in its request, and has no meaning to the server + + """ + if msg: + frame = _ResponseFrame(id=request_id, body=msg.model_dump(**dump_opts)) + else: + err_resp = error.model_dump() if error else None + frame = _ResponseFrame(id=request_id, error=err_resp) + + self.stdin.sendall(frame.as_bytes()) - def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]: """Handle incoming requests from the task process, respond with the appropriate data.""" while True: - line = yield + request = yield try: - msg = self.decoder.validate_json(line) + msg = self.decoder.validate_python(request.body) except Exception: - log.exception("Unable to decode message", line=line) + log.exception("Unable to decode message", body=request.body) continue try: - self._handle_request(msg, log) + self._handle_request(msg, log, request.id) except ServerResponseError as e: error_details = e.response.json() if e.response else None log.error( @@ -594,27 +599,25 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N # Send error response back to task so that the error appears in the task logs self.send_msg( - ErrorResponse( + msg=None, + error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={ "status_code": e.response.status_code, "message": str(e), "detail": error_details, }, - ) + ), + request_id=request.id, ) - def _handle_request(self, msg, log: FilteringBoundLogger) -> None: + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: raise NotImplementedError() @staticmethod def _close_unused_sockets(*sockets): """Close unused ends of sockets after fork.""" for sock in sockets: - if isinstance(sock, SocketIO): - # If we have the socket IO object, we need to close the underlying socket foricebly here too, - # else we get unclosed socket warnings, and likely leaking FDs too - sock._sock.close() sock.close() def _cleanup_open_sockets(self): @@ -624,20 +627,18 @@ def _cleanup_open_sockets(self): # sockets the supervisor would wait forever thinking they are still # active. This cleanup ensures we always release resources and exit. stuck_sockets = [] - for key in list(self.selector.get_map().values()): - socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}") - stuck_sockets.append(f"{socket_type}({key.fd})") + for sock, socket_type in self._open_sockets.items(): + fileno = "unknown" with suppress(Exception): - self.selector.unregister(key.fileobj) - with suppress(Exception): - key.fileobj.close() # type: ignore[union-attr] + fileno = sock.fileno() + sock.close() + stuck_sockets.append(f"{socket_type}(fd={fileno})") if stuck_sockets: log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets) self.selector.close() - self._close_unused_sockets(self.stdin) - self._num_open_sockets = 0 + self.stdin.close() def kill( self, @@ -736,7 +737,7 @@ def _service_subprocess( events = self.selector.select(timeout=timeout) for key, _ in events: # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) - socket_handler = key.data + socket_handler, on_close = key.data # Example of handler behavior: # If the subprocess writes "Hello, World!"
to stdout: @@ -746,15 +747,16 @@ # to EOF case try: need_more = socket_handler(key.fileobj) - except BrokenPipeError: + except (BrokenPipeError, ConnectionResetError): need_more = False # If the handler signals that the file object is no longer needed (EOF, closed, etc.) # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors # are removed. if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -773,16 +775,16 @@ def _check_subprocess_exit( raise else: self._process_exit_monotonic = time.monotonic() - self._close_unused_sockets(self.stdin) - # Put a message in the viewable task logs if expect_signal is not None and self._exit_code == -expect_signal: # Bypass logging, the caller expected us to exit with this return self._exit_code - # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) - if exit_code and (name := getattr(exit_code, "name")): + # Put a message in the viewable task logs + if self._exit_code == -signal.SIGSEGV: self.process_log.critical(SIGSEGV_MESSAGE) + # psutil turns signal exit codes into an enum for us. Handy. (Otherwise it's a plain integer) elif name := getattr(self._exit_code, "name", None): message = "Process terminated by signal" level = logging.ERROR @@ -809,7 +811,7 @@ class ActivitySubprocess(WatchedSubprocess): _last_heartbeat_attempt: float = attrs.field(default=0, init=False) # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we # will kill the process. This is to handle temporary network issues etc. ensuring that the process # does not hang around forever. failed_heartbeats: int = attrs.field(default=0, init=False) @@ -861,7 +863,6 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st ti=ti, dag_rel_path=os.fspath(dag_rel_path), bundle_info=bundle_info, - requests_fd=self._requests_fd, ti_context=ti_context, start_date=start_date, ) @@ -870,8 +871,8 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st log.debug("Sending", msg=msg) try: - self.send_msg(msg) - except BrokenPipeError: + self.send_msg(msg, request_id=0) + except (BrokenPipeError, ConnectionResetError): # Debug is fine, the process will have shown _something_ in its last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -930,7 +931,7 @@ def _monitor_subprocess(self): - Processes events triggered on the monitored file objects, such as data availability or EOF. - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - while self._exit_code is None or self._num_open_sockets > 0: + while self._exit_code is None or self._open_sockets: last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can.
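The `key.data` unpacking above is the heart of the selector change: each registered socket now carries a `(handler, on_close)` pair rather than a bare callback, so the service loop itself can unregister and close exhausted sockets. A minimal sketch of that dispatch pattern, with illustrative names only (this is not the PR's code):

import selectors
from socket import socket

sel = selectors.DefaultSelector()

def on_close(sock: socket) -> None:
    # Same bookkeeping role as the supervisor's _on_socket_closed: stop watching the socket
    sel.unregister(sock)

def make_handler(tag: str):
    def handler(sock: socket) -> bool:
        data = sock.recv(4096)
        if not data:
            return False  # EOF: tell the loop to invoke on_close and close the socket
        print(tag, data)
        return True
    return handler

def register(sock: socket, tag: str) -> None:
    sel.register(sock, selectors.EVENT_READ, (make_handler(tag), on_close))

def service_once(timeout: float | None = 1.0) -> None:
    for key, _ in sel.select(timeout=timeout):
        handler, close_cb = key.data  # the (callback, on_close) pair stored at registration
        try:
            need_more = handler(key.fileobj)
        except (BrokenPipeError, ConnectionResetError):
            need_more = False
        if not need_more:
            close_cb(key.fileobj)
            key.fileobj.close()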
@@ -946,16 +947,11 @@ def _monitor_subprocess(self): # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None - if self._exit_code is not None and self._num_open_sockets > 0: + if self._exit_code is not None and self._open_sockets: if ( self._process_exit_monotonic and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT ): - log.debug( - "Forcefully closing remaining sockets", - open_sockets=self._num_open_sockets, - pid=self.pid, - ) self._cleanup_open_sockets() if alive: @@ -1051,7 +1047,7 @@ def final_state(self): return SERVER_TERMINATED return TaskInstanceState.FAILED - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} @@ -1224,10 +1220,17 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): dump_opts = {"exclude_unset": True} else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + request_id=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return - if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, request_id=req_id, error=None, **dump_opts) def in_process_api_server(): @@ -1237,24 +1240,26 @@ def in_process_api_server(): return api -@attrs.define +@attrs.define(kw_only=True) class InProcessSupervisorComms: """In-process communication handler that uses deques instead of sockets.""" + log: FilteringBoundLogger = attrs.field(repr=False, factory=structlog.get_logger) supervisor: InProcessTestSupervisor - messages: deque[BaseModel] = attrs.field(factory=deque) - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock) + messages: deque[BaseModel | None] = attrs.field(factory=deque) - def get_message(self) -> BaseModel: + def _get_response(self) -> BaseModel | None: """Get a message from the supervisor. 
Blocks until a message is available.""" return self.messages.popleft() - def send_request(self, log, msg: BaseModel): + def send(self, msg: BaseModel): """Send a request to the supervisor.""" - log.debug("Sending request", msg=msg) + self.log.debug("Sending request", msg=msg) with set_supervisor_comms(None): - self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + self.supervisor._handle_request(msg, self.log, 0) # type: ignore[arg-type] + + return self._get_response() @attrs.define @@ -1272,7 +1277,8 @@ class InProcessTestSupervisor(ActivitySubprocess): """A supervisor that runs tasks in-process for easier testing.""" comms: InProcessSupervisorComms = attrs.field(init=False) - stdin = attrs.field(init=False) + + stdin: socket = attrs.field(init=False) @classmethod def start( # type: ignore[override] @@ -1298,7 +1304,6 @@ def start( # type: ignore[override] id=what.id, pid=os.getpid(), # Use current process process=psutil.Process(), # Current process - requests_fd=-1, # Not used in in-process mode process_log=logger or structlog.get_logger(logger_name="task").bind(), client=cls._api_client(task.dag), **kwargs, @@ -1363,7 +1368,9 @@ def _api_client(dag=None): client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def send_msg(self, msg: BaseModel, **dump_opts): + def send_msg( + self, msg: BaseModel | None, request_id: int, error: ErrorResponse | None = None, **dump_opts + ): """Override to use in-process comms.""" self.comms.messages.append(msg) @@ -1421,9 +1428,9 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: # to a (sync) generator def make_buffered_socket_reader( gen: Generator[None, bytes | bytearray, None], - on_close: Callable, + on_close: Callable[[socket], None], buffer_size: int = 4096, -) -> Callable[[socket], bool]: +): buffer = bytearray() # This will hold our accumulated binary data read_buffer = bytearray(buffer_size) # Temporary buffer for each read @@ -1440,8 +1447,6 @@ def cb(sock: socket): if len(buffer): with suppress(StopIteration): gen.send(buffer) - # Tell loop to close this selector - on_close() return False buffer.extend(read_buffer[:n_received]) @@ -1452,18 +1457,62 @@ def cb(sock: socket): try: gen.send(line) except StopIteration: - on_close() return False buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data return True - return cb + return cb, on_close + + +def length_prefixed_frame_reader( + gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None] +): + length_needed: int | None = None + # This will hold our accumulated/partial binary frame if it doesn't come in a single read + buffer: memoryview | None = None + # position in the buffer to store the next read + pos = 0 + decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame) + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, length_needed, pos + + if length_needed is None: + # Read the 32-bit length of the frame + bytes = sock.recv(4) + if bytes == b"": + return False + + length_needed = int.from_bytes(bytes, byteorder="big") + buffer = memoryview(bytearray(length_needed)) + if length_needed and buffer: + n = sock.recv_into(buffer[pos:]) + if n == 0: + # EOF + return False + pos += n + + if pos >= length_needed: + request = decoder.decode(buffer) + buffer = None + pos = 0 + length_needed = None + try: + gen.send(request) + except StopIteration: + return False + return True + + return cb, on_close def
process_log_messages_from_subprocess( loggers: tuple[FilteringBoundLogger, ...], -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: from structlog.stdlib import NAME_TO_LEVEL while True: @@ -1499,10 +1548,9 @@ def process_log_messages_from_subprocess( def forward_to_log( target_loggers: tuple[FilteringBoundLogger, ...], chan: str, level: int -) -> Generator[None, bytes, None]: +) -> Generator[None, bytes | bytearray, None]: while True: - buf = yield - line = bytes(buf) + line = yield # Strip off new line line = line.rstrip() try: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 33583ffe9c7b8..1dbe8e0221608 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -28,16 +28,14 @@ from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from datetime import datetime, timezone -from io import FileIO from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Literal -import aiologic import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -59,6 +57,7 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, + CommsDecoder, DagRunStateResult, DeferTask, DRCount, @@ -410,10 +409,9 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - SUPERVISOR_COMMS.send_request( - log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) + response = SUPERVISOR_COMMS.send( + msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -431,22 +429,17 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTICount( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, TICount) @@ -463,21 +456,16 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTaskStates( - dag_id=dag_id, - map_index=map_index, - task_ids=task_ids, - task_group_id=task_group_id, - 
logical_dates=logical_dates, - run_ids=run_ids, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetTaskStates( + dag_id=dag_id, + map_index=map_index, + task_ids=task_ids, + task_group_id=task_group_id, + logical_dates=logical_dates, + run_ids=run_ids, + ), + ) if TYPE_CHECKING: assert isinstance(response, TaskStatesResult) @@ -492,19 +480,14 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( - dag_id=dag_id, - logical_dates=logical_dates, - run_ids=run_ids, - states=states, - ), - ) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send( + GetDRCount( + dag_id=dag_id, + logical_dates=logical_dates, + run_ids=run_ids, + states=states, + ), + ) if TYPE_CHECKING: assert isinstance(response, DRCount) @@ -514,10 +497,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the DAG run with the given Run ID.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -619,62 +599,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -SendMsgType = TypeVar("SendMsgType", bound=BaseModel) -ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) - - -@attrs.define() -class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): - """Handle communication between the task in this process and the supervisor parent process.""" - - input: TextIO - - request_socket: FileIO = attrs.field(init=False, default=None) - - # We could be "clever" here and set the default to this based type parameters and a custom - # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a - # "sort of wrong default" - decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) - - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) - - def get_message(self) -> ReceiveMsgType: - """ - Get a message from the parent. - - This will block until the message has been received. - """ - line = None - - # TODO: Investigate why some empty lines are sent to the processes stdin. - # That was highlighted when working on https://github.com/apache/airflow/issues/48183 - # and is maybe related to deferred/triggerer only context. - while not line: - line = self.input.readline() - - try: - msg = self.decoder.validate_json(line) - except Exception: - structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) - raise - - if isinstance(msg, StartupDetails): - # If we read a startup message, pull out the FDs we care about! 
- if msg.requests_fd > 0: - self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR: - structlog.get_logger(logger_name="task").error("Error response from the API Server") - raise AirflowRuntimeError(error=msg) - - return msg - - def send_request(self, log: Logger, msg: SendMsgType): - encoded_msg = msg.model_dump_json().encode() + b"\n" - - log.debug("Sending request", json=encoded_msg) - self.request_socket.write(encoded_msg) - - # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # @@ -694,30 +618,32 @@ def send_request(self, log: Logger, msg: SendMsgType): def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - msg = SUPERVISOR_COMMS.get_message() + # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent + # in response to us sending a request. + msg = SUPERVISOR_COMMS._get_response() + + if not isinstance(msg, StartupDetails): + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") log = structlog.get_logger(logger_name="task") + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + log.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + try: get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) except Exception: log.exception("error calling listener") - if isinstance(msg, StartupDetails): - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - log.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle(f"airflow worker -- {msg.ti.id}") - - with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg, log) - log.debug("DAG file parsed", file=msg.dag_rel_path) - else: - raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): + ti = parse(msg, log) + log.debug("DAG file parsed", file=msg.dag_rel_path) return ti, ti.get_template_context(), log @@ -765,7 +691,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) _validate_task_inlets_and_outlets(ti=ti, log=log) @@ -785,8 +711,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log) - inactive_assets_resp = SUPERVISOR_COMMS.get_message() + inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -882,7 +807,7 @@ def run( except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, 
list) else [skip.tasks] - SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -942,7 +867,7 @@ def run( error = e finally: if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + SUPERVISOR_COMMS.send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -980,9 +905,8 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - SUPERVISOR_COMMS.send_request( - log=log, - msg=TriggerDagRun( + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, logical_date=drte.logical_date, @@ -991,7 +915,6 @@ ), ) - comms_msg = SUPERVISOR_COMMS.get_message() if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: if drte.skip_when_already_exists: log.info( @@ -1046,10 +969,9 @@ ) time.sleep(drte.poke_interval) - SUPERVISOR_COMMS.send_request( - log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) + comms_msg = SUPERVISOR_COMMS.send( + GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) - comms_msg = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1238,10 +1160,7 @@ if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)), - ) + SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1282,9 +1201,11 @@ def finalize( def main(): - # TODO: add an exception here, it causes an oof of a stack trace! + # TODO: add an exception here, it causes an oof of a stack trace if it happens too early! + log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) try: ti, context, log = startup() @@ -1295,19 +1216,17 @@ def main(): state, msg, error = run(ti, context, log) finalize(ti, state, context, log, error) except KeyboardInterrupt: - log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") exit(2) except Exception: - log = structlog.get_logger(logger_name="task") log.exception("Top level error") exit(1) finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates.
- if SUPERVISOR_COMMS and SUPERVISOR_COMMS.request_socket: + if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: with suppress(Exception): - SUPERVISOR_COMMS.request_socket.close() + SUPERVISOR_COMMS.socket.close() if __name__ == "__main__": diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py b/task-sdk/tests/task_sdk/bases/test_sensor.py index 2c0a82783d4d6..d0e2b430d2c74 100644 --- a/task-sdk/tests/task_sdk/bases/test_sensor.py +++ b/task-sdk/tests/task_sdk/bases/test_sensor.py @@ -205,7 +205,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su time_machine.coordinates.shift(sensor.poke_interval) # Mocking values from DB/API-server - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=date1) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1) state, msg, error = run_task(task=sensor, context_update={"task_reschedule_count": 1}) assert state == State.FAILED @@ -227,7 +227,7 @@ def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mo time_machine.coordinates.shift(sensor.poke_interval) # Mocking values from DB/API-server - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=date1) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1) state, msg, _ = run_task(task=sensor, context_update={"task_reschedule_count": 1}) assert state == State.SKIPPED @@ -254,7 +254,7 @@ def run_duration(): return (timezone.utcnow() - task_start_date).total_seconds() new_interval = 0 - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=task_start_date) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=task_start_date) # loop poke returns false for _poke_count in range(1, false_count + 1): @@ -516,9 +516,9 @@ def _run_task(): # For timeout calculation, we need to use the first reschedule date # This ensures the timeout is calculated from the start of the task if test_state["first_reschedule_date"] is None: - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(start_date=None) + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=None) else: - mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate( + mock_supervisor_comms.send.return_value = TaskRescheduleStartDate( start_date=test_state["first_reschedule_date"] ) diff --git a/task-sdk/tests/task_sdk/definitions/conftest.py b/task-sdk/tests/task_sdk/definitions/conftest.py index c0ac385628388..3f89f34b4d2da 100644 --- a/task-sdk/tests/task_sdk/definitions/conftest.py +++ b/task-sdk/tests/task_sdk/definitions/conftest.py @@ -36,12 +36,12 @@ def run(dag: DAG, task_id: str, map_index: int): log = structlog.get_logger(__name__) - mock_supervisor_comms.send_request.reset_mock() + mock_supervisor_comms.send.reset_mock() ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) run(ti, ti.get_template_context(), log) - for call in mock_supervisor_comms.send_request.mock_calls: - msg = call.kwargs["msg"] + for call in mock_supervisor_comms.send.mock_calls: + msg = call.kwargs.get("msg") or call.args[0] if isinstance(msg, (TaskState, SucceedTask)): return msg.state raise RuntimeError("Unable to find call to TaskState") diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py index 264cfee06290a..c066b70d99c9f 100644 --- 
a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py @@ -295,7 +295,7 @@ def test_determine_kwargs( example_asset_func_with_valid_arg_as_inlet_asset ) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult( name="example_asset_func", uri="s3://bucket/object", @@ -326,12 +326,9 @@ def test_determine_kwargs( } assert mock_supervisor_comms.mock_calls == [ - mock.call.send_request(mock.ANY, GetAssetByName(name="example_asset_func")), - mock.call.get_message(), - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")), - mock.call.get_message(), - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_2")), - mock.call.get_message(), + mock.call.send(GetAssetByName(name="example_asset_func")), + mock.call.send(GetAssetByName(name="inlet_asset_1")), + mock.call.send(GetAssetByName(name="inlet_asset_2")), ] @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) @@ -342,7 +339,7 @@ def test_determine_kwargs_defaults( ): asset_definition = asset(schedule=None)(example_asset_func_with_valid_arg_as_inlet_asset_and_default) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult(name="inlet_asset_1", uri="s3://bucket/object1", group="asset", extra=None), ] @@ -360,6 +357,5 @@ def test_determine_kwargs_defaults( } assert mock_supervisor_comms.mock_calls == [ - mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")), - mock.call.get_message(), + mock.call.send(GetAssetByName(name="inlet_asset_1")), ] diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py index 102e85d36b30d..3bbb63a769788 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connections.py +++ b/task-sdk/tests/task_sdk/definitions/test_connections.py @@ -104,7 +104,7 @@ def test_get_uri(self): def test_conn_get(self, mock_supervisor_comms): conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result conn = Connection.get(conn_id="mysql_conn") assert conn is not None diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index 6cdb4b520f21e..f9c075352a0d4 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -251,7 +251,7 @@ def execute(self, context): ) mapped = callable(mapped, task1.output) - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["{{ ds }}"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["{{ ds }}"]) mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1}) @@ -299,7 +299,7 @@ def test_expand_kwargs_render_template_fields_validating_operator( task1 = BaseOperator(task_id="op1") mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - mock_supervisor_comms.get_message.return_value = XComResult( + mock_supervisor_comms.send.return_value = XComResult( key="return_value", value=[{"arg1": "{{ ds }}"}, {"arg1": 2}] ) @@ -427,16 +427,14 @@ def show(number, letter): show.expand(number=emit_numbers(), letter=emit_letters()) - def xcom_get(): - # 
TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - task = dag.get_task(last_request.task_id) + task = dag.get_task(msg.task_id) value = task.python_callable() return XComResult(key="return_value", value=value) - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(6)] assert states == [TaskInstanceState.SUCCESS] * 6 @@ -467,16 +465,14 @@ def show(a, b): emit_task = emit_numbers() show.expand(a=emit_task, b=emit_task) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - task = dag.get_task(last_request.task_id) + task = dag.get_task(msg.task_id) value = task.python_callable() return XComResult(key="return_value", value=value) - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get states = [run_ti(dag, "show", map_index) for map_index in range(4)] assert states == [TaskInstanceState.SUCCESS] * 4 @@ -594,22 +590,20 @@ def tg(va): # Aggregates results from task group. t.override(task_id="t3")(tg1) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - key = (last_request.task_id, last_request.map_index) + key = (msg.task_id, msg.map_index) if key in expected_values: value = expected_values[key] return XComResult(key="return_value", value=value) - if last_request.map_index is None: + if msg.map_index is None: # Get all mapped XComValues for this ti - value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] + value = [v for k, v in expected_values.items() if k[0] == msg.task_id] return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get expected_values = { ("tg.t1", 0): ["a", "b"], @@ -683,10 +677,9 @@ def group(x): ti.task.execute(context) assert ti - mock_supervisor_comms.send_request.assert_has_calls( + mock_supervisor_comms.send.assert_has_calls( [ mock.call( - log=mock.ANY, msg=SetXCom( key="skipmixin_key", value={"skipped": ["group.empty_task"]}, diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 8bcf9ef28199a..c85924df6a6d9 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -51,7 +51,7 @@ class TestVariables: ) def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_comms): var_result = VariableResult(key="my_key", value=value) - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result var = Variable.get(key="my_key", deserialize_json=deserialize_json) assert var is not None @@ -83,8 +83,7 @@ def test_var_set(self, key, value, description, serialize_json, mock_supervisor_ if 
serialize_json: expected_value = json.dumps(value, indent=2) - mock_supervisor_comms.send_request.assert_called_once_with( - log=mock.ANY, + mock_supervisor_comms.send.assert_called_once_with( msg=PutVariable( key=key, value=expected_value, description=description, serialize_json=serialize_json ), diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index d313b45dae265..73468dcb9e12b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -52,7 +52,7 @@ def pull(value): assert set(dag.task_dict) == {"push", "pull"} # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) for map_index in range(3): assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS @@ -81,7 +81,7 @@ def c_to_none(v): pull.expand(value=push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Run "pull". This should automatically convert "c" to None. for map_index in range(3): @@ -111,7 +111,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # The first two "pull" tis should succeed. for map_index in range(2): @@ -165,7 +165,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # The third one (for "c") will fail. assert run_ti(dag, "pull", 2) == TaskInstanceState.FAILED @@ -209,7 +209,7 @@ def pull(value): pull.expand_kwargs(converted) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Now "pull" should apply the mapping functions in order. for map_index in range(3): @@ -243,20 +243,18 @@ def convert_zipped(zipped): pull.expand(value=combined.map(convert_zipped)) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - if last_request.task_id == "push_letters": + if msg.task_id == "push_letters": value = push_letters.function() return XComResult(key="return_value", value=value) - if last_request.task_id == "push_numbers": + if msg.task_id == "push_numbers": value = push_numbers.function() return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get # Run "pull". 
for map_index in range(4): @@ -286,7 +284,7 @@ def skip_c(v): forward.expand_kwargs(push().map(skip_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.send.return_value = XComResult(key="return_value", value=["a", "b", "c"]) # Run "forward". This should automatically skip "c". states = [run_ti(dag, "forward", map_index) for map_index in range(3)] @@ -341,20 +339,18 @@ def pull_all(value): pull_one.expand(value=pushed_values) pull_all(pushed_values) - def xcom_get(): - # TODO: Tidy this after #45927 is reopened and fixed properly - last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] - if not isinstance(last_request, GetXCom): + def xcom_get(msg): + if not isinstance(msg, GetXCom): return mock.DEFAULT - if last_request.task_id == "push_letters": + if msg.task_id == "push_letters": value = push_letters.function() return XComResult(key="return_value", value=value) - if last_request.task_id == "push_numbers": + if msg.task_id == "push_numbers": value = push_numbers.function() return XComResult(key="return_value", value=value) return mock.DEFAULT - mock_supervisor_comms.get_message.side_effect = xcom_get + mock_supervisor_comms.send.side_effect = xcom_get # Run "pull_one" and "pull_all". assert run_ti(dag, "pull_all", None) == TaskInstanceState.SUCCESS diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py new file mode 100644 index 0000000000000..5adaa2562abc7 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from __future__ import annotations
+
+import uuid
+from socket import socketpair
+
+import msgspec
+import pytest
+
+from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame
+from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.utils import timezone
+
+
+class TestCommsDecoder:
+    """Test the communication between the subprocess and the "supervisor"."""
+
+    @pytest.mark.usefixtures("disable_capturing")
+    def test_recv_StartupDetails(self):
+        r, w = socketpair()
+
+        msg = {
+            "type": "StartupDetails",
+            "ti": {
+                "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
+                "task_id": "a",
+                "try_number": 1,
+                "run_id": "b",
+                "dag_id": "c",
+            },
+            "ti_context": {
+                "dag_run": {
+                    "dag_id": "c",
+                    "run_id": "b",
+                    "logical_date": "2024-12-01T01:00:00Z",
+                    "data_interval_start": "2024-12-01T00:00:00Z",
+                    "data_interval_end": "2024-12-01T01:00:00Z",
+                    "start_date": "2024-12-01T01:00:00Z",
+                    "run_after": "2024-12-01T01:00:00Z",
+                    "end_date": None,
+                    "run_type": "manual",
+                    "conf": None,
+                    "consumed_asset_events": [],
+                },
+                "max_tries": 0,
+                "should_retry": False,
+                "variables": None,
+                "connections": None,
+            },
+            "file": "/dev/null",
+            "start_date": "2024-12-01T01:00:00Z",
+            "dag_rel_path": "/dev/null",
+            "bundle_info": {"name": "any-name", "version": "any-version"},
+        }
+        frame_bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None))
+        w.sendall(len(frame_bytes).to_bytes(4, byteorder="big") + frame_bytes)
+
+        decoder = CommsDecoder(socket=r, log=None)
+
+        msg = decoder._get_response()
+        assert isinstance(msg, StartupDetails)
+        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        assert msg.ti.task_id == "a"
+        assert msg.ti.dag_id == "c"
+        assert msg.dag_rel_path == "/dev/null"
+        assert msg.bundle_info == BundleInfo(name="any-name", version="any-version")
+        assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 59266da758892..6880c656b01cd 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -190,7 +190,7 @@ def test_getattr_connection(self, mock_supervisor_comms):

        # Conn from the supervisor / API Server
        conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
-        mock_supervisor_comms.get_message.return_value = conn_result
+        mock_supervisor_comms.send.return_value = conn_result

        # Fetch the connection; triggers __getattr__
        conn = accessor.mysql_conn
@@ -203,7 +203,7 @@ def test_get_method_valid_connection(self, mock_supervisor_comms):
        accessor = ConnectionAccessor()
        conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

-        mock_supervisor_comms.get_message.return_value = conn_result
+        mock_supervisor_comms.send.return_value = conn_result

        conn = accessor.get("mysql_conn")
        assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
@@ -216,7 +216,7 @@ def test_get_method_with_default(self, mock_supervisor_comms):
            error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"}
        )

-        mock_supervisor_comms.get_message.return_value = error_response
+        mock_supervisor_comms.send.return_value = error_response

        conn = accessor.get("nonexistent_conn", default_conn=default_conn)
        assert conn == default_conn
@@ -233,7 +233,7 @@ def test_getattr_connection_for_extra_dejson(self, mock_supervisor_comms):
extra='{"extra_key": "extra_value"}', ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -251,7 +251,7 @@ def test_getattr_connection_for_extra_dejson_decode_error(self, mock_log, mock_s conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306, extra="This is not JSON!" ) - mock_supervisor_comms.get_message.return_value = conn_result + mock_supervisor_comms.send.return_value = conn_result # Fetch the connection's dejson; triggers __getattr__ dejson = accessor.mysql_conn.extra_dejson @@ -274,7 +274,7 @@ def test_getattr_variable(self, mock_supervisor_comms): # Variable from the supervisor / API Server var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result # Fetch the variable; triggers __getattr__ value = accessor.test_key @@ -286,7 +286,7 @@ def test_get_method_valid_variable(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) var_result = VariableResult(key="test_key", value="test_value") - mock_supervisor_comms.get_message.return_value = var_result + mock_supervisor_comms.send.return_value = var_result val = accessor.get("test_key") assert val == var_result.value @@ -297,7 +297,7 @@ def test_get_method_with_default(self, mock_supervisor_comms): accessor = VariableAccessor(deserialize_json=False) error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"}) - mock_supervisor_comms.get_message.return_value = error_response + mock_supervisor_comms.send.return_value = error_response val = accessor.get("nonexistent_var_key", default="default_value") assert val == "default_value" @@ -367,7 +367,7 @@ class TestOutletEventAccessor: ), ) def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={}) outlet_event_accessor.add(add_arg) @@ -398,7 +398,7 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): ), ) def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="") + mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) outlet_event_accessor.add(add_arg, extra={}) @@ -497,11 +497,11 @@ def test_getitem_name_ref( resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(name=name)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByName(name=name) + + mock_supervisor_comms.send.assert_called_once_with(GetAssetByName(name=name, type="GetAssetByName")) assert _AssetRefResolutionMixin._asset_ref_cache @pytest.mark.parametrize( @@ -520,11 +520,10 @@ def test_getitem_uri_ref( 
resolved_asset, result_indexes, ): - mock_supervisor_comms.get_message.return_value = resolved_asset + mock_supervisor_comms.send.return_value = resolved_asset expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes] assert accessor[Asset.ref(uri=uri)] == expected - assert len(mock_supervisor_comms.send_request.mock_calls) == 1 - assert mock_supervisor_comms.send_request.mock_calls[0].kwargs["msg"] == GetAssetByUri(uri=uri) + mock_supervisor_comms.send.assert_called_once_with(GetAssetByUri(uri=uri)) assert _AssetRefResolutionMixin._asset_ref_cache def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): @@ -534,22 +533,19 @@ def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor): assert source == AssetEventSourceTaskInstance(dag_id="d1", task_id="t2", run_id="r1", map_index=-1) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="d1", - run_id="r1", - task_id="t2", - map_index=-1, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="d1", + run_id="r1", + task_id="t2", + map_index=-1, ), - ] + ) TEST_ASSET = Asset(name="test_uri", uri="test://test") @@ -593,7 +589,7 @@ def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms): assert len(outlet_event_accessors) == 0 # Asset from the API Server via the supervisor - mock_supervisor_comms.get_message.return_value = AssetResult( + mock_supervisor_comms.send.return_value = AssetResult( name=asset.name, uri=asset.uri, group=asset.group, @@ -628,7 +624,7 @@ def test_for_asset_alias(self, mocked__getitem__): class TestInletEventAccessor: @pytest.fixture def sample_inlet_evnets_accessor(self, mock_supervisor_comms): - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetResult(name="test_uri", uri="test://test", group="asset"), AssetResult(name="test_uri", uri="test://test", group="asset"), ] @@ -656,7 +652,7 @@ def test__get_item__(self, key, sample_inlet_evnets_accessor, mock_supervisor_co asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.get_message.side_effect = [events_result] * 4 + mock_supervisor_comms.send.side_effect = [events_result] * 4 assert sample_inlet_evnets_accessor[key] == [asset_event_resp] @@ -684,7 +680,7 @@ def test_for_asset_alias(self, mocked__getitem__, sample_inlet_evnets_accessor): assert mocked__getitem__.call_args[0][0] == TEST_ASSET_ALIAS def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock_supervisor_comms): - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ AssetEventsResult( asset_events=[ AssetEventResponse( @@ -707,9 +703,7 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) ] events = sample_inlet_evnets_accessor[Asset.ref(name="test_uri")] - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call(log=mock.ANY, msg=GetAssetEventByAsset(name="test_uri", uri=None)), - ] + mock_supervisor_comms.send.assert_called_once_with(GetAssetEventByAsset(name="test_uri", 
uri=None)) assert len(events) == 2 assert events[1].source_task_instance is None @@ -723,19 +717,16 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock ) mock_supervisor_comms.reset_mock() - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComResult(key="return_value", value="__example_xcom_value__"), ] assert source.xcom_pull() == "__example_xcom_value__" - assert mock_supervisor_comms.send_request.mock_calls == [ - mock.call( - log=mock.ANY, - msg=GetXCom( - key="return_value", - dag_id="__dag__", - run_id="__run__", - task_id="__task__", - map_index=0, - ), + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXCom( + key="return_value", + dag_id="__dag__", + run_id="__run__", + task_id="__task__", + map_index=0, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py index e4943196a09da..65dd30b2c7c7b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -17,7 +17,7 @@ from __future__ import annotations -from unittest.mock import ANY, Mock, call +from unittest.mock import Mock, call import pytest @@ -66,121 +66,109 @@ def deserialize_value(cls, xcom): def test_len(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3) + mock_supervisor_comms.send.return_value = XComCountResponse(len=3) assert len(lazy_sequence) == 3 - assert mock_supervisor_comms.send_request.mock_calls == [ - call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")), - ] + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run"), + ) def test_iter(mock_supervisor_comms, lazy_sequence): it = iter(lazy_sequence) - mock_supervisor_comms.get_message.side_effect = [ + mock_supervisor_comms.send.side_effect = [ XComSequenceIndexResult(root="f"), ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}), ] assert list(it) == ["f"] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=0, + mock_supervisor_comms.send.assert_has_calls( + [ + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=0, + ), ), - ), - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=1, + call( + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=1, + ), ), - ), - ] + ] + ) def test_getitem_index(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="f") + mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="f") assert lazy_sequence[4] == "f" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) @conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"}) 
def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="some-value") + mock_supervisor_comms.send.return_value = XComSequenceIndexResult(root="some-value") xcom = resolve_xcom_backend() assert xcom.__name__ == "CustomXCom" monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom) assert lazy_sequence[4] == "Made with CustomXCom: some-value" - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = ErrorResponse( + mock_supervisor_comms.send.return_value = ErrorResponse( error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}, ) with pytest.raises(IndexError) as ctx: lazy_sequence[4] assert ctx.value.args == (4,) - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceItem( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - offset=4, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, ), - ] + ) def test_getitem_slice(mock_supervisor_comms, lazy_sequence): - mock_supervisor_comms.get_message.return_value = XComSequenceSliceResult(root=[6, 4, 1]) + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=[6, 4, 1]) assert lazy_sequence[:5] == [6, 4, 1] - assert mock_supervisor_comms.send_request.mock_calls == [ - call( - log=ANY, - msg=GetXComSequenceSlice( - key="return_value", - dag_id="dag", - task_id="task", - run_id="run", - start=None, - stop=5, - step=None, - ), + mock_supervisor_comms.send.assert_called_once_with( + GetXComSequenceSlice( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + start=None, + stop=5, + step=None, ), - ] + ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4f5e4cfc7ac9d..d2f8ef40c0881 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,13 +27,14 @@ import socket import sys import time -from io import BytesIO from operator import attrgetter +from random import randint from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import httpx +import msgspec import psutil import pytest from pytest_unordered import unordered @@ -56,6 +57,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, DagRunStateResult, DeferTask, @@ -97,17 +99,16 @@ XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.supervisor import ( - BUFFER_SIZE, ActivitySubprocess, InProcessSupervisorComms, InProcessTestSupervisor, - mkpipe, set_supervisor_comms, supervise, ) -from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"): } 
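+
+# Shared fixture for the supervisor tests below: a mocked API client whose
+# `task_instances.start` call returns a TIRunContext built by make_ti_context,
+# replacing the MagicMock clients that individual tests used to build inline.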
+@pytest.fixture +def client_with_ti_start(make_ti_context): + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @pytest.fixture(autouse=True) def disable_log_upload(self, spy_agency): spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) - def test_reading_from_pipes(self, captured_logs, time_machine): + def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): def subprocess_main(): # This is run in the subprocess! - # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + # Ensure we follow the "protocol" and get the startup message before we do anything else + CommsDecoder()._get_response() import logging import warnings @@ -180,7 +188,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -228,12 +236,12 @@ def subprocess_main(): ] ) - def test_subprocess_sigkilled(self): + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid() def subprocess_main(): # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() assert os.getpid() != main_pid os.kill(os.getpid(), signal.SIGKILL) @@ -248,7 +256,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -285,7 +293,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch, mocker, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -314,7 +322,7 @@ def test_no_heartbeat_in_overtime(self, spy_agency: kgb.SpyAgency, monkeypatch, monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -340,7 +348,7 @@ def _on_child_started(self, *args, **kwargs): assert proc.wait() == 0 spy_agency.assert_spy_not_called(heartbeat_spy) - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -355,11 +363,6 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker try_number=1, ) - # Create a mock client to assert calls to the client - # We assume the implementation of the client is correct and only need to check the calls - mock_client = mocker.Mock(spec=sdk_client.Client) - mock_client.task_instances.start.return_value = make_ti_context() - bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -368,7 +371,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker token="", server="", dry_run=True, - client=mock_client, + client=client_with_ti_start, bundle_info=bundle_info, ) assert exit_code == 0, captured_logs @@ -498,7 
+501,7 @@ def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, m monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() sleep(5) # Shouldn't get here exit(5) @@ -611,7 +614,6 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t stdin=mocker.MagicMock(), client=client, process=mock_process, - requests_fd=-1, ) time_now = tz.datetime(2024, 11, 28, 12, 0, 0) @@ -701,7 +703,6 @@ def test_overtime_handling( stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock(), - requests_fd=-1, ) # Set the terminal state and task end datetime @@ -738,7 +739,7 @@ def test_overtime_handling( ), ), ) - def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog): + def test_exit_by_signal(self, signal_to_raise, log_pattern, cap_structlog, client_with_ti_start): def subprocess_main(): import faulthandler import os @@ -748,7 +749,7 @@ def subprocess_main(): faulthandler.disable() # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() os.kill(os.getpid(), signal_to_raise) @@ -762,7 +763,7 @@ def subprocess_main(): run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -791,26 +792,26 @@ def test_cleanup_sockets_after_delay(self, monkeypatch, mocker, time_machine): stdin=mocker.MagicMock(), client=mocker.MagicMock(), process=mock_process, - requests_fd=-1, ) proc.selector = mocker.MagicMock() proc.selector.select.return_value = [] proc._exit_code = 0 - proc._num_open_sockets = 1 + # Create a fake placeholder in the open socket weakref + proc._open_sockets[mocker.MagicMock()] = "test placeholder" proc._process_exit_monotonic = time.monotonic() mocker.patch.object( ActivitySubprocess, "_cleanup_open_sockets", - side_effect=lambda: setattr(proc, "_num_open_sockets", 0), + side_effect=lambda: setattr(proc, "_open_sockets", {}), ) time_machine.shift(2) proc._monitor_subprocess() - assert proc._num_open_sockets == 0 + assert len(proc._open_sockets) == 0 class TestWatchedSubprocessKill: @@ -829,7 +830,6 @@ def watched_subprocess(self, mocker, mock_process): stdin=mocker.Mock(), client=mocker.Mock(), process=mock_process, - requests_fd=-1, ) # Mock the selector mock_selector = mocker.Mock(spec=selectors.DefaultSelector) @@ -888,7 +888,7 @@ def test_kill_process_custom_signal(self, watched_subprocess, mock_process): ), ], ) - def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch): + def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): def subprocess_main(): import signal @@ -905,7 +905,7 @@ def _handler(sig, frame): signal.signal(signal.SIGINT, _handler) signal.signal(signal.SIGTERM, _handler) try: - sys.stdin.readline() + CommsDecoder()._get_response() print("Ready") sleep(10) except Exception as e: @@ -919,7 +919,7 @@ def _handler(sig, frame): dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -976,9 +976,11 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stdout_handler = mocker.Mock(return_value=False) # Simulate EOF for stdout mock_stderr_handler = 
mocker.Mock(return_value=True) # Continue processing for stderr + mock_on_close = mocker.Mock() + # Mock selector to return events - mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=mock_stdout_handler) - mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=mock_stderr_handler) + mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=(mock_stdout_handler, mock_on_close)) + mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=(mock_stderr_handler, mock_on_close)) watched_subprocess.selector.select.return_value = [(mock_key_stdout, None), (mock_key_stderr, None)] # Mock to simulate process exited successfully @@ -996,8 +998,7 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker): mock_stderr_handler.assert_called_once_with(mock_stderr) # Validate unregistering and closing of EOF file object - watched_subprocess.selector.unregister.assert_called_once_with(mock_stdout) - mock_stdout.close.assert_called_once() + mock_on_close.assert_called_once_with(mock_stdout) # Validate that `_check_subprocess_exit` is called mock_process.wait.assert_called_once_with(timeout=0) @@ -1073,16 +1074,15 @@ def test_max_wait_time_calculation_edge_cases( class TestHandleRequest: @pytest.fixture def watched_subprocess(self, mocker): - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socket.socketpair() subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), id=TI_ID, pid=12345, - stdin=write_end, # this is the writer side + stdin=write_end, client=mocker.Mock(), process=mocker.Mock(), - requests_fd=-1, ) return subprocess, read_end @@ -1091,7 +1091,7 @@ def watched_subprocess(self, mocker): @pytest.mark.parametrize( [ "message", - "expected_buffer", + "expected_body", "client_attr_path", "method_arg", "method_kwarg", @@ -1101,7 +1101,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1111,7 +1111,12 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + { + "conn_id": "test_conn", + "conn_type": "mysql", + "password": "password", + "type": "ConnectionResult", + }, "connections.get", ("test_conn",), {}, @@ -1121,7 +1126,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1131,7 +1136,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', + {"key": "test_key", "value": "test_value", "type": "VariableResult"}, "variables.get", ("test_key",), {}, @@ -1141,7 +1146,7 @@ def watched_subprocess(self, mocker): ), pytest.param( PutVariable(key="test_key", value="test_value", description="test_description"), - b"", + None, "variables.set", ("test_key", "test_value", "test_description"), {}, @@ -1151,7 +1156,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeleteVariable(key="test_key"), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "variables.delete", 
("test_key",), {}, @@ -1161,7 +1166,7 @@ def watched_subprocess(self, mocker): ), pytest.param( DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", + None, "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, @@ -1174,7 +1179,7 @@ def watched_subprocess(self, mocker): reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - b"", + None, "task_instances.reschedule", ( TI_ID, @@ -1190,7 +1195,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1202,7 +1207,7 @@ def watched_subprocess(self, mocker): GetXCom( dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, @@ -1212,7 +1217,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1228,7 +1233,7 @@ def watched_subprocess(self, mocker): key="test_key", include_prior_dates=True, ), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, True), {}, @@ -1244,7 +1249,7 @@ def watched_subprocess(self, mocker): key="test_key", value='{"key": "test_key", "value": {"key2": "value2"}}', ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1269,7 +1274,7 @@ def watched_subprocess(self, mocker): value='{"key": "test_key", "value": {"key2": "value2"}}', map_index=2, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1295,7 +1300,7 @@ def watched_subprocess(self, mocker): map_index=2, mapped_length=3, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1319,15 +1324,9 @@ def watched_subprocess(self, mocker): key="test_key", map_index=2, ), - b"", + None, "xcoms.delete", - ( - "test_dag", - "test_run", - "test_task", - "test_key", - 2, - ), + ("test_dag", "test_run", "test_task", "test_key", 2), {}, OKResponse(ok=True), None, @@ -1337,7 +1336,7 @@ def watched_subprocess(self, mocker): # if it can handle TaskState message pytest.param( TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - b"", + None, "", (), {}, @@ -1349,7 +1348,7 @@ def watched_subprocess(self, mocker): RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" ), - b"", + None, "task_instances.retry", (), { @@ -1363,7 +1362,7 @@ def watched_subprocess(self, mocker): ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - b"", + None, "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, @@ -1373,7 +1372,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByName(name="asset"), - 
b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"name": "asset"}, @@ -1383,7 +1382,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetByUri(uri="s3://bucket/obj"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"uri": "s3://bucket/obj"}, @@ -1393,11 +1392,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": "test"}, @@ -1416,11 +1421,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": None}, @@ -1439,11 +1450,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAsset(uri=None, name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": None, "name": "test"}, @@ -1462,11 +1479,17 @@ def watched_subprocess(self, mocker): ), pytest.param( GetAssetEventByAssetAlias(alias_name="test_alias"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"alias_name": "test_alias"}, @@ -1485,7 +1508,10 @@ def watched_subprocess(self, mocker): ), pytest.param( ValidateInletsAndOutlets(ti_id=TI_ID), - b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n', + { + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, "task_instances.validate_inlets_and_outlets", (TI_ID,), {}, @@ -1499,7 +1525,7 @@ def watched_subprocess(self, 
mocker):
                SucceedTask(
                    end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task"
                ),
-                b"",
+                None,
                "task_instances.succeed",
                (),
                {
@@ -1515,11 +1541,13 @@
            ),
            pytest.param(
                GetPrevSuccessfulDagRun(ti_id=TI_ID),
-                (
-                    b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",'
-                    b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",'
-                    b'"type":"PrevSuccessfulDagRunResult"}\n'
-                ),
+                {
+                    "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"),
+                    "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"),
+                    "start_date": timezone.parse("2025-01-10T12:00:00Z"),
+                    "end_date": timezone.parse("2025-01-10T14:00:00Z"),
+                    "type": "PrevSuccessfulDagRunResult",
+                },
                "task_instances.get_previous_successful_dagrun",
                (TI_ID,),
                {},
@@ -1540,7 +1568,7 @@ def watched_subprocess(self, mocker):
                    logical_date=timezone.datetime(2025, 1, 1),
                    reset_dag_run=True,
                ),
-                b'{"ok":true,"type":"OKResponse"}\n',
+                {"ok": True, "type": "OKResponse"},
                "dag_runs.trigger",
                ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True),
                {},
                OKResponse(ok=True),
                None,
                id="dag_run_trigger",
            ),
            pytest.param(
+                # TODO: This should raise an exception, not return an ErrorResponse. Fix this before the PR is merged.
                TriggerDagRun(dag_id="test_dag", run_id="test_run"),
-                b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n',
+                {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"},
                "dag_runs.trigger",
                ("test_dag", "test_run", None, None, False),
                {},
@@ -1560,7 +1589,7 @@ def watched_subprocess(self, mocker):
            ),
            pytest.param(
                GetDagRunState(dag_id="test_dag", run_id="test_run"),
-                b'{"state":"running","type":"DagRunStateResult"}\n',
+                {"state": "running", "type": "DagRunStateResult"},
                "dag_runs.get_state",
                ("test_dag", "test_run"),
                {},
@@ -1570,7 +1599,7 @@ def watched_subprocess(self, mocker):
            ),
            pytest.param(
                GetTaskRescheduleStartDate(ti_id=TI_ID),
-                b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n',
+                {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"},
                "task_instances.get_reschedule_start_date",
                (TI_ID, 1),
                {},
@@ -1580,7 +1609,7 @@ def watched_subprocess(self, mocker):
            ),
            pytest.param(
                GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
-                b'{"count":2,"type":"TICount"}\n',
+                {"count": 2, "type": "TICount"},
                "task_instances.get_count",
                (),
                {
@@ -1598,7 +1627,7 @@ def watched_subprocess(self, mocker):
            ),
            pytest.param(
                GetDRCount(dag_id="test_dag", states=["success", "failed"]),
-                b'{"count":2,"type":"DRCount"}\n',
+                {"count": 2, "type": "DRCount"},
                "dag_runs.get_count",
                (),
                {
@@ -1613,7 +1642,10 @@ def watched_subprocess(self, mocker):
            ),
            pytest.param(
                GetTaskStates(dag_id="test_dag", task_group_id="test_group"),
-                b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n',
+                {
+                    "task_states": {"run_id": {"task1": "success", "task2": "failed"}},
+                    "type": "TaskStatesResult",
+                },
                "task_instances.get_task_states",
                (),
                {
@@ -1636,7 +1668,7 @@ def watched_subprocess(self, mocker):
                    task_id="test_task",
                    offset=0,
                ),
-                b'{"root":"test_value","type":"XComSequenceIndexResult"}\n',
+                {"root": "test_value", "type": "XComSequenceIndexResult"},
                "xcoms.get_sequence_item",
                ("test_dag", "test_run", "test_task", "test_key", 0),
                {},
@@ -1645,6 +1677,7 @@ def watched_subprocess(self, mocker):
                id="get_xcom_seq_item",
            ),
            pytest.param(
+                # TODO: This should raise an exception, not return an ErrorResponse. Fix this before the PR is merged.
                GetXComSequenceItem(
                    key="test_key",
                    dag_id="test_dag",
                    run_id="test_run",
                    task_id="test_task",
                    offset=2,
                ),
-                b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n',
+                {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"},
                "xcoms.get_sequence_item",
                ("test_dag", "test_run", "test_task", "test_key", 2),
                {},
@@ -1670,7 +1703,7 @@ def watched_subprocess(self, mocker):
                    stop=None,
                    step=None,
                ),
-                b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n',
+                {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"},
                "xcoms.get_sequence_slice",
                ("test_dag", "test_run", "test_task", "test_key", None, None, None),
                {},
@@ -1687,7 +1720,7 @@ def test_handle_requests(
        mocker,
        time_machine,
        message,
-        expected_buffer,
+        expected_body,
        client_attr_path,
        method_arg,
        method_kwarg,
@@ -1715,8 +1748,9 @@ def test_handle_requests(
        generator = watched_subprocess.handle_requests(log=mocker.Mock())
        # Initialize the generator
        next(generator)
-        msg = message.model_dump_json().encode() + b"\n"
-        generator.send(msg)
+
+        req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump())
+        generator.send(req_frame)

        if mask_secret_args:
            mock_mask_secret.assert_called_with(*mask_secret_args)
@@ -1729,33 +1763,22 @@

        # Read response from the read end of the socket
        read_socket.settimeout(0.1)
-        val = b""
-        try:
-            while not val.endswith(b"\n"):
-                chunk = read_socket.recv(BUFFER_SIZE)
-                if not chunk:
-                    break
-                val += chunk
-        except (BlockingIOError, TimeoutError, socket.timeout):
-            # no response written, valid for some message types like setters and TI operations.
-            pass
+        frame_len = int.from_bytes(read_socket.recv(4), "big")
+        frame_bytes = read_socket.recv(frame_len)
+        frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(frame_bytes)
+
+        assert frame.id == req_frame.id

        # Verify the response was added to the buffer
-        assert val == expected_buffer
+        assert frame.body == expected_body

        # Verify the response is correctly decoded
        # This is important because the subprocess/task runner will read the response
        # and deserialize it to the correct message type
-        # Only decode the buffer if it contains data. An empty buffer implies no response was written.
-        if not val and (mock_response and not isinstance(mock_response, OKResponse)):
-            pytest.fail("Expected a response, but got an empty buffer.")
-
-        if val:
-            # Using BytesIO to simulate a readable stream for CommsDecoder.
-            input_stream = BytesIO(val)
-            decoder = CommsDecoder(input=input_stream)
-            assert decoder.get_message() == mock_response
+        if frame.body is not None:
+            decoder = CommsDecoder(socket=None).body_decoder
+            assert decoder.validate_python(frame.body) == mock_response

    def test_handle_requests_api_server_error(self, watched_subprocess, mocker):
        """Test that API server errors are properly handled and sent back to the task."""
@@ -1777,28 +1800,32 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker):

        next(generator)

-        msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n"
-        generator.send(msg)
+        msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z"))
+        req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump())
+        generator.send(req_frame)

-        # Read response directly from the reader socket
+        # Read response from the read end of the socket
        read_socket.settimeout(0.1)
-        val = b""
-        try:
-            while not val.endswith(b"\n"):
-                val += read_socket.recv(4096)
-        except (BlockingIOError, TimeoutError):
-            pass
-
-        assert val == (
-            b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",'
-            b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n'
-        )
+        frame_len = int.from_bytes(read_socket.recv(4), "big")
+        frame_bytes = read_socket.recv(frame_len)
+        frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(frame_bytes)
+
+        assert frame.id == req_frame.id
+
+        assert frame.error == {
+            "error": "API_SERVER_ERROR",
+            "detail": {
+                "status_code": 500,
+                "message": "API Server Error",
+                "detail": {"detail": "Internal Server Error"},
+            },
+            "type": "ErrorResponse",
+        }

        # Verify the error can be decoded correctly
-        input_stream = BytesIO(val)
-        decoder = CommsDecoder(input=input_stream)
+        comms = CommsDecoder(socket=None)
        with pytest.raises(AirflowRuntimeError) as exc_info:
-            decoder.get_message()
+            comms._from_frame(frame)

        assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR
        assert exc_info.value.error.detail == {
@@ -1864,14 +1891,13 @@ def test_inprocess_supervisor_comms_roundtrip(self):
        """

        class MinimalSupervisor(InProcessTestSupervisor):
-            def _handle_request(self, msg, log):
+            def _handle_request(self, msg, log, req_id):
                resp = VariableResult(key=msg.key, value="value")
-                self.send_msg(resp)
+                self.send_msg(resp, req_id)

        supervisor = MinimalSupervisor(
            id="test",
            pid=123,
-            requests_fd=-1,
            process=MagicMock(),
            process_log=MagicMock(),
            client=MagicMock(),
@@ -1881,9 +1907,8 @@ def _handle_request(self, msg, log):

        test_msg = GetVariable(key="test_key")

-        comms.send_request(log=MagicMock(), msg=test_msg)
+        response = comms.send(test_msg)

        # Ensure we got back what we expect
-        response = comms.get_message()
        assert isinstance(response, VariableResult)
        assert response.value == "value"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 09d4d9709e789..7dd6e07695255 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -22,11 +22,9 @@
 import json
 import os
 import textwrap
-import uuid
 from collections.abc import Iterable
 from datetime import datetime, timedelta
 from pathlib import Path
-from socket import socketpair
 from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import patch
@@ -101,7 +99,6 @@
    VariableAccessor,
)
from airflow.sdk.execution_time.task_runner import (
-
-    CommsDecoder,
     RuntimeTaskInstance,
     TaskRunnerMarker,
     _push_xcom_if_needed,
@@ -136,47 +133,6 @@ def execute(self, context):
         print(f"Hello World {task_id}!")


-class TestCommsDecoder:
-    """Test the communication between the subprocess and the "supervisor"."""
-
-    @pytest.mark.usefixtures("disable_capturing")
-    def test_recv_StartupDetails(self):
-        r, w = socketpair()
-        # Create a valid FD for the decoder to open
-        _, w2 = socketpair()
-
-        w.makefile("wb").write(
-            b'{"type":"StartupDetails", "ti": {'
-            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", '
-            b'"dag_id": "c"}, "ti_context":{"dag_run":{"dag_id":"c","run_id":"b",'
-            b'"logical_date":"2024-12-01T01:00:00Z",'
-            b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
-            b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,'
-            b'"run_type":"manual","conf":null,"consumed_asset_events":[]},'
-            b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file": "/dev/null",'
-            b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": "/dev/null", "bundle_info": {"name": '
-            b'"any-name", "version": "any-version"}, "requests_fd": '
-            + str(w2.fileno()).encode("ascii")
-            + b"}\n"
-        )
-
-        decoder = CommsDecoder(input=r.makefile("r"))
-
-        msg = decoder.get_message()
-        assert isinstance(msg, StartupDetails)
-        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
-        assert msg.ti.task_id == "a"
-        assert msg.ti.dag_id == "c"
-        assert msg.dag_rel_path == "/dev/null"
-        assert msg.bundle_info == BundleInfo(name="any-name", version="any-version")
-        assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
-
-        # Since this was a StartupDetails message, the decoder should open the other socket
-        assert decoder.request_socket is not None
-        assert decoder.request_socket.writable()
-        assert decoder.request_socket.fileno() == w2.fileno()
-
-
 def test_parse(test_dags_dir: Path, make_ti_context):
     """Test that checks parsing of a basic dag with an un-mocked parse."""
     what = StartupDetails(
@@ -189,7 +145,6 @@ def test_parse(test_dags_dir: Path, make_ti_context):
         ),
         dag_rel_path="super_basic.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -245,7 +200,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
         ),
         dag_rel_path="super_basic.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -299,7 +253,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
         ),
         dag_rel_path="path_test.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -357,8 +310,8 @@ def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_com

     assert ti.state == TaskInstanceState.DEFERRED

-    # send_request will only be called when the TaskDeferred exception is raised
-    mock_supervisor_comms.send_request.assert_any_call(msg=expected_defer_task, log=mock.ANY)
+    # send will only be called when the TaskDeferred exception is raised
+    mock_supervisor_comms.send.assert_any_call(expected_defer_task)


 def test_run_downstream_skipped(mocked_parse, create_runtime_ti, mock_supervisor_comms):
@@ -381,8 +334,8 @@ def execute(self, context):

     assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
     log.info.assert_called_with("Skipping downstream tasks.")
-    mock_supervisor_comms.send_request.assert_any_call(
-        log=mock.ANY, msg=SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks")
+    mock_supervisor_comms.send.assert_any_call(
+        SkipDownstreamTasks(tasks=["task1", "task2"], type="SkipDownstreamTasks")
     )


@@ -433,8 +386,8 @@ def test_run_basic_skipped(time_machine, create_runtime_ti, mock_supervisor_comm

     assert ti.state == TaskInstanceState.SKIPPED

-    mock_supervisor_comms.send_request.assert_called_with(
-        msg=TaskState(state=TaskInstanceState.SKIPPED, end_date=instant), log=mock.ANY
+    mock_supervisor_comms.send.assert_called_with(
+        TaskState(state=TaskInstanceState.SKIPPED, end_date=instant)
     )


@@ -455,12 +408,11 @@ def test_run_raises_base_exception(time_machine, create_runtime_ti, mock_supervi

     assert ti.state == TaskInstanceState.FAILED

-    mock_supervisor_comms.send_request.assert_called_with(
+    mock_supervisor_comms.send.assert_called_with(
         msg=TaskState(
             state=TaskInstanceState.FAILED,
             end_date=instant,
         ),
-        log=mock.ANY,
     )


@@ -482,13 +434,7 @@ def test_run_raises_system_exit(time_machine, create_runtime_ti, mock_supervisor

     assert ti.state == TaskInstanceState.FAILED

-    mock_supervisor_comms.send_request.assert_called_with(
-        msg=TaskState(
-            state=TaskInstanceState.FAILED,
-            end_date=instant,
-        ),
-        log=mock.ANY,
-    )
+    mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant))

     log.exception.assert_not_called()
     log.error.assert_called_with(mock.ANY, exit_code=10)
@@ -513,13 +459,7 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe

     assert ti.state == TaskInstanceState.FAILED

-    mock_supervisor_comms.send_request.assert_called_with(
-        msg=TaskState(
-            state=TaskInstanceState.FAILED,
-            end_date=instant,
-        ),
-        log=mock.ANY,
-    )
+    mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant))


 def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms):
@@ -542,13 +482,7 @@ def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms

     assert ti.state == TaskInstanceState.FAILED

     # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout
-    mock_supervisor_comms.send_request.assert_called_with(
-        msg=TaskState(
-            state=TaskInstanceState.FAILED,
-            end_date=instant,
-        ),
-        log=mock.ANY,
-    )
+    mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant))


 def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency):
@@ -570,7 +504,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
         ),
         bundle_info=FAKE_BUNDLE,
         dag_rel_path="",
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -580,7 +513,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
     spy_agency.spy_on(task.prepare_for_execution)
     assert not task._lock_for_execution

-    # mock_supervisor_comms.get_message.return_value = what
     run(ti, context=ti.get_template_context(), log=mock.Mock())

     spy_agency.assert_spy_called(task.prepare_for_execution)
@@ -588,7 +520,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
     assert ti.task is not task, "ti.task should be a copy of the original task"
     assert ti.state == TaskInstanceState.SUCCESS

-    mock_supervisor_comms.send_request.assert_any_call(
+    mock_supervisor_comms.send.assert_any_call(
         msg=SetRenderedFields(
             rendered_fields={
                 "bash_command": "echo 'Logical date is 2024-12-01 01:00:00+00:00'",
@@ -596,7 +528,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
                 "env": None,
             }
         ),
-        log=mock.ANY,
     )


@@ -686,7 +617,6 @@ def execute(self, context):
         ),
         dag_rel_path="",
         bundle_info=FAKE_BUNDLE,
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -694,22 +624,18 @@ def execute(self, context):

     time_machine.move_to(instant, tick=False)

-    mock_supervisor_comms.get_message.return_value = what
+    mock_supervisor_comms._get_response.return_value = what
     run(*startup())

     expected_calls = [
-        mock.call.send_request(
-            msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
-            log=mock.ANY,
-        ),
-        mock.call.send_request(
+        mock.call.send(SetRenderedFields(rendered_fields=expected_rendered_fields)),
+        mock.call.send(
             msg=SucceedTask(
                 end_date=instant,
                 state=TaskInstanceState.SUCCESS,
                 task_outlets=[],
                 outlet_events=[],
             ),
-            log=mock.ANY,
         ),
     ]
     mock_supervisor_comms.assert_has_calls(expected_calls)
@@ -763,9 +689,8 @@ def execute(self, context):
     assert ti.state == TaskInstanceState.SUCCESS

     # Ensure the task is Successful
-    mock_supervisor_comms.send_request.assert_called_once_with(
+    mock_supervisor_comms.send.assert_called_once_with(
         msg=SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
-        log=mock.ANY,
     )


@@ -812,8 +737,8 @@ def execute(self, context):

     assert ti.state == TaskInstanceState.FAILED

-    mock_supervisor_comms.send_request.assert_called_once_with(
-        msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant), log=mock.ANY
+    mock_supervisor_comms.send.assert_called_once_with(
+        msg=TaskState(state=TaskInstanceState.FAILED, end_date=instant)
     )


@@ -831,12 +756,11 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch
         ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
         dag_rel_path="dag_parsing_context.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(dag_id=dag_id, run_id="c"),
         start_date=timezone.utcnow(),
     )

-    mock_supervisor_comms.get_message.return_value = what
+    mock_supervisor_comms._get_response.return_value = what

     # Set the environment variable for DAG bundles
     # We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test!
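Note: the mock rewrites in these hunks all track the new framed comms protocol: each message now travels as a msgpack-encoded frame preceded by a 4-byte big-endian length, instead of newline-delimited JSON on stdin. As a rough mental model only (read_exact, write_frame, and read_frame are illustrative helper names for this note, not code from this PR):

    import msgspec

    def read_exact(sock, n: int) -> bytes:
        # recv() may legally return fewer than n bytes, so loop until complete.
        buf = b""
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise EOFError("socket closed mid-frame")
            buf += chunk
        return buf

    def write_frame(sock, frame) -> None:
        # 4-byte big-endian length prefix, then the msgpack payload.
        payload = msgspec.msgpack.encode(frame)
        sock.sendall(len(payload).to_bytes(4, "big") + payload)

    def read_frame(sock, frame_type):
        frame_len = int.from_bytes(read_exact(sock, 4), "big")
        return msgspec.msgpack.Decoder(frame_type).decode(read_exact(sock, frame_len))

The tests above read a frame with two bare recv() calls, which is acceptable on a loopback socketpair; a production reader should loop as sketched here.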
@@ -952,7 +876,7 @@ def test_run_with_asset_outlets(

     validate_mock.assert_called_once()

-    mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)
+    mock_supervisor_comms.send.assert_any_call(expected_msg)


 def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms):
@@ -964,7 +888,7 @@ def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms):
         asset=AssetResponse(name="test", uri="test", group="asset"),
     )
     events_result = AssetEventsResult(asset_events=[asset_event_resp])
-    mock_supervisor_comms.get_message.return_value = events_result
+    mock_supervisor_comms.send.return_value = events_result

     from airflow.providers.standard.operators.bash import BashOperator

@@ -1118,7 +1042,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s

         dr = runtime_ti._ti_context_from_server.dag_run

-        mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult(
+        mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult(
             data_interval_end=dr.logical_date - timedelta(hours=1),
             data_interval_start=dr.logical_date - timedelta(hours=2),
             start_date=dr.start_date - timedelta(hours=1),
@@ -1168,7 +1092,7 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock
         task = BaseOperator(task_id="hello")
         runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")

-        mock_supervisor_comms.get_message.return_value = PrevSuccessfulDagRunResult(
+        mock_supervisor_comms.send.return_value = PrevSuccessfulDagRunResult(
             data_interval_end=timezone.datetime(2025, 1, 1, 2, 0, 0),
             data_interval_start=timezone.datetime(2025, 1, 1, 1, 0, 0),
             start_date=timezone.datetime(2025, 1, 1, 1, 0, 0),
@@ -1178,13 +1102,13 @@ def test_lazy_loading_not_triggered_until_accessed(self, create_runtime_ti, mock
         context = runtime_ti.get_template_context()

         # Assert lazy attributes are not resolved initially
-        mock_supervisor_comms.get_message.assert_not_called()
+        mock_supervisor_comms.send.assert_not_called()

         # Access a lazy-loaded attribute to trigger computation
         assert context["prev_data_interval_start_success"] == timezone.datetime(2025, 1, 1, 1, 0, 0)

         # Now the lazy attribute should trigger the call
-        mock_supervisor_comms.get_message.assert_called_once()
+        mock_supervisor_comms.send.assert_called_once()

     def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_comms):
         """Test that the connection is fetched from the API server via the Supervisor lazily when accessed"""
@@ -1203,22 +1127,18 @@ def test_get_connection_from_context(self, create_runtime_ti, mock_supervisor_co
         )
         runtime_ti = create_runtime_ti(task=task, dag_id="test_get_connection_from_context")

-        mock_supervisor_comms.get_message.return_value = conn
+        mock_supervisor_comms.send.return_value = conn

         context = runtime_ti.get_template_context()

         # Assert that the connection is not fetched from the API server yet!
         # The connection should be only fetched connection is accessed
-        mock_supervisor_comms.send_request.assert_not_called()
-        mock_supervisor_comms.get_message.assert_not_called()
+        mock_supervisor_comms.send.assert_not_called()

         # Access the connection from the context
         conn_from_context = context["conn"].test_conn

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY, msg=GetConnection(conn_id="test_conn")
-        )
-        mock_supervisor_comms.get_message.assert_called_once_with()
+        mock_supervisor_comms.send.assert_called_once_with(GetConnection(conn_id="test_conn"))

         assert conn_from_context == Connection(
             conn_id="test_conn",
@@ -1277,7 +1197,7 @@ def test_template_with_connection(
             extra='{"extra__asana__workspace": "extra1"}',
         )

-        mock_supervisor_comms.get_message.return_value = conn
+        mock_supervisor_comms.send.return_value = conn

         context = runtime_ti.get_template_context()
         result = runtime_ti.task.render_template(content, context)
@@ -1304,22 +1224,18 @@ def test_get_variable_from_context(

         var = VariableResult(key="test_key", value=var_value)

-        mock_supervisor_comms.get_message.return_value = var
+        mock_supervisor_comms.send.return_value = var

         context = runtime_ti.get_template_context()

         # Assert that the variable is not fetched from the API server yet!
         # The variable should be only fetched connection is accessed
-        mock_supervisor_comms.send_request.assert_not_called()
-        mock_supervisor_comms.get_message.assert_not_called()
+        mock_supervisor_comms.send.assert_not_called()

         # Access the variable from the context
         var_from_context = context["var"][accessor_type].test_key

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY, msg=GetVariable(key="test_key")
-        )
-        mock_supervisor_comms.get_message.assert_called_once_with()
+        mock_supervisor_comms.send.assert_called_once_with(GetVariable(key="test_key"))

         assert var_from_context == expected_value
@@ -1386,7 +1302,11 @@ def execute(self, context):
         runtime_ti = create_runtime_ti(task=task, **extra_for_ti)

         ser_value = BaseXCom.serialize_value(xcom_values)
-        mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value)
+
+        def mock_send_side_effect(*args, **kwargs):
+            return XComResult(key="key", value=ser_value)
+
+        mock_supervisor_comms.send.side_effect = mock_send_side_effect

         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
@@ -1403,8 +1323,7 @@ def execute(self, context):
         for map_index in map_indexes:
             if map_index == NOTSET:
                 map_index = -1
-            mock_supervisor_comms.send_request.assert_any_call(
-                log=mock.ANY,
+            mock_supervisor_comms.send.assert_any_call(
                 msg=GetXCom(
                     key="key",
                     dag_id="test_dag",
@@ -1457,7 +1376,7 @@ def execute(self, context):

         value = {"a": 1, "b": 2}
         # API server returns serialised value for xcom result, staging it in that way
         xcom_value = BaseXCom.serialize_value(value)
-        mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value)
+        mock_supervisor_comms.send.return_value = XComResult(key="key", value=xcom_value)

         returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
         assert returned_xcom == expected_value
@@ -1544,11 +1463,8 @@ def execute(self, context):
         context = runtime_ti.get_template_context()
         run(runtime_ti, context=context, log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )

         with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set:
@@ -1597,9 +1513,8 @@ def __init__(self, bash_command, *args, **kwargs):
             log=mock.MagicMock(),
         )

-        mock_supervisor_comms.send_request.assert_called_with(
-            msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd}),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_with(
+            msg=SetRenderedFields(rendered_fields={"bash_command": rendered_cmd})
         )

     @pytest.mark.parametrize(
@@ -1622,7 +1537,7 @@ def test_get_first_reschedule_date(
         task = BaseOperator(task_id="hello")
         runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count)

-        mock_supervisor_comms.get_message.return_value = TaskRescheduleStartDate(
+        mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(
             start_date=timezone.datetime(2025, 1, 1)
         )
@@ -1631,7 +1546,7 @@ def test_get_first_reschedule_date(

     def test_get_ti_count(self, mock_supervisor_comms):
         """Test that get_ti_count sends the correct request and returns the count."""
-        mock_supervisor_comms.get_message.return_value = TICount(count=2)
+        mock_supervisor_comms.send.return_value = TICount(count=2)

         count = RuntimeTaskInstance.get_ti_count(
             dag_id="test_dag",
@@ -1642,8 +1557,7 @@ def test_get_ti_count(self, mock_supervisor_comms):
             states=["success", "failed"],
         )

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
             msg=GetTICount(
                 dag_id="test_dag",
                 task_ids=["task1", "task2"],
@@ -1657,7 +1571,7 @@ def test_get_ti_count(self, mock_supervisor_comms):

     def test_get_dr_count(self, mock_supervisor_comms):
         """Test that get_dr_count sends the correct request and returns the count."""
-        mock_supervisor_comms.get_message.return_value = DRCount(count=2)
+        mock_supervisor_comms.send.return_value = DRCount(count=2)

         count = RuntimeTaskInstance.get_dr_count(
             dag_id="test_dag",
@@ -1666,8 +1580,7 @@ def test_get_dr_count(self, mock_supervisor_comms):
             states=["success", "failed"],
         )

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
             msg=GetDRCount(
                 dag_id="test_dag",
                 logical_dates=[timezone.datetime(2024, 1, 1)],
@@ -1679,27 +1592,21 @@ def test_get_dr_count(self, mock_supervisor_comms):

     def test_get_dagrun_state(self, mock_supervisor_comms):
         """Test that get_dagrun_state sends the correct request and returns the state."""
-        mock_supervisor_comms.get_message.return_value = DagRunStateResult(state="running")
+        mock_supervisor_comms.send.return_value = DagRunStateResult(state="running")

         state = RuntimeTaskInstance.get_dagrun_state(
             dag_id="test_dag",
             run_id="run1",
         )

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY,
-            msg=GetDagRunState(
-                dag_id="test_dag",
-                run_id="run1",
-            ),
+        mock_supervisor_comms.send.assert_called_once_with(
+            msg=GetDagRunState(dag_id="test_dag", run_id="run1"),
         )

         assert state == "running"

     def test_get_task_states(self, mock_supervisor_comms):
         """Test that get_task_states sends the correct request and returns the states."""
-        mock_supervisor_comms.get_message.return_value = TaskStatesResult(
-            task_states={"run1": {"task1": "running"}}
-        )
+        mock_supervisor_comms.send.return_value = TaskStatesResult(task_states={"run1": {"task1": "running"}})

         states = RuntimeTaskInstance.get_task_states(
             dag_id="test_dag",
@@ -1707,8 +1614,7 @@ def test_get_task_states(self, mock_supervisor_comms):
             run_ids=["run1"],
         )

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
             msg=GetTaskStates(
                 dag_id="test_dag",
                 task_ids=["task1"],
@@ -1882,7 +1788,6 @@ def execute(self, context):

         assert not any(
             x
             == mock.call(
-                log=mock.ANY,
                 msg=SetXCom(
                     key="key",
                     value="pushing to xcom backend!",
@@ -1892,7 +1797,7 @@ def execute(self, context):
                     map_index=-1,
                 ),
             )
-            for x in mock_supervisor_comms.send_request.call_args_list
+            for x in mock_supervisor_comms.send.call_args_list
         )

     def test_xcom_pull_from_custom_xcom_backend(
@@ -1921,7 +1826,6 @@ def execute(self, context):

         assert not any(
             x
             == mock.call(
-                log=mock.ANY,
                 msg=GetXCom(
                     key="key",
                     dag_id="test_dag",
@@ -1930,7 +1834,7 @@ def execute(self, context):
                     map_index=-1,
                 ),
             )
-            for x in mock_supervisor_comms.send_request.call_args_list
+            for x in mock_supervisor_comms.send.call_args_list
         )
@@ -1961,11 +1865,8 @@ def execute(self, context):

         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )

     def test_dag_param_dag_overwrite(self, create_runtime_ti, mock_supervisor_comms, time_machine):
@@ -1990,11 +1891,8 @@ def execute(self, context):
             task=task, dag_id="dag_with_dag_params_overwrite", conf={"value": "new_value"}
         )
         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )

     def test_dag_param_dag_default(self, create_runtime_ti, mock_supervisor_comms, time_machine):
@@ -2017,11 +1915,8 @@ def execute(self, context):
         runtime_ti = create_runtime_ti(task=task, dag_id="dag_with_dag_params_default")
         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
        )

     def test_dag_param_resolves(
@@ -2052,11 +1947,8 @@ def execute(self, context):

         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )

     def test_dag_param_dagrun_parameterized(
@@ -2091,11 +1983,8 @@ def execute(self, context):

         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_called_once_with(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )

     @pytest.mark.parametrize("value", [VALUE, 0])
@@ -2122,11 +2011,8 @@ def return_num(num):

         run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

-        mock_supervisor_comms.send_request.assert_any_call(
-            msg=SucceedTask(
-                state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]
-            ),
-            log=mock.ANY,
+        mock_supervisor_comms.send.assert_any_call(
+            SucceedTask(state=TaskInstanceState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[]),
         )
@@ -2189,12 +2075,11 @@ def execute(self, context):
         ),
         dag_rel_path="",
         bundle_info=FAKE_BUNDLE,
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )

-    mock_supervisor_comms.get_message.return_value = what
+    mock_supervisor_comms._get_response.return_value = what
     mocked_parse(what, "basic_dag", task)

     runtime_ti, context, log = startup()
@@ -2492,17 +2377,15 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
         assert msg.state == TaskInstanceState.SUCCESS

         expected_calls = [
-            mock.call.send_request(
+            mock.call.send(
                 msg=TriggerDagRun(
                     dag_id="test_dag",
                     run_id="test_run_id",
                     reset_dag_run=False,
                     logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                 ),
-                log=mock.ANY,
             ),
-            mock.call.get_message(),
-            mock.call.send_request(
+            mock.call.send(
                 msg=SetXCom(
                     key="trigger_run_id",
                     value="test_run_id",
@@ -2511,7 +2394,6 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
                     run_id="test_run",
                     map_index=-1,
                 ),
-                log=mock.ANY,
             ),
         ]
         mock_supervisor_comms.assert_has_calls(expected_calls)
@@ -2539,23 +2421,21 @@ def test_handle_trigger_dag_run_conflict(
         ti = create_runtime_ti(dag_id="test_handle_trigger_dag_run_conflict", run_id="test_run", task=task)
         log = mock.MagicMock()

-        mock_supervisor_comms.get_message.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)
+        mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)
         state, msg, _ = run(ti, ti.get_template_context(), log)

         assert state == expected_state
         assert msg.state == expected_state

         expected_calls = [
-            mock.call.send_request(
+            mock.call.send(
                 msg=TriggerDagRun(
                     dag_id="test_dag",
                     logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                     run_id="test_run_id",
                     reset_dag_run=False,
                 ),
-                log=mock.ANY,
             ),
-            mock.call.get_message(),
         ]
         mock_supervisor_comms.assert_has_calls(expected_calls)
@@ -2600,13 +2480,19 @@ def test_handle_trigger_dag_run_wait_for_completion(
         )
         log = mock.MagicMock()

-        mock_supervisor_comms.get_message.side_effect = [
+        mock_supervisor_comms.send.side_effect = [
+            # Set RTIF
+            None,
             # Successful Dag Run trigger
             OKResponse(ok=True),
+            # Set XCom
+            None,
             # Dag Run is still running
             DagRunStateResult(state=DagRunState.RUNNING),
             # Dag Run completes execution on the next poll
             DagRunStateResult(state=target_dr_state),
+            # Succeed/Fail task
+            None,
         ]
         with mock.patch("time.sleep", return_value=None):
             state, msg, _ = run(ti, ti.get_template_context(), log)
@@ -2615,16 +2501,14 @@ def test_handle_trigger_dag_run_wait_for_completion(
         assert msg.state == expected_task_state

         expected_calls = [
-            mock.call.send_request(
+            mock.call.send(
                 msg=TriggerDagRun(
                     dag_id="test_dag",
                     run_id="test_run_id",
                     logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                 ),
-                log=mock.ANY,
             ),
-            mock.call.get_message(),
-            mock.call.send_request(
+            mock.call.send(
                 msg=SetXCom(
                     key="trigger_run_id",
                     value="test_run_id",
@@ -2633,22 +2517,18 @@ def test_handle_trigger_dag_run_wait_for_completion(
                     run_id="test_run",
                     map_index=-1,
                 ),
-                log=mock.ANY,
             ),
-            mock.call.send_request(
+            mock.call.send(
                 msg=GetDagRunState(
                     dag_id="test_dag",
                     run_id="test_run_id",
                 ),
-                log=mock.ANY,
             ),
-            mock.call.get_message(),
-            mock.call.send_request(
+            mock.call.send(
                 msg=GetDagRunState(
                     dag_id="test_dag",
                     run_id="test_run_id",
                 ),
-                log=mock.ANY,
             ),
         ]
         mock_supervisor_comms.assert_has_calls(expected_calls)
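Note: the expected_calls rewrites above reflect that supervisor comms is now a single blocking request/response round trip: the separate mock.call.get_message() entries disappear because send() writes a _RequestFrame, waits for the _ResponseFrame whose id matches, and either returns the decoded body or raises. A minimal sketch of that correlation logic, reusing the frame helpers from the earlier note and with simplified error handling (the real method lives in airflow.sdk.execution_time.comms and may differ in detail):

    from random import randint

    def send(self, msg):
        # Correlate request and response by a random frame id, as the tests
        # assert with `frame.id == req_frame.id`.
        req = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump())
        write_frame(self.socket, req)
        resp = read_frame(self.socket, _ResponseFrame)
        assert resp.id == req.id
        if resp.error is not None:
            # API-server-style errors surface as AirflowRuntimeError (see the
            # _from_frame test above); some error kinds may instead come back
            # to the caller as plain ErrorResponse values, as the
            # DAGRUN_ALREADY_EXISTS test mocks.
            raise AirflowRuntimeError(ErrorResponse(**resp.error))
        if resp.body is None:
            # Setter-style messages (SetRenderedFields, SetXCom, SucceedTask)
            # carry no response body, which is why the wait_for_completion
            # side_effect list gains extra None entries.
            return None
        return self.body_decoder.validate_python(resp.body)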