From 571f17fdfe5a1c1dec31b444d0a0efcdf15a7a5b Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:11:42 -0400 Subject: [PATCH] Restore non-aio GRPC and a few improvements (#2058) * restore non-aio grpc driver * restore non-aio grpc driver * fix unit tests * fix drivers * fix CP HB bug; add retry for result submit * fix test cases * fix test case * address pr review * fix download_job bug in flare_api --- nvflare/apis/fl_constant.py | 2 + nvflare/apis/impl/controller.py | 5 + nvflare/apis/responder.py | 10 + nvflare/fuel/f3/comm_config.py | 28 +- nvflare/fuel/f3/drivers/aio_grpc_driver.py | 66 ++-- nvflare/fuel/f3/drivers/grpc/qq.py | 52 ++++ nvflare/fuel/f3/drivers/grpc/utils.py | 51 ++++ nvflare/fuel/f3/drivers/grpc_driver.py | 288 ++++++++++++++++++ nvflare/fuel/f3/drivers/net_utils.py | 3 +- nvflare/fuel/flare_api/flare_api.py | 8 +- nvflare/fuel/hci/client/cli.py | 17 +- nvflare/fuel/hci/client/file_transfer.py | 6 +- nvflare/fuel/hci/proto.py | 1 + .../private/fed/client/client_run_manager.py | 3 +- nvflare/private/fed/client/client_runner.py | 138 ++++++++- nvflare/private/fed/client/communicator.py | 8 +- nvflare/private/fed/client/fed_client_base.py | 86 ++---- nvflare/private/fed/server/server_runner.py | 33 ++ nvflare/private/fed/utils/fed_utils.py | 25 +- tests/unit_test/fuel/f3/communicator_test.py | 1 + .../fuel/f3/drivers/custom_driver_test.py | 2 + .../fuel/f3/drivers/driver_manager_test.py | 3 + 22 files changed, 677 insertions(+), 159 deletions(-) create mode 100644 nvflare/fuel/f3/drivers/grpc/qq.py create mode 100644 nvflare/fuel/f3/drivers/grpc/utils.py create mode 100644 nvflare/fuel/f3/drivers/grpc_driver.py diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index ddea05cd24..2ecb828f22 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -181,6 +181,8 @@ class ReservedTopic(object): DO_TASK = "__do_task__" AUX_COMMAND = "__aux_command__" SYNC_RUNNER = "__sync_runner__" + JOB_HEART_BEAT = "__job_heartbeat__" + TASK_CHECK = "__task_check__" class AdminCommandNames(object): diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 16725f7ad0..0d8ca2f88b 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -344,6 +344,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if not self._dead_client_reports.get(client_name): self._dead_client_reports[client_name] = time.time() + def process_task_check(self, task_id: str, fl_ctx: FLContext): + with self._task_lock: + # task_id is the uuid associated with the client_task + return self._client_task_map.get(task_id, None) + def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): """Called to process a submission from one client. diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index a411b30bd4..6072670886 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -63,6 +63,16 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul """ pass + @abstractmethod + def process_task_check(self, task_id: str, fl_ctx: FLContext): + """Called by the Engine to check whether a specified task still exists. + Args: + task_id: the id of the task + fl_ctx: the FLContext + Returns: the ClientTask object if exists; None otherwise + """ + pass + @abstractmethod def handle_dead_job(self, client_name: str, fl_ctx: FLContext): """Called by the Engine to handle the case that the job on the client is dead. diff --git a/nvflare/fuel/f3/comm_config.py b/nvflare/fuel/f3/comm_config.py index c2fa51d9b5..4a93bd1d89 100644 --- a/nvflare/fuel/f3/comm_config.py +++ b/nvflare/fuel/f3/comm_config.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging - from nvflare.fuel.f3.drivers.net_utils import MAX_PAYLOAD_SIZE from nvflare.fuel.utils.config import Config from nvflare.fuel.utils.config_service import ConfigService @@ -34,6 +32,7 @@ class VarName: SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold" COMM_DRIVER_PATH = "comm_driver_path" HEARTBEAT_INTERVAL = "heartbeat_interval" + USE_AIO_GRPC_VAR_NAME = "use_aio_grpc" STREAMING_CHUNK_SIZE = "streaming_chunk_size" STREAMING_ACK_WAIT = "streaming_ack_wait" STREAMING_WINDOW_SIZE = "streaming_window_size" @@ -43,10 +42,26 @@ class VarName: class CommConfigurator: + + _config_loaded = False + _configuration = None + def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) - config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0]) - self.config = None if config is None else config.to_dict() + # only load once! + if not CommConfigurator._config_loaded: + config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0]) + CommConfigurator._configuration = None if config is None else config.to_dict() + CommConfigurator._config_loaded = True + self.config = CommConfigurator._configuration + + @staticmethod + def reset(): + """Reset the configurator to allow reloading config files. + + Returns: + + """ + CommConfigurator._config_loaded = False def get_config(self): return self.config @@ -78,6 +93,9 @@ def get_comm_driver_path(self, default): def get_heartbeat_interval(self, default): return ConfigService.get_int_var(VarName.HEARTBEAT_INTERVAL, self.config, default=default) + def use_aio_grpc(self, default): + return ConfigService.get_bool_var(VarName.USE_AIO_GRPC_VAR_NAME, self.config, default) + def get_streaming_chunk_size(self, default): return ConfigService.get_int_var(VarName.STREAMING_CHUNK_SIZE, self.config, default=default) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 2872e3b7b9..c93dffed93 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -34,6 +34,7 @@ from .base_driver import BaseDriver from .driver_params import DriverCap, DriverParams from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required GRPC_DEFAULT_OPTIONS = [ @@ -68,11 +69,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di def get_conn_properties(self) -> dict: return self.conn_props + async def _abort(self): + try: + self.context.abort(grpc.StatusCode.CANCELLED, "service closed") + except: + # ignore exception (if any) when aborting + pass + def close(self): self.closing = True with self.lock: if self.context: - self.aio_ctx.run_coro(self.context.abort(grpc.StatusCode.CANCELLED, "service closed")) + self.aio_ctx.run_coro(self._abort()) self.context = None if self.channel: self.aio_ctx.run_coro(self.channel.close()) @@ -197,20 +205,18 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C servicer = Servicer(self, aio_ctx) add_StreamerServicer_to_server(servicer, self.grpc_server) params = connector.params - host = params.get(DriverParams.HOST.value) - if not host: - host = "0.0.0.0" - port = int(params.get(DriverParams.PORT.value)) - addr = f"{host}:{port}" + addr = get_address(params) try: self.logger.debug(f"SERVER: connector params: {params}") secure = ssl_required(params) if secure: - credentials = AioGrpcDriver.get_grpc_server_credentials(params) + credentials = get_grpc_server_credentials(params) self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") else: self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") except Exception as ex: conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" self.logger.debug(conn_ctx.error) @@ -251,7 +257,10 @@ def __init__(self): @staticmethod def supported_transports() -> List[str]: - return ["grpc", "grpcs"] + if use_aio_grpc(): + return ["grpc", "grpcs"] + else: + return ["agrpc", "agrpcs"] @staticmethod def capabilities() -> Dict[str, Any]: @@ -295,10 +304,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co secure = ssl_required(params) if secure: grpc_channel = grpc.aio.secure_channel( - address, options=self.options, credentials=self.get_grpc_client_credentials(params) + address, options=self.options, credentials=get_grpc_client_credentials(params) ) + self.logger.info(f"created secure channel at {address}") else: grpc_channel = grpc.aio.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") async with grpc_channel as channel: self.logger.debug(f"CLIENT: connected to {address}") @@ -374,38 +385,9 @@ def shutdown(self): def get_urls(scheme: str, resources: dict) -> (str, str): secure = resources.get(DriverParams.SECURE) if secure: - scheme = "grpcs" + if use_aio_grpc(): + scheme = "grpcs" + else: + scheme = "agrpcs" return get_tcp_urls(scheme, resources) - - @staticmethod - def get_grpc_client_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY)) - - return grpc.ssl_channel_credentials( - certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert - ) - - @staticmethod - def get_grpc_server_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY)) - - return grpc.ssl_server_credentials( - [(private_key, cert_chain)], - root_certificates=root_cert, - require_client_auth=True, - ) - - @staticmethod - def read_file(file_name: str): - if not file_name: - return None - - with open(file_name, "rb") as f: - return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc/qq.py b/nvflare/fuel/f3/drivers/grpc/qq.py new file mode 100644 index 0000000000..ca0eeb25f2 --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/qq.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import queue + + +class QueueClosed(Exception): + pass + + +class QQ: + def __init__(self): + self.q = queue.Queue() + self.closed = False + self.logger = logging.getLogger(self.__class__.__name__) + + def close(self): + self.closed = True + + def append(self, i): + if self.closed: + raise QueueClosed("queue stopped") + self.q.put_nowait(i) + + def __iter__(self): + return self + + def __next__(self): + if self.closed: + raise StopIteration() + while True: + try: + return self.q.get(block=True, timeout=0.1) + except queue.Empty: + if self.closed: + self.logger.debug("Queue closed - stop iteration") + raise StopIteration() + except Exception as e: + self.logger.error(f"queue exception {type(e)}") + raise e diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py new file mode 100644 index 0000000000..d95bb8138a --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def use_aio_grpc(): + configurator = CommConfigurator() + return configurator.use_aio_grpc(default=True) + + +def get_grpc_client_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT)) + private_key = _read_file(params.get(DriverParams.CLIENT_KEY)) + return grpc.ssl_channel_credentials( + certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert + ) + + +def get_grpc_server_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.SERVER_CERT)) + private_key = _read_file(params.get(DriverParams.SERVER_KEY)) + + return grpc.ssl_server_credentials( + [(private_key, cert_chain)], + root_certificates=root_cert, + require_client_auth=True, + ) + + +def _read_file(file_name: str): + if not file_name: + return None + + with open(file_name, "rb") as f: + return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py new file mode 100644 index 0000000000..6ef5711aec --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -0,0 +1,288 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from concurrent import futures +from typing import Any, Dict, List, Union + +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_error import CommError +from nvflare.fuel.f3.connection import Connection +from nvflare.fuel.f3.drivers.driver import ConnectorInfo +from nvflare.fuel.f3.drivers.grpc.streamer_pb2_grpc import ( + StreamerServicer, + StreamerStub, + add_StreamerServicer_to_server, +) +from nvflare.fuel.utils.obj_utils import get_logger +from nvflare.security.logging import secure_format_exception + +from .base_driver import BaseDriver +from .driver_params import DriverCap, DriverParams +from .grpc.qq import QQ +from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc +from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required + +GRPC_DEFAULT_OPTIONS = [ + ("grpc.max_send_message_length", MAX_FRAME_SIZE), + ("grpc.max_receive_message_length", MAX_FRAME_SIZE), +] + + +class StreamConnection(Connection): + + seq_num = 0 + + def __init__(self, oq: QQ, connector: ConnectorInfo, conn_props: dict, side: str, context=None, channel=None): + super().__init__(connector) + self.side = side + self.oq = oq + self.closing = False + self.conn_props = conn_props + self.context = context # for server side + self.channel = channel # for client side + self.lock = threading.Lock() + self.logger = get_logger(self) + + def get_conn_properties(self) -> dict: + return self.conn_props + + def close(self): + self.closing = True + with self.lock: + self.oq.close() + if self.context: + try: + self.context.abort(grpc.StatusCode.CANCELLED, "service closed") + except: + # ignore any exception when aborting + pass + self.context = None + if self.channel: + self.channel.close() + self.channel = None + + def send_frame(self, frame: Union[bytes, bytearray, memoryview]): + try: + StreamConnection.seq_num += 1 + seq = StreamConnection.seq_num + self.logger.debug(f"{self.side}: queued frame #{seq}") + self.oq.append(Frame(seq=seq, data=bytes(frame))) + except BaseException as ex: + raise CommError(CommError.ERROR, f"Error sending frame: {ex}") + + def read_loop(self, msg_iter, q: QQ): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: started read_loop in thread {ct.name}") + try: + for f in msg_iter: + if self.closing: + break + + assert isinstance(f, Frame) + self.logger.debug(f"{self.side} in {ct.name}: incoming frame #{f.seq}") + if self.frame_receiver: + self.frame_receiver.process_frame(f.data) + else: + self.logger.error(f"{self.side}: Frame receiver not registered for connection: {self.name}") + except Exception as ex: + if not self.closing: + self.logger.debug(f"{self.side}: exception {type(ex)} in read_loop") + if q: + self.logger.debug(f"{self.side}: closing queue") + q.close() + self.logger.debug(f"{self.side} in {ct.name}: done read_loop") + + def generate_output(self): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: generate_output in thread {ct.name}") + for i in self.oq: + assert isinstance(i, Frame) + self.logger.debug(f"{self.side}: outgoing frame #{i.seq}") + yield i + self.logger.debug(f"{self.side}: done generate_output in thread {ct.name}") + + +class Servicer(StreamerServicer): + def __init__(self, server): + self.server = server + self.logger = get_logger(self) + + def Stream(self, request_iterator, context): + connection = None + oq = QQ() + t = None + ct = threading.current_thread() + conn_props = { + DriverParams.PEER_ADDR.value: context.peer(), + DriverParams.LOCAL_ADDR.value: get_address(self.server.connector.params), + } + cn_names = context.auth_context().get("x509_common_name") + if cn_names: + conn_props[DriverParams.PEER_CN.value] = cn_names[0].decode("utf-8") + + try: + self.logger.debug(f"SERVER started Stream CB in thread {ct.name}") + connection = StreamConnection(oq, self.server.connector, conn_props, "SERVER", context=context) + self.logger.debug(f"SERVER created connection in thread {ct.name}") + self.server.driver.add_connection(connection) + self.logger.debug(f"SERVER created read_loop thread in thread {ct.name}") + t = threading.Thread(target=connection.read_loop, args=(request_iterator, oq)) + t.start() + + # DO NOT use connection.generate_output()! + self.logger.debug(f"SERVER: generate_output in thread {ct.name}") + for i in oq: + assert isinstance(i, Frame) + self.logger.debug(f"SERVER: outgoing frame #{i.seq}") + yield i + self.logger.debug(f"SERVER: done generate_output in thread {ct.name}") + + except BaseException as ex: + self.logger.error(f"Connection closed due to error: {ex}") + finally: + if t is not None: + t.join() + if connection: + self.logger.debug(f"SERVER: closing connection {connection.name}") + self.server.driver.close_connection(connection) + self.logger.debug(f"SERVER: cleanly finished Stream CB in thread {ct.name}") + + +class Server: + def __init__( + self, + driver, + connector, + max_workers, + options, + ): + self.driver = driver + self.logger = get_logger(self) + self.connector = connector + self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=options) + servicer = Servicer(self) + add_StreamerServicer_to_server(servicer, self.grpc_server) + + params = connector.params + addr = get_address(params) + try: + self.logger.debug(f"SERVER: connector params: {params}") + secure = ssl_required(params) + if secure: + credentials = get_grpc_server_credentials(params) + self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") + else: + self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") + except Exception as ex: + error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" + self.logger.debug(error) + + def start(self): + self.grpc_server.start() + self.grpc_server.wait_for_termination() + + def shutdown(self): + self.grpc_server.stop(grace=0.5) + + +class GrpcDriver(BaseDriver): + def __init__(self): + BaseDriver.__init__(self) + self.server = None + self.closing = False + self.max_workers = 100 + self.options = GRPC_DEFAULT_OPTIONS + self.logger = get_logger(self) + configurator = CommConfigurator() + config = configurator.get_config() + if config: + my_params = config.get("grpc") + if my_params: + self.max_workers = my_params.get("max_workers", 100) + self.options = my_params.get("options") + self.logger.debug(f"GRPC Config: max_workers={self.max_workers}, options={self.options}") + + @staticmethod + def supported_transports() -> List[str]: + if use_aio_grpc(): + return ["nagrpc", "nagrpcs"] + else: + return ["grpc", "grpcs"] + + @staticmethod + def capabilities() -> Dict[str, Any]: + return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: True} + + def listen(self, connector: ConnectorInfo): + self.connector = connector + self.server = Server(self, connector, max_workers=self.max_workers, options=self.options) + self.server.start() + + def connect(self, connector: ConnectorInfo): + self.logger.debug("CLIENT: trying connect ...") + params = connector.params + address = get_address(params) + conn_props = {DriverParams.PEER_ADDR.value: address} + + secure = ssl_required(params) + if secure: + self.logger.debug("CLIENT: creating secure channel") + channel = grpc.secure_channel( + address, options=self.options, credentials=get_grpc_client_credentials(params) + ) + self.logger.info(f"created secure channel at {address}") + else: + self.logger.info("CLIENT: creating insecure channel") + channel = grpc.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") + + self.logger.debug("CLIENT: created channel") + stub = StreamerStub(channel) + self.logger.debug("CLIENT: got stub") + oq = QQ() + connection = StreamConnection(oq, connector, conn_props, "CLIENT", channel=channel) + self.add_connection(connection) + self.logger.debug("CLIENT: added connection") + try: + received = stub.Stream(connection.generate_output()) + connection.read_loop(received, oq) + except BaseException as ex: + self.logger.info(f"CLIENT: connection done: {type(ex)}") + connection.close() + self.close_connection(connection) + self.logger.info(f"CLIENT: finished connection {connection}") + + @staticmethod + def get_urls(scheme: str, resources: dict) -> (str, str): + secure = resources.get(DriverParams.SECURE) + if secure: + if use_aio_grpc(): + scheme = "nagrpcs" + else: + scheme = "grpcs" + return get_tcp_urls(scheme, resources) + + def shutdown(self): + if self.closing: + return + self.closing = True + self.close_all() + if self.server: + self.server.shutdown() diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 3f9ae1e77d..6dd17b8979 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -79,7 +79,8 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: def get_address(params: dict) -> str: host = params.get(DriverParams.HOST.value, "0.0.0.0") port = params.get(DriverParams.PORT.value, 0) - + if not host: + host = "0.0.0.0" return f"{host}:{port}" diff --git a/nvflare/fuel/flare_api/flare_api.py b/nvflare/fuel/flare_api/flare_api.py index 4a6de97917..e3ba0f2b85 100644 --- a/nvflare/fuel/flare_api/flare_api.py +++ b/nvflare/fuel/flare_api/flare_api.py @@ -358,12 +358,8 @@ def download_job_result(self, job_id: str) -> str: self._validate_job_id(job_id) result = self._do_command(AdminCommandNames.DOWNLOAD_JOB + " " + job_id) meta = result[ResultKey.META] - download_job_id = meta.get(MetaKey.JOB_ID, None) - job_download_url = meta.get(MetaKey.JOB_DOWNLOAD_URL, None) - if not job_download_url: - return os.path.join(self.download_dir, download_job_id) - else: - return job_download_url + location = meta.get(MetaKey.LOCATION) + return location def abort_job(self, job_id: str): """Abort the specified job. diff --git a/nvflare/fuel/hci/client/cli.py b/nvflare/fuel/hci/client/cli.py index d43a9a7439..d18c54dcad 100644 --- a/nvflare/fuel/hci/client/cli.py +++ b/nvflare/fuel/hci/client/cli.py @@ -304,6 +304,13 @@ def default(self, line): self.write_stdout(f"exception occurred: {secure_format_exception(e)}") self._close_output_file() + @staticmethod + def _user_input(prompt: str) -> str: + answer = input(prompt) + + # remove leading and trailing spaces + return answer.strip() + def _do_default(self, line): args = split_to_args(line) cmd_name = args[0] @@ -360,14 +367,14 @@ def _do_default(self, line): info = CommandInfo.CONFIRM_YN if info == CommandInfo.CONFIRM_YN: - answer = input("Are you sure (y/N): ") + answer = self._user_input("Are you sure (y/N): ") answer = answer.lower() if answer != "y" and answer != "yes": return elif info == CommandInfo.CONFIRM_USER_NAME: - answer = input("Confirm with User Name: ") + answer = self._user_input("Confirm with User Name: ") if answer != self.user_name: - self.write_string("user name mismatch") + self.write_string(f"user name mismatch: {answer} != {self.user_name}") return elif info == CommandInfo.CONFIRM_PWD: pwd = getpass.getpass("Enter password to confirm: ") @@ -428,7 +435,7 @@ def cmdloop(self, intro=None): else: if self.use_rawinput: try: - line = input(self.prompt) + line = self._user_input(self.prompt) except (EOFError, ConnectionError): line = "bye" except KeyboardInterrupt: @@ -477,7 +484,7 @@ def _get_login_creds(self): elif self.credential_type == CredentialType.LOCAL_CERT: self.user_name = self.username else: - self.user_name = input("User Name: ") + self.user_name = self._user_input("User Name: ") def print_resp(self, resp: dict): """Prints the server response diff --git a/nvflare/fuel/hci/client/file_transfer.py b/nvflare/fuel/hci/client/file_transfer.py index 113e304eba..a31eb51ad7 100644 --- a/nvflare/fuel/hci/client/file_transfer.py +++ b/nvflare/fuel/hci/client/file_transfer.py @@ -430,7 +430,11 @@ def pull_folder(self, args, ctx: CommandContext): tx_path = self._tx_path(tx_id, folder_name) destination_path = os.path.join(self.download_dir, destination_name) location = self._rename_folder(tx_path, destination_path) - reply = {ProtoKey.STATUS: APIStatus.SUCCESS, ProtoKey.DETAILS: f"content downloaded to {location}"} + reply = { + ProtoKey.STATUS: APIStatus.SUCCESS, + ProtoKey.DETAILS: f"content downloaded to {location}", + ProtoKey.META: {MetaKey.LOCATION: location}, + } else: reply = error return reply diff --git a/nvflare/fuel/hci/proto.py b/nvflare/fuel/hci/proto.py index 2459181902..fb0e272b63 100644 --- a/nvflare/fuel/hci/proto.py +++ b/nvflare/fuel/hci/proto.py @@ -66,6 +66,7 @@ class MetaKey(object): CMD_NAME = "cmd_name" TX_ID = "tx_id" FOLDER_NAME = "folder_name" + LOCATION = "location" class MetaStatusValue(object): diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index d52fdf2583..6dcf435046 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -128,9 +128,10 @@ def new_context(self) -> FLContext: def send_task_result(self, result: Shareable, fl_ctx: FLContext) -> bool: push_result = self.client.push_results(result, fl_ctx) # push task execution results - if push_result[0] == CellReturnCode.OK: + if push_result == CellReturnCode.OK: return True else: + self.logger.error(f"failed to send task result: {push_result}") return False def get_workspace(self) -> Workspace: diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index ab7543b55b..ab8d0d57d5 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -25,6 +25,7 @@ from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.task_utils import apply_filters +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.config_service import ConfigService from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment @@ -33,6 +34,10 @@ from nvflare.security.logging import secure_format_exception from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector +_TASK_CHECK_RESULT_OK = 0 +_TASK_CHECK_RESULT_TRY_AGAIN = 1 +_TASK_CHECK_RESULT_TASK_GONE = 2 + class TaskRouter: def __init__(self): @@ -136,6 +141,9 @@ def __init__( self.run_abort_signal = Signal() self.task_lock = threading.Lock() self.running_tasks = {} # task_id => TaskAssignment + + self.task_check_timeout = 5.0 + self.task_check_interval = 5.0 self._register_aux_message_handlers(engine) def find_executor(self, task_name): @@ -398,11 +406,39 @@ def _do_process_task(self, task: TaskAssignment, fl_ctx: FLContext, abort_signal return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK") def _try_run(self): + heartbeat_thread = threading.Thread(target=self._send_job_heartbeat, args=[], daemon=True) + heartbeat_thread.start() + while not self.run_abort_signal.triggered: with self.engine.new_context() as fl_ctx: task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx) time.sleep(task_fetch_interval) + def _send_job_heartbeat(self, interval=30.0): + sleep_time = 1.0 + wait_times = int(interval / sleep_time) + if wait_times == 0: + wait_times = 1 + request = Shareable() + while not self.run_abort_signal.triggered: + with self.engine.new_context() as fl_ctx: + self.engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=ReservedTopic.JOB_HEART_BEAT, + request=request, + timeout=0, + fl_ctx=fl_ctx, + optional=True, + ) + + # we want to send the HB every "interval" secs. + # but we don't want to sleep that long since it will block us from checking abort signal. + # hence we only sleep 1 sec, and check the abort signal. + for i in range(wait_times): + time.sleep(sleep_time) + if self.run_abort_signal.triggered: + break + def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): """Fetches and runs a task. @@ -439,19 +475,105 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): self.log_debug(fl_ctx, "firing event EventType.BEFORE_SEND_TASK_RESULT") self.fire_event(EventType.BEFORE_SEND_TASK_RESULT, fl_ctx) - reply_sent = self.engine.send_task_result(task_reply, fl_ctx) - if reply_sent: - self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id)) - else: - self.log_error( - fl_ctx, - "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id), - ) + self._send_task_result(task_reply, task.task_id, fl_ctx) self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT") self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx) return task_fetch_interval, True + def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext): + try_count = 1 + while True: + self.log_info(fl_ctx, f"try #{try_count}: sending task result to server") + + if self.run_abort_signal.triggered: + self.log_info(fl_ctx, "job aborted: stopped trying to send result") + return False + + try_count += 1 + rc = self._try_send_result_once(result, task_id, fl_ctx) + + if rc == _TASK_CHECK_RESULT_OK: + return True + elif rc == _TASK_CHECK_RESULT_TASK_GONE: + return False + else: + # retry + time.sleep(self.task_check_interval) + + def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext): + # wait until server is ready to receive + while True: + if self.run_abort_signal.triggered: + return _TASK_CHECK_RESULT_TASK_GONE + + rc = self._check_task_once(task_id, fl_ctx) + if rc == _TASK_CHECK_RESULT_OK: + break + elif rc == _TASK_CHECK_RESULT_TASK_GONE: + return rc + else: + # try again + time.sleep(self.task_check_interval) + + # try to send the result + self.log_info(fl_ctx, "start to send task result to server") + reply_sent = self.engine.send_task_result(result, fl_ctx) + if reply_sent: + self.log_info(fl_ctx, "task result sent to server") + return _TASK_CHECK_RESULT_OK + else: + self.log_error(fl_ctx, "failed to send task result to server - will try again") + return _TASK_CHECK_RESULT_TRY_AGAIN + + def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int: + """This method checks whether the server is still waiting for the specified task. + The real reason for this method is to fight against unstable network connections. + We try to make sure that when we send task result to the server, the connection is available. + If the task check succeeds, then the network connection is likely to be available. + Otherwise, we keep retrying until task check succeeds or the server tells us that the task is gone (timed out). + Args: + task_id: + fl_ctx: + Returns: + """ + self.log_info(fl_ctx, "checking task ...") + task_check_req = Shareable() + task_check_req.set_header(ReservedKey.TASK_ID, task_id) + resp = self.engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=ReservedTopic.TASK_CHECK, + request=task_check_req, + timeout=self.task_check_timeout, + fl_ctx=fl_ctx, + optional=True, + ) + if resp and isinstance(resp, dict): + reply = resp.get(FQCN.ROOT_SERVER) + if not isinstance(reply, Shareable): + self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}") + return _TASK_CHECK_RESULT_TRY_AGAIN + + rc = reply.get_return_code() + if rc == ReturnCode.OK: + return _TASK_CHECK_RESULT_OK + elif rc == ReturnCode.COMMUNICATION_ERROR: + self.log_error(fl_ctx, f"failed task_check: {rc}") + return _TASK_CHECK_RESULT_TRY_AGAIN + elif rc == ReturnCode.SERVER_NOT_READY: + self.log_error(fl_ctx, f"server rejected task_check: {rc}") + return _TASK_CHECK_RESULT_TRY_AGAIN + elif rc == ReturnCode.TASK_UNKNOWN: + self.log_error(fl_ctx, f"task no longer exists on server: {rc}") + return _TASK_CHECK_RESULT_TASK_GONE + else: + # this should never happen + self.log_error(fl_ctx, f"programming error: received {rc} from server") + return _TASK_CHECK_RESULT_OK # try to push the result regardless + else: + self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}") + return _TASK_CHECK_RESULT_TRY_AGAIN + def run(self, app_root, args): self.init_run(app_root, args) diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 0f0f44231b..4c8c96b21a 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -64,6 +64,7 @@ def __init__( cell: CoreCell = None, client_register_interval=2, timeout=5.0, + maint_msg_timeout=5.0, ): """To init the Communicator. @@ -84,6 +85,7 @@ def __init__( self.compression = compression self.client_register_interval = client_register_interval self.timeout = timeout + self.maint_msg_timeout = maint_msg_timeout self.logger = logging.getLogger(self.__class__.__name__) @@ -130,7 +132,7 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.Register, request=login_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: @@ -298,7 +300,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.Quit, request=quit_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: @@ -336,7 +338,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.HEART_BEAT, request=heartbeat_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 9ff11a2ca7..b62642c8ee 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -15,8 +15,6 @@ import logging import threading import time -from functools import partial -from multiprocessing.dummy import Pool as ThreadPool from typing import List, Optional from nvflare.apis.filter import Filter @@ -40,15 +38,6 @@ from .communicator import Communicator -def _check_progress(remote_tasks): - if remote_tasks[0] is not None: - # shareable = fobs.loads(remote_tasks[0].payload) - shareable = remote_tasks[0].payload - return True, shareable.get_header(ServerCommandKey.TASK_NAME), shareable - else: - return False, None, None - - class FederatedClientBase: """The client-side base implementation of federated learning. @@ -104,6 +93,7 @@ def __init__( cell=cell, client_register_interval=client_args.get("client_register_interval", 2.0), timeout=client_args.get("communication_timeout", 30.0), + maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0), ) self.secure_train = secure_train @@ -336,63 +326,38 @@ def quit_remote(self, project_name, fl_ctx: FLContext): """ return self.communicator.quit_remote(self.servers, project_name, self.token, self.ssid, fl_ctx) + def _get_project_name(self): + """Get name of the project that the site is part of. + + Returns: + + """ + s = tuple(self.servers) # self.servers is a dict of project_name => server config + return s[0] + def heartbeat(self, interval): """Sends a heartbeat from the client to the server.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.send_heartbeat, interval=interval), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.send_heartbeat(self._get_project_name(), interval) def pull_task(self, fl_ctx: FLContext): """Fetch remote models and update the local client's session.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - self.remote_tasks = pool.map(partial(self.fetch_execute_task, fl_ctx=fl_ctx), tuple(self.servers)) - pull_success, task_name, shareable = _check_progress(self.remote_tasks) - # TODO: if some of the servers failed - return pull_success, task_name, shareable - finally: - if pool: - pool.terminate() + result = self.fetch_execute_task(self._get_project_name(), fl_ctx) + if result: + shareable = result.payload + return True, shareable.get_header(ServerCommandKey.TASK_NAME), shareable + else: + return False, None, None def push_results(self, shareable: Shareable, fl_ctx: FLContext): """Push the local model to multiple servers.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.push_execute_result, shareable=shareable, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.push_execute_result(self._get_project_name(), shareable, fl_ctx) def register(self, fl_ctx: FLContext): - """Push the local model to multiple servers. - - Args: - fl_ctx: FLContext - - Returns: N/A - """ - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.client_register, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + """Push the local model to multiple servers.""" + return self.client_register(self._get_project_name(), fl_ctx) def set_primary_sp(self, sp): - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.set_sp, sp=sp), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.set_sp(self._get_project_name(), sp) def run_heartbeat(self, interval): """Periodically runs the heartbeat.""" @@ -403,6 +368,7 @@ def run_heartbeat(self, interval): def start_heartbeat(self, interval=30): heartbeat_thread = threading.Thread(target=self.run_heartbeat, args=[interval]) + heartbeat_thread.daemon = True heartbeat_thread.start() def logout_client(self, fl_ctx: FLContext): @@ -414,13 +380,7 @@ def logout_client(self, fl_ctx: FLContext): Returns: N/A """ - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.quit_remote, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.quit_remote(self._get_project_name(), fl_ctx) def set_client_engine(self, engine): self.engine = engine diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index d9635eb6c1..5e782f45f2 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -99,11 +99,19 @@ def __init__(self, config: ServerRunnerConfig, job_id: str, engine: ServerEngine self.current_wf_index = 0 self.status = "init" self.turn_to_cold = False + self._register_aux_message_handler(engine) + def _register_aux_message_handler(self, engine): engine.register_aux_message_handler( topic=ReservedTopic.SYNC_RUNNER, message_handle_func=self._handle_sync_runner ) + engine.register_aux_message_handler( + topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat + ) + + engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check) + def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: # simply ack return make_reply(ReturnCode.OK) @@ -475,6 +483,31 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul "Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)), ) + def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + self.log_info(fl_ctx, "received client job_heartbeat aux request") + return make_reply(ReturnCode.OK) + + def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + task_id = request.get_header(ReservedHeaderKey.TASK_ID) + if not task_id: + self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request") + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + self.log_info(fl_ctx, f"received task_check on task {task_id}") + + with self.wf_lock: + if self.current_wf is None or self.current_wf.responder is None: + self.log_info(fl_ctx, "no current workflow - dropped task_check.") + return make_reply(ReturnCode.TASK_UNKNOWN) + + task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx) + if task: + self.log_info(fl_ctx, f"task {task_id} is still good") + return make_reply(ReturnCode.OK) + else: + self.log_info(fl_ctx, f"task {task_id} is not found") + return make_reply(ReturnCode.TASK_UNKNOWN) + def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False): self.status = "done" self.abort_signal.trigger(value=True) diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index b6cdcc534f..ba39d26e23 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -18,7 +18,6 @@ import os import sys from logging.handlers import RotatingFileHandler -from multiprocessing.connection import Listener from typing import List from nvflare.apis.app_validation import AppValidator @@ -39,7 +38,7 @@ from nvflare.private.event import fire_event from nvflare.private.fed.utils.decomposers import private_decomposers from nvflare.private.privacy_manager import PrivacyManager, PrivacyService -from nvflare.security.logging import secure_format_exception, secure_log_traceback +from nvflare.security.logging import secure_format_exception from nvflare.security.security import EmptyAuthorizer, FLAuthorizer from .app_authz import AppAuthzService @@ -54,28 +53,6 @@ def add_logfile_handler(log_file): root_logger.addHandler(file_handler) -def listen_command(listen_port, engine, execute_func, logger): - conn = None - listener = None - try: - address = ("localhost", listen_port) - listener = Listener(address, authkey="client process secret password".encode()) - conn = listener.accept() - - execute_func(conn, engine) - - except Exception as e: - logger.exception( - f"Could not create the listener for this process on port: {listen_port}: {secure_format_exception(e)}." - ) - secure_log_traceback(logger) - finally: - if conn: - conn.close() - if listener: - listener.close() - - def _check_secure_content(site_type: str) -> List[str]: """To check the security contents. diff --git a/tests/unit_test/fuel/f3/communicator_test.py b/tests/unit_test/fuel/f3/communicator_test.py index b08bf56b18..45cbd178d6 100644 --- a/tests/unit_test/fuel/f3/communicator_test.py +++ b/tests/unit_test/fuel/f3/communicator_test.py @@ -90,6 +90,7 @@ class TestCommunicator: [ ("tcp", "2000-3000"), ("grpc", "3000-4000"), + ("nagrpc", "4000-5000"), # ("http", "4000-5000"), TODO (YT): We disable this, as it is causing our jenkins hanging ("atcp", "5000-6000"), ], diff --git a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py index 5137f52db5..90583c653c 100644 --- a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py +++ b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py @@ -18,12 +18,14 @@ from nvflare.fuel.f3 import communicator # Setup custom driver path before communicator module initialization +from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.utils.config_service import ConfigService class TestCustomDriver: @pytest.fixture def manager(self): + CommConfigurator.reset() rel_path = "../../../data/custom_drivers/config" config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), rel_path)) ConfigService.initialize({}, [config_path]) diff --git a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py index a653a47af6..438ccea80b 100644 --- a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py +++ b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py @@ -20,6 +20,7 @@ from nvflare.fuel.f3.drivers.aio_http_driver import AioHttpDriver from nvflare.fuel.f3.drivers.aio_tcp_driver import AioTcpDriver from nvflare.fuel.f3.drivers.driver_manager import DriverManager +from nvflare.fuel.f3.drivers.grpc_driver import GrpcDriver from nvflare.fuel.f3.drivers.tcp_driver import TcpDriver @@ -37,6 +38,8 @@ def manager(self): ("stcp", TcpDriver), ("grpc", AioGrpcDriver), ("grpcs", AioGrpcDriver), + ("nagrpc", GrpcDriver), + ("nagrpcs", GrpcDriver), ("http", AioHttpDriver), ("https", AioHttpDriver), ("ws", AioHttpDriver),