From 3a2ef9567e08112f41029432f96af701b6a3ac12 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 4 Dec 2024 08:43:17 +0000 Subject: [PATCH 1/8] Add protos for flwr stop --- src/proto/flwr/proto/control.proto | 3 -- src/proto/flwr/proto/exec.proto | 8 +++++ src/proto/flwr/proto/run.proto | 7 ++-- src/proto/flwr/proto/serverappio.proto | 3 ++ src/py/flwr/proto/control_pb2.py | 4 +-- src/py/flwr/proto/control_pb2_grpc.py | 34 ------------------- src/py/flwr/proto/control_pb2_grpc.pyi | 13 -------- src/py/flwr/proto/exec_pb2.py | 10 ++++-- src/py/flwr/proto/exec_pb2.pyi | 27 +++++++++++++++ src/py/flwr/proto/exec_pb2_grpc.py | 34 +++++++++++++++++++ src/py/flwr/proto/exec_pb2_grpc.pyi | 13 ++++++++ src/py/flwr/proto/run_pb2.py | 20 +++++------- src/py/flwr/proto/run_pb2.pyi | 38 +++++----------------- src/py/flwr/proto/serverappio_pb2.py | 4 +-- src/py/flwr/proto/serverappio_pb2_grpc.py | 34 +++++++++++++++++++ src/py/flwr/proto/serverappio_pb2_grpc.pyi | 13 ++++++++ 16 files changed, 162 insertions(+), 103 deletions(-) diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto index 8b75c66fccaa..f5668a3d977f 100644 --- a/src/proto/flwr/proto/control.proto +++ b/src/proto/flwr/proto/control.proto @@ -23,9 +23,6 @@ service Control { // Request to create a new run rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} - // Get the status of a given run - rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {} - // Update the status of a given run rpc UpdateRunStatus(UpdateRunStatusRequest) returns (UpdateRunStatusResponse) {} diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 583c42ff5704..3e624063cebd 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -26,6 +26,9 @@ service Exec { // Start run upon request rpc StartRun(StartRunRequest) returns (StartRunResponse) {} + // Stop run upon request + rpc StopRun(StopRunRequest) returns (StopRunResponse) {} + // Start log stream upon request rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} @@ -52,3 +55,8 @@ message ListRunsResponse { map run_dict = 1; string now = 2; } +message StopRunRequest { + uint64 run_id = 1; + Fab fab = 2; +} +message StopRunResponse { bool success = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 75bd0c8860d9..547f228ba4c1 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -68,11 +68,8 @@ message UpdateRunStatusRequest { message UpdateRunStatusResponse {} // GetRunStatus -message GetRunStatusRequest { - Node node = 1; - repeated uint64 run_ids = 2; -} -message GetRunStatusResponse { map run_status_dict = 1; } +message GetRunStatusRequest { uint64 run_id = 1; } +message GetRunStatusResponse { RunStatus run_status = 1; } // Get Federation Options associated with run message GetFederationOptionsRequest { uint64 run_id = 1; } diff --git a/src/proto/flwr/proto/serverappio.proto b/src/proto/flwr/proto/serverappio.proto index 3d8d3d6aa0d6..5b729e505356 100644 --- a/src/proto/flwr/proto/serverappio.proto +++ b/src/proto/flwr/proto/serverappio.proto @@ -55,6 +55,9 @@ service ServerAppIo { rpc UpdateRunStatus(UpdateRunStatusRequest) returns (UpdateRunStatusResponse) {} + // Get the status of a given run + rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {} + // Push ServerApp logs rpc PushLogs(PushLogsRequest) returns (PushLogsResponse) {} } diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py index eb1c18d8dcff..b9d78ef96746 100644 --- a/src/py/flwr/proto/control_pb2.py +++ b/src/py/flwr/proto/control_pb2.py @@ -15,7 +15,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\x88\x02\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\xb3\x01\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,5 +23,5 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals['_CONTROL']._serialized_start=63 - _globals['_CONTROL']._serialized_end=327 + _globals['_CONTROL']._serialized_end=242 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py index a59f90f15935..e970d8f327fa 100644 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -19,11 +19,6 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) - self.GetRunStatus = channel.unary_unary( - '/flwr.proto.Control/GetRunStatus', - request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, - ) self.UpdateRunStatus = channel.unary_unary( '/flwr.proto.Control/UpdateRunStatus', request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, @@ -41,13 +36,6 @@ def CreateRun(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def GetRunStatus(self, request, context): - """Get the status of a given run - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def UpdateRunStatus(self, request, context): """Update the status of a given run """ @@ -63,11 +51,6 @@ def add_ControlServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), - 'GetRunStatus': grpc.unary_unary_rpc_method_handler( - servicer.GetRunStatus, - request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString, - response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString, - ), 'UpdateRunStatus': grpc.unary_unary_rpc_method_handler( servicer.UpdateRunStatus, request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, @@ -100,23 +83,6 @@ def CreateRun(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - @staticmethod - def GetRunStatus(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus', - flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, - flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - @staticmethod def UpdateRunStatus(request, target, diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi index 7817e2b12e31..1e008ad1492b 100644 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -13,11 +13,6 @@ class ControlStub: flwr.proto.run_pb2.CreateRunResponse] """Request to create a new run""" - GetRunStatus: grpc.UnaryUnaryMultiCallable[ - flwr.proto.run_pb2.GetRunStatusRequest, - flwr.proto.run_pb2.GetRunStatusResponse] - """Get the status of a given run""" - UpdateRunStatus: grpc.UnaryUnaryMultiCallable[ flwr.proto.run_pb2.UpdateRunStatusRequest, flwr.proto.run_pb2.UpdateRunStatusResponse] @@ -33,14 +28,6 @@ class ControlServicer(metaclass=abc.ABCMeta): """Request to create a new run""" pass - @abc.abstractmethod - def GetRunStatus(self, - request: flwr.proto.run_pb2.GetRunStatusRequest, - context: grpc.ServicerContext, - ) -> flwr.proto.run_pb2.GetRunStatusResponse: - """Get the status of a given run""" - pass - @abc.abstractmethod def UpdateRunStatus(self, request: flwr.proto.run_pb2.UpdateRunStatusRequest, diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 2240988e87a0..4a251fa58637 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xe9\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\">\n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x1c\n\x03\x66\x61\x62\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Fab\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xaf\x02\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -45,6 +45,10 @@ _globals['_LISTRUNSRESPONSE']._serialized_end=766 _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_start=703 _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_end=766 - _globals['_EXEC']._serialized_start=769 - _globals['_EXEC']._serialized_end=1002 + _globals['_STOPRUNREQUEST']._serialized_start=768 + _globals['_STOPRUNREQUEST']._serialized_end=830 + _globals['_STOPRUNRESPONSE']._serialized_start=832 + _globals['_STOPRUNRESPONSE']._serialized_end=866 + _globals['_EXEC']._serialized_start=869 + _globals['_EXEC']._serialized_end=1172 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 08e0b1c14346..de022246501b 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -132,3 +132,30 @@ class ListRunsResponse(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["now",b"now","run_dict",b"run_dict"]) -> None: ... global___ListRunsResponse = ListRunsResponse + +class StopRunRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + FAB_FIELD_NUMBER: builtins.int + run_id: builtins.int + @property + def fab(self) -> flwr.proto.fab_pb2.Fab: ... + def __init__(self, + *, + run_id: builtins.int = ..., + fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","run_id",b"run_id"]) -> None: ... +global___StopRunRequest = StopRunRequest + +class StopRunResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__(self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ... +global___StopRunResponse = StopRunResponse diff --git a/src/py/flwr/proto/exec_pb2_grpc.py b/src/py/flwr/proto/exec_pb2_grpc.py index 63f9285fed58..9ce04915771d 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.py +++ b/src/py/flwr/proto/exec_pb2_grpc.py @@ -19,6 +19,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_exec__pb2.StartRunRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_exec__pb2.StartRunResponse.FromString, ) + self.StopRun = channel.unary_unary( + '/flwr.proto.Exec/StopRun', + request_serializer=flwr_dot_proto_dot_exec__pb2.StopRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_exec__pb2.StopRunResponse.FromString, + ) self.StreamLogs = channel.unary_stream( '/flwr.proto.Exec/StreamLogs', request_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.SerializeToString, @@ -41,6 +46,13 @@ def StartRun(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def StopRun(self, request, context): + """Stop run upon request + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def StreamLogs(self, request, context): """Start log stream upon request """ @@ -63,6 +75,11 @@ def add_ExecServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_exec__pb2.StartRunRequest.FromString, response_serializer=flwr_dot_proto_dot_exec__pb2.StartRunResponse.SerializeToString, ), + 'StopRun': grpc.unary_unary_rpc_method_handler( + servicer.StopRun, + request_deserializer=flwr_dot_proto_dot_exec__pb2.StopRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_exec__pb2.StopRunResponse.SerializeToString, + ), 'StreamLogs': grpc.unary_stream_rpc_method_handler( servicer.StreamLogs, request_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.FromString, @@ -100,6 +117,23 @@ def StartRun(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def StopRun(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Exec/StopRun', + flwr_dot_proto_dot_exec__pb2.StopRunRequest.SerializeToString, + flwr_dot_proto_dot_exec__pb2.StopRunResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def StreamLogs(request, target, diff --git a/src/py/flwr/proto/exec_pb2_grpc.pyi b/src/py/flwr/proto/exec_pb2_grpc.pyi index 550c282bface..12f42befab06 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.pyi +++ b/src/py/flwr/proto/exec_pb2_grpc.pyi @@ -14,6 +14,11 @@ class ExecStub: flwr.proto.exec_pb2.StartRunResponse] """Start run upon request""" + StopRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.exec_pb2.StopRunRequest, + flwr.proto.exec_pb2.StopRunResponse] + """Stop run upon request""" + StreamLogs: grpc.UnaryStreamMultiCallable[ flwr.proto.exec_pb2.StreamLogsRequest, flwr.proto.exec_pb2.StreamLogsResponse] @@ -34,6 +39,14 @@ class ExecServicer(metaclass=abc.ABCMeta): """Start run upon request""" pass + @abc.abstractmethod + def StopRun(self, + request: flwr.proto.exec_pb2.StopRunRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.exec_pb2.StopRunResponse: + """Stop run upon request""" + pass + @abc.abstractmethod def StreamLogs(self, request: flwr.proto.exec_pb2.StreamLogsRequest, diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index a3aac417f9a9..df219df168ed 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xce\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"U\n\x1cGetFederationOptionsResponse\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecordb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xce\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"%\n\x13GetRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"A\n\x14GetRunStatusResponse\x12)\n\nrun_status\x18\x01 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"U\n\x1cGetFederationOptionsResponse\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecordb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,8 +29,6 @@ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' _globals['_RUN']._serialized_start=138 _globals['_RUN']._serialized_end=472 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=399 @@ -52,13 +50,11 @@ _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=1013 _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=1038 _globals['_GETRUNSTATUSREQUEST']._serialized_start=1040 - _globals['_GETRUNSTATUSREQUEST']._serialized_end=1110 - _globals['_GETRUNSTATUSRESPONSE']._serialized_start=1113 - _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1290 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=1215 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1290 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1292 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1337 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1339 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1424 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=1077 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=1079 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1144 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1146 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1191 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1193 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1278 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index cbaad46f2785..9aab14c24b69 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -200,46 +200,26 @@ global___UpdateRunStatusResponse = UpdateRunStatusResponse class GetRunStatusRequest(google.protobuf.message.Message): """GetRunStatus""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - NODE_FIELD_NUMBER: builtins.int - RUN_IDS_FIELD_NUMBER: builtins.int - @property - def node(self) -> flwr.proto.node_pb2.Node: ... - @property - def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - node: typing.Optional[flwr.proto.node_pb2.Node] = ..., - run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., + run_id: builtins.int = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_ids",b"run_ids"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... global___GetRunStatusRequest = GetRunStatusRequest class GetRunStatusResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class RunStatusDictEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.int - @property - def value(self) -> global___RunStatus: ... - def __init__(self, - *, - key: builtins.int = ..., - value: typing.Optional[global___RunStatus] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - RUN_STATUS_DICT_FIELD_NUMBER: builtins.int + RUN_STATUS_FIELD_NUMBER: builtins.int @property - def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___RunStatus]: ... + def run_status(self) -> global___RunStatus: ... def __init__(self, *, - run_status_dict: typing.Optional[typing.Mapping[builtins.int, global___RunStatus]] = ..., + run_status: typing.Optional[global___RunStatus] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> None: ... global___GetRunStatusResponse = GetRunStatusResponse class GetFederationOptionsRequest(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/serverappio_pb2.py b/src/py/flwr/proto/serverappio_pb2.py index 2bbd33b5c42b..76e2dfc31c00 100644 --- a/src/py/flwr/proto/serverappio_pb2.py +++ b/src/py/flwr/proto/serverappio_pb2.py @@ -20,7 +20,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xca\x06\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -48,5 +48,5 @@ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760 _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790 _globals['_SERVERAPPIO']._serialized_start=793 - _globals['_SERVERAPPIO']._serialized_end=1635 + _globals['_SERVERAPPIO']._serialized_end=1720 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/serverappio_pb2_grpc.py b/src/py/flwr/proto/serverappio_pb2_grpc.py index 1a7740db4271..ede888543883 100644 --- a/src/py/flwr/proto/serverappio_pb2_grpc.py +++ b/src/py/flwr/proto/serverappio_pb2_grpc.py @@ -62,6 +62,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, ) + self.GetRunStatus = channel.unary_unary( + '/flwr.proto.ServerAppIo/GetRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + ) self.PushLogs = channel.unary_unary( '/flwr.proto.ServerAppIo/PushLogs', request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString, @@ -135,6 +140,13 @@ def UpdateRunStatus(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRunStatus(self, request, context): + """Get the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PushLogs(self, request, context): """Push ServerApp logs """ @@ -190,6 +202,11 @@ def add_ServerAppIoServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString, ), + 'GetRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString, + ), 'PushLogs': grpc.unary_unary_rpc_method_handler( servicer.PushLogs, request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString, @@ -358,6 +375,23 @@ def UpdateRunStatus(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def GetRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.ServerAppIo/GetRunStatus', + flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def PushLogs(request, target, diff --git a/src/py/flwr/proto/serverappio_pb2_grpc.pyi b/src/py/flwr/proto/serverappio_pb2_grpc.pyi index aa2d29473ae8..f4e3fdc208a8 100644 --- a/src/py/flwr/proto/serverappio_pb2_grpc.pyi +++ b/src/py/flwr/proto/serverappio_pb2_grpc.pyi @@ -56,6 +56,11 @@ class ServerAppIoStub: flwr.proto.run_pb2.UpdateRunStatusResponse] """Update the status of a given run""" + GetRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunStatusRequest, + flwr.proto.run_pb2.GetRunStatusResponse] + """Get the status of a given run""" + PushLogs: grpc.UnaryUnaryMultiCallable[ flwr.proto.log_pb2.PushLogsRequest, flwr.proto.log_pb2.PushLogsResponse] @@ -135,6 +140,14 @@ class ServerAppIoServicer(metaclass=abc.ABCMeta): """Update the status of a given run""" pass + @abc.abstractmethod + def GetRunStatus(self, + request: flwr.proto.run_pb2.GetRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunStatusResponse: + """Get the status of a given run""" + pass + @abc.abstractmethod def PushLogs(self, request: flwr.proto.log_pb2.PushLogsRequest, From 3019ebb1211a9844ff4307b062f94de87b26ed1b Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 4 Dec 2024 08:57:31 +0000 Subject: [PATCH 2/8] Add StopRun to ExecServicer --- src/py/flwr/superexec/exec_servicer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 3a484ea8c47c..7f47f2bdbe60 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -36,6 +36,8 @@ ListRunsResponse, StartRunRequest, StartRunResponse, + StopRunRequest, + StopRunResponse, StreamLogsRequest, StreamLogsResponse, ) @@ -126,6 +128,12 @@ def ListRuns( # Handle `flwr ls --run-id ` return _create_list_runs_response({request.run_id}, state) + def StopRun( + self, request: StopRunRequest, context: grpc.ServicerContext + ) -> StopRunResponse: + """Stop a given run ID.""" + raise NotImplementedError() + def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse: """Create response for `flwr ls --runs` and `flwr ls --run-id `.""" From dde847189eb1aa9f54f5478156466b26dc48e995 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 4 Dec 2024 16:43:03 +0000 Subject: [PATCH 3/8] Update --- src/proto/flwr/proto/control.proto | 29 ----- src/proto/flwr/proto/exec.proto | 5 +- src/py/flwr/proto/control_pb2_grpc.py | 101 ------------------ src/py/flwr/proto/control_pb2_grpc.pyi | 40 ------- src/py/flwr/proto/exec_pb2.py | 12 +-- src/py/flwr/proto/exec_pb2.pyi | 7 +- .../superlink/driver/serverappio_servicer.py | 9 ++ src/py/flwr/superexec/exec_servicer.py | 1 + 8 files changed, 18 insertions(+), 186 deletions(-) delete mode 100644 src/proto/flwr/proto/control.proto delete mode 100644 src/py/flwr/proto/control_pb2_grpc.py delete mode 100644 src/py/flwr/proto/control_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto deleted file mode 100644 index f5668a3d977f..000000000000 --- a/src/proto/flwr/proto/control.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 Flower Labs GmbH. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== - -syntax = "proto3"; - -package flwr.proto; - -import "flwr/proto/run.proto"; - -service Control { - // Request to create a new run - rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} - - // Update the status of a given run - rpc UpdateRunStatus(UpdateRunStatusRequest) - returns (UpdateRunStatusResponse) {} -} diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 972838a9e2f5..6fb2bec9f0a3 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -55,8 +55,5 @@ message ListRunsResponse { map run_dict = 1; string now = 2; } -message StopRunRequest { - uint64 run_id = 1; - Fab fab = 2; -} +message StopRunRequest { uint64 run_id = 1; } message StopRunResponse { bool success = 1; } diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py deleted file mode 100644 index e970d8f327fa..000000000000 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ /dev/null @@ -1,101 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 - - -class ControlStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.CreateRun = channel.unary_unary( - '/flwr.proto.Control/CreateRun', - request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, - ) - self.UpdateRunStatus = channel.unary_unary( - '/flwr.proto.Control/UpdateRunStatus', - request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, - ) - - -class ControlServicer(object): - """Missing associated documentation comment in .proto file.""" - - def CreateRun(self, request, context): - """Request to create a new run - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def UpdateRunStatus(self, request, context): - """Update the status of a given run - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_ControlServicer_to_server(servicer, server): - rpc_method_handlers = { - 'CreateRun': grpc.unary_unary_rpc_method_handler( - servicer.CreateRun, - request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, - ), - 'UpdateRunStatus': grpc.unary_unary_rpc_method_handler( - servicer.UpdateRunStatus, - request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, - response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'flwr.proto.Control', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class Control(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def CreateRun(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/CreateRun', - flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, - flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def UpdateRunStatus(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/UpdateRunStatus', - flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, - flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi deleted file mode 100644 index 1e008ad1492b..000000000000 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ /dev/null @@ -1,40 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import abc -import flwr.proto.run_pb2 -import grpc - -class ControlStub: - def __init__(self, channel: grpc.Channel) -> None: ... - CreateRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.run_pb2.CreateRunRequest, - flwr.proto.run_pb2.CreateRunResponse] - """Request to create a new run""" - - UpdateRunStatus: grpc.UnaryUnaryMultiCallable[ - flwr.proto.run_pb2.UpdateRunStatusRequest, - flwr.proto.run_pb2.UpdateRunStatusResponse] - """Update the status of a given run""" - - -class ControlServicer(metaclass=abc.ABCMeta): - @abc.abstractmethod - def CreateRun(self, - request: flwr.proto.run_pb2.CreateRunRequest, - context: grpc.ServicerContext, - ) -> flwr.proto.run_pb2.CreateRunResponse: - """Request to create a new run""" - pass - - @abc.abstractmethod - def UpdateRunStatus(self, - request: flwr.proto.run_pb2.UpdateRunStatusRequest, - context: grpc.ServicerContext, - ) -> flwr.proto.run_pb2.UpdateRunStatusResponse: - """Update the status of a given run""" - pass - - -def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 71a426ba128d..18253dd1d02a 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\">\n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x1c\n\x03\x66\x61\x62\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Fab\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xaf\x02\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xaf\x02\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -46,9 +46,9 @@ _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_start=719 _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_end=782 _globals['_STOPRUNREQUEST']._serialized_start=784 - _globals['_STOPRUNREQUEST']._serialized_end=846 - _globals['_STOPRUNRESPONSE']._serialized_start=848 - _globals['_STOPRUNRESPONSE']._serialized_end=882 - _globals['_EXEC']._serialized_start=885 - _globals['_EXEC']._serialized_end=1188 + _globals['_STOPRUNREQUEST']._serialized_end=816 + _globals['_STOPRUNRESPONSE']._serialized_start=818 + _globals['_STOPRUNRESPONSE']._serialized_end=852 + _globals['_EXEC']._serialized_start=855 + _globals['_EXEC']._serialized_end=1158 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 0e7c42de554b..70ff9147e02a 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -138,17 +138,12 @@ global___ListRunsResponse = ListRunsResponse class StopRunRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor RUN_ID_FIELD_NUMBER: builtins.int - FAB_FIELD_NUMBER: builtins.int run_id: builtins.int - @property - def fab(self) -> flwr.proto.fab_pb2.Fab: ... def __init__(self, *, run_id: builtins.int = ..., - fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","run_id",b"run_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... global___StopRunRequest = StopRunRequest class StopRunResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index dddac1a93b1a..7b51e02b00b5 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -48,6 +48,8 @@ CreateRunResponse, GetRunRequest, GetRunResponse, + GetRunStatusRequest, + GetRunStatusResponse, UpdateRunStatusRequest, UpdateRunStatusResponse, ) @@ -284,6 +286,13 @@ def PushLogs( state.add_serverapp_log(request.run_id, merged_logs) return PushLogsResponse() + def GetRunStatus( + self, request: GetRunStatusRequest, context: grpc.ServicerContext + ) -> GetRunStatusResponse: + """Get the status of a run.""" + log(DEBUG, "ServerAppIoServicer.GetRunStatus") + raise NotImplementedError() + def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 7f47f2bdbe60..64fc17695afc 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -132,6 +132,7 @@ def StopRun( self, request: StopRunRequest, context: grpc.ServicerContext ) -> StopRunResponse: """Stop a given run ID.""" + log(INFO, "ExecServicer.StopRun") raise NotImplementedError() From 18d10db8fde5e6d77d1acbbadcb4e72db6b009fb Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 9 Dec 2024 10:07:47 +0000 Subject: [PATCH 4/8] Update protos --- src/proto/flwr/proto/run.proto | 7 +++++-- src/py/flwr/proto/run_pb2.py | 20 +++++++++++------- src/py/flwr/proto/run_pb2.pyi | 38 ++++++++++++++++++++++++++-------- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 547f228ba4c1..75bd0c8860d9 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -68,8 +68,11 @@ message UpdateRunStatusRequest { message UpdateRunStatusResponse {} // GetRunStatus -message GetRunStatusRequest { uint64 run_id = 1; } -message GetRunStatusResponse { RunStatus run_status = 1; } +message GetRunStatusRequest { + Node node = 1; + repeated uint64 run_ids = 2; +} +message GetRunStatusResponse { map run_status_dict = 1; } // Get Federation Options associated with run message GetFederationOptionsRequest { uint64 run_id = 1; } diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index df219df168ed..a3aac417f9a9 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xce\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"%\n\x13GetRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"A\n\x14GetRunStatusResponse\x12)\n\nrun_status\x18\x01 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"U\n\x1cGetFederationOptionsResponse\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecordb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xce\x02\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"U\n\x1cGetFederationOptionsResponse\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecordb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,6 +29,8 @@ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' _globals['_RUN']._serialized_start=138 _globals['_RUN']._serialized_end=472 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=399 @@ -50,11 +52,13 @@ _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=1013 _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=1038 _globals['_GETRUNSTATUSREQUEST']._serialized_start=1040 - _globals['_GETRUNSTATUSREQUEST']._serialized_end=1077 - _globals['_GETRUNSTATUSRESPONSE']._serialized_start=1079 - _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1144 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1146 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1191 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1193 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1278 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=1110 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=1113 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1290 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=1215 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1290 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=1292 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=1337 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=1339 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1424 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 9aab14c24b69..cbaad46f2785 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -200,26 +200,46 @@ global___UpdateRunStatusResponse = UpdateRunStatusResponse class GetRunStatusRequest(google.protobuf.message.Message): """GetRunStatus""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - run_id: builtins.int + NODE_FIELD_NUMBER: builtins.int + RUN_IDS_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... + @property + def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__(self, *, - run_id: builtins.int = ..., + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., + run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_ids",b"run_ids"]) -> None: ... global___GetRunStatusRequest = GetRunStatusRequest class GetRunStatusResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_STATUS_FIELD_NUMBER: builtins.int + class RunStatusDictEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.int + @property + def value(self) -> global___RunStatus: ... + def __init__(self, + *, + key: builtins.int = ..., + value: typing.Optional[global___RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + RUN_STATUS_DICT_FIELD_NUMBER: builtins.int @property - def run_status(self) -> global___RunStatus: ... + def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___RunStatus]: ... def __init__(self, *, - run_status: typing.Optional[global___RunStatus] = ..., + run_status_dict: typing.Optional[typing.Mapping[builtins.int, global___RunStatus]] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... global___GetRunStatusResponse = GetRunStatusResponse class GetFederationOptionsRequest(google.protobuf.message.Message): From 2ebd50e68cab954acf0b41d5ba7cbefcda258ec8 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 9 Dec 2024 10:10:55 +0000 Subject: [PATCH 5/8] Add GetRunStatus --- .../flwr/server/superlink/driver/serverappio_servicer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 7b51e02b00b5..7565dab83aba 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -32,6 +32,7 @@ fab_from_proto, fab_to_proto, run_status_from_proto, + run_status_to_proto, run_to_proto, user_config_from_proto, ) @@ -291,7 +292,13 @@ def GetRunStatus( ) -> GetRunStatusResponse: """Get the status of a run.""" log(DEBUG, "ServerAppIoServicer.GetRunStatus") - raise NotImplementedError() + state = self.state_factory.state() + + # Get run status from LinkState + run_status = run_status_to_proto( + state.get_run_status({request.run_id})[request.run_id] + ) + return GetRunStatusResponse(run_status=run_status) def _raise_if(validation_error: bool, detail: str) -> None: From 468ac5e1eb38052d508703fa7b60be47d3542674 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 9 Dec 2024 10:15:35 +0000 Subject: [PATCH 6/8] Add run_id to PushTaskInsRequest and PullTaskResRequest --- src/proto/flwr/proto/serverappio.proto | 6 +++- src/py/flwr/proto/serverappio_pb2.py | 36 ++++++++++++------------ src/py/flwr/proto/serverappio_pb2.pyi | 10 +++++-- src/py/flwr/server/driver/grpc_driver.py | 8 ++++-- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/proto/flwr/proto/serverappio.proto b/src/proto/flwr/proto/serverappio.proto index 5b729e505356..76352866a891 100644 --- a/src/proto/flwr/proto/serverappio.proto +++ b/src/proto/flwr/proto/serverappio.proto @@ -67,13 +67,17 @@ message GetNodesRequest { uint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages -message PushTaskInsRequest { repeated TaskIns task_ins_list = 1; } +message PushTaskInsRequest { + repeated TaskIns task_ins_list = 1; + uint64 run_id = 2; +} message PushTaskInsResponse { repeated string task_ids = 2; } // PullTaskRes messages message PullTaskResRequest { Node node = 1; repeated string task_ids = 2; + uint64 run_id = 3; } message PullTaskResResponse { repeated TaskRes task_res_list = 1; } diff --git a/src/py/flwr/proto/serverappio_pb2.py b/src/py/flwr/proto/serverappio_pb2.py index 76e2dfc31c00..f97d8362e8df 100644 --- a/src/py/flwr/proto/serverappio_pb2.py +++ b/src/py/flwr/proto/serverappio_pb2.py @@ -20,7 +20,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/serverappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"P\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"V\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\x1c\n\x1aPullServerAppInputsRequest\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9f\x07\n\x0bServerAppIo\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,21 +32,21 @@ _globals['_GETNODESRESPONSE']._serialized_start=217 _globals['_GETNODESRESPONSE']._serialized_end=268 _globals['_PUSHTASKINSREQUEST']._serialized_start=270 - _globals['_PUSHTASKINSREQUEST']._serialized_end=334 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=336 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=375 - _globals['_PULLTASKRESREQUEST']._serialized_start=377 - _globals['_PULLTASKRESREQUEST']._serialized_end=447 - _globals['_PULLTASKRESRESPONSE']._serialized_start=449 - _globals['_PULLTASKRESRESPONSE']._serialized_end=514 - _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=516 - _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=544 - _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=546 - _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=673 - _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=675 - _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=758 - _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=760 - _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=790 - _globals['_SERVERAPPIO']._serialized_start=793 - _globals['_SERVERAPPIO']._serialized_end=1720 + _globals['_PUSHTASKINSREQUEST']._serialized_end=350 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=352 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=391 + _globals['_PULLTASKRESREQUEST']._serialized_start=393 + _globals['_PULLTASKRESREQUEST']._serialized_end=479 + _globals['_PULLTASKRESRESPONSE']._serialized_start=481 + _globals['_PULLTASKRESRESPONSE']._serialized_end=546 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=548 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=576 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=578 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=705 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=707 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=790 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=792 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=822 + _globals['_SERVERAPPIO']._serialized_start=825 + _globals['_SERVERAPPIO']._serialized_end=1752 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/serverappio_pb2.pyi b/src/py/flwr/proto/serverappio_pb2.pyi index 8191ec663442..38eff6456c03 100644 --- a/src/py/flwr/proto/serverappio_pb2.pyi +++ b/src/py/flwr/proto/serverappio_pb2.pyi @@ -44,13 +44,16 @@ class PushTaskInsRequest(google.protobuf.message.Message): """PushTaskIns messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_INS_LIST_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int @property def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ... + run_id: builtins.int def __init__(self, *, task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["task_ins_list",b"task_ins_list"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","task_ins_list",b"task_ins_list"]) -> None: ... global___PushTaskInsRequest = PushTaskInsRequest class PushTaskInsResponse(google.protobuf.message.Message): @@ -70,17 +73,20 @@ class PullTaskResRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor NODE_FIELD_NUMBER: builtins.int TASK_IDS_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int @property def node(self) -> flwr.proto.node_pb2.Node: ... @property def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + run_id: builtins.int def __init__(self, *, node: typing.Optional[flwr.proto.node_pb2.Node] = ..., task_ids: typing.Optional[typing.Iterable[typing.Text]] = ..., + run_id: builtins.int = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id","task_ids",b"task_ids"]) -> None: ... global___PullTaskResRequest = PullTaskResRequest class PullTaskResResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 05b7ce4be8bc..09318c32b704 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -203,7 +203,9 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: task_ins_list.append(taskins) # Call GrpcDriverStub method res: PushTaskInsResponse = self._stub.PushTaskIns( - PushTaskInsRequest(task_ins_list=task_ins_list) + PushTaskInsRequest( + task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id + ) ) return list(res.task_ids) @@ -215,7 +217,9 @@ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: """ # Pull TaskRes res: PullTaskResResponse = self._stub.PullTaskRes( - PullTaskResRequest(node=self.node, task_ids=message_ids) + PullTaskResRequest( + node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id + ) ) # Convert TaskRes to Message msgs = [message_from_taskres(taskres) for taskres in res.task_res_list] From 73483b8c638e28925fee6ae42e0fffabd43eaad6 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 9 Dec 2024 10:31:16 +0000 Subject: [PATCH 7/8] Fix GetRunStatus --- .../server/superlink/driver/serverappio_servicer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 7565dab83aba..6c46b5ea4408 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -295,10 +295,12 @@ def GetRunStatus( state = self.state_factory.state() # Get run status from LinkState - run_status = run_status_to_proto( - state.get_run_status({request.run_id})[request.run_id] - ) - return GetRunStatusResponse(run_status=run_status) + run_statuses = state.get_run_status({request.run_ids}) + run_status_dict = { + run_id: run_status_to_proto(run_status) + for run_id, run_status in run_statuses.items() + } + return GetRunStatusResponse(run_status_dict=run_status_dict) def _raise_if(validation_error: bool, detail: str) -> None: From 4f3399213df1560c87855a95d31cb70c06b67a14 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 9 Dec 2024 11:08:10 +0000 Subject: [PATCH 8/8] Add set --- src/py/flwr/server/superlink/driver/serverappio_servicer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 6c46b5ea4408..ca0b2ec0d8a5 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -295,7 +295,7 @@ def GetRunStatus( state = self.state_factory.state() # Get run status from LinkState - run_statuses = state.get_run_status({request.run_ids}) + run_statuses = state.get_run_status(set(request.run_ids)) run_status_dict = { run_id: run_status_to_proto(run_status) for run_id, run_status in run_statuses.items()