From 82a399a7530bbdc43ec6bb46dc5ed8acaede92a0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 11 Jun 2024 13:21:15 +0100 Subject: [PATCH 01/23] add run proto --- src/proto/flwr/proto/driver.proto | 4 ++ src/proto/flwr/proto/fleet.proto | 10 +-- src/proto/flwr/proto/run.proto | 26 ++++++++ .../grpc_rere_client/client_interceptor.py | 2 +- .../client_interceptor_test.py | 3 +- .../client/grpc_rere_client/connection.py | 3 +- src/py/flwr/client/rest_client/connection.py | 3 +- src/py/flwr/proto/driver_pb2.py | 39 ++++++------ src/py/flwr/proto/driver_pb2_grpc.py | 35 +++++++++++ src/py/flwr/proto/driver_pb2_grpc.pyi | 14 +++++ src/py/flwr/proto/fleet_pb2.py | 61 +++++++++---------- src/py/flwr/proto/fleet_pb2.pyi | 42 ------------- src/py/flwr/proto/fleet_pb2_grpc.py | 13 ++-- src/py/flwr/proto/fleet_pb2_grpc.pyi | 9 +-- src/py/flwr/proto/run_pb2.py | 30 +++++++++ src/py/flwr/proto/run_pb2.pyi | 52 ++++++++++++++++ src/py/flwr/proto/run_pb2_grpc.py | 4 ++ src/py/flwr/proto/run_pb2_grpc.pyi | 4 ++ .../superlink/driver/driver_servicer.py | 7 +++ .../fleet/grpc_rere/fleet_servicer.py | 3 +- .../fleet/grpc_rere/server_interceptor.py | 3 +- .../grpc_rere/server_interceptor_test.py | 3 +- .../fleet/message_handler/message_handler.py | 8 ++- .../superlink/fleet/rest_rere/rest_api.py | 2 +- src/py/flwr_tool/protoc_test.py | 2 +- 25 files changed, 251 insertions(+), 131 deletions(-) create mode 100644 src/proto/flwr/proto/run.proto create mode 100644 src/py/flwr/proto/run_pb2.py create mode 100644 src/py/flwr/proto/run_pb2.pyi create mode 100644 src/py/flwr/proto/run_pb2_grpc.py create mode 100644 src/py/flwr/proto/run_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 54e6b6b41b68..edbd5d91bb5b 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -19,6 +19,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; +import "flwr/proto/run.proto"; service Driver { // Request run_id @@ -32,6 +33,9 @@ service Driver { // Get task results rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {} + + // Get run details + rpc GetRun(GetRunRequest) returns (GetRunResponse) {} } // CreateRun diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index df6b5843023d..24f60bb3d825 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -19,6 +19,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; +import "flwr/proto/run.proto"; service Fleet { rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {} @@ -70,13 +71,4 @@ message PushTaskResResponse { map results = 2; } -// GetRun messages -message Run { - sint64 run_id = 1; - string fab_id = 2; - string fab_version = 3; -} -message GetRunRequest { sint64 run_id = 1; } -message GetRunResponse { Run run = 1; } - message Reconnect { uint64 reconnect = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto new file mode 100644 index 000000000000..fba08eac171a --- /dev/null +++ b/src/proto/flwr/proto/run.proto @@ -0,0 +1,26 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message Run { + sint64 run_id = 1; + string fab_id = 2; + string fab_version = 3; +} +message GetRunRequest { sint64 run_id = 1; } +message GetRunResponse { Run run = 1; } \ No newline at end of file diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 8bc55878971d..d2dded8a73d9 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -31,11 +31,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, PingRequest, PullTaskInsRequest, PushTaskResRequest, ) +from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 487361a06026..cc35ffef46db 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -41,13 +41,12 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 9579d5830165..52d0cc58b2bb 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -44,8 +44,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -53,6 +51,7 @@ ) from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .client_interceptor import AuthenticateClientInterceptor diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index da8fbd351ab1..7383eae3d22b 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -46,8 +46,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -56,6 +54,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 try: diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index b0caae58ff6f..a2458b445563 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -14,31 +14,32 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc1\x02\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=85 - _globals['_CREATERUNREQUEST']._serialized_end=140 - _globals['_CREATERUNRESPONSE']._serialized_start=142 - _globals['_CREATERUNRESPONSE']._serialized_end=177 - _globals['_GETNODESREQUEST']._serialized_start=179 - _globals['_GETNODESREQUEST']._serialized_end=212 - _globals['_GETNODESRESPONSE']._serialized_start=214 - _globals['_GETNODESRESPONSE']._serialized_end=265 - _globals['_PUSHTASKINSREQUEST']._serialized_start=267 - _globals['_PUSHTASKINSREQUEST']._serialized_end=331 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=333 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=372 - _globals['_PULLTASKRESREQUEST']._serialized_start=374 - _globals['_PULLTASKRESREQUEST']._serialized_end=444 - _globals['_PULLTASKRESRESPONSE']._serialized_start=446 - _globals['_PULLTASKRESRESPONSE']._serialized_end=511 - _globals['_DRIVER']._serialized_start=514 - _globals['_DRIVER']._serialized_end=835 + _globals['_CREATERUNREQUEST']._serialized_start=107 + _globals['_CREATERUNREQUEST']._serialized_end=162 + _globals['_CREATERUNRESPONSE']._serialized_start=164 + _globals['_CREATERUNRESPONSE']._serialized_end=199 + _globals['_GETNODESREQUEST']._serialized_start=201 + _globals['_GETNODESREQUEST']._serialized_end=234 + _globals['_GETNODESRESPONSE']._serialized_start=236 + _globals['_GETNODESRESPONSE']._serialized_end=287 + _globals['_PUSHTASKINSREQUEST']._serialized_start=289 + _globals['_PUSHTASKINSREQUEST']._serialized_end=353 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 + _globals['_PULLTASKRESREQUEST']._serialized_start=396 + _globals['_PULLTASKRESREQUEST']._serialized_end=466 + _globals['_PULLTASKRESRESPONSE']._serialized_start=468 + _globals['_PULLTASKRESRESPONSE']._serialized_end=533 + _globals['_DRIVER']._serialized_start=536 + _globals['_DRIVER']._serialized_end=924 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index ac6815023ebd..2cd3ebe62a63 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class DriverStub(object): @@ -34,6 +35,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString, ) + self.GetRun = channel.unary_unary( + '/flwr.proto.Driver/GetRun', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, + ) class DriverServicer(object): @@ -67,6 +73,13 @@ def PullTaskRes(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRun(self, request, context): + """Get run details + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { @@ -90,6 +103,11 @@ def add_DriverServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.FromString, response_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.SerializeToString, ), + 'GetRun': grpc.unary_unary_rpc_method_handler( + servicer.GetRun, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Driver', rpc_method_handlers) @@ -167,3 +185,20 @@ def PullTaskRes(request, flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetRun(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/GetRun', + flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 43cf45f39b25..4ff09db588ca 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import abc import flwr.proto.driver_pb2 +import flwr.proto.run_pb2 import grpc class DriverStub: @@ -28,6 +29,11 @@ class DriverStub: flwr.proto.driver_pb2.PullTaskResResponse] """Get task results""" + GetRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunRequest, + flwr.proto.run_pb2.GetRunResponse] + """Get run details""" + class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -62,5 +68,13 @@ class DriverServicer(metaclass=abc.ABCMeta): """Get task results""" pass + @abc.abstractmethod + def GetRun(self, + request: flwr.proto.run_pb2.GetRunRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunResponse: + """Get run details""" + pass + def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index 42f3292d910d..9763b71fed2f 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -14,9 +14,10 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,36 +26,30 @@ DESCRIPTOR._options = None _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' - _globals['_CREATENODEREQUEST']._serialized_start=84 - _globals['_CREATENODEREQUEST']._serialized_end=126 - _globals['_CREATENODERESPONSE']._serialized_start=128 - _globals['_CREATENODERESPONSE']._serialized_end=180 - _globals['_DELETENODEREQUEST']._serialized_start=182 - _globals['_DELETENODEREQUEST']._serialized_end=233 - _globals['_DELETENODERESPONSE']._serialized_start=235 - _globals['_DELETENODERESPONSE']._serialized_end=255 - _globals['_PINGREQUEST']._serialized_start=257 - _globals['_PINGREQUEST']._serialized_end=325 - _globals['_PINGRESPONSE']._serialized_start=327 - _globals['_PINGRESPONSE']._serialized_end=358 - _globals['_PULLTASKINSREQUEST']._serialized_start=360 - _globals['_PULLTASKINSREQUEST']._serialized_end=430 - _globals['_PULLTASKINSRESPONSE']._serialized_start=432 - _globals['_PULLTASKINSRESPONSE']._serialized_end=539 - _globals['_PUSHTASKRESREQUEST']._serialized_start=541 - _globals['_PUSHTASKRESREQUEST']._serialized_end=605 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=608 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=782 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782 - _globals['_RUN']._serialized_start=784 - _globals['_RUN']._serialized_end=842 - _globals['_GETRUNREQUEST']._serialized_start=844 - _globals['_GETRUNREQUEST']._serialized_end=875 - _globals['_GETRUNRESPONSE']._serialized_start=877 - _globals['_GETRUNRESPONSE']._serialized_end=923 - _globals['_RECONNECT']._serialized_start=925 - _globals['_RECONNECT']._serialized_end=955 - _globals['_FLEET']._serialized_start=958 - _globals['_FLEET']._serialized_end=1415 + _globals['_CREATENODEREQUEST']._serialized_start=106 + _globals['_CREATENODEREQUEST']._serialized_end=148 + _globals['_CREATENODERESPONSE']._serialized_start=150 + _globals['_CREATENODERESPONSE']._serialized_end=202 + _globals['_DELETENODEREQUEST']._serialized_start=204 + _globals['_DELETENODEREQUEST']._serialized_end=255 + _globals['_DELETENODERESPONSE']._serialized_start=257 + _globals['_DELETENODERESPONSE']._serialized_end=277 + _globals['_PINGREQUEST']._serialized_start=279 + _globals['_PINGREQUEST']._serialized_end=347 + _globals['_PINGRESPONSE']._serialized_start=349 + _globals['_PINGRESPONSE']._serialized_end=380 + _globals['_PULLTASKINSREQUEST']._serialized_start=382 + _globals['_PULLTASKINSREQUEST']._serialized_end=452 + _globals['_PULLTASKINSRESPONSE']._serialized_start=454 + _globals['_PULLTASKINSRESPONSE']._serialized_end=561 + _globals['_PUSHTASKRESREQUEST']._serialized_start=563 + _globals['_PUSHTASKRESREQUEST']._serialized_end=627 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=630 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=804 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=758 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=804 + _globals['_RECONNECT']._serialized_start=806 + _globals['_RECONNECT']._serialized_end=836 + _globals['_FLEET']._serialized_start=839 + _globals['_FLEET']._serialized_end=1296 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index a6f38b703e76..5989f45c5c60 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -164,48 +164,6 @@ class PushTaskResResponse(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ... global___PushTaskResResponse = PushTaskResResponse -class Run(google.protobuf.message.Message): - """GetRun messages""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - FAB_ID_FIELD_NUMBER: builtins.int - FAB_VERSION_FIELD_NUMBER: builtins.int - run_id: builtins.int - fab_id: typing.Text - fab_version: typing.Text - def __init__(self, - *, - run_id: builtins.int = ..., - fab_id: typing.Text = ..., - fab_version: typing.Text = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... -global___Run = Run - -class GetRunRequest(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - run_id: builtins.int - def __init__(self, - *, - run_id: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... -global___GetRunRequest = GetRunRequest - -class GetRunResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_FIELD_NUMBER: builtins.int - @property - def run(self) -> global___Run: ... - def __init__(self, - *, - run: typing.Optional[global___Run] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... -global___GetRunResponse = GetRunResponse - class Reconnect(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor RECONNECT_FIELD_NUMBER: builtins.int diff --git a/src/py/flwr/proto/fleet_pb2_grpc.py b/src/py/flwr/proto/fleet_pb2_grpc.py index 16757eaed381..e0b0fbc50460 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.py +++ b/src/py/flwr/proto/fleet_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class FleetStub(object): @@ -41,8 +42,8 @@ def __init__(self, channel): ) self.GetRun = channel.unary_unary( '/flwr.proto.Fleet/GetRun', - request_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, ) @@ -121,8 +122,8 @@ def add_FleetServicer_to_server(servicer, server): ), 'GetRun': grpc.unary_unary_rpc_method_handler( servicer.GetRun, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -231,7 +232,7 @@ def GetRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/GetRun', - flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString, - flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/fleet_pb2_grpc.pyi b/src/py/flwr/proto/fleet_pb2_grpc.pyi index f275cd149d69..1c0ab862d45c 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.pyi +++ b/src/py/flwr/proto/fleet_pb2_grpc.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import abc import flwr.proto.fleet_pb2 +import flwr.proto.run_pb2 import grpc class FleetStub: @@ -37,8 +38,8 @@ class FleetStub: """ GetRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.fleet_pb2.GetRunRequest, - flwr.proto.fleet_pb2.GetRunResponse] + flwr.proto.run_pb2.GetRunRequest, + flwr.proto.run_pb2.GetRunResponse] class FleetServicer(metaclass=abc.ABCMeta): @@ -84,9 +85,9 @@ class FleetServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def GetRun(self, - request: flwr.proto.fleet_pb2.GetRunRequest, + request: flwr.proto.run_pb2.GetRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.fleet_pb2.GetRunResponse: ... + ) -> flwr.proto.run_pb2.GetRunResponse: ... def add_FleetServicer_to_server(servicer: FleetServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py new file mode 100644 index 000000000000..13f06e7169aa --- /dev/null +++ b/src/py/flwr/proto/run_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/run.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_RUN']._serialized_start=36 + _globals['_RUN']._serialized_end=94 + _globals['_GETRUNREQUEST']._serialized_start=96 + _globals['_GETRUNREQUEST']._serialized_end=127 + _globals['_GETRUNRESPONSE']._serialized_start=129 + _globals['_GETRUNRESPONSE']._serialized_end=175 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi new file mode 100644 index 000000000000..401d27855a41 --- /dev/null +++ b/src/py/flwr/proto/run_pb2.pyi @@ -0,0 +1,52 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class Run(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + FAB_ID_FIELD_NUMBER: builtins.int + FAB_VERSION_FIELD_NUMBER: builtins.int + run_id: builtins.int + fab_id: typing.Text + fab_version: typing.Text + def __init__(self, + *, + run_id: builtins.int = ..., + fab_id: typing.Text = ..., + fab_version: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... +global___Run = Run + +class GetRunRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int + def __init__(self, + *, + run_id: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___GetRunRequest = GetRunRequest + +class GetRunResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_FIELD_NUMBER: builtins.int + @property + def run(self) -> global___Run: ... + def __init__(self, + *, + run: typing.Optional[global___Run] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... +global___GetRunResponse = GetRunResponse diff --git a/src/py/flwr/proto/run_pb2_grpc.py b/src/py/flwr/proto/run_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/run_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/run_pb2_grpc.pyi b/src/py/flwr/proto/run_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/run_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index ce2d9d68d8ca..e808616af778 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -35,6 +35,7 @@ PushTaskInsResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -129,6 +130,12 @@ def on_rpc_done() -> None: context.set_code(grpc.StatusCode.OK) return PullTaskResResponse(task_res_list=task_res_list) + def GetRun( + self, request: GetRunRequest, context: grpc.ServicerContext + ) -> GetRunResponse: + """Get run information.""" + raise NotImplementedError + def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 03a2ec064213..13e024eb31e4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -26,8 +26,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -35,6 +33,7 @@ PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.state import StateFactory diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 6a302679a235..21e9c44907cd 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -34,8 +34,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -44,6 +42,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.state import State _PUBLIC_KEY_HEADER = "public-key" diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index c4c71e5a8188..01499102b7d8 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -32,8 +32,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -42,6 +40,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere from flwr.server.superlink.state.state_factory import StateFactory diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 83b005a4cb8e..4c796502436b 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -24,8 +24,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -33,9 +31,13 @@ PushTaskResRequest, PushTaskResResponse, Reconnect, - Run, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + GetRunRequest, + GetRunResponse, + Run, +) from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index 8ac7c6cfc613..c7ff496d39bf 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -21,11 +21,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, PingRequest, PullTaskInsRequest, PushTaskResRequest, ) +from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.state import State diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 8dcf4c6474d6..6aec4251c384 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 8 + assert len(PROTO_FILES) == 9 From 019d490b2a823be4f12241e4c7fd6ea4b0e021df Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 11 Jun 2024 16:18:22 +0100 Subject: [PATCH 02/23] implement getrun in driver_servicer and grpc_driver --- src/py/flwr/server/driver/grpc_driver.py | 62 ++++++++++++++++--- src/py/flwr/server/driver/grpc_driver_test.py | 18 +++++- .../superlink/driver/driver_servicer.py | 16 ++++- 3 files changed, 83 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index d339f1b232f9..1cafbac4c5b7 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -37,6 +37,7 @@ ) from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .driver import Driver @@ -101,6 +102,17 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse: res: CreateRunResponse = self.stub.CreateRun(request=req) return res + def get_run(self, req: GetRunRequest) -> GetRunResponse: + """Get run information.""" + # Check if channel is open + if self.stub is None: + log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) + raise ConnectionError("`GrpcDriverHelper` instance not connected") + + # Call gRPC Driver API + res: GetRunResponse = self.stub.GetRun(request=req) + return res + def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open @@ -157,39 +169,69 @@ class GrpcDriver(Driver): The version of the FAB used in the run. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, fab_id: Optional[str] = None, fab_version: Optional[str] = None, + run_id: Optional[int] = None, ) -> None: self.addr = driver_service_address self.root_certificates = root_certificates self.driver_helper: Optional[GrpcDriverHelper] = None - self.run_id: Optional[int] = None - self.fab_id = fab_id if fab_id is not None else "" - self.fab_version = fab_version if fab_version is not None else "" + self._run_id = run_id + self._fab_id = fab_id if fab_id is not None else "" + self._fab_ver = fab_version if fab_version is not None else "" self.node = Node(node_id=0, anonymous=True) + @property + def run_id(self) -> int: + """Run ID.""" + _, run_id = self._get_grpc_driver_helper_and_run_id() + return run_id + + @property + def fab_id(self) -> str: + """FAB ID.""" + self._get_grpc_driver_helper_and_run_id() + return self._fab_id + + @property + def fab_version(self) -> str: + """FAB version.""" + self._get_grpc_driver_helper_and_run_id() + return self._fab_ver + def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: # Check if the GrpcDriverHelper is initialized - if self.driver_helper is None or self.run_id is None: + if self.driver_helper is None or self._run_id is None: # Connect and create run self.driver_helper = GrpcDriverHelper( driver_service_address=self.addr, root_certificates=self.root_certificates, ) self.driver_helper.connect() - req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version) - res = self.driver_helper.create_run(req) - self.run_id = res.run_id - return self.driver_helper, self.run_id + # Create the run if the run_id is not provided + if self._run_id is None: + create_run_req = CreateRunRequest( + fab_id=self._fab_id, fab_version=self._fab_ver + ) + create_run_res = self.driver_helper.create_run(create_run_req) + self._run_id = create_run_res.run_id + # Get the run if the run_id is provided + else: + get_run_req = GetRunRequest(run_id=self._run_id) + get_run_res = self.driver_helper.get_run(get_run_req) + self._fab_id = get_run_res.run.fab_id + self._fab_ver = get_run_res.run.fab_version + + return self.driver_helper, self._run_id def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == self.run_id + message.metadata.run_id == self._run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index fbead0e3043d..61971c6690bf 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -52,11 +52,27 @@ def tearDown(self) -> None: """Cleanup after each test.""" self.patcher.stop() + def test_get_run(self) -> None: + """Test the GrpcDriver starting with run_id.""" + # Prepare + self.driver._run_id = 61016 # pylint: disable=protected-access + mock_response = Mock() + mock_response.run = Mock() + mock_response.run.run_id = 61016 + mock_response.run.fab_id = "mock/mock" + mock_response.run.fab_version = "v1.0.0" + self.mock_grpc_driver_helper.get_run.return_value = mock_response + + # Assert + self.assertEqual(self.driver.run_id, 61016) + self.assertEqual(self.driver.fab_id, "mock/mock") + self.assertEqual(self.driver.fab_version, "v1.0.0") + def test_check_and_init_grpc_driver_already_initialized(self) -> None: """Test that GrpcDriverHelper doesn't initialize if run is created.""" # Prepare self.driver.driver_helper = self.mock_grpc_driver_helper - self.driver.run_id = 61016 + self.driver._run_id = 61016 # pylint: disable=protected-access # Execute # pylint: disable-next=protected-access diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index e808616af778..6b4b0b5d1b59 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -35,7 +35,11 @@ PushTaskInsResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + GetRunRequest, + GetRunResponse, + Run, +) from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -134,7 +138,15 @@ def GetRun( self, request: GetRunRequest, context: grpc.ServicerContext ) -> GetRunResponse: """Get run information.""" - raise NotImplementedError + log(DEBUG, "DriverServicer.GetRun") + + # Init state + state: State = self.state_factory.state() + + # Retrieve run information + run_id, fab_id, fab_version = state.get_run(request.run_id) + run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version) + return GetRunResponse(run=run) def _raise_if(validation_error: bool, detail: str) -> None: From 7c820821bb88cbec9c582ede3dbc04d24ea7ea59 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 18 Jun 2024 13:37:30 +0100 Subject: [PATCH 03/23] amend driver class and in mem driver --- src/py/flwr/server/compat/app_utils.py | 2 +- src/py/flwr/server/driver/driver.py | 15 ++++++ src/py/flwr/server/driver/inmemory_driver.py | 53 +++++++++++++------ .../server/driver/inmemory_driver_test.py | 32 +++++++---- src/py/flwr/simulation/run_simulation.py | 2 +- 5 files changed, 76 insertions(+), 28 deletions(-) diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index 1cdf1efbffb9..c0f5f2956eba 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -91,7 +91,7 @@ def _update_client_manager( node_id=node_id, driver=driver, anonymous=False, - run_id=driver.run_id, # type: ignore + run_id=driver.run_id, ) if client_manager.register(client_proxy): registered_nodes[node_id] = client_proxy diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index b95cec95ab47..4e7dd3d39ee3 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -24,6 +24,21 @@ class Driver(ABC): """Abstract base Driver class for the Driver API.""" + @property + @abstractmethod + def run_id(self) -> int: + """Run ID.""" + + @property + @abstractmethod + def fab_id(self) -> str: + """FAB ID.""" + + @property + @abstractmethod + def fab_version(self) -> str: + """FAB version.""" + @abstractmethod def create_message( # pylint: disable=too-many-arguments self, diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 8c71b1067293..5d40c58fd2ae 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,7 @@ import time import warnings -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -46,10 +46,11 @@ def __init__( state_factory: StateFactory, fab_id: Optional[str] = None, fab_version: Optional[str] = None, + run_id: Optional[int] = None, ) -> None: - self.run_id: Optional[int] = None - self.fab_id = fab_id if fab_id is not None else "" - self.fab_version = fab_version if fab_version is not None else "" + self._run_id = run_id + self._fab_id = fab_id + self._fab_ver = fab_version self.node = Node(node_id=0, anonymous=True) self.state = state_factory.state() @@ -64,16 +65,36 @@ def _check_message(self, message: Message) -> None: ): raise ValueError(f"Invalid message: {message}") - def _get_run_id(self) -> int: - """Return run_id. - - If unset, create a new run. - """ - if self.run_id is None: - self.run_id = self.state.create_run( - fab_id=self.fab_id, fab_version=self.fab_version + def _init_run(self) -> None: + """Initialize the run.""" + # Run ID is not provided + if self._run_id is None: + self._fab_id = "" if self._fab_id is None else self._fab_id + self._fab_ver = "" if self._fab_ver is None else self._fab_ver + self._run_id = self.state.create_run( + fab_id=self._fab_id, fab_version=self._fab_ver ) - return self.run_id + # Run ID is provided + elif self._fab_id is None or self._fab_ver is None: + _, self._fab_id, self._fab_ver = self.state.get_run(self._run_id) + + @property + def run_id(self) -> int: + """Run ID.""" + self._init_run() + return cast(int, self._run_id) + + @property + def fab_id(self) -> str: + """FAB ID.""" + self._init_run() + return cast(str, self._fab_id) + + @property + def fab_version(self) -> str: + """FAB version.""" + self._init_run() + return cast(str, self._fab_ver) def create_message( # pylint: disable=too-many-arguments self, @@ -88,7 +109,6 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ - run_id = self._get_run_id() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -99,7 +119,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=run_id, + run_id=self.run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -112,8 +132,7 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - run_id = self._get_run_id() - return list(self.state.get_nodes(run_id)) + return list(self.state.get_nodes(self.run_id)) def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 95c2a0b277af..5aaa80934b08 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -79,12 +79,26 @@ def setUp(self) -> None: """ # Create driver self.num_nodes = 42 - self.driver = InMemoryDriver(StateFactory("")) - self.driver.state = MagicMock() - self.driver.state.get_nodes.return_value = [ + self.state = MagicMock() + self.state.get_nodes.return_value = [ int.from_bytes(os.urandom(8), "little", signed=True) for _ in range(self.num_nodes) ] + state_factory = MagicMock() + state_factory.state.return_value = self.state + self.driver = InMemoryDriver(state_factory) + self.driver.state = self.state + + def test_get_run(self) -> None: + """Test the InMemoryDriver starting with run_id.""" + # Prepare + self.driver._run_id = 61016 # pylint: disable=protected-access + self.state.get_run.return_value = (61016, "mock/mock", "v1.0.0") + + # Assert + self.assertEqual(self.driver.run_id, 61016) + self.assertEqual(self.driver.fab_id, "mock/mock") + self.assertEqual(self.driver.fab_version, "v1.0.0") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -104,7 +118,7 @@ def test_push_messages_valid(self) -> None: ] taskins_ids = [uuid4() for _ in range(num_messages)] - self.driver.state.store_task_ins.side_effect = taskins_ids # type: ignore + self.state.store_task_ins.side_effect = taskins_ids # Execute msg_ids = list(self.driver.push_messages(msgs)) @@ -141,7 +155,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.get_task_res.return_value = task_res_list # Execute pulled_msgs = list(self.driver.pull_messages(msg_ids)) @@ -167,8 +181,8 @@ def test_send_and_receive_messages_complete(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.store_task_ins.side_effect = msg_ids + self.state.get_task_res.return_value = task_res_list # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) @@ -193,8 +207,8 @@ def test_send_and_receive_messages_timeout(self) -> None: task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0))) ), ] - self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore - self.driver.state.get_task_res.return_value = task_res_list # type: ignore + self.state.store_task_ins.side_effect = msg_ids + self.state.get_task_res.return_value = task_res_list # Execute with patch("time.sleep", side_effect=lambda t: time.sleep(t * 0.01)): diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 3532c5a4e877..3e5eb266c89a 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -173,7 +173,7 @@ def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> No """Create a run with a given `run_id`.""" log(DEBUG, "Pre-registering run with id %s", run_id) state.state().run_ids[run_id] = ("", "") # type: ignore - driver.run_id = run_id + driver._run_id = run_id # pylint: disable=protected-access # pylint: disable=too-many-locals From 6ef8cebe2c9019c761fccfd738466153bda12d95 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 12:33:45 +0100 Subject: [PATCH 04/23] update with main --- src/py/flwr/server/compat/app_utils.py | 2 +- src/py/flwr/server/driver/driver.py | 15 ++------- src/py/flwr/server/driver/grpc_driver.py | 25 ++++++-------- src/py/flwr/server/driver/grpc_driver_test.py | 8 ++--- src/py/flwr/server/driver/inmemory_driver.py | 33 +++++++++---------- .../server/driver/inmemory_driver_test.py | 6 ++-- .../superlink/driver/driver_servicer.py | 6 ++-- 7 files changed, 39 insertions(+), 56 deletions(-) diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index c0f5f2956eba..baff27307b88 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -91,7 +91,7 @@ def _update_client_manager( node_id=node_id, driver=driver, anonymous=False, - run_id=driver.run_id, + run_id=driver.run.run_id, ) if client_manager.register(client_proxy): registered_nodes[node_id] = client_proxy diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 4e7dd3d39ee3..4f888323e586 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -19,6 +19,7 @@ from typing import Iterable, List, Optional from flwr.common import Message, RecordSet +from flwr.common.typing import Run class Driver(ABC): @@ -26,18 +27,8 @@ class Driver(ABC): @property @abstractmethod - def run_id(self) -> int: - """Run ID.""" - - @property - @abstractmethod - def fab_id(self) -> str: - """FAB ID.""" - - @property - @abstractmethod - def fab_version(self) -> str: - """FAB version.""" + def run(self) -> Run: + """Run information.""" @abstractmethod def create_message( # pylint: disable=too-many-arguments diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 1cafbac4c5b7..f2a43f0e2c8e 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -25,6 +25,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, @@ -186,22 +187,14 @@ def __init__( # pylint: disable=too-many-arguments self.node = Node(node_id=0, anonymous=True) @property - def run_id(self) -> int: - """Run ID.""" + def run(self) -> Run: + """Run information.""" _, run_id = self._get_grpc_driver_helper_and_run_id() - return run_id - - @property - def fab_id(self) -> str: - """FAB ID.""" - self._get_grpc_driver_helper_and_run_id() - return self._fab_id - - @property - def fab_version(self) -> str: - """FAB version.""" - self._get_grpc_driver_helper_and_run_id() - return self._fab_ver + return Run( + run_id=run_id, + fab_id=self._fab_id, + fab_version=self._fab_ver, + ) def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: # Check if the GrpcDriverHelper is initialized @@ -223,6 +216,8 @@ def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: else: get_run_req = GetRunRequest(run_id=self._run_id) get_run_res = self.driver_helper.get_run(get_run_req) + if not get_run_res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") self._fab_id = get_run_res.run.fab_id self._fab_ver = get_run_res.run.fab_version diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 61971c6690bf..642fdbe9d8ab 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -64,9 +64,9 @@ def test_get_run(self) -> None: self.mock_grpc_driver_helper.get_run.return_value = mock_response # Assert - self.assertEqual(self.driver.run_id, 61016) - self.assertEqual(self.driver.fab_id, "mock/mock") - self.assertEqual(self.driver.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.run_id, 61016) + self.assertEqual(self.driver.run.fab_id, "mock/mock") + self.assertEqual(self.driver.run.fab_version, "v1.0.0") def test_check_and_init_grpc_driver_already_initialized(self) -> None: """Test that GrpcDriverHelper doesn't initialize if run is created.""" @@ -89,7 +89,7 @@ def test_check_and_init_grpc_driver_needs_initialization(self) -> None: # Assert self.mock_grpc_driver_helper.connect.assert_called_once() - self.assertEqual(self.driver.run_id, 61016) + self.assertEqual(self.driver.run.run_id, 61016) def test_get_nodes(self) -> None: """Test retrieval of nodes.""" diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 5d40c58fd2ae..eda4051b5a70 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -22,6 +22,7 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.server.superlink.state import StateFactory @@ -57,7 +58,7 @@ def __init__( def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == self.run_id + message.metadata.run_id == self.run.run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -76,25 +77,21 @@ def _init_run(self) -> None: ) # Run ID is provided elif self._fab_id is None or self._fab_ver is None: - _, self._fab_id, self._fab_ver = self.state.get_run(self._run_id) + run = self.state.get_run(self._run_id) + if run is None: + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._fab_id = run.fab_id + self._fab_ver = run.fab_version @property - def run_id(self) -> int: + def run(self) -> Run: """Run ID.""" self._init_run() - return cast(int, self._run_id) - - @property - def fab_id(self) -> str: - """FAB ID.""" - self._init_run() - return cast(str, self._fab_id) - - @property - def fab_version(self) -> str: - """FAB version.""" - self._init_run() - return cast(str, self._fab_ver) + return Run( + run_id=cast(int, self._run_id), + fab_id=cast(str, self._fab_id), + fab_version=cast(str, self._fab_ver), + ) def create_message( # pylint: disable=too-many-arguments self, @@ -119,7 +116,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=self.run_id, + run_id=self.run.run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -132,7 +129,7 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - return list(self.state.get_nodes(self.run_id)) + return list(self.state.get_nodes(self.run.run_id)) def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 5aaa80934b08..a41a7c57236e 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -96,9 +96,9 @@ def test_get_run(self) -> None: self.state.get_run.return_value = (61016, "mock/mock", "v1.0.0") # Assert - self.assertEqual(self.driver.run_id, 61016) - self.assertEqual(self.driver.fab_id, "mock/mock") - self.assertEqual(self.driver.fab_version, "v1.0.0") + self.assertEqual(self.driver.run.run_id, 61016) + self.assertEqual(self.driver.run.fab_id, "mock/mock") + self.assertEqual(self.driver.run.fab_version, "v1.0.0") def test_get_nodes(self) -> None: """Test retrieval of nodes.""" diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 6b4b0b5d1b59..30d2e883dc63 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -144,9 +144,9 @@ def GetRun( state: State = self.state_factory.state() # Retrieve run information - run_id, fab_id, fab_version = state.get_run(request.run_id) - run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version) - return GetRunResponse(run=run) + run = state.get_run(request.run_id) + run_proto = None if run is None else Run(**vars(run)) + return GetRunResponse(run=run_proto) def _raise_if(validation_error: bool, detail: str) -> None: From bf10f81f5693e7e779494f25e667f705a3ee43d0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 12:39:22 +0100 Subject: [PATCH 05/23] fix the test for in mem driver --- src/py/flwr/server/driver/inmemory_driver_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index a41a7c57236e..8bfa14def0c2 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -93,7 +93,9 @@ def test_get_run(self) -> None: """Test the InMemoryDriver starting with run_id.""" # Prepare self.driver._run_id = 61016 # pylint: disable=protected-access - self.state.get_run.return_value = (61016, "mock/mock", "v1.0.0") + self.state.get_run.return_value = MagicMock( + run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" + ) # Assert self.assertEqual(self.driver.run.run_id, 61016) From 5461e5d26f8e8e9baad3e1106b191e3762369735 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 15:27:03 +0100 Subject: [PATCH 06/23] make run_id mandatory --- src/py/flwr/server/driver/grpc_driver.py | 124 +++++++----------- src/py/flwr/server/driver/grpc_driver_test.py | 93 ++++--------- src/py/flwr/server/driver/inmemory_driver.py | 39 +++--- .../server/driver/inmemory_driver_test.py | 41 +++--- src/py/flwr/server/run_serverapp.py | 18 ++- src/py/flwr/simulation/run_simulation.py | 5 +- 6 files changed, 131 insertions(+), 189 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index f2a43f0e2c8e..cb5a617164ff 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,11 +16,9 @@ import time import warnings -from logging import DEBUG, ERROR, WARNING +from logging import DEBUG, ERROR from typing import Iterable, List, Optional, Tuple -import grpc - from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event from flwr.common.grpc import create_channel from flwr.common.logger import log @@ -48,103 +46,94 @@ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ [Driver] Error: Not connected. -Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other -`GrpcDriverHelper` methods. +Call `connect()` on the `GrpcDriverStub` instance before calling any of the other +`GrpcDriverStub` methods. """ -class GrpcDriverHelper: - """`GrpcDriverHelper` provides access to the gRPC Driver API/service.""" +class GrpcDriverStub(DriverStub): + """`GrpcDriverStub` provides access to the gRPC Driver API/service.""" def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, ) -> None: + event(EventType.DRIVER_CONNECT) self.driver_service_address = driver_service_address self.root_certificates = root_certificates - self.channel: Optional[grpc.Channel] = None - self.stub: Optional[DriverStub] = None - - def connect(self) -> None: - """Connect to the Driver API.""" - event(EventType.DRIVER_CONNECT) - if self.channel is not None or self.stub is not None: - log(WARNING, "Already connected") - return self.channel = create_channel( server_address=self.driver_service_address, insecure=(self.root_certificates is None), root_certificates=self.root_certificates, ) - self.stub = DriverStub(self.channel) + super().__init__(self.channel) log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) def disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) - if self.channel is None or self.stub is None: + if self.channel is None: log(DEBUG, "Already disconnected") return channel = self.channel self.channel = None - self.stub = None channel.close() log(DEBUG, "[Driver] Disconnected") def create_run(self, req: CreateRunRequest) -> CreateRunResponse: """Request for run ID.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: CreateRunResponse = self.stub.CreateRun(request=req) + res: CreateRunResponse = self.CreateRun(request=req) return res def get_run(self, req: GetRunRequest) -> GetRunResponse: """Get run information.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetRunResponse = self.stub.GetRun(request=req) + res: GetRunResponse = self.GetRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetNodesResponse = self.stub.GetNodes(request=req) + res: GetNodesResponse = self.GetNodes(request=req) return res def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: """Schedule tasks.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) + res: PushTaskInsResponse = self.PushTaskIns(request=req) return res def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: """Get task results.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: PullTaskResResponse = self.stub.PullTaskRes(request=req) + res: PullTaskResResponse = self.PullTaskRes(request=req) return res @@ -172,18 +161,14 @@ class GrpcDriver(Driver): def __init__( # pylint: disable=too-many-arguments self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, - run_id: Optional[int] = None, + run_id: int, + stub: Optional[GrpcDriverStub] = None, ) -> None: - self.addr = driver_service_address - self.root_certificates = root_certificates - self.driver_helper: Optional[GrpcDriverHelper] = None + self.stub = stub self._run_id = run_id - self._fab_id = fab_id if fab_id is not None else "" - self._fab_ver = fab_version if fab_version is not None else "" + self._fab_id = "" + self._fab_ver = "" + self._has_initialized = False self.node = Node(node_id=0, anonymous=True) @property @@ -196,32 +181,23 @@ def run(self) -> Run: fab_version=self._fab_ver, ) - def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: - # Check if the GrpcDriverHelper is initialized - if self.driver_helper is None or self._run_id is None: + def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: + # Check if the GrpcDriverStub is initialized + if not self._has_initialized or self.stub is None: # Connect and create run - self.driver_helper = GrpcDriverHelper( - driver_service_address=self.addr, - root_certificates=self.root_certificates, - ) - self.driver_helper.connect() - # Create the run if the run_id is not provided - if self._run_id is None: - create_run_req = CreateRunRequest( - fab_id=self._fab_id, fab_version=self._fab_ver - ) - create_run_res = self.driver_helper.create_run(create_run_req) - self._run_id = create_run_res.run_id - # Get the run if the run_id is provided - else: - get_run_req = GetRunRequest(run_id=self._run_id) - get_run_res = self.driver_helper.get_run(get_run_req) - if not get_run_res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = get_run_res.run.fab_id - self._fab_ver = get_run_res.run.fab_version - - return self.driver_helper, self._run_id + if self.stub is None: + self.stub = GrpcDriverStub() + + # Get the run info + req = GetRunRequest(run_id=self._run_id) + res = self.stub.get_run(req) + if not res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._fab_id = res.run.fab_id + self._fab_ver = res.run.fab_version + self._has_initialized = True + + return self.stub, self._run_id def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -272,7 +248,7 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() - # Call GrpcDriverHelper method + # Call GrpcDriverStub method res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) return [node.node_id for node in res.nodes] @@ -292,7 +268,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: taskins = message_to_taskins(msg) # Add to list task_ins_list.append(taskins) - # Call GrpcDriverHelper method + # Call GrpcDriverStub method res = grpc_driver_helper.push_task_ins( PushTaskInsRequest(task_ins_list=task_ins_list) ) @@ -345,8 +321,8 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" - # Check if GrpcDriverHelper is initialized - if self.driver_helper is None: + # Check if GrpcDriverStub is initialized + if self.stub is None: return # Disconnect - self.driver_helper.disconnect() + self.stub.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 642fdbe9d8ab..775348437da6 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -36,74 +36,36 @@ class TestGrpcDriver(unittest.TestCase): """Tests for `GrpcDriver` class.""" def setUp(self) -> None: - """Initialize mock GrpcDriverHelper and Driver instance before each test.""" - mock_response = Mock() - mock_response.run_id = 61016 - self.mock_grpc_driver_helper = Mock() - self.mock_grpc_driver_helper.create_run.return_value = mock_response - self.patcher = patch( - "flwr.server.driver.grpc_driver.GrpcDriverHelper", - return_value=self.mock_grpc_driver_helper, + """Initialize mock GrpcDriverStub and Driver instance before each test.""" + mock_response = Mock( + run=Mock(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) - self.patcher.start() - self.driver = GrpcDriver() - - def tearDown(self) -> None: - """Cleanup after each test.""" - self.patcher.stop() - - def test_get_run(self) -> None: - """Test the GrpcDriver starting with run_id.""" - # Prepare - self.driver._run_id = 61016 # pylint: disable=protected-access - mock_response = Mock() - mock_response.run = Mock() - mock_response.run.run_id = 61016 - mock_response.run.fab_id = "mock/mock" - mock_response.run.fab_version = "v1.0.0" - self.mock_grpc_driver_helper.get_run.return_value = mock_response + self.mock_grpc_driver_stub = Mock() + self.mock_grpc_driver_stub.get_run.return_value = mock_response + self.mock_grpc_driver_stub.HasField.return_value = True + self.driver = GrpcDriver(run_id=61016, stub=self.mock_grpc_driver_stub) + def test_init_grpc_driver(self) -> None: + """Test GrpcDriverStub initialization.""" # Assert self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") - - def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriverHelper doesn't initialize if run is created.""" - # Prepare - self.driver.driver_helper = self.mock_grpc_driver_helper - self.driver._run_id = 61016 # pylint: disable=protected-access - - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() - - # Assert - self.mock_grpc_driver_helper.connect.assert_not_called() - - def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriverHelper initialization when run is not created.""" - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() - - # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() - self.assertEqual(self.driver.run.run_id, 61016) + self.mock_grpc_driver_stub.get_run.assert_called_once() def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] - self.mock_grpc_driver_helper.get_nodes.return_value = mock_response + self.mock_grpc_driver_stub.get_nodes.return_value = mock_response # Execute node_ids = self.driver.get_node_ids() - args, kwargs = self.mock_grpc_driver_helper.get_nodes.call_args + args, kwargs = self.mock_grpc_driver_stub.get_nodes.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) @@ -114,7 +76,7 @@ def test_push_messages_valid(self) -> None: """Test pushing valid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -122,10 +84,10 @@ def test_push_messages_valid(self) -> None: # Execute msg_ids = self.driver.push_messages(msgs) - args, kwargs = self.mock_grpc_driver_helper.push_task_ins.call_args + args, kwargs = self.mock_grpc_driver_stub.push_task_ins.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) @@ -137,7 +99,7 @@ def test_push_messages_invalid(self) -> None: """Test pushing invalid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -161,16 +123,16 @@ def test_pull_messages_with_given_message_ids(self) -> None: ), TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msg_ids = ["id1", "id2", "id3"] # Execute msgs = self.driver.pull_messages(msg_ids) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_grpc_driver_helper.pull_task_res.call_args + args, kwargs = self.mock_grpc_driver_stub.pull_task_res.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) @@ -181,14 +143,14 @@ def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response # The response message must include either `content` (i.e. a recordset) or # an `Error`. We choose the latter in this case error_proto = error_to_proto(Error(code=0)) mock_response = Mock( task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -203,9 +165,9 @@ def test_send_and_receive_messages_timeout(self) -> None: # Prepare sleep_fn = time.sleep mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -227,12 +189,15 @@ def test_del_with_initialized_driver(self) -> None: self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_called_once() + self.mock_grpc_driver_stub.disconnect.assert_called_once() def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" + # Prepare + self.driver.stub = None + # Execute self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_not_called() + self.mock_grpc_driver_stub.disconnect.assert_not_called() diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index eda4051b5a70..96dc8a2a5716 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,7 @@ import time import warnings -from typing import Iterable, List, Optional, cast +from typing import Iterable, List, Optional from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -44,14 +44,13 @@ class InMemoryDriver(Driver): def __init__( self, + run_id: int, state_factory: StateFactory, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, - run_id: Optional[int] = None, ) -> None: self._run_id = run_id - self._fab_id = fab_id - self._fab_ver = fab_version + self._fab_id = "" + self._fab_ver = "" + self._has_initialized = False self.node = Node(node_id=0, anonymous=True) self.state = state_factory.state() @@ -68,29 +67,23 @@ def _check_message(self, message: Message) -> None: def _init_run(self) -> None: """Initialize the run.""" - # Run ID is not provided - if self._run_id is None: - self._fab_id = "" if self._fab_id is None else self._fab_id - self._fab_ver = "" if self._fab_ver is None else self._fab_ver - self._run_id = self.state.create_run( - fab_id=self._fab_id, fab_version=self._fab_ver - ) - # Run ID is provided - elif self._fab_id is None or self._fab_ver is None: - run = self.state.get_run(self._run_id) - if run is None: - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = run.fab_id - self._fab_ver = run.fab_version + if self._has_initialized: + return + run = self.state.get_run(self._run_id) + if run is None: + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._fab_id = run.fab_id + self._fab_ver = run.fab_version + self._has_initialized = True @property def run(self) -> Run: """Run ID.""" self._init_run() return Run( - run_id=cast(int, self._run_id), - fab_id=cast(str, self._fab_id), - fab_version=cast(str, self._fab_ver), + run_id=self._run_id, + fab_id=self._fab_id, + fab_version=self._fab_ver, ) def create_message( # pylint: disable=too-many-arguments diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 8bfa14def0c2..1f457decf228 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -32,7 +32,7 @@ recordset_to_proto, ) from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory from .inmemory_driver import InMemoryDriver @@ -84,19 +84,15 @@ def setUp(self) -> None: int.from_bytes(os.urandom(8), "little", signed=True) for _ in range(self.num_nodes) ] - state_factory = MagicMock() - state_factory.state.return_value = self.state - self.driver = InMemoryDriver(state_factory) - self.driver.state = self.state - - def test_get_run(self) -> None: - """Test the InMemoryDriver starting with run_id.""" - # Prepare - self.driver._run_id = 61016 # pylint: disable=protected-access self.state.get_run.return_value = MagicMock( run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" ) + state_factory = MagicMock(state=lambda: self.state) + self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) + self.driver.state = self.state + def test_get_run(self) -> None: + """Test the InMemoryDriver starting with run_id.""" # Assert self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") @@ -224,19 +220,23 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory("")) + state = StateFactory("").state() + self.driver = InMemoryDriver( + state.create_run("", ""), MagicMock(state=lambda: state) + ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, SqliteState) # Check recorded - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_ins = state.query("SELECT * FROM task_ins;") self.assertEqual(len(task_ins), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Query number of task_ins and task_res in State - task_res = self.driver.state.query("SELECT * FROM task_res;") # type: ignore - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_res = state.query("SELECT * FROM task_res;") + task_ins = state.query("SELECT * FROM task_ins;") # Assert self.assertEqual(reply_tos, msg_ids) @@ -246,18 +246,19 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: """Test tasks are deleted in in-memory state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory(":flwr-in-memory-state:")) + state_factory = StateFactory(":flwr-in-memory-state:") + state = state_factory.state() + self.driver = InMemoryDriver(state.create_run("", ""), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, InMemoryState) # Check recorded - self.assertEqual( - len(self.driver.state.task_ins_store), len(list(msg_ids)) # type: ignore - ) + self.assertEqual(len(state.task_ins_store), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Assert self.assertEqual(reply_tos, msg_ids) - self.assertEqual(len(self.driver.state.task_res_store), 0) # type: ignore - self.assertEqual(len(self.driver.state.task_ins_store), 0) # type: ignore + self.assertEqual(len(state.task_res_store), 0) + self.assertEqual(len(state.task_ins_store), 0) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index fd0214a040bc..25a523229f52 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -24,8 +24,10 @@ from flwr.common import Context, EventType, RecordSet, event from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app +from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 -from .driver import Driver, GrpcDriver +from .driver import Driver +from .driver.grpc_driver import GrpcDriver, GrpcDriverStub from .server_app import LoadServerAppError, ServerApp ADDRESS_DRIVER_API = "0.0.0.0:9091" @@ -147,13 +149,15 @@ def run_server_app() -> None: server_app_dir = args.dir server_app_attr = getattr(args, "server-app") - # Initialize GrpcDriver - driver = GrpcDriver( - driver_service_address=args.superlink, - root_certificates=root_certificates, - fab_id=args.fab_id, - fab_version=args.fab_version, + # Create run + stub = GrpcDriverStub( + driver_service_address=args.superlink, root_certificates=root_certificates ) + req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) + res = stub.create_run(req) + + # Initialize GrpcDriver + driver = GrpcDriver(run_id=res.run_id, stub=stub) # Run the ServerApp with the Driver run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 3e5eb266c89a..6785f3ac38b6 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -201,8 +201,11 @@ def _main_loop( f_stop = asyncio.Event() serverapp_th = None try: + # Create run (with empty fab_id and fab_version) + run_id = state_factory.state().create_run("", "") + # Initialize Driver - driver = InMemoryDriver(state_factory) + driver = InMemoryDriver(run_id=run_id, state_factory=state_factory) if run_id: _init_run_id(driver, state_factory, run_id) From dcc6cfd9da47d8f89547d18aa9584cada486c7e6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 16:05:50 +0100 Subject: [PATCH 07/23] fix a bug in _init_run_id in simulation --- src/py/flwr/simulation/run_simulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 6785f3ac38b6..4ed6360d5c77 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -21,6 +21,7 @@ import threading import traceback from logging import DEBUG, ERROR, INFO, WARNING +from flwr.common.typing import Run from time import sleep from typing import Dict, Optional @@ -172,7 +173,7 @@ def server_th_with_start_checks( # type: ignore def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None: """Create a run with a given `run_id`.""" log(DEBUG, "Pre-registering run with id %s", run_id) - state.state().run_ids[run_id] = ("", "") # type: ignore + state.state().run_ids[run_id] = Run(run_id, "", "") # type: ignore driver._run_id = run_id # pylint: disable=protected-access From dc033864f987352dd0f69d2c37d65449e1daf7d6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 16:07:55 +0100 Subject: [PATCH 08/23] format --- src/py/flwr/simulation/run_simulation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 4ed6360d5c77..7dc12a5afac0 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -21,14 +21,13 @@ import threading import traceback from logging import DEBUG, ERROR, INFO, WARNING -from flwr.common.typing import Run from time import sleep from typing import Dict, Optional from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.logger import set_logger_propagation, update_console_handler -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import ConfigsRecordValues, Run from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp From 5aa7ced008ad00cfc3f579139eca2c33f9252bd1 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 17:34:46 +0100 Subject: [PATCH 09/23] update doc string --- src/py/flwr/server/driver/grpc_driver.py | 9 +++++---- src/py/flwr/server/driver/inmemory_driver.py | 6 ++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index cb5a617164ff..b78824107d64 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -153,10 +153,11 @@ class GrpcDriver(Driver): * CA certificate. * server certificate. * server private key. - fab_id : str (default: None) - The identifier of the FAB used in the run. - fab_version : str (default: None) - The version of the FAB used in the run. + run_id : int + The identifier of the run. + stub : Optional[GrpcDriverStub] (default: None) + The ``GrpcDriverStub`` instance used to communicate with the SuperLink. + If None, an instance connected to "[::]:9091" will be created. """ def __init__( # pylint: disable=too-many-arguments diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 96dc8a2a5716..caaaaa5ae9e5 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -34,12 +34,10 @@ class InMemoryDriver(Driver): Parameters ---------- + run_id : int + The identifier of the run. state_factory : StateFactory A StateFactory embedding a state that this driver can interface with. - fab_id : str (default: None) - The identifier of the FAB used in the run. - fab_version : str (default: None) - The version of the FAB used in the run. """ def __init__( From 084e22cabb57952357f01f28c98e384876eef230 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 21:07:45 +0100 Subject: [PATCH 10/23] update doc string --- src/py/flwr/server/driver/grpc_driver.py | 28 ++++++++++++++---------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index b78824107d64..4594592d7133 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -52,7 +52,22 @@ class GrpcDriverStub(DriverStub): - """`GrpcDriverStub` provides access to the gRPC Driver API/service.""" + """`GrpcDriverStub` provides access to the gRPC Driver API/service. + + Parameters + ---------- + driver_service_address : Optional[str] + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. + certificates : bytes (default: None) + Tuple containing root certificate, server certificate, and private key + to start a secure SSL-enabled server. The tuple is expected to have + three bytes elements in the following order: + + * CA certificate. + * server certificate. + * server private key. + """ def __init__( self, @@ -142,17 +157,6 @@ class GrpcDriver(Driver): Parameters ---------- - driver_service_address : Optional[str] - The IPv4 or IPv6 address of the Driver API server. - Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. run_id : int The identifier of the run. stub : Optional[GrpcDriverStub] (default: None) From db1640885126a219322f738d29bdcc98ce560b50 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 19 Jun 2024 22:25:49 +0100 Subject: [PATCH 11/23] update GrpcDriverStub --- src/py/flwr/server/driver/grpc_driver.py | 64 ++++++++++--------- src/py/flwr/server/driver/grpc_driver_test.py | 3 +- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 4594592d7133..c773fbd8ba1d 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,9 +16,11 @@ import time import warnings -from logging import DEBUG, ERROR +from logging import DEBUG, ERROR, WARNING from typing import Iterable, List, Optional, Tuple +import grpc + from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event from flwr.common.grpc import create_channel from flwr.common.logger import log @@ -51,7 +53,7 @@ """ -class GrpcDriverStub(DriverStub): +class GrpcDriverStub: """`GrpcDriverStub` provides access to the gRPC Driver API/service. Parameters @@ -74,81 +76,90 @@ def __init__( driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, ) -> None: - event(EventType.DRIVER_CONNECT) self.driver_service_address = driver_service_address self.root_certificates = root_certificates + self.channel: Optional[grpc.Channel] = None + self.stub: Optional[DriverStub] = None + + def connect(self) -> None: + """Connect to the Driver API.""" + event(EventType.DRIVER_CONNECT) + if self.channel is not None or self.stub is not None: + log(WARNING, "Already connected") + return self.channel = create_channel( server_address=self.driver_service_address, insecure=(self.root_certificates is None), root_certificates=self.root_certificates, ) - super().__init__(self.channel) + self.stub = DriverStub(self.channel) log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) def disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) - if self.channel is None: + if self.channel is None or self.stub is None: log(DEBUG, "Already disconnected") return channel = self.channel self.channel = None + self.stub = None channel.close() log(DEBUG, "[Driver] Disconnected") def create_run(self, req: CreateRunRequest) -> CreateRunResponse: """Request for run ID.""" # Check if channel is open - if self.channel is None: + if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: CreateRunResponse = self.CreateRun(request=req) + res: CreateRunResponse = self.stub.CreateRun(request=req) return res def get_run(self, req: GetRunRequest) -> GetRunResponse: """Get run information.""" # Check if channel is open - if self.channel is None: + if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetRunResponse = self.GetRun(request=req) + res: GetRunResponse = self.stub.GetRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open - if self.channel is None: + if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetNodesResponse = self.GetNodes(request=req) + res: GetNodesResponse = self.stub.GetNodes(request=req) return res def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: """Schedule tasks.""" # Check if channel is open - if self.channel is None: + if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: PushTaskInsResponse = self.PushTaskIns(request=req) + res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) return res def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: """Get task results.""" # Check if channel is open - if self.channel is None: + if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: PullTaskResResponse = self.PullTaskRes(request=req) + res: PullTaskResResponse = self.stub.PullTaskRes(request=req) return res @@ -170,21 +181,15 @@ def __init__( # pylint: disable=too-many-arguments stub: Optional[GrpcDriverStub] = None, ) -> None: self.stub = stub - self._run_id = run_id - self._fab_id = "" - self._fab_ver = "" + self._run = Run(run_id=run_id, fab_id="", fab_version="") self._has_initialized = False self.node = Node(node_id=0, anonymous=True) @property def run(self) -> Run: """Run information.""" - _, run_id = self._get_grpc_driver_helper_and_run_id() - return Run( - run_id=run_id, - fab_id=self._fab_id, - fab_version=self._fab_ver, - ) + self._get_grpc_driver_helper_and_run_id() + return Run(**vars(self._run)) def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: # Check if the GrpcDriverStub is initialized @@ -194,20 +199,19 @@ def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: self.stub = GrpcDriverStub() # Get the run info - req = GetRunRequest(run_id=self._run_id) + req = GetRunRequest(run_id=self._run.run_id) res = self.stub.get_run(req) if not res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = res.run.fab_id - self._fab_ver = res.run.fab_version + raise RuntimeError(f"Cannot find the run with ID: {self._run.run_id}") + self._run = Run(**{fld.name: v for fld, v in res.run.ListFields()}) self._has_initialized = True - return self.stub, self._run_id + return self.stub, self._run.run_id def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == self._run_id + message.metadata.run_id == self._run.run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 775348437da6..d41bf1001b71 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -27,6 +27,7 @@ PullTaskResRequest, PushTaskInsRequest, ) +from flwr.proto.run_pb2 import Run # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from .grpc_driver import GrpcDriver @@ -38,7 +39,7 @@ class TestGrpcDriver(unittest.TestCase): def setUp(self) -> None: """Initialize mock GrpcDriverStub and Driver instance before each test.""" mock_response = Mock( - run=Mock(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") + run=Run(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) self.mock_grpc_driver_stub = Mock() self.mock_grpc_driver_stub.get_run.return_value = mock_response From 6c9657ec72a74f8153b6e8a5af3af8a198fc28c7 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 09:30:40 +0100 Subject: [PATCH 12/23] use _run & _run_id --- src/py/flwr/server/driver/grpc_driver.py | 31 +++++++++---------- src/py/flwr/server/driver/grpc_driver_test.py | 3 -- src/py/flwr/server/driver/inmemory_driver.py | 20 ++++-------- .../server/driver/inmemory_driver_test.py | 3 +- 4 files changed, 22 insertions(+), 35 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index c773fbd8ba1d..b72390fc31ef 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -17,7 +17,7 @@ import time import warnings from logging import DEBUG, ERROR, WARNING -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, cast import grpc @@ -180,38 +180,35 @@ def __init__( # pylint: disable=too-many-arguments run_id: int, stub: Optional[GrpcDriverStub] = None, ) -> None: - self.stub = stub - self._run = Run(run_id=run_id, fab_id="", fab_version="") - self._has_initialized = False + self._run_id = run_id + self._run: Optional[Run] = None + self.stub = stub if stub is not None else GrpcDriverStub() self.node = Node(node_id=0, anonymous=True) @property def run(self) -> Run: """Run information.""" self._get_grpc_driver_helper_and_run_id() - return Run(**vars(self._run)) + return Run(**vars(cast(Run, self._run))) def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: - # Check if the GrpcDriverStub is initialized - if not self._has_initialized or self.stub is None: - # Connect and create run - if self.stub is None: - self.stub = GrpcDriverStub() - + # Check if is initialized + if self._run is None: + # Connect + self.stub.connect() # Get the run info - req = GetRunRequest(run_id=self._run.run_id) + req = GetRunRequest(run_id=self._run_id) res = self.stub.get_run(req) if not res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run.run_id}") + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") self._run = Run(**{fld.name: v for fld, v in res.run.ListFields()}) - self._has_initialized = True return self.stub, self._run.run_id def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( - message.metadata.run_id == self._run.run_id + message.metadata.run_id == cast(Run, self._run).run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -330,8 +327,8 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" - # Check if GrpcDriverStub is initialized - if self.stub is None: + # Check if is initialized + if self._run is None: return # Disconnect self.stub.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index d41bf1001b71..ed1928a50887 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -194,9 +194,6 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" - # Prepare - self.driver.stub = None - # Execute self.driver.close() diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index caaaaa5ae9e5..ce809d823e59 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,7 @@ import time import warnings -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -46,11 +46,9 @@ def __init__( state_factory: StateFactory, ) -> None: self._run_id = run_id - self._fab_id = "" - self._fab_ver = "" - self._has_initialized = False - self.node = Node(node_id=0, anonymous=True) + self._run: Optional[Run] = None self.state = state_factory.state() + self.node = Node(node_id=0, anonymous=True) def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -65,24 +63,18 @@ def _check_message(self, message: Message) -> None: def _init_run(self) -> None: """Initialize the run.""" - if self._has_initialized: + if self._run is not None: return run = self.state.get_run(self._run_id) if run is None: raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = run.fab_id - self._fab_ver = run.fab_version - self._has_initialized = True + self._run = run @property def run(self) -> Run: """Run ID.""" self._init_run() - return Run( - run_id=self._run_id, - fab_id=self._fab_id, - fab_version=self._fab_ver, - ) + return Run(**vars(cast(Run, self._run))) def create_message( # pylint: disable=too-many-arguments self, diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 1f457decf228..55d52d848dfd 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -31,6 +31,7 @@ message_to_taskres, recordset_to_proto, ) +from flwr.common.typing import Run from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory @@ -84,7 +85,7 @@ def setUp(self) -> None: int.from_bytes(os.urandom(8), "little", signed=True) for _ in range(self.num_nodes) ] - self.state.get_run.return_value = MagicMock( + self.state.get_run.return_value = Run( run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" ) state_factory = MagicMock(state=lambda: self.state) From 51108737614df60cb963c831aaede2720efa1785 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 09:41:37 +0100 Subject: [PATCH 13/23] update sim --- src/py/flwr/simulation/run_simulation.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 7dc12a5afac0..a6c5021a0d24 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -27,7 +27,7 @@ from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.logger import set_logger_propagation, update_console_handler -from flwr.common.typing import ConfigsRecordValues, Run +from flwr.common.typing import ConfigsRecordValues from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp @@ -169,13 +169,6 @@ def server_th_with_start_checks( # type: ignore return serverapp_th -def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None: - """Create a run with a given `run_id`.""" - log(DEBUG, "Pre-registering run with id %s", run_id) - state.state().run_ids[run_id] = Run(run_id, "", "") # type: ignore - driver._run_id = run_id # pylint: disable=protected-access - - # pylint: disable=too-many-locals def _main_loop( num_supernodes: int, @@ -202,13 +195,10 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "") # Initialize Driver - driver = InMemoryDriver(run_id=run_id, state_factory=state_factory) - - if run_id: - _init_run_id(driver, state_factory, run_id) + driver = InMemoryDriver(run_id=run_id_, state_factory=state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( From d60d295e2d223e3b35aa697f54e40b160be989d6 Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 20 Jun 2024 11:58:35 +0200 Subject: [PATCH 14/23] fix in run_simulation() (#3654) --- src/py/flwr/simulation/run_simulation.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index a6c5021a0d24..60f84758ffcd 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -27,7 +27,7 @@ from flwr.client import ClientApp from flwr.common import EventType, event, log from flwr.common.logger import set_logger_propagation, update_console_handler -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import ConfigsRecordValues, Run from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run from flwr.server.server_app import ServerApp @@ -169,6 +169,16 @@ def server_th_with_start_checks( # type: ignore return serverapp_th +def _override_run_id(state: StateFactory, run_id_to_replace: int, run_id: int) -> None: + """Override the run_id of an existing Run.""" + log(DEBUG, "Pre-registering run with id %s", run_id) + # Remove run + run: Run = state.state().run_ids.pop(run_id_to_replace) # type: ignore + # Update with new run_id and insert back in state + run.run_id = run_id + state.state().run_ids[run_id] = run # type: ignore + + # pylint: disable=too-many-locals def _main_loop( num_supernodes: int, @@ -176,7 +186,7 @@ def _main_loop( backend_config_stream: str, app_dir: str, enable_tf_gpu_growth: bool, - run_id: Optional[int] = None, + run_id_: Optional[int] = None, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, @@ -195,10 +205,14 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id_ = state_factory.state().create_run("", "") + run_id = state_factory.state().create_run("", "") + + if run_id_: + _override_run_id(state_factory, run_id_to_replace=run_id, run_id=run_id_) + run_id = run_id_ # Initialize Driver - driver = InMemoryDriver(run_id=run_id_, state_factory=state_factory) + driver = InMemoryDriver(run_id=run_id, state_factory=state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( From e65af5fa151c59f8cb82ce01e3248c845b2c3df1 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:00:47 +0100 Subject: [PATCH 15/23] fix doc string --- src/py/flwr/server/driver/grpc_driver.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index b72390fc31ef..b6b54945f84d 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -61,14 +61,10 @@ class GrpcDriverStub: driver_service_address : Optional[str] The IPv4 or IPv6 address of the Driver API server. Defaults to `"[::]:9091"`. - certificates : bytes (default: None) - Tuple containing root certificate, server certificate, and private key - to start a secure SSL-enabled server. The tuple is expected to have - three bytes elements in the following order: - - * CA certificate. - * server certificate. - * server private key. + root_certificates : Optional[bytes] (default: None) + The PEM-encoded root certificates as a byte string. + If provided, a secure connection using the certificates will be + established to an SSL-enabled Flower server. """ def __init__( From f914b2c98953b22515e5ae1eb502139854172a16 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 20 Jun 2024 12:09:49 +0200 Subject: [PATCH 16/23] Update src/py/flwr/server/driver/grpc_driver.py --- src/py/flwr/server/driver/grpc_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index b6b54945f84d..f7df4fba98a0 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -323,7 +323,7 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" - # Check if is initialized + # Check if `connect` was called before if self._run is None: return # Disconnect From c62f49624d30fe7131db3b0cec02ada4f799ea20 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:10:41 +0100 Subject: [PATCH 17/23] fix naming conflicts --- src/py/flwr/simulation/run_simulation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 60f84758ffcd..51f8edccb85f 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -173,10 +173,10 @@ def _override_run_id(state: StateFactory, run_id_to_replace: int, run_id: int) - """Override the run_id of an existing Run.""" log(DEBUG, "Pre-registering run with id %s", run_id) # Remove run - run: Run = state.state().run_ids.pop(run_id_to_replace) # type: ignore + run_info: Run = state.state().run_ids.pop(run_id_to_replace) # type: ignore # Update with new run_id and insert back in state - run.run_id = run_id - state.state().run_ids[run_id] = run # type: ignore + run_info.run_id = run_id + state.state().run_ids[run_id] = run_info # type: ignore # pylint: disable=too-many-locals From b8515e3c5507f037cfdd5f3d2ca258ad318f0c41 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:24:48 +0100 Subject: [PATCH 18/23] fix a bug that driver stub not connected --- src/py/flwr/server/driver/grpc_driver.py | 21 ++++++++++++------- src/py/flwr/server/driver/grpc_driver_test.py | 2 +- src/py/flwr/server/run_serverapp.py | 1 + 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index f7df4fba98a0..03f517846c4c 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -77,6 +77,10 @@ def __init__( self.channel: Optional[grpc.Channel] = None self.stub: Optional[DriverStub] = None + def is_connected(self) -> bool: + """Return True if connected to the Driver API server, otherwise False.""" + return self.channel is not None + def connect(self) -> None: """Connect to the Driver API.""" event(EventType.DRIVER_CONNECT) @@ -184,14 +188,15 @@ def __init__( # pylint: disable=too-many-arguments @property def run(self) -> Run: """Run information.""" - self._get_grpc_driver_helper_and_run_id() + self._get_stub_and_run_id() return Run(**vars(cast(Run, self._run))) - def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: + def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: # Check if is initialized if self._run is None: # Connect - self.stub.connect() + if self.stub.is_connected(): + self.stub.connect() # Get the run info req = GetRunRequest(run_id=self._run_id) res = self.stub.get_run(req) @@ -225,7 +230,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ - _, run_id = self._get_grpc_driver_helper_and_run_id() + _, run_id = self._get_stub_and_run_id() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -249,7 +254,7 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() + grpc_driver_helper, run_id = self._get_stub_and_run_id() # Call GrpcDriverStub method res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) return [node.node_id for node in res.nodes] @@ -260,7 +265,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id() + grpc_driver_helper, _ = self._get_stub_and_run_id() # Construct TaskIns task_ins_list: List[TaskIns] = [] for msg in messages: @@ -282,7 +287,7 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. """ - grpc_driver, _ = self._get_grpc_driver_helper_and_run_id() + grpc_driver, _ = self._get_stub_and_run_id() # Pull TaskRes res = grpc_driver.pull_task_res( PullTaskResRequest(node=self.node, task_ids=message_ids) @@ -324,7 +329,7 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" # Check if `connect` was called before - if self._run is None: + if not self.stub.is_connected(): return # Disconnect self.stub.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index ed1928a50887..28df05110359 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -184,7 +184,7 @@ def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() + self.driver._get_stub_and_run_id() # Execute self.driver.close() diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index af751d1fa129..63ffc4a1caae 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -155,6 +155,7 @@ def run_server_app() -> None: stub = GrpcDriverStub( driver_service_address=args.superlink, root_certificates=root_certificates ) + stub.connect() req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) res = stub.create_run(req) From fe5a8db89b9bcde21b5842659bbaae0d4819a384 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:26:21 +0100 Subject: [PATCH 19/23] quick fix --- src/py/flwr/server/driver/grpc_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 03f517846c4c..ea9b8dabad52 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -195,7 +195,7 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: # Check if is initialized if self._run is None: # Connect - if self.stub.is_connected(): + if not self.stub.is_connected(): self.stub.connect() # Get the run info req = GetRunRequest(run_id=self._run_id) From 25e6354f909dbdd062a19f2d9bb4f795f106b2e0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:33:53 +0100 Subject: [PATCH 20/23] update naming --- src/py/flwr/simulation/run_simulation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 51f8edccb85f..a3de1401d252 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -186,7 +186,7 @@ def _main_loop( backend_config_stream: str, app_dir: str, enable_tf_gpu_growth: bool, - run_id_: Optional[int] = None, + run_id: Optional[int] = None, client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, @@ -205,14 +205,14 @@ def _main_loop( serverapp_th = None try: # Create run (with empty fab_id and fab_version) - run_id = state_factory.state().create_run("", "") + run_id_ = state_factory.state().create_run("", "") - if run_id_: - _override_run_id(state_factory, run_id_to_replace=run_id, run_id=run_id_) - run_id = run_id_ + if run_id: + _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id) + run_id_ = run_id # Initialize Driver - driver = InMemoryDriver(run_id=run_id, state_factory=state_factory) + driver = InMemoryDriver(run_id=run_id_, state_factory=state_factory) # Get and run ServerApp thread serverapp_th = run_serverapp_th( From 05945f04a831b4e27c30ee2b25202e13b9a2547a Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:39:15 +0100 Subject: [PATCH 21/23] fix unit tests --- src/py/flwr/server/driver/grpc_driver_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 28df05110359..60ab79002f85 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -183,8 +183,7 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare - # pylint: disable-next=protected-access - self.driver._get_stub_and_run_id() + self.mock_grpc_driver_stub.is_connected.return_value = True # Execute self.driver.close() @@ -194,6 +193,9 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" + # Prepare + self.mock_grpc_driver_stub.is_connected.return_value = False + # Execute self.driver.close() From 1efc53d36034aca5346554b49fdb09bb57fc9b67 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 11:54:39 +0100 Subject: [PATCH 22/23] fix get_run --- src/py/flwr/server/driver/grpc_driver.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index ea9b8dabad52..2016d54b655a 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -202,7 +202,11 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]: res = self.stub.get_run(req) if not res.HasField("run"): raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._run = Run(**{fld.name: v for fld, v in res.run.ListFields()}) + self._run = Run( + run_id=res.run.run_id, + fab_id=res.run.fab_id, + fab_version=res.run.fab_version, + ) return self.stub, self._run.run_id @@ -254,9 +258,9 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - grpc_driver_helper, run_id = self._get_stub_and_run_id() + stub, run_id = self._get_stub_and_run_id() # Call GrpcDriverStub method - res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) + res = stub.get_nodes(GetNodesRequest(run_id=run_id)) return [node.node_id for node in res.nodes] def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: @@ -265,7 +269,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - grpc_driver_helper, _ = self._get_stub_and_run_id() + stub, _ = self._get_stub_and_run_id() # Construct TaskIns task_ins_list: List[TaskIns] = [] for msg in messages: @@ -276,9 +280,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: # Add to list task_ins_list.append(taskins) # Call GrpcDriverStub method - res = grpc_driver_helper.push_task_ins( - PushTaskInsRequest(task_ins_list=task_ins_list) - ) + res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) return list(res.task_ids) def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: @@ -287,9 +289,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. """ - grpc_driver, _ = self._get_stub_and_run_id() + stub, _ = self._get_stub_and_run_id() # Pull TaskRes - res = grpc_driver.pull_task_res( + res = stub.pull_task_res( PullTaskResRequest(node=self.node, task_ids=message_ids) ) # Convert TaskRes to Message From 2cdaf83694b40e3028fbad1373926f0a8188c638 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 20 Jun 2024 12:49:25 +0100 Subject: [PATCH 23/23] update in mem driver --- src/py/flwr/server/driver/inmemory_driver.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index ce809d823e59..53406796750f 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -51,9 +51,10 @@ def __init__( self.node = Node(node_id=0, anonymous=True) def _check_message(self, message: Message) -> None: + self._init_run() # Check if the message is valid if not ( - message.metadata.run_id == self.run.run_id + message.metadata.run_id == cast(Run, self._run).run_id and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" @@ -89,6 +90,7 @@ def create_message( # pylint: disable=too-many-arguments This method constructs a new `Message` with given content and metadata. The `run_id` and `src_node_id` will be set automatically. """ + self._init_run() if ttl: warnings.warn( "A custom TTL was set, but note that the SuperLink does not enforce " @@ -99,7 +101,7 @@ def create_message( # pylint: disable=too-many-arguments ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( - run_id=self.run.run_id, + run_id=cast(Run, self._run).run_id, message_id="", # Will be set by the server src_node_id=self.node.node_id, dst_node_id=dst_node_id, @@ -112,7 +114,8 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" - return list(self.state.get_nodes(self.run.run_id)) + self._init_run() + return list(self.state.get_nodes(cast(Run, self._run).run_id)) def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs.