Skip to content

Commit

Permalink
Merge branch 'main' into fin_xgb
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored Oct 3, 2023
2 parents 5c312a5 + 57fea64 commit 3b5501f
Show file tree
Hide file tree
Showing 20 changed files with 246 additions and 92 deletions.
15 changes: 13 additions & 2 deletions nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ def start_controller(self, fl_ctx: FLContext):
self.system_panic("Not enough client sites.", fl_ctx)
self._last_client = None

def _get_relay_orders(self):
def _get_relay_orders(self, fl_ctx: FLContext):
targets = list(self._participating_clients)
if len(targets) <= 1:
self.system_panic("Not enough client sites.", fl_ctx)
if self._order == RelayOrder.RANDOM:
random.shuffle(targets)
elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
Expand Down Expand Up @@ -169,7 +171,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)

# Task for one cyclic
targets = self._get_relay_orders()
targets = self._get_relay_orders(fl_ctx)
targets_names = [t.name for t in targets]
self.log_debug(fl_ctx, f"Relay on {targets_names}")

Expand Down Expand Up @@ -244,3 +246,12 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
self._start_round = self._current_round
finally:
pass

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
    """Remove a client whose job died from the participating clients.

    The parent implementation is invoked first; afterwards
    ``self._participating_clients`` is rebuilt as a new list that keeps only
    the entries whose ``name`` differs from ``client_name``.

    Args:
        client_name: name of the client whose job is reported dead.
        fl_ctx: FL context, forwarded unchanged to the parent implementation.
    """
    super().handle_dead_job(client_name, fl_ctx)

    # Build a fresh list rather than mutating the existing one in place.
    self._participating_clients = [
        participant for participant in self._participating_clients if client_name != participant.name
    ]
13 changes: 10 additions & 3 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def call(self, future): # this will be called by StreamCell upon receiving the

req_id = request.get_header(MessageHeaderKey.REQ_ID, "")
secure = request.get_header(MessageHeaderKey.SECURE, False)
optional = request.get_header(MessageHeaderKey.OPTIONAL, False)
self.logger.debug(f"{stream_req_id=}: on {channel=}, {topic=}")
response = self.cb(request)
self.logger.debug(f"response available: {stream_req_id=}: on {channel=}, {topic=}")
Expand All @@ -84,7 +85,9 @@ def call(self, future): # this will be called by StreamCell upon receiving the

encode_payload(response, StreamHeaderKey.PAYLOAD_ENCODING)
self.logger.debug(f"sending: {stream_req_id=}: {response.headers=}, target={origin}")
reply_future = self.cell.send_blob(CellChannel.RETURN_ONLY, f"{channel}:{topic}", origin, response, secure)
reply_future = self.cell.send_blob(
CellChannel.RETURN_ONLY, f"{channel}:{topic}", origin, response, secure, optional
)
self.logger.debug(f"Done sending: {stream_req_id=}: {reply_future=}")


Expand Down Expand Up @@ -201,7 +204,9 @@ def _fire_and_forget(

result = {}
for target in targets:
self.send_blob(channel=channel, topic=topic, target=target, message=message, secure=secure)
self.send_blob(
channel=channel, topic=topic, target=target, message=message, secure=secure, optional=optional
)
result[target] = ""
return result

Expand Down Expand Up @@ -237,7 +242,9 @@ def _send_one_request(self, channel, target, topic, request, timeout=10.0, secur

# this future can be used to check sending progress, but not for checking return blob
self.logger.debug(f"{req_id=}, {channel=}, {topic=}, {target=}, {timeout=}: send_request about to send_blob")
future = self.send_blob(channel=channel, topic=topic, target=target, message=request, secure=secure)
future = self.send_blob(
channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional
)

waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT))
self.requests_dict[req_id] = waiter
Expand Down
30 changes: 27 additions & 3 deletions nvflare/fuel/f3/comm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class VarName:
SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold"
COMM_DRIVER_PATH = "comm_driver_path"
HEARTBEAT_INTERVAL = "heartbeat_interval"
STREAMING_CHUNK_SIZE = "streaming_chunk_size"
STREAMING_ACK_WAIT = "streaming_ack_wait"
STREAMING_WINDOW_SIZE = "streaming_window_size"
STREAMING_ACK_INTERVAL = "streaming_ack_interval"
STREAMING_MAX_OUT_SEQ_CHUNKS = "streaming_max_out_seq_chunks"
STREAMING_READ_TIMEOUT = "streaming_read_timeout"


class CommConfigurator:
Expand Down Expand Up @@ -61,13 +67,31 @@ def get_backbone_connection_generation(self, default):
return ConfigService.get_int_var(VarName.BACKBONE_CONN_GEN, self.config, default=default)

def get_subnet_heartbeat_interval(self, default):
return ConfigService.get_int_var(VarName.SUBNET_HEARTBEAT_INTERVAL, self.config, default)
return ConfigService.get_int_var(VarName.SUBNET_HEARTBEAT_INTERVAL, self.config, default=default)

def get_subnet_trouble_threshold(self, default):
return ConfigService.get_int_var(VarName.SUBNET_TROUBLE_THRESHOLD, self.config, default)
return ConfigService.get_int_var(VarName.SUBNET_TROUBLE_THRESHOLD, self.config, default=default)

def get_comm_driver_path(self, default):
    """Look up the ``comm_driver_path`` setting via ``ConfigService.get_str_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_str_var`` returns for this variable.
    """
    var_name = VarName.COMM_DRIVER_PATH
    return ConfigService.get_str_var(var_name, self.config, default=default)

def get_heartbeat_interval(self, default):
return ConfigService.get_int_var(VarName.HEARTBEAT_INTERVAL, self.config, default)
return ConfigService.get_int_var(VarName.HEARTBEAT_INTERVAL, self.config, default=default)

def get_streaming_chunk_size(self, default):
    """Look up the ``streaming_chunk_size`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    var_name = VarName.STREAMING_CHUNK_SIZE
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_streaming_ack_wait(self, default):
    """Look up the ``streaming_ack_wait`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    var_name = VarName.STREAMING_ACK_WAIT
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_streaming_window_size(self, default):
    """Look up the ``streaming_window_size`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    var_name = VarName.STREAMING_WINDOW_SIZE
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_streaming_ack_interval(self, default):
    """Look up the ``streaming_ack_interval`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    var_name = VarName.STREAMING_ACK_INTERVAL
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_streaming_max_out_seq_chunks(self, default):
    """Look up the ``streaming_max_out_seq_chunks`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    var_name = VarName.STREAMING_MAX_OUT_SEQ_CHUNKS
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_streaming_read_timeout(self, default):
    """Look up the ``streaming_read_timeout`` setting via ``ConfigService.get_int_var``.

    Args:
        default: forwarded as the ``default`` keyword argument; presumably the
            value returned when the variable is not configured (TODO confirm
            against ConfigService).

    Returns:
        Whatever ``ConfigService.get_int_var`` returns for this variable.
    """
    # Pass the fallback by keyword, like every other getter in this class, so
    # the call does not depend on the positional order of get_int_var's
    # parameters.
    return ConfigService.get_int_var(VarName.STREAMING_READ_TIMEOUT, self.config, default=default)
3 changes: 3 additions & 0 deletions nvflare/fuel/f3/sfm/conn_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -302,11 +303,13 @@ def state_change(self, connection: Connection):
state = connection.state
connector = connection.connector
if state == ConnState.CONNECTED:
log.info(f"Connection {connection} is created: PID: {os.getpid()}")
self.handle_new_connection(connection)
with self.lock:
connector.total_conns += 1
connector.curr_conns += 1
elif state == ConnState.CLOSED:
log.info(f"Connection {connection} is closed PID: {os.getpid()}")
self.close_connection(connection)
with self.lock:
connector.curr_conns -= 1
Expand Down
31 changes: 22 additions & 9 deletions nvflare/fuel/f3/stream_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def get_chunk_size():
"""
return ByteStreamer.get_chunk_size()

def send_stream(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture:
def send_stream(
self, channel: str, topic: str, target: str, message: Message, secure=False, optional=False
) -> StreamFuture:
"""
Send a byte-stream over a channel/topic asynchronously. The streaming is performed in a different thread.
The streamer will read from stream and send the data in chunks till the stream reaches EOF.
Expand All @@ -51,6 +53,7 @@ def send_stream(self, channel: str, topic: str, target: str, message: Message, s
target: destination cell FQCN
message: The payload is the stream to send
secure: Send the message with end-end encryption if True
optional: Optional message, error may be suppressed
Returns: StreamFuture that can be used to check status/progress, or register callbacks.
The future result is the number of bytes sent
Expand All @@ -60,7 +63,7 @@ def send_stream(self, channel: str, topic: str, target: str, message: Message, s
if not isinstance(message.payload, Stream):
raise StreamError(f"Message payload is not a stream: {type(message.payload)}")

return self.byte_streamer.send(channel, topic, target, message.headers, message.payload, secure)
return self.byte_streamer.send(channel, topic, target, message.headers, message.payload, secure, optional)

def register_stream_cb(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs):
"""
Expand Down Expand Up @@ -88,7 +91,9 @@ def register_stream_cb(self, channel: str, topic: str, stream_cb: Callable, *arg
"""
self.byte_receiver.register_callback(channel, topic, stream_cb, *args, **kwargs)

def send_blob(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture:
def send_blob(
self, channel: str, topic: str, target: str, message: Message, secure=False, optional=False
) -> StreamFuture:
"""
Send a BLOB (Binary Large Object) to the target. The payload of message is the BLOB. The BLOB must fit in
memory on the receiving end.
Expand All @@ -99,6 +104,7 @@ def send_blob(self, channel: str, topic: str, target: str, message: Message, sec
target: destination cell IDs
message: the headers and the blob as payload
secure: Send the message with end-end encryption if True
optional: Optional message, error may be suppressed
Returns: StreamFuture that can be used to check status/progress and get result
The future result is the total number of bytes sent
Expand All @@ -111,7 +117,7 @@ def send_blob(self, channel: str, topic: str, target: str, message: Message, sec
if not isinstance(message.payload, (bytes, bytearray, memoryview)):
raise StreamError(f"Message payload is not a byte array: {type(message.payload)}")

return self.blob_streamer.send(channel, topic, target, message, secure)
return self.blob_streamer.send(channel, topic, target, message, secure, optional)

def register_blob_cb(self, channel: str, topic: str, blob_cb, *args, **kwargs):
"""
Expand All @@ -131,7 +137,9 @@ def register_blob_cb(self, channel: str, topic: str, blob_cb, *args, **kwargs):
"""
self.blob_streamer.register_blob_callback(channel, topic, blob_cb, *args, **kwargs)

def send_file(self, channel: str, topic: str, target: str, message: Message, secure=False) -> StreamFuture:
def send_file(
self, channel: str, topic: str, target: str, message: Message, secure=False, optional=False
) -> StreamFuture:
"""
Send a file to target using stream API.
Expand All @@ -141,6 +149,7 @@ def send_file(self, channel: str, topic: str, target: str, message: Message, sec
target: destination cell FQCN
message: the headers and the full path of the file to be sent as payload
secure: Send the message with end-end encryption if True
optional: Optional message, error may be suppressed
Returns: StreamFuture that can be used to check status/progress and get the total bytes sent
Expand All @@ -152,7 +161,7 @@ def send_file(self, channel: str, topic: str, target: str, message: Message, sec
if not os.path.isfile(file_name) or not os.access(file_name, os.R_OK):
raise StreamError(f"File {file_name} doesn't exist or isn't readable")

return self.file_streamer.send(channel, topic, target, message, secure)
return self.file_streamer.send(channel, topic, target, message, secure, optional)

def register_file_cb(self, channel: str, topic: str, file_cb, *args, **kwargs):
"""
Expand All @@ -168,7 +177,9 @@ def register_file_cb(self, channel: str, topic: str, file_cb, *args, **kwargs):
"""
self.file_streamer.register_file_callback(channel, topic, file_cb, *args, **kwargs)

def send_objects(self, channel: str, topic: str, target: str, message: Message, secure=False) -> ObjectStreamFuture:
def send_objects(
self, channel: str, topic: str, target: str, message: Message, secure=False, optional=False
) -> ObjectStreamFuture:
"""
Send a list of objects to the destination. Each object is sent as BLOB, so it must fit in memory
Expand All @@ -178,14 +189,16 @@ def send_objects(self, channel: str, topic: str, target: str, message: Message,
target: destination cell IDs
message: Headers and the payload which is an iterator that provides next object
secure: Send the message with end-end encryption if True
optional: Optional message, error may be suppressed
Returns: ObjectStreamFuture that can be used to check status/progress, or register callbacks
"""
if not isinstance(message.payload, ObjectIterator):
raise StreamError(f"Message payload is not an object iterator: {type(message.payload)}")

return self.object_streamer.stream_objects(channel, topic, target, message.headers, message.payload, secure)
return self.object_streamer.stream_objects(
channel, topic, target, message.headers, message.payload, secure, optional
)

def register_objects_cb(
self, channel: str, topic: str, object_stream_cb: Callable, object_cb: Callable, *args, **kwargs
Expand Down
6 changes: 4 additions & 2 deletions nvflare/fuel/f3/streaming/blob_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,17 @@ def __init__(self, byte_streamer: ByteStreamer, byte_receiver: ByteReceiver):
self.byte_streamer = byte_streamer
self.byte_receiver = byte_receiver

def send(self, channel: str, topic: str, target: str, message: Message, secure: bool) -> StreamFuture:
def send(
self, channel: str, topic: str, target: str, message: Message, secure: bool, optional: bool
) -> StreamFuture:
if message.payload is None:
message.payload = bytes(0)

if not isinstance(message.payload, (bytes, bytearray, memoryview)):
raise StreamError(f"BLOB is invalid type: {type(message.payload)}")

blob_stream = BlobStream(message.payload, message.headers)
return self.byte_streamer.send(channel, topic, target, message.headers, blob_stream, secure)
return self.byte_streamer.send(channel, topic, target, message.headers, blob_stream, secure, optional)

def register_blob_callback(self, channel, topic, blob_cb: Callable, *args, **kwargs):
handler = BlobHandler(blob_cb)
Expand Down
13 changes: 9 additions & 4 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.connection import BytesAlike
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.streaming.byte_streamer import ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import (
EOS,
STREAM_ACK_TOPIC,
Expand Down Expand Up @@ -90,8 +91,9 @@ def read(self, chunk_size: int) -> bytes:
log.debug(f"Read block is unblocked multiple times: {count}")

self.task.waiter.clear()
if not self.task.waiter.wait(READ_TIMEOUT):
error = StreamError(f"{self.task} read timed out after {READ_TIMEOUT} seconds")
timeout = ByteStreamer.comm_config.get_streaming_read_timeout(READ_TIMEOUT)
if not self.task.waiter.wait(timeout):
error = StreamError(f"{self.task} read timed out after {timeout} seconds")
self.byte_receiver.stop_task(self.task, error)
raise error

Expand All @@ -112,7 +114,9 @@ def read(self, chunk_size: int) -> bytes:
self.task.eos = True

self.task.offset += len(result)
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > ACK_INTERVAL):

ack_interval = ByteStreamer.comm_config.get_streaming_ack_interval(ACK_INTERVAL)
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > ack_interval):
# Send ACK
message = Message()
message.add_headers(
Expand Down Expand Up @@ -231,7 +235,8 @@ def _data_handler(self, message: Message):

else:
# Out-of-seq chunk reassembly
if len(task.out_seq_buffers) >= MAX_OUT_SEQ_CHUNKS:
max_out_seq = ByteStreamer.comm_config.get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
if len(task.out_seq_buffers) >= max_out_seq:
self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}"))
return
else:
Expand Down
Loading

0 comments on commit 3b5501f

Please sign in to comment.