Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore non-aio GRPC and a few improvements #2058

Merged
merged 11 commits into from
Oct 6, 2023
Merged
2 changes: 2 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class ReservedTopic(object):
DO_TASK = "__do_task__"
AUX_COMMAND = "__aux_command__"
SYNC_RUNNER = "__sync_runner__"
JOB_HEART_BEAT = "__job_heartbeat__"
TASK_CHECK = "__task_check__"


class AdminCommandNames(object):
Expand Down
5 changes: 5 additions & 0 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
return self._client_task_map.get(task_id, None)

def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
"""Called to process a submission from one client.
Expand Down
10 changes: 10 additions & 0 deletions nvflare/apis/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
"""
pass

@abstractmethod
def process_task_check(self, task_id: str, fl_ctx: FLContext):
"""Called by the Engine to check whether a specified task still exists.
Args:
task_id: the id of the task
fl_ctx: the FLContext
Returns: the ClientTask object if exists; None otherwise
"""
pass

@abstractmethod
def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Expand Down
28 changes: 23 additions & 5 deletions nvflare/fuel/f3/comm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from nvflare.fuel.f3.drivers.net_utils import MAX_PAYLOAD_SIZE
from nvflare.fuel.utils.config import Config
from nvflare.fuel.utils.config_service import ConfigService
Expand All @@ -34,6 +32,7 @@ class VarName:
SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold"
COMM_DRIVER_PATH = "comm_driver_path"
HEARTBEAT_INTERVAL = "heartbeat_interval"
USE_AIO_GRPC_VAR_NAME = "use_aio_grpc"
STREAMING_CHUNK_SIZE = "streaming_chunk_size"
STREAMING_ACK_WAIT = "streaming_ack_wait"
STREAMING_WINDOW_SIZE = "streaming_window_size"
Expand All @@ -43,10 +42,26 @@ class VarName:


class CommConfigurator:

_config_loaded = False
_configuration = None

def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0])
self.config = None if config is None else config.to_dict()
# only load once!
if not CommConfigurator._config_loaded:
config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0])
CommConfigurator._configuration = None if config is None else config.to_dict()
CommConfigurator._config_loaded = True
self.config = CommConfigurator._configuration

@staticmethod
def reset():
"""Reset the configurator to allow reloading config files.
Returns:
"""
CommConfigurator._config_loaded = False

def get_config(self):
return self.config
Expand Down Expand Up @@ -78,6 +93,9 @@ def get_comm_driver_path(self, default):
def get_heartbeat_interval(self, default):
return ConfigService.get_int_var(VarName.HEARTBEAT_INTERVAL, self.config, default=default)

def use_aio_grpc(self, default):
return ConfigService.get_bool_var(VarName.USE_AIO_GRPC_VAR_NAME, self.config, default)

def get_streaming_chunk_size(self, default):
return ConfigService.get_int_var(VarName.STREAMING_CHUNK_SIZE, self.config, default=default)

Expand Down
66 changes: 24 additions & 42 deletions nvflare/fuel/f3/drivers/aio_grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .base_driver import BaseDriver
from .driver_params import DriverCap, DriverParams
from .grpc.streamer_pb2 import Frame
from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc
from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required

GRPC_DEFAULT_OPTIONS = [
Expand Down Expand Up @@ -68,11 +69,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di
def get_conn_properties(self) -> dict:
return self.conn_props

async def _abort(self):
try:
self.context.abort(grpc.StatusCode.CANCELLED, "service closed")
except:
# ignore exception (if any) when aborting
pass

def close(self):
self.closing = True
with self.lock:
if self.context:
self.aio_ctx.run_coro(self.context.abort(grpc.StatusCode.CANCELLED, "service closed"))
self.aio_ctx.run_coro(self._abort())
self.context = None
if self.channel:
self.aio_ctx.run_coro(self.channel.close())
Expand Down Expand Up @@ -197,20 +205,18 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C
servicer = Servicer(self, aio_ctx)
add_StreamerServicer_to_server(servicer, self.grpc_server)
params = connector.params
host = params.get(DriverParams.HOST.value)
if not host:
host = "0.0.0.0"
port = int(params.get(DriverParams.PORT.value))
addr = f"{host}:{port}"
addr = get_address(params)
try:
self.logger.debug(f"SERVER: connector params: {params}")

secure = ssl_required(params)
if secure:
credentials = AioGrpcDriver.get_grpc_server_credentials(params)
credentials = get_grpc_server_credentials(params)
self.grpc_server.add_secure_port(addr, server_credentials=credentials)
self.logger.info(f"added secure port at {addr}")
else:
self.grpc_server.add_insecure_port(addr)
self.logger.info(f"added insecure port at {addr}")
except Exception as ex:
conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}"
self.logger.debug(conn_ctx.error)
Expand Down Expand Up @@ -251,7 +257,10 @@ def __init__(self):

