diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index fb9bdbaf6ba61..74a53991ff286 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -234,15 +234,18 @@ def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[i
         length = int.from_bytes(len_bytes, byteorder="big")
 
         buffer = bytearray(length)
-        nread = self.socket.recv_into(buffer)
-        if nread != length:
-            raise RuntimeError(
-                f"unable to read full response in child. (We read {nread}, but expected {length})"
-            )
-        if nread == 0:
-            raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
-
-        resp = self.resp_decoder.decode(buffer)
+        mv = memoryview(buffer)
+
+        # A single recv_into() may return only part of a large frame, so keep filling
+        # the unread tail of the buffer until all `length` bytes have arrived.
+        pos = 0
+        while pos < length:
+            nread = self.socket.recv_into(mv[pos:])
+            if nread == 0:
+                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
+            pos += nread
+
+        resp = self.resp_decoder.decode(mv)
         if maxfds:
             return resp, fds or []
         return resp
diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py
index b2e7f5e71c0fa..5595fc2775fab 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_comms.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import threading
 import uuid
 from socket import socketpair
 
@@ -82,3 +83,35 @@ def test_recv_StartupDetails(self):
         assert msg.dag_rel_path == "/dev/null"
         assert msg.bundle_info == BundleInfo(name="any-name", version="any-version")
         assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
+
+    def test_huge_payload(self):
+        """A ~10 MiB response frame must be read back completely."""
+        r, w = socketpair()
+
+        msg = {
+            "type": "XComResult",
+            "key": "a",
+            "value": ("a" * 10 * 1024 * 1024) + "b",  # A 10 MiB XCom value ending in a sentinel "b"
+        }
+
+        w.settimeout(1.0)
+        frame = msgspec.msgpack.encode(_ResponseFrame(0, msg, None))
+
+        # Since `sendall` blocks, we need to do the send in another thread, so we can perform the read here
+        t = threading.Thread(target=w.sendall, args=(len(frame).to_bytes(4, byteorder="big") + frame,))
+        t.start()
+
+        decoder = CommsDecoder(socket=r, log=None)
+
+        try:
+            msg = decoder._get_response()
+        finally:
+            t.join(2)
+            r.close()
+            w.close()
+
+        assert msg is not None
+
+        # It actually failed to read at all for large values, but let's just make sure we get it all
+        assert len(msg.value) == 10 * 1024 * 1024 + 1
+        assert msg.value[-1] == "b"