From 5e65e3e78fca258c2b374511fab2d6c17df025dd Mon Sep 17 00:00:00 2001
From: Zhihong Zhang
Date: Wed, 25 Sep 2024 15:47:46 -0400
Subject: [PATCH] Refactored byte_receiver to be more object oriented

---
 nvflare/fuel/f3/streaming/blob_streamer.py |  26 +-
 nvflare/fuel/f3/streaming/byte_receiver.py | 401 +++++++++++----------
 nvflare/fuel/f3/streaming/stream_types.py  |   8 +-
 3 files changed, 240 insertions(+), 195 deletions(-)

diff --git a/nvflare/fuel/f3/streaming/blob_streamer.py b/nvflare/fuel/f3/streaming/blob_streamer.py
index 506ecd1692..7b8e17a61f 100644
--- a/nvflare/fuel/f3/streaming/blob_streamer.py
+++ b/nvflare/fuel/f3/streaming/blob_streamer.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import threading
 from typing import Callable, Optional
 
+from nvflare.fuel.f3.comm_config import CommConfigurator
 from nvflare.fuel.f3.connection import BytesAlike
 from nvflare.fuel.f3.message import Message
 from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
-from nvflare.fuel.f3.streaming.byte_streamer import STREAM_TYPE_BLOB, ByteStreamer
+from nvflare.fuel.f3.streaming.byte_streamer import STREAM_CHUNK_SIZE, STREAM_TYPE_BLOB, ByteStreamer
 from nvflare.fuel.f3.streaming.stream_const import EOS
 from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
 from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, stream_thread_pool, wrap_view
@@ -84,6 +86,7 @@ def __str__(self):
 class BlobHandler:
     def __init__(self, blob_cb: Callable):
         self.blob_cb = blob_cb
+        self.chunk_size = CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE)
 
     def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:
 
@@ -98,16 +101,15 @@ def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *ar
 
         return 0
 
-    @staticmethod
-    def _read_stream(blob_task: BlobTask):
+    def _read_stream(self, blob_task: BlobTask):
         try:
-            # It's most efficient to use the same chunk size as the stream
-            chunk_size = ByteStreamer.get_chunk_size()
-
+            # It's most efficient to read the whole chunk
+            size = self.chunk_size
+            thread_id = threading.get_native_id()
             buf_size = 0
             while True:
-                buf = blob_task.stream.read(chunk_size)
+                buf = blob_task.stream.read(size)
                 if not buf:
                     break
@@ -116,22 +118,26 @@ def _read_stream(blob_task: BlobTask):
                     if blob_task.pre_allocated:
                         remaining = len(blob_task.buffer) - buf_size
                         if length > remaining:
-                            log.error(f"{blob_task} Buffer overrun: {remaining=} {length=} {buf_size=}")
+                            log.error(f"{blob_task} Buffer overrun: {thread_id=} {remaining=} {length=} {buf_size=}")
                             if remaining > 0:
                                 blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
+                                buf_size += remaining
                             break
                         else:
                             blob_task.buffer[buf_size : buf_size + length] = buf
                     else:
                         blob_task.buffer.append(buf)
                 except Exception as ex:
-                    log.error(f"{blob_task} memoryview error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
+                    log.error(
+                        f"{blob_task} memoryview error: {ex} Debug info: "
+                        f"{thread_id=} {length=} {buf_size=} {type(buf)=}"
+                    )
                     raise ex
 
                 buf_size += length
 
             if blob_task.size and blob_task.size != buf_size:
-                log.warning(f"Stream {blob_task} Size doesn't match: " f"{blob_task.size} <> {buf_size}")
+                log.warning(f"Stream {blob_task} Size doesn't match: {blob_task.size} <> {buf_size} {thread_id=}")
 
             if blob_task.pre_allocated:
                 result = blob_task.buffer
diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py
index f2aff8d0aa..7e70a8113b 100644
--- a/nvflare/fuel/f3/streaming/byte_receiver.py
+++ b/nvflare/fuel/f3/streaming/byte_receiver.py
@@ -14,7 +14,7 @@
 import logging
 import threading
 from collections import deque
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Deque, Dict, Optional, Tuple
 
 from nvflare.fuel.f3.cellnet.core_cell import CoreCell
 from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
@@ -41,115 +41,267 @@
 ACK_INTERVAL = 1024 * 1024 * 4
 READ_TIMEOUT = 300
 COUNTER_NAME_RECEIVED = "received"
+
+# Read result status
 RESULT_DATA = 0
-RESULT_WAIT = 1
+RESULT_NO_DATA = 1
 RESULT_EOS = 2
 
 
 class RxTask:
     """Receiving task for ByteStream"""
 
-    def __init__(self, sid: int, origin: str):
+    rx_task_map = {}
+    map_lock = threading.Lock()
+
+    def __init__(self, sid: int, origin: str, cell: CoreCell):
         self.sid = sid
         self.origin = origin
+        self.cell = cell
+
         self.channel = None
         self.topic = None
         self.headers = None
         self.size = 0
-        # The reassembled buffer in a double-ended queue
-        self.buffers = deque()
-        # Out-of-sequence buffers to be assembled
-        self.out_seq_buffers: Dict[int, Tuple[bool, BytesAlike]] = {}
+
+        # The reassembled chunks in a double-ended queue
+        self.chunks: Deque[Tuple[bool, BytesAlike]] = deque()
+        self.chunk_offset = 0  # Start of the remaining data for partially read left-most chunk
+
+        # Out-of-sequence chunks to be assembled
+        self.out_seq_chunks: Dict[int, Tuple[bool, BytesAlike]] = {}
         self.stream_future = None
         self.next_seq = 0
         self.offset = 0
        self.offset_ack = 0
-        self.eos = False
         self.waiter = threading.Event()
-        self.task_lock = threading.Lock()
+        self.lock = threading.Lock()
+        self.eos = False
         self.last_chunk_received = False
 
+        self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
+        self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)
+        self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
+
     def __str__(self):
         return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"
 
+    @classmethod
+    def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTask"]:
 
-class RxStream(Stream):
-    """A stream that's used to read streams from the buffer"""
+        sid = message.get_header(StreamHeaderKey.STREAM_ID)
+        origin = message.get_header(MessageHeaderKey.ORIGIN)
+        error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
 
-    def __init__(self, byte_receiver: "ByteReceiver", task: RxTask):
-        super().__init__(task.size, task.headers)
-        self.byte_receiver = byte_receiver
-        self.task = task
-        self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
-        self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)
+        with cls.map_lock:
+            task = cls.rx_task_map.get(sid, None)
+            if not task:
+                if error:
+                    log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
+                    return None
 
-    def read(self, chunk_size: int) -> bytes:
-        if self.closed:
-            raise StreamError("Read from closed stream")
+                task = RxTask(sid, origin, cell)
+                cls.rx_task_map[sid] = task
+            else:
+                if error:
+                    task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
+                    return None
+
+        return task
+
+    def read(self, size: int) -> BytesAlike:
 
         count = 0
         while True:
-            result_code, result = self._read_chunk(chunk_size)
+            result_code, result = self._try_to_read(size)
             if result_code == RESULT_EOS:
                 return EOS
             elif result_code == RESULT_DATA:
                 return result
 
-            # Block if buffers are empty
+            # result_code == RESULT_NO_DATA: block until more chunks are received
             if count > 0:
-                log.warning(f"{self.task} Read block is unblocked multiple times: {count}")
+                log.warning(f"{self} Read block is unblocked multiple times: {count}")
 
-            if not self.task.waiter.wait(self.timeout):
-                error = StreamError(f"{self.task} read timed out after {self.timeout} seconds")
-                self.byte_receiver.stop_task(self.task, error)
+            if not self.waiter.wait(self.timeout):
+                error = StreamError(f"{self} read timed out after {self.timeout} seconds")
+                self.stop(error)
                 raise error
 
             count += 1
 
-    def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]:
+    def process_chunk(self, message: Message) -> bool:
+        """Returns True if a new stream is created"""
+
+        new_stream = False
+        with self.lock:
+            seq = message.get_header(StreamHeaderKey.SEQUENCE)
+            if seq == 0:
+                if self.stream_future:
+                    log.warning(f"{self} Received duplicate chunk 0, ignored")
+                    return new_stream
+
+                self._handle_new_stream(message)
+                new_stream = True
+
+            self._handle_incoming_data(seq, message)
+        return new_stream
 
-        with self.task.task_lock:
+    def _handle_new_stream(self, message: Message):
+        self.channel = message.get_header(StreamHeaderKey.CHANNEL)
+        self.topic = message.get_header(StreamHeaderKey.TOPIC)
+        self.headers = message.headers
+        self.size = message.get_header(StreamHeaderKey.SIZE, 0)
 
-            if not self.task.buffers:
-                if self.task.eos:
-                    return RESULT_EOS, None
+        self.stream_future = StreamFuture(self.sid, self.headers)
+        self.stream_future.set_size(self.size)
+
+    def _handle_incoming_data(self, seq: int, message: Message):
+
+        data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
+
+        last_chunk = data_type == StreamDataType.FINAL
+        if last_chunk:
+            self.last_chunk_received = True
+
+        if seq < self.next_seq:
+            log.warning(f"{self} Duplicate chunk ignored {seq=}")
+            return
+
+        if seq == self.next_seq:
+            self._append((last_chunk, message.payload))
+
+            # Try to reassemble out-of-seq chunks
+            while self.next_seq in self.out_seq_chunks:
+                chunk = self.out_seq_chunks.pop(self.next_seq)
+                self._append(chunk)
+        else:
+            # Save out-of-seq chunks
+            if len(self.out_seq_chunks) >= self.max_out_seq:
+                self.stop(StreamError(f"{self} Too many out-of-sequence chunks: {len(self.out_seq_chunks)}"))
+                return
+            else:
+                if seq not in self.out_seq_chunks:
+                    self.out_seq_chunks[seq] = last_chunk, message.payload
                 else:
-                    self.task.waiter.clear()
-                    return RESULT_WAIT, None
+                    log.warning(f"{self} Duplicate out-of-seq chunk ignored {seq=}")
+
+        # If all chunks are lined up and the last chunk is received, the task can be deleted
+        if not self.out_seq_chunks and self.chunks:
+            last_chunk, _ = self.chunks[-1]
+            if last_chunk:
+                self.stop()
+
+    def stop(self, error: StreamError = None, notify=True):
+
+        with RxTask.map_lock:
+            RxTask.rx_task_map.pop(self.sid, None)
+
+        if not error:
+            return
+
+        if self.headers:
+            optional = self.headers.get(StreamHeaderKey.OPTIONAL, False)
+        else:
+            optional = False
+
+        msg = f"Stream error: {error}"
+        if optional:
+            log.debug(msg)
+        else:
+            log.error(msg)
+
+        self.stream_future.set_exception(error)
+
+        if notify:
+            message = Message()
 
-            last_chunk, buf = self.task.buffers.popleft()
+            message.add_headers(
+                {
+                    StreamHeaderKey.STREAM_ID: self.sid,
+                    StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
+                    StreamHeaderKey.ERROR_MSG: str(error),
+                }
+            )
+            self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
+
+    def _try_to_read(self, size: int) -> Tuple[int, Optional[BytesAlike]]:
+
+        with self.lock:
+            if self.eos:
+                return RESULT_EOS, None
+
+            if not self.chunks:
+                self.waiter.clear()
+                return RESULT_NO_DATA, None
+
+            # Get the left-most chunk
+            last_chunk, buf = self.chunks[0]
             if buf is None:
                 buf = bytes(0)
-
-            if 0 < chunk_size < len(buf):
-                result = buf[0:chunk_size]
-                # Put leftover to the head of the queue
-                self.task.buffers.appendleft((last_chunk, buf[chunk_size:]))
+            end_offset = self.chunk_offset + size
+            if 0 < end_offset < len(buf):
+                # Partial read
+                result = buf[self.chunk_offset : end_offset]
+                self.chunk_offset = end_offset
             else:
-                result = buf
+                # Whole chunk is consumed
+                if self.chunk_offset:
+                    result = buf[self.chunk_offset :]
+                else:
+                    result = buf
+
+                self.chunk_offset = 0
+                self.chunks.popleft()
+
                 if last_chunk:
-                    self.task.eos = True
+                    self.eos = True
 
-            self.task.offset += len(result)
+            self.offset += len(result)
 
-            if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > self.ack_interval):
+            if not self.last_chunk_received and (self.offset - self.offset_ack > self.ack_interval):
                 # Send ACK
                 message = Message()
                 message.add_headers(
                     {
-                        StreamHeaderKey.STREAM_ID: self.task.sid,
+                        StreamHeaderKey.STREAM_ID: self.sid,
                         StreamHeaderKey.DATA_TYPE: StreamDataType.ACK,
-                        StreamHeaderKey.OFFSET: self.task.offset,
+                        StreamHeaderKey.OFFSET: self.offset,
                     }
                 )
-                self.byte_receiver.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.task.origin, message)
-                self.task.offset_ack = self.task.offset
+                self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, self.origin, message)
+                self.offset_ack = self.offset
 
-            self.task.stream_future.set_progress(self.task.offset)
+            self.stream_future.set_progress(self.offset)
 
         return RESULT_DATA, result
 
+    def _append(self, buf: Tuple[bool, BytesAlike]):
+        if self.eos:
+            log.error(f"{self} Data after EOS is ignored")
+            return
+
+        self.chunks.append(buf)
+        self.next_seq += 1
+
+        # Wake up blocking read()
+        if not self.waiter.is_set():
+            self.waiter.set()
+
+
+class RxStream(Stream):
+    """A stream used to read received data from the streaming task"""
+
+    def __init__(self, task: RxTask):
+        super().__init__(task.size, task.headers)
+        self.task = task
+
+    def read(self, size: int) -> bytes:
+        if self.closed:
+            raise StreamError("Read from closed stream")
+
+        return self.task.read(size)
+
     def close(self):
         if not self.task.stream_future.done():
             self.task.stream_future.set_result(self.task.offset)
@@ -161,9 +313,6 @@ def __init__(self, cell: CoreCell):
         self.cell = cell
         self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler)
         self.registry = Registry()
-        self.rx_task_map = {}
-        self.map_lock = threading.Lock()
-        self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
 
         self.received_stream_counter_pool = StatsPoolManager.add_counter_pool(
             name="Received_Stream_Counters",
@@ -182,148 +331,38 @@ def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args
 
         self.registry.set(channel, topic, Callback(stream_cb, args, kwargs))
 
-    def stop_task(self, task: RxTask, error: StreamError = None, notify=True):
-
-        with self.map_lock:
-            self.rx_task_map.pop(task.sid, None)
-
-        if error:
-            if task.headers:
-                optional = task.headers.get(StreamHeaderKey.OPTIONAL, False)
-            else:
-                optional = False
-
-            msg = f"Stream error: {error}"
-            if optional:
-                log.debug(msg)
-            else:
-                log.error(msg)
-
-            task.stream_future.set_exception(error)
-
-            if notify:
-                message = Message()
-
-                message.add_headers(
-                    {
-                        StreamHeaderKey.STREAM_ID: task.sid,
-                        StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR,
-                        StreamHeaderKey.ERROR_MSG: str(error),
-                    }
-                )
-                self.cell.fire_and_forget(STREAM_CHANNEL, STREAM_ACK_TOPIC, task.origin, message)
-
-        task.eos = True
-
     def _data_handler(self, message: Message):
 
-        sid = message.get_header(StreamHeaderKey.STREAM_ID)
-        origin = message.get_header(MessageHeaderKey.ORIGIN)
-        seq = message.get_header(StreamHeaderKey.SEQUENCE)
-        error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
-
-        payload = message.payload
-
-        with self.map_lock:
-            task = self.rx_task_map.get(sid, None)
-            if not task:
-                if error:
-                    log.debug(f"Received error for non-existing stream: SID {sid} from {origin}")
-                    return
-
-                task = RxTask(sid, origin)
-                self.rx_task_map[sid] = task
-
-        if error:
-            self.stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False)
+        task = RxTask.find_or_create_task(message, self.cell)
+        if not task:
             return
 
-        with task.task_lock:
-            if seq == 0:
-                # Handle new stream
-                task.channel = message.get_header(StreamHeaderKey.CHANNEL)
-                task.topic = message.get_header(StreamHeaderKey.TOPIC)
-                task.headers = message.headers
-
-                # GRPC may re-send the same request, causing seq 0 delivered more than once
-                if task.stream_future:
-                    log.warning(f"{task} Received duplicate chunk 0, ignored")
-                    return
-
-                task.stream_future = StreamFuture(sid, message.headers)
-                task.size = message.get_header(StreamHeaderKey.SIZE, 0)
-                task.stream_future.set_size(task.size)
-
-                # Invoke callback
-                callback = self.registry.find(task.channel, task.topic)
-                if not callback:
-                    self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}"))
-                    return
-
-                self.received_stream_counter_pool.increment(
-                    category=stream_stats_category(task.channel, task.topic, "stream"),
-                    counter_name=COUNTER_NAME_RECEIVED,
-                )
-
-                self.received_stream_size_pool.record_value(
-                    category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
-                )
-
-                stream_thread_pool.submit(self._callback_wrapper, task, callback)
-
-            data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
-            last_chunk = data_type == StreamDataType.FINAL
-            if last_chunk:
-                task.last_chunk_received = True
-
-            if seq < task.next_seq:
-                log.warning(f"{task} Duplicate chunk ignored {seq=}")
+        new_stream = task.process_chunk(message)
+        if new_stream:
+            # Invoke callback
+            callback = self.registry.find(task.channel, task.topic)
+            if not callback:
+                task.stop(StreamError(f"{task} No callback is registered for {task.channel}/{task.topic}"))
                 return
 
-            if seq == task.next_seq:
-                self._append(task, (last_chunk, payload))
-                task.next_seq += 1
-
-                # Try to reassemble out-of-seq buffers
-                while task.next_seq in task.out_seq_buffers:
-                    chunk = task.out_seq_buffers.pop(task.next_seq)
-                    self._append(task, chunk)
-                    task.next_seq += 1
+            self.received_stream_counter_pool.increment(
+                category=stream_stats_category(task.channel, task.topic, "stream"),
+                counter_name=COUNTER_NAME_RECEIVED,
+            )
 
-            else:
-                # Out-of-seq chunk reassembly
-                if len(task.out_seq_buffers) >= self.max_out_seq:
-                    self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}"))
-                    return
-                else:
-                    task.out_seq_buffers[seq] = last_chunk, payload
+            self.received_stream_size_pool.record_value(
+                category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
+            )
 
-            # If all chunks are lined up, the task can be deleted
-            if not task.out_seq_buffers and task.buffers:
-                last_chunk, _ = task.buffers[-1]
-                if last_chunk:
-                    self.stop_task(task)
+            stream_thread_pool.submit(self._callback_wrapper, task, callback)
 
-    def _callback_wrapper(self, task: RxTask, callback: Callback):
+    @staticmethod
+    def _callback_wrapper(task: RxTask, callback: Callback):
         """A wrapper to catch all exceptions in the callback"""
         try:
-            stream = RxStream(self, task)
+            stream = RxStream(task)
             return callback.cb(task.stream_future, stream, False, *callback.args, **callback.kwargs)
         except Exception as ex:
             msg = f"{task} callback {callback.cb} throws exception: {ex}"
             log.error(msg)
-            self.stop_task(task, StreamError(msg))
-
-    @staticmethod
-    def _append(task: RxTask, buf: Tuple[bool, BytesAlike]):
-        if not buf:
-            return
-
-        if task.eos:
-            log.error(f"{task} Data after EOS is ignored")
-        else:
-            task.buffers.append(buf)
-
-            # Wake up blocking read()
-            if not task.waiter.is_set():
-                task.waiter.set()
+            task.stop(StreamError(msg))
diff --git a/nvflare/fuel/f3/streaming/stream_types.py b/nvflare/fuel/f3/streaming/stream_types.py
index 679f1b54dd..bff6c23954 100644
--- a/nvflare/fuel/f3/streaming/stream_types.py
+++ b/nvflare/fuel/f3/streaming/stream_types.py
@@ -60,15 +60,15 @@ def get_headers(self) -> Optional[dict]:
         return self.headers
 
     @abstractmethod
-    def read(self, chunk_size: int) -> BytesAlike:
-        """Read and return up to chunk_size bytes. It can return less but not more than the chunk_size.
+    def read(self, size: int) -> BytesAlike:
+        """Read and return up to size bytes. It can return fewer bytes, but never more than size.
 
         An empty bytes object is returned if the stream reaches the end.
 
         Args:
-            chunk_size: Up to (but maybe less) this many bytes will be returned
+            size: At most this many bytes will be returned (possibly fewer)
 
         Returns:
-            Binary data. If empty, it means the stream is depleted (EOF)
+            Binary data. If empty, it means the stream is depleted (EOS)
         """
         pass
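
Note (not part of the patch): a minimal sketch of how the refactored receiver is
driven, assuming an already-constructed CoreCell. The callback signature
(future, stream, resume, *args, **kwargs), ByteReceiver.register_callback(), and
the read() semantics (at most `size` bytes per call, empty result at EOS) come
from the diff above; the channel/topic names and the summary line are hypothetical.

    from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
    from nvflare.fuel.f3.streaming.stream_types import Stream, StreamFuture

    def on_stream(future: StreamFuture, stream: Stream, resume: bool):
        # Drain the stream: RxTask.read() blocks until chunks arrive and
        # returns an empty result (EOS) once the final chunk is consumed.
        total = 0
        while True:
            buf = stream.read(1024 * 1024)  # read at most 1 MB per call
            if not buf:
                break
            total += len(buf)
        print(f"Received {total} bytes")

    receiver = ByteReceiver(cell)  # `cell` is an existing CoreCell, elided here
    receiver.register_callback("demo_channel", "demo_topic", on_stream)

Because read() may return fewer bytes than requested (a partially consumed chunk
is tracked via chunk_offset), callers should loop until an empty result comes back
rather than assume one read() per transmitted chunk.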