Skip to content

Commit

Permalink
fix wait-for-reply (NVIDIA#2478)
Browse files: browse the repository at this point in the history
  • Loading branch information
yanchengnv authored and YuanTingHsieh committed Apr 18, 2024
1 parent e3d11d2 commit 7c8b833
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 43 deletions.
84 changes: 43 additions & 41 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.private.defs import CellChannel
from nvflare.security.logging import secure_format_exception

CHANNELS_TO_EXCLUDE = (
CellChannel.CLIENT_MAIN,
Expand Down Expand Up @@ -233,15 +234,21 @@ def _get_result(self, req_id):
return waiter.result

def _future_wait(self, future, timeout):
# future could have an error!
last_progress = 0
while not future.waiter.wait(timeout):
if future.error:
return False
current_progress = future.get_progress()
if last_progress == current_progress:
return False
else:
self.logger.debug(f"{current_progress=}")
last_progress = current_progress
return True
if future.error:
return False
else:
return True

def _encode_message(self, msg: Message):
try:
Expand All @@ -259,7 +266,6 @@ def _send_request(
timeout=10.0,
secure=False,
optional=False,
wait_for_reply=True,
):
"""Stream one request to the target
Expand All @@ -271,14 +277,12 @@ def _send_request(
timeout: how long to wait
secure: is P2P security to be applied
optional: is the message optional
wait_for_reply: whether to wait for reply
Returns: if wait_for_reply, then reply data; otherwise only a bool to indicate whether the request
is sent successfully
Returns: reply data
"""
self._encode_message(request)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional, wait_for_reply)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional)

def _send_one_request(
self,
Expand All @@ -289,7 +293,6 @@ def _send_one_request(
timeout=10.0,
secure=False,
optional=False,
wait_for_reply=True,
):
req_id = str(uuid.uuid4())
request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id})
Expand All @@ -299,47 +302,46 @@ def _send_one_request(

waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT))
self.requests_dict[req_id] = waiter
future = self.send_blob(
channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional
)

self.logger.debug(f"{req_id=}: Waiting starts")
try:
future = self.send_blob(
channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional
)

# Three stages, sending, waiting for receiving first byte, receiving
self.logger.debug(f"{req_id=}: Waiting starts")

# sending with progress timeout
self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
sending_complete = self._future_wait(future, timeout)
if not sending_complete:
self.logger.info(f"{req_id=}: sending timeout {timeout=}")
if wait_for_reply:
# Three stages, sending, waiting for receiving first byte, receiving
# sending with progress timeout
self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
sending_complete = self._future_wait(future, timeout)
if not sending_complete:
self.logger.debug(f"{req_id=}: sending timeout {timeout=}")
return self._get_result(req_id)
else:
return False
self.logger.debug(f"{req_id=}: sending complete")
if not wait_for_reply:
return True

# waiting for receiving first byte
self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
if not waiter.in_receiving.wait(timeout):
self.logger.info(f"{req_id=}: remote processing timeout {timeout=}")
self.logger.debug(f"{req_id=}: sending complete")

# waiting for receiving first byte
self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
if not waiter.in_receiving.wait(timeout):
self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: in receiving")

# receiving with progress timeout
r_future = waiter.receiving_future
self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
receiving_complete = self._future_wait(r_future, timeout)
if not receiving_complete:
self.logger.info(f"{req_id=}: receiving timeout {timeout=}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: receiving complete")
waiter.result = Message(r_future.headers, r_future.result())
decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING)
self.logger.debug(f"{req_id=}: return result {waiter.result=}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: in receiving")

# receiving with progress timeout
r_future = waiter.receiving_future
self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
receiving_complete = self._future_wait(r_future, timeout)
if not receiving_complete:
self.logger.info(f"{req_id=}: receiving timeout {timeout=}")
except Exception as ex:
self.logger.error(f"exception sending request: {secure_format_exception(ex)}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: receiving complete")
waiter.result = Message(r_future.headers, r_future.result())
decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING)
self.logger.debug(f"{req_id=}: return result {waiter.result=}")
result = self._get_result(req_id)
return result

def _process_reply(self, future: StreamFuture):
headers = future.headers
Expand Down
16 changes: 14 additions & 2 deletions nvflare/fuel/utils/pipe/cell_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,27 @@ def send(self, msg: Message, timeout=None) -> bool:
)
return True

return self.cell.send_request(
reply = self.cell.send_request(
channel=self.channel,
topic=msg.topic,
target=self.peer_fqcn,
request=_to_cell_message(msg),
timeout=timeout,
optional=optional,
wait_for_reply=False,
)
if reply:
rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
if rc == ReturnCode.OK:
return True
else:
err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}"
if optional:
self.logger.debug(err)
else:
self.logger.error(err)
return False
else:
return False

def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]:
sender = request.get_header(MessageHeaderKey.ORIGIN)
Expand Down

0 comments on commit 7c8b833

Please sign in to comment.