diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 4b834fa17302..d389cbb24c0e 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -30,6 +30,7 @@ from flwr.common.typing import RunNotRunningException from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub +from flwr.proto.simulationio_pb2_grpc import SimulationIoStub def exponential( @@ -365,7 +366,8 @@ def _should_giveup_fn(e: Exception) -> bool: def _wrap_stub( - stub: Union[ServerAppIoStub, ClientAppIoStub], retry_invoker: RetryInvoker + stub: Union[ServerAppIoStub, ClientAppIoStub, SimulationIoStub], + retry_invoker: RetryInvoker, ) -> None: """Wrap a gRPC stub with a retry invoker.""" diff --git a/src/py/flwr/simulation/simulationio_connection.py b/src/py/flwr/simulation/simulationio_connection.py index ab6e5450c90e..011d2ff8e9d0 100644 --- a/src/py/flwr/simulation/simulationio_connection.py +++ b/src/py/flwr/simulation/simulationio_connection.py @@ -23,6 +23,7 @@ from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611 @@ -48,6 +49,7 @@ def __init__( # pylint: disable=too-many-arguments self._cert = root_certificates self._grpc_stub: Optional[SimulationIoStub] = None self._channel: Optional[grpc.Channel] = None + self._retry_invoker = _make_simple_grpc_retry_invoker() @property def _is_connected(self) -> bool: @@ -72,6 +74,7 @@ def _connect(self) -> None: root_certificates=self._cert, ) self._grpc_stub = SimulationIoStub(self._channel) + _wrap_stub(self._grpc_stub, self._retry_invoker) log(DEBUG, "[SimulationIO] Connected to %s", self._addr) def _disconnect(self) -> None: