Skip to content

Commit

Permalink
Enable StreamCell for all application channels (#2407)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Mar 22, 2024
1 parent 295ef96 commit b8d2c1a
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,25 @@
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.private.defs import CellChannel

CHANNELS_TO_HANDLE = (CellChannel.SERVER_COMMAND, CellChannel.AUX_COMMUNICATION)
CHANNELS_TO_EXCLUDE = (
CellChannel.CLIENT_MAIN,
CellChannel.SERVER_MAIN,
CellChannel.SERVER_PARENT_LISTENER,
CellChannel.CLIENT_COMMAND,
CellChannel.CLIENT_SUB_WORKER_COMMAND,
CellChannel.MULTI_PROCESS_EXECUTOR,
CellChannel.SIMULATOR_RUNNER,
CellChannel.RETURN_ONLY,
)


def _is_stream_channel(channel: str) -> bool:
if channel is None or channel == "":
return False
elif channel in CHANNELS_TO_EXCLUDE:
return False
# if not excluded, all channels supporting streaming capabilities
return True


class SimpleWaiter:
Expand Down Expand Up @@ -104,13 +122,13 @@ def __getattr__(self, func):
This method is called when Python cannot find an invoked method "x" of this class.
Method "x" is one of the message sending methods (send_request, broadcast_request, etc.)
In this method, we decide which method should be used instead, based on the "channel" of the message.
- If the channel is in CHANNELS_TO_HANDLE, use the method "_x" of this class.
- Otherwise, user the method "x" of the core_cell.
- If the channel is stream channel, use the method "_x" of this class.
- Otherwise, user the method "x" of the CoreCell.
"""

def method(*args, **kwargs):
self.logger.debug(f"__getattr__: {args=}, {kwargs=}")
if kwargs.get("channel") in CHANNELS_TO_HANDLE:
if _is_stream_channel(kwargs.get("channel")):
self.logger.debug(f"calling cell {func}")
return getattr(self, f"_{func}")(*args, **kwargs)
if not hasattr(self.core_cell, func):
Expand Down Expand Up @@ -311,7 +329,7 @@ def _register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs):

if not callable(cb):
raise ValueError(f"specified request_cb {type(cb)} is not callable")
if channel in CHANNELS_TO_HANDLE:
if _is_stream_channel(channel):
self.logger.info(f"Register blob CB for {channel=}, {topic=}")
adapter = Adapter(cb, self.core_cell.my_info, self)
self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs)
Expand Down

0 comments on commit b8d2c1a

Please sign in to comment.