Skip to content

Commit

Permalink
Refactored byte_receiver to be more object-oriented
Browse files · Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Sep 25, 2024
1 parent 3dc3d3c commit 5e65e3e
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 195 deletions.
26 changes: 16 additions & 10 deletions nvflare/fuel/f3/streaming/blob_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from typing import Callable, Optional

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_receiver import ByteReceiver
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_CHUNK_SIZE, STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import EOS
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, stream_thread_pool, wrap_view
Expand Down Expand Up @@ -84,6 +86,7 @@ def __str__(self):
class BlobHandler:
def __init__(self, blob_cb: Callable):
    """Create a blob handler.

    Args:
        blob_cb: Callback invoked once a complete blob has been assembled
            from the incoming stream.
    """
    # Resolve the configured streaming chunk size once at construction time
    # (falls back to STREAM_CHUNK_SIZE when not configured) so reads can
    # match the sender's chunking.
    self.chunk_size = CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE)
    self.blob_cb = blob_cb

def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *args, **kwargs) -> int:

Expand All @@ -98,16 +101,15 @@ def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *ar

return 0

@staticmethod
def _read_stream(blob_task: BlobTask):
def _read_stream(self, blob_task: BlobTask):

try:
# It's most efficient to use the same chunk size as the stream
chunk_size = ByteStreamer.get_chunk_size()

# It's most efficient to read the whole chunk
size = self.chunk_size
thread_id = threading.get_native_id()
buf_size = 0
while True:
buf = blob_task.stream.read(chunk_size)
buf = blob_task.stream.read(size)
if not buf:
break

Expand All @@ -116,22 +118,26 @@ def _read_stream(blob_task: BlobTask):
if blob_task.pre_allocated:
remaining = len(blob_task.buffer) - buf_size
if length > remaining:
log.error(f"{blob_task} Buffer overrun: {remaining=} {length=} {buf_size=}")
log.error(f"{blob_task} Buffer overrun: {thread_id=} {remaining=} {length=} {buf_size=}")
if remaining > 0:
blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
buf_size += remaining
break
else:
blob_task.buffer[buf_size : buf_size + length] = buf
else:
blob_task.buffer.append(buf)
except Exception as ex:
log.error(f"{blob_task} memoryview error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
log.error(
f"{blob_task} memoryview error: {ex} Debug info: "
f"{thread_id=} {length=} {buf_size=} {type(buf)=}"
)
raise ex

buf_size += length

if blob_task.size and blob_task.size != buf_size:
log.warning(f"Stream {blob_task} Size doesn't match: " f"{blob_task.size} <> {buf_size}")
log.warning(f"Stream {blob_task} Size doesn't match: {blob_task.size} <> {buf_size} {thread_id=}")

if blob_task.pre_allocated:
result = blob_task.buffer
Expand Down
Loading

0 comments on commit 5e65e3e

Please sign in to comment.