From e116a816dea43934cbb9f776bc0d41d391af1d52 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 4 Dec 2024 19:24:17 +0000 Subject: [PATCH 1/2] init --- src/py/flwr/common/retry_invoker.py | 67 ++++++++++++++++++++++++ src/py/flwr/server/driver/grpc_driver.py | 65 ++--------------------- 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 9785b0fbd9b4..b942bb86a0ff 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -20,8 +20,16 @@ import time from collections.abc import Generator, Iterable from dataclasses import dataclass +from logging import INFO, WARN from typing import Any, Callable, Optional, Union, cast +import grpc + +from flwr.common.constant import MAX_RETRY_DELAY +from flwr.common.logger import log +from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub +from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub + def exponential( base_delay: float = 1, @@ -303,3 +311,62 @@ def giveup_check(_exception: Exception) -> bool: # Trigger success event try_call_event_handler(self.on_success) return ret + + +def _make_simple_grpc_retry_invoker() -> RetryInvoker: + """Create a simple gRPC retry invoker.""" + + def _on_sucess(retry_state: RetryState) -> None: + if retry_state.tries > 1: + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + + def _on_backoff(retry_state: RetryState) -> None: + if retry_state.tries == 1: + log(WARN, "Connection attempt failed, retrying...") + else: + log( + WARN, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + + def _on_giveup(retry_state: RetryState) -> None: + if retry_state.tries > 1: + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + + return RetryInvoker( + wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY), + recoverable_exceptions=grpc.RpcError, + max_tries=None, + max_time=None, + on_success=_on_sucess, + on_backoff=_on_backoff, + on_giveup=_on_giveup, + should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore + ) + + +def _wrap_stub( + stub: Union[ServerAppIoStub, ClientAppIoStub], retry_invoker: RetryInvoker +) -> None: + """Wrap a gRPC stub with a retry invoker.""" + + def make_lambda(original_method: Any) -> Any: + return lambda *args, **kwargs: retry_invoker.invoke( + original_method, *args, **kwargs + ) + + for method_name in vars(stub): + method = getattr(stub, method_name) + if callable(method): + setattr(stub, method_name, make_lambda(method)) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 05b7ce4be8bc..1b38c297460b 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -17,16 +17,16 @@ import time import warnings from collections.abc import Iterable -from logging import DEBUG, INFO, WARN, WARNING -from typing import Any, Optional, cast +from logging import DEBUG, WARNING +from typing import Optional, cast import grpc from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet -from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS +from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -258,60 +258,3 @@ def close(self) -> None: return # Disconnect self._disconnect() - - -def _make_simple_grpc_retry_invoker() -> RetryInvoker: - """Create a simple gRPC retry invoker.""" - - def _on_sucess(retry_state: RetryState) -> None: - if retry_state.tries > 1: - log( - INFO, - "Connection successful after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - - def _on_backoff(retry_state: RetryState) -> None: - if retry_state.tries == 1: - log(WARN, "Connection attempt failed, retrying...") - else: - log( - WARN, - "Connection attempt failed, retrying in %.2f seconds", - retry_state.actual_wait, - ) - - def _on_giveup(retry_state: RetryState) -> None: - if retry_state.tries > 1: - log( - WARN, - "Giving up reconnection after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - - return RetryInvoker( - wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY), - recoverable_exceptions=grpc.RpcError, - max_tries=None, - max_time=None, - on_success=_on_sucess, - on_backoff=_on_backoff, - on_giveup=_on_giveup, - should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore - ) - - -def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None: - """Wrap the gRPC stub with a retry invoker.""" - - def make_lambda(original_method: Any) -> Any: - return lambda *args, **kwargs: retry_invoker.invoke( - original_method, *args, **kwargs - ) - - for method_name in vars(stub): - method = getattr(stub, method_name) - if callable(method): - setattr(stub, method_name, make_lambda(method)) From 7479d2b1fc2bd838fafcc33dd3c0d482dff85fad Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 5 Dec 2024 20:41:14 +0000 Subject: [PATCH 2/2] init --- src/py/flwr/common/retry_invoker.py | 4 +++- src/py/flwr/simulation/simulationio_connection.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index b942bb86a0ff..3c226ca09151 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -29,6 +29,7 @@ from flwr.common.logger import log from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub +from flwr.proto.simulationio_pb2_grpc import SimulationIoStub def exponential( @@ -357,7 +358,8 @@ def _on_giveup(retry_state: RetryState) -> None: def _wrap_stub( - stub: Union[ServerAppIoStub, ClientAppIoStub], retry_invoker: RetryInvoker + stub: Union[ServerAppIoStub, ClientAppIoStub, SimulationIoStub], + retry_invoker: RetryInvoker, ) -> None: """Wrap a gRPC stub with a retry invoker.""" diff --git a/src/py/flwr/simulation/simulationio_connection.py b/src/py/flwr/simulation/simulationio_connection.py index ab6e5450c90e..011d2ff8e9d0 100644 --- a/src/py/flwr/simulation/simulationio_connection.py +++ b/src/py/flwr/simulation/simulationio_connection.py @@ -23,6 +23,7 @@ from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611 @@ -48,6 +49,7 @@ def __init__( # pylint: disable=too-many-arguments self._cert = root_certificates self._grpc_stub: Optional[SimulationIoStub] = None self._channel: Optional[grpc.Channel] = None + self._retry_invoker = _make_simple_grpc_retry_invoker() @property def _is_connected(self) -> bool: @@ -72,6 +74,7 @@ def _connect(self) -> None: root_certificates=self._cert, ) self._grpc_stub = SimulationIoStub(self._channel) + _wrap_stub(self._grpc_stub, self._retry_invoker) log(DEBUG, "[SimulationIO] Connected to %s", self._addr) def _disconnect(self) -> None: