diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 34ee6df7cc..723b894f75 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -27,6 +27,7 @@ from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey from nvflare.fuel.f3.streaming.stream_types import StreamFuture from nvflare.private.defs import CellChannel +from nvflare.security.logging import secure_format_exception CHANNELS_TO_EXCLUDE = ( CellChannel.CLIENT_MAIN, @@ -233,15 +234,21 @@ def _get_result(self, req_id): return waiter.result def _future_wait(self, future, timeout): + # future could have an error! last_progress = 0 while not future.waiter.wait(timeout): + if future.error: + return False current_progress = future.get_progress() if last_progress == current_progress: return False else: self.logger.debug(f"{current_progress=}") last_progress = current_progress - return True + if future.error: + return False + else: + return True def _encode_message(self, msg: Message): try: @@ -259,7 +266,6 @@ def _send_request( timeout=10.0, secure=False, optional=False, - wait_for_reply=True, ): """Stream one request to the target @@ -271,14 +277,12 @@ def _send_request( timeout: how long to wait secure: is P2P security to be applied optional: is the message optional - wait_for_reply: whether to wait for reply - Returns: if wait_for_reply, then reply data; otherwise only a bool to indicate whether the request - is sent successfully + Returns: reply data """ self._encode_message(request) - return self._send_one_request(channel, target, topic, request, timeout, secure, optional, wait_for_reply) + return self._send_one_request(channel, target, topic, request, timeout, secure, optional) def _send_one_request( self, @@ -289,7 +293,6 @@ def _send_one_request( timeout=10.0, secure=False, optional=False, - wait_for_reply=True, ): req_id = str(uuid.uuid4()) request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id}) @@ -299,47 +302,46 @@ def _send_one_request( waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT)) self.requests_dict[req_id] = waiter - future = self.send_blob( - channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional - ) - self.logger.debug(f"{req_id=}: Waiting starts") + try: + future = self.send_blob( + channel=channel, topic=topic, target=target, message=request, secure=secure, optional=optional + ) - # Three stages, sending, waiting for receiving first byte, receiving + self.logger.debug(f"{req_id=}: Waiting starts") - # sending with progress timeout - self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") - sending_complete = self._future_wait(future, timeout) - if not sending_complete: - self.logger.info(f"{req_id=}: sending timeout {timeout=}") - if wait_for_reply: + # Three stages, sending, waiting for receiving first byte, receiving + # sending with progress timeout + self.logger.debug(f"{req_id=}: entering sending wait {timeout=}") + sending_complete = self._future_wait(future, timeout) + if not sending_complete: + self.logger.debug(f"{req_id=}: sending timeout {timeout=}") return self._get_result(req_id) - else: - return False - self.logger.debug(f"{req_id=}: sending complete") - if not wait_for_reply: - return True - # waiting for receiving first byte - self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") - if not waiter.in_receiving.wait(timeout): - self.logger.info(f"{req_id=}: remote processing timeout {timeout=}") + self.logger.debug(f"{req_id=}: sending complete") + + # waiting for receiving first byte + self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}") + if not waiter.in_receiving.wait(timeout): + self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: in receiving") + + # receiving with progress timeout + r_future = waiter.receiving_future + self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}") + receiving_complete = self._future_wait(r_future, timeout) + if not receiving_complete: + self.logger.info(f"{req_id=}: receiving timeout {timeout=}") + return self._get_result(req_id) + self.logger.debug(f"{req_id=}: receiving complete") + waiter.result = Message(r_future.headers, r_future.result()) + decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) + self.logger.debug(f"{req_id=}: return result {waiter.result=}") return self._get_result(req_id) - self.logger.debug(f"{req_id=}: in receiving") - - # receiving with progress timeout - r_future = waiter.receiving_future - self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}") - receiving_complete = self._future_wait(r_future, timeout) - if not receiving_complete: - self.logger.info(f"{req_id=}: receiving timeout {timeout=}") + except Exception as ex: + self.logger.error(f"exception sending request: {secure_format_exception(ex)}") return self._get_result(req_id) - self.logger.debug(f"{req_id=}: receiving complete") - waiter.result = Message(r_future.headers, r_future.result()) - decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING) - self.logger.debug(f"{req_id=}: return result {waiter.result=}") - result = self._get_result(req_id) - return result def _process_reply(self, future: StreamFuture): headers = future.headers diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 579a9190db..56f6716454 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -271,15 +271,27 @@ def send(self, msg: Message, timeout=None) -> bool: ) return True - return self.cell.send_request( + reply = self.cell.send_request( channel=self.channel, topic=msg.topic, target=self.peer_fqcn, request=_to_cell_message(msg), timeout=timeout, optional=optional, - wait_for_reply=False, ) + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == ReturnCode.OK: + return True + else: + err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}" + if optional: + self.logger.debug(err) + else: + self.logger.error(err) + return False + else: + return False def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]: sender = request.get_header(MessageHeaderKey.ORIGIN)