diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index a91ec90d8db26..06ccda68e4fb0 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -35,7 +35,7 @@ No significant changes. Bug Fixes """"""""" -- Fix task execution failures with large data by improving internal communication protocol (#51924) +- Fix task execution failures with large data by improving internal communication protocol (#51924, #53194) - Fix reschedule sensors failing after multiple re-queue attempts over long periods (#52706) - Improve ``xcom_pull`` to cover different scenarios for mapped tasks (#51568) - Fix connection retrieval failures in triggerer when schema field is used (#52691) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 512f8a4f0e6aa..95f8d58bf0d7f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -714,11 +714,11 @@ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None: async def _aread_frame(self): len_bytes = await self._async_reader.readexactly(4) - len = int.from_bytes(len_bytes, byteorder="big") - if len >= 2**32: - raise OverflowError(f"Refusing to receive messages larger than 4GiB {len=}") + length = int.from_bytes(len_bytes, byteorder="big") + if length >= 2**32: + raise OverflowError(f"Refusing to receive messages larger than 4GiB {length=}") - buffer = await self._async_reader.readexactly(len) + buffer = await self._async_reader.readexactly(length) return self.resp_decoder.decode(buffer) async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None: diff --git a/reproducible_build.yaml b/reproducible_build.yaml index a75a2c426a177..89c63b40f73ba 100644 --- a/reproducible_build.yaml +++ b/reproducible_build.yaml @@ -1,2 +1,2 @@ -release-notes-hash: 7079979bff6b12b0e8cff21b3e453319 -source-date-epoch: 1752089294 +release-notes-hash: 6f266812d7639b6204f98b6450b4518e +source-date-epoch: 1752254243 diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 0589f12f866e3..2c6dfea4e601c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -183,9 +183,9 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): def send(self, msg: SendMsgType) -> ReceiveMsgType | None: """Send a request to the parent and block until the response is received.""" frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) - bytes = frame.as_bytes() + frame_bytes = frame.as_bytes() - self.socket.sendall(bytes) + self.socket.sendall(frame_bytes) if isinstance(msg, ResendLoggingFD): if recv_fds is None: return None @@ -225,18 +225,19 @@ def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[i if len_bytes == b"": raise EOFError("Request socket closed before length") - len = int.from_bytes(len_bytes, byteorder="big") + length = int.from_bytes(len_bytes, byteorder="big") - buffer = bytearray(len) - nread = self.socket.recv_into(buffer) - if nread != len: - raise RuntimeError( - f"unable to read full response in child. (We read {nread}, but expected {len})" - ) - if nread == 0: - raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})") + buffer = bytearray(length) + mv = memoryview(buffer) - resp = self.resp_decoder.decode(buffer) + pos = 0 + while pos < length: + nread = self.socket.recv_into(mv[pos:]) + if nread == 0: + raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})") + pos += nread + + resp = self.resp_decoder.decode(mv) if maxfds: return resp, fds or [] return resp diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 5adaa2562abc7..48c7ad74a1501 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -17,6 +17,7 @@ from __future__ import annotations +import threading import uuid from socket import socketpair @@ -81,3 +82,32 @@ def test_recv_StartupDetails(self): assert msg.dag_rel_path == "/dev/null" assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") assert msg.start_date == timezone.datetime(2024, 12, 1, 1) + + def test_huge_payload(self): + r, w = socketpair() + + msg = { + "type": "XComResult", + "key": "a", + "value": ("a" * 10 * 1024 * 1024) + "b", # A 10mb xcom value + } + + w.settimeout(1.0) + bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) + + # Since `sendall` blocks, we need to do the send in another thread, so we can perform the read here + t = threading.Thread(target=w.sendall, args=(len(bytes).to_bytes(4, byteorder="big") + bytes,)) + t.start() + + decoder = CommsDecoder(socket=r, log=None) + + try: + msg = decoder._get_response() + finally: + t.join(2) + + assert msg is not None + + # It actually failed to read at all for large values, but lets just make sure we get it all + assert len(msg.value) == 10 * 1024 * 1024 + 1 + assert msg.value[-1] == "b"