diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index e50e631b3ca..079d6923b23 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -532,7 +532,7 @@ def _run_fleet_api_grpc_rere( """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server fleet_servicer = FleetServicer( - state=state_factory.state(), + state_factory=state_factory, ) fleet_add_servicer_to_server_fn = add_FleetServicer_to_server fleet_grpc_server = generic_create_grpc_server( diff --git a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py index b12f365e898..25707b6247f 100644 --- a/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/fleet/grpc_rere/fleet_servicer.py @@ -32,14 +32,14 @@ PushTaskResResponse, ) from flwr.server.fleet.message_handler import message_handler -from flwr.server.state import State +from flwr.server.state import StateFactory class FleetServicer(fleet_pb2_grpc.FleetServicer): """Fleet API servicer.""" - def __init__(self, state: State) -> None: - self.state = state + def __init__(self, state_factory: StateFactory) -> None: + self.state_factory = state_factory def CreateNode( self, request: CreateNodeRequest, context: grpc.ServicerContext @@ -48,7 +48,7 @@ def CreateNode( log(INFO, "FleetServicer.CreateNode") return message_handler.create_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def DeleteNode( @@ -58,7 +58,7 @@ def DeleteNode( log(INFO, "FleetServicer.DeleteNode") return message_handler.delete_node( request=request, - state=self.state, + state=self.state_factory.state(), ) def PullTaskIns( @@ -68,7 +68,7 @@ def PullTaskIns( log(INFO, "FleetServicer.PullTaskIns") return message_handler.pull_task_ins( request=request, - state=self.state, + state=self.state_factory.state(), ) def PushTaskRes( @@ -78,5 +78,5 @@ def PushTaskRes( log(INFO, "FleetServicer.PushTaskRes") return message_handler.push_task_res( request=request, - state=self.state, + state=self.state_factory.state(), )