@staticmethod
def supported_transports() -> List[str]:
return ["grpc", "grpcs"]
if use_aio_grpc():
return ["grpc", "grpcs"]
else:
return ["agrpc", "agrpcs"]

@staticmethod
def capabilities() -> Dict[str, Any]:
Expand Down Expand Up @@ -295,10 +304,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co
secure = ssl_required(params)
if secure:
grpc_channel = grpc.aio.secure_channel(
address, options=self.options, credentials=self.get_grpc_client_credentials(params)
address, options=self.options, credentials=get_grpc_client_credentials(params)
)
self.logger.info(f"created secure channel at {address}")
else:
grpc_channel = grpc.aio.insecure_channel(address, options=self.options)
self.logger.info(f"created insecure channel at {address}")

async with grpc_channel as channel:
self.logger.debug(f"CLIENT: connected to {address}")
Expand Down Expand Up @@ -374,38 +385,9 @@ def shutdown(self):
def get_urls(scheme: str, resources: dict) -> (str, str):
secure = resources.get(DriverParams.SECURE)
if secure:
scheme = "grpcs"
if use_aio_grpc():
scheme = "grpcs"
else:
scheme = "agrpcs"

return get_tcp_urls(scheme, resources)

@staticmethod
def get_grpc_client_credentials(params: dict):

root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT))
private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY))

return grpc.ssl_channel_credentials(
certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert
)

@staticmethod
def get_grpc_server_credentials(params: dict):

root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value))
cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT))
private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY))

return grpc.ssl_server_credentials(
[(private_key, cert_chain)],
root_certificates=root_cert,
require_client_auth=True,
)

@staticmethod
def read_file(file_name: str):
if not file_name:
return None

with open(file_name, "rb") as f:
return f.read()
52 changes: 52 additions & 0 deletions nvflare/fuel/f3/drivers/grpc/qq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import queue


class QueueClosed(Exception):
pass


class QQ:
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
self.q = queue.Queue()
self.closed = False
self.logger = logging.getLogger(self.__class__.__name__)

def close(self):
self.closed = True

def append(self, i):
if self.closed:
raise QueueClosed("queue stopped")
self.q.put_nowait(i)

def __iter__(self):
return self

def __next__(self):
if self.closed:
raise StopIteration()
while True:
try:
return self.q.get(block=True, timeout=0.1)
except queue.Empty:
if self.closed:
self.logger.debug("Queue closed - stop iteration")
raise StopIteration()
except Exception as e:
self.logger.error(f"queue exception {type(e)}")
raise e
51 changes: 51 additions & 0 deletions nvflare/fuel/f3/drivers/grpc/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.drivers.driver_params import DriverParams


def use_aio_grpc():
configurator = CommConfigurator()
return configurator.use_aio_grpc(default=True)


def get_grpc_client_credentials(params: dict):
root_cert = _read_file(params.get(DriverParams.CA_CERT.value))
cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT))
private_key = _read_file(params.get(DriverParams.CLIENT_KEY))
return grpc.ssl_channel_credentials(
certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert
)


def get_grpc_server_credentials(params: dict):
root_cert = _read_file(params.get(DriverParams.CA_CERT.value))
cert_chain = _read_file(params.get(DriverParams.SERVER_CERT))
private_key = _read_file(params.get(DriverParams.SERVER_KEY))

return grpc.ssl_server_credentials(
[(private_key, cert_chain)],
root_certificates=root_cert,
require_client_auth=True,
)


def _read_file(file_name: str):
if not file_name:
return None

with open(file_name, "rb") as f:
return f.read()
Loading