From 02f4f69d4448a81b1562d1a70bf9921d15e20832 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 15:48:25 +0000 Subject: [PATCH 01/15] add abc method --- .../server/superlink/linkstate/linkstate.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index a9965d430a3a..6b57b8f82368 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -182,6 +182,35 @@ def get_run(self, run_id: int) -> Optional[Run]: The `Run` instance if found; otherwise, `None`. """ + @abc.abstractmethod + def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: + """Retrieve the timestamps for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run. + + Returns + ------- + tuple[str, str, str, str] + A tuple containing four ISO format timestamps, representing: + - pending_at : str + The timestamp when the run was created. + - starting_at : str + The timestamp when the run started. + - running_at : str + The timestamp when the run began running. + - finished_at : str + The timestamp when the run finished. + + Notes + ----- + If a particular timestamp is not available (e.g., if the run is still + starting and doesn't have a `running_at` timestamp), an empty + string will be returned for that timestamp. + """ + @abc.abstractmethod def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs. From 36efd0b7c741d1317a3a3500bfacba15662aae3f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 18:20:59 +0000 Subject: [PATCH 02/15] implement methods --- .../linkstate/in_memory_linkstate.py | 14 +++++++++++++ .../superlink/linkstate/sqlite_linkstate.py | 20 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index b689fb8b0791..e05455e5d310 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -446,6 +446,20 @@ def get_run(self, run_id: int) -> Optional[Run]: return None return self.run_ids[run_id].run + def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: + """Retrieve the timestamps for the specified `run_id`.""" + with self.lock: + if run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") + return "", "", "", "" + run_record = self.run_ids[run_id] + return ( + run_record.pending_at, + run_record.starting_at, + run_record.running_at, + run_record.finished_at, + ) + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs.""" with self.lock: diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index ed0acf4213c1..73ad4b00022f 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -935,6 +935,26 @@ def get_run(self, run_id: int) -> Optional[Run]: log(ERROR, "`run_id` does not exist.") return None + def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: + """Retrieve the timestamps for the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + query = """ + SELECT pending_at, starting_at, running_at, finished_at + FROM run WHERE run_id = ?; + """ + rows = self.query(query, (sint64_run_id,)) + if rows: + row = rows[0] + return ( + row["pending_at"], + row["starting_at"], + row["running_at"], + row["finished_at"], + ) + log(ERROR, "`run_id` is invalid") + return ("", "", "", "") + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs.""" # Convert the uint64 value to sint64 for SQLite From 53d69a071cd1c49eee106af1762275e82b0c231f Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 18:29:03 +0000 Subject: [PATCH 03/15] add unit test --- .../superlink/linkstate/linkstate_test.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index d53b1ee51f6e..bab155a3b42f 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -75,6 +75,34 @@ def test_create_and_get_run(self) -> None: assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" + @parameterized.expand([(0,), (1,), (2,), (3,)]) # type: ignore + def test_get_run_timestamps(self, num_status_transitions: int) -> None: + """Test if get_run_timestamps works correctly.""" + # Prepare + state: LinkState = self.state_factory() + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) + + # Execute + if num_status_transitions > 0: + state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + if num_status_transitions > 1: + state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + if num_status_transitions > 2: + state.update_run_status( + run_id, RunStatus(Status.FINISHED, SubStatus.COMPLETED, "") + ) + timestamps = state.get_run_timestamps(run_id) + + # Assert + prev_timestamp = 0.0 + for i in range(num_status_transitions + 1): + assert timestamps[i] != "" + timestamp = datetime.fromisoformat(timestamps[i]).timestamp() + assert timestamp > prev_timestamp + prev_timestamp = timestamp + def test_get_all_run_ids(self) -> None: """Test if get_run_ids works correctly.""" # Prepare From 714d5e429a40fd9cebfdc5794e3be92e934dac26 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 8 Nov 2024 15:51:48 +0000 Subject: [PATCH 04/15] implement rpc for flwr ls --- src/proto/flwr/proto/exec.proto | 9 ++++++ src/py/flwr/proto/exec_pb2.py | 35 ++++++++++++-------- src/py/flwr/proto/exec_pb2.pyi | 44 ++++++++++++++++++++++++++ src/py/flwr/proto/exec_pb2_grpc.py | 34 ++++++++++++++++++++ src/py/flwr/proto/exec_pb2_grpc.pyi | 13 ++++++++ src/py/flwr/superexec/exec_servicer.py | 41 ++++++++++++++++++++++-- 6 files changed, 161 insertions(+), 15 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index e5b4136d18e8..93f3ee43df29 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -20,6 +20,7 @@ package flwr.proto; import "flwr/proto/fab.proto"; import "flwr/proto/transport.proto"; import "flwr/proto/recordset.proto"; +import "flwr/proto/run.proto"; service Exec { // Start run upon request @@ -27,6 +28,9 @@ service Exec { // Start log stream upon request rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} + + // flwr ls command + rpc List(ListRequest) returns (ListResponse) {} } message StartRunRequest { @@ -43,3 +47,8 @@ message StreamLogsResponse { string log_output = 1; double latest_timestamp = 2; } +message ListRequest { + string option = 1; + Scalar value = 2; +} +message ListResponse { map run_status_dict = 1; } diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index e8fda2cfb4f8..0dd6c865fa27 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -15,9 +15,10 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\x32\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xa1\x01\n\x0cListResponse\x12\x44\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32+.flwr.proto.ListResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -26,16 +27,24 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_STARTRUNREQUEST']._serialized_start=116 - _globals['_STARTRUNREQUEST']._serialized_end=367 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=294 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=367 - _globals['_STARTRUNRESPONSE']._serialized_start=369 - _globals['_STARTRUNRESPONSE']._serialized_end=403 - _globals['_STREAMLOGSREQUEST']._serialized_start=405 - _globals['_STREAMLOGSREQUEST']._serialized_end=465 - _globals['_STREAMLOGSRESPONSE']._serialized_start=467 - _globals['_STREAMLOGSRESPONSE']._serialized_end=533 - _globals['_EXEC']._serialized_start=536 - _globals['_EXEC']._serialized_end=696 + _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._options = None + _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' + _globals['_STARTRUNREQUEST']._serialized_start=138 + _globals['_STARTRUNREQUEST']._serialized_end=389 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=316 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=389 + _globals['_STARTRUNRESPONSE']._serialized_start=391 + _globals['_STARTRUNRESPONSE']._serialized_end=425 + _globals['_STREAMLOGSREQUEST']._serialized_start=427 + _globals['_STREAMLOGSREQUEST']._serialized_end=487 + _globals['_STREAMLOGSRESPONSE']._serialized_start=489 + _globals['_STREAMLOGSRESPONSE']._serialized_end=555 + _globals['_LISTREQUEST']._serialized_start=557 + _globals['_LISTREQUEST']._serialized_end=621 + _globals['_LISTRESPONSE']._serialized_start=624 + _globals['_LISTRESPONSE']._serialized_end=785 + _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=710 + _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=785 + _globals['_EXEC']._serialized_start=788 + _globals['_EXEC']._serialized_end=1009 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 380c57ab0780..0d3773864348 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -5,6 +5,7 @@ isort:skip_file import builtins import flwr.proto.fab_pb2 import flwr.proto.recordset_pb2 +import flwr.proto.run_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -88,3 +89,46 @@ class StreamLogsResponse(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["latest_timestamp",b"latest_timestamp","log_output",b"log_output"]) -> None: ... global___StreamLogsResponse = StreamLogsResponse + +class ListRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + OPTION_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + option: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + option: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["option",b"option","value",b"value"]) -> None: ... +global___ListRequest = ListRequest + +class ListResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class RunStatusDictEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.int + @property + def value(self) -> flwr.proto.run_pb2.RunStatus: ... + def __init__(self, + *, + key: builtins.int = ..., + value: typing.Optional[flwr.proto.run_pb2.RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + RUN_STATUS_DICT_FIELD_NUMBER: builtins.int + @property + def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, flwr.proto.run_pb2.RunStatus]: ... + def __init__(self, + *, + run_status_dict: typing.Optional[typing.Mapping[builtins.int, flwr.proto.run_pb2.RunStatus]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... +global___ListResponse = ListResponse diff --git a/src/py/flwr/proto/exec_pb2_grpc.py b/src/py/flwr/proto/exec_pb2_grpc.py index 8cf4ce52a300..374fe9289878 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.py +++ b/src/py/flwr/proto/exec_pb2_grpc.py @@ -24,6 +24,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.FromString, ) + self.List = channel.unary_unary( + '/flwr.proto.Exec/List', + request_serializer=flwr_dot_proto_dot_exec__pb2.ListRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_exec__pb2.ListResponse.FromString, + ) class ExecServicer(object): @@ -43,6 +48,13 @@ def StreamLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def List(self, request, context): + """flwr ls command + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ExecServicer_to_server(servicer, server): rpc_method_handlers = { @@ -56,6 +68,11 @@ def add_ExecServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.FromString, response_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.SerializeToString, ), + 'List': grpc.unary_unary_rpc_method_handler( + servicer.List, + request_deserializer=flwr_dot_proto_dot_exec__pb2.ListRequest.FromString, + response_serializer=flwr_dot_proto_dot_exec__pb2.ListResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Exec', rpc_method_handlers) @@ -99,3 +116,20 @@ def StreamLogs(request, flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def List(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Exec/List', + flwr_dot_proto_dot_exec__pb2.ListRequest.SerializeToString, + flwr_dot_proto_dot_exec__pb2.ListResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/exec_pb2_grpc.pyi b/src/py/flwr/proto/exec_pb2_grpc.pyi index 20da3a53f4a8..ffe03de7e584 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.pyi +++ b/src/py/flwr/proto/exec_pb2_grpc.pyi @@ -19,6 +19,11 @@ class ExecStub: flwr.proto.exec_pb2.StreamLogsResponse] """Start log stream upon request""" + List: grpc.UnaryUnaryMultiCallable[ + flwr.proto.exec_pb2.ListRequest, + flwr.proto.exec_pb2.ListResponse] + """flwr ls command""" + class ExecServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -37,5 +42,13 @@ class ExecServicer(metaclass=abc.ABCMeta): """Start log stream upon request""" pass + @abc.abstractmethod + def List(self, + request: flwr.proto.exec_pb2.ListRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.exec_pb2.ListResponse: + """flwr ls command""" + pass + def add_ExecServicer_to_server(servicer: ExecServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 359f87894021..ad31a6de733c 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -18,15 +18,22 @@ import time from collections.abc import Generator from logging import ERROR, INFO -from typing import Any +from typing import Any, cast import grpc from flwr.common.constant import LOG_STREAM_INTERVAL, Status from flwr.common.logger import log -from flwr.common.serde import configs_record_from_proto, user_config_from_proto +from flwr.common.serde import ( + configs_record_from_proto, + run_status_to_proto, + scalar_from_proto, + user_config_from_proto, +) from flwr.proto import exec_pb2_grpc # pylint: disable=E0611 from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 + ListRequest, + ListResponse, StartRunRequest, StartRunResponse, StreamLogsRequest, @@ -105,3 +112,33 @@ def StreamLogs( # pylint: disable=C0103 context.cancel() time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting + + def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListResponse: + """Handle flwr list command.""" + log(INFO, "ExecServicer.List") + state = self.linkstate_factory.state() + + # Handle `flwr list --runs` + if request.option == "--runs": + run_ids = state.get_run_ids() + run_status_dict = state.get_run_status(run_ids) + return ListResponse( + run_status_dict={ + run_id: run_status_to_proto(run_status) + for run_id, run_status in run_status_dict.items() + } + ) + # Handle `flwr list --run-id ` + if request.option == "--run-id": + run_id = cast(int, scalar_from_proto(request.value)) + if not isinstance(run_id, int): + context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid run ID") + + status = state.get_run_status({run_id}).get(run_id) + return ListResponse( + run_status_dict={run_id: run_status_to_proto(status)} if status else {} + ) + + # Unknown option + context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid option") + return ListResponse() From b12b2479b9438e0a8b7b3cb2924f95fcc4778160 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Nov 2024 16:15:48 +0000 Subject: [PATCH 05/15] fix comment --- src/py/flwr/superexec/exec_servicer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index ad31a6de733c..b67c8689d73c 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -114,11 +114,11 @@ def StreamLogs( # pylint: disable=C0103 time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListResponse: - """Handle flwr list command.""" + """Handle `flwr ls` command.""" log(INFO, "ExecServicer.List") state = self.linkstate_factory.state() - # Handle `flwr list --runs` + # Handle `flwr ls --runs` if request.option == "--runs": run_ids = state.get_run_ids() run_status_dict = state.get_run_status(run_ids) @@ -128,7 +128,7 @@ def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListRespo for run_id, run_status in run_status_dict.items() } ) - # Handle `flwr list --run-id ` + # Handle `flwr ls --run-id ` if request.option == "--run-id": run_id = cast(int, scalar_from_proto(request.value)) if not isinstance(run_id, int): From 46c8669d3348649c6531ff5fe548628aa0b04534 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Nov 2024 17:37:02 +0000 Subject: [PATCH 06/15] change error code --- src/py/flwr/superexec/exec_servicer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index b67c8689d73c..b29dbda2e028 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -140,5 +140,5 @@ def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListRespo ) # Unknown option - context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid option") + context.abort(grpc.StatusCode.UNIMPLEMENTED, "Invalid option") return ListResponse() From 8bc57da465f135877792a93a3b5c1f593b63ef09 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 18:49:29 +0000 Subject: [PATCH 07/15] amend ListReq/Res --- src/proto/flwr/proto/exec.proto | 13 +++++++++- src/py/flwr/proto/exec_pb2.py | 18 +++++++------- src/py/flwr/proto/exec_pb2.pyi | 42 +++++++++++++++++++++++++++------ 3 files changed, 57 insertions(+), 16 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 93f3ee43df29..4ce670472037 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -51,4 +51,15 @@ message ListRequest { string option = 1; Scalar value = 2; } -message ListResponse { map run_status_dict = 1; } +message ListResponse { + message RunInfo { + Run run = 1; + RunStatus status = 2; + string pending_at = 3; + string starting_at = 4; + string running_at = 5; + string finished_at = 6; + } + + map run_info_dict = 1; +} diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 0dd6c865fa27..d7a0d11d31ed 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xa1\x01\n\x0cListResponse\x12\x44\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32+.flwr.proto.ListResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xc9\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xa0\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,8 +27,8 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._options = None - _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._options = None + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_options = b'8\001' _globals['_STARTRUNREQUEST']._serialized_start=138 _globals['_STARTRUNREQUEST']._serialized_end=389 _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=316 @@ -42,9 +42,11 @@ _globals['_LISTREQUEST']._serialized_start=557 _globals['_LISTREQUEST']._serialized_end=621 _globals['_LISTRESPONSE']._serialized_start=624 - _globals['_LISTRESPONSE']._serialized_end=785 - _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=710 - _globals['_LISTRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=785 - _globals['_EXEC']._serialized_start=788 - _globals['_EXEC']._serialized_end=1009 + _globals['_LISTRESPONSE']._serialized_end=953 + _globals['_LISTRESPONSE_RUNINFO']._serialized_start=707 + _globals['_LISTRESPONSE_RUNINFO']._serialized_end=867 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=869 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=953 + _globals['_EXEC']._serialized_start=956 + _globals['_EXEC']._serialized_end=1177 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 0d3773864348..2f6d26e8bc4f 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -108,27 +108,55 @@ global___ListRequest = ListRequest class ListResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class RunStatusDictEntry(google.protobuf.message.Message): + class RunInfo(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_FIELD_NUMBER: builtins.int + STATUS_FIELD_NUMBER: builtins.int + PENDING_AT_FIELD_NUMBER: builtins.int + STARTING_AT_FIELD_NUMBER: builtins.int + RUNNING_AT_FIELD_NUMBER: builtins.int + FINISHED_AT_FIELD_NUMBER: builtins.int + @property + def run(self) -> flwr.proto.run_pb2.Run: ... + @property + def status(self) -> flwr.proto.run_pb2.RunStatus: ... + pending_at: typing.Text + starting_at: typing.Text + running_at: typing.Text + finished_at: typing.Text + def __init__(self, + *, + run: typing.Optional[flwr.proto.run_pb2.Run] = ..., + status: typing.Optional[flwr.proto.run_pb2.RunStatus] = ..., + pending_at: typing.Text = ..., + starting_at: typing.Text = ..., + running_at: typing.Text = ..., + finished_at: typing.Text = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run",b"run","status",b"status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["finished_at",b"finished_at","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... + + class RunInfoDictEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: builtins.int @property - def value(self) -> flwr.proto.run_pb2.RunStatus: ... + def value(self) -> global___ListResponse.RunInfo: ... def __init__(self, *, key: builtins.int = ..., - value: typing.Optional[flwr.proto.run_pb2.RunStatus] = ..., + value: typing.Optional[global___ListResponse.RunInfo] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - RUN_STATUS_DICT_FIELD_NUMBER: builtins.int + RUN_INFO_DICT_FIELD_NUMBER: builtins.int @property - def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, flwr.proto.run_pb2.RunStatus]: ... + def run_info_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___ListResponse.RunInfo]: ... def __init__(self, *, - run_status_dict: typing.Optional[typing.Mapping[builtins.int, flwr.proto.run_pb2.RunStatus]] = ..., + run_info_dict: typing.Optional[typing.Mapping[builtins.int, global___ListResponse.RunInfo]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_info_dict",b"run_info_dict"]) -> None: ... global___ListResponse = ListResponse From b5f0306662ed774c7dc08cdb94efda1d7f00162d Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 20:03:05 +0000 Subject: [PATCH 08/15] amend protocols --- src/proto/flwr/proto/exec.proto | 1 + src/py/flwr/proto/exec_pb2.py | 14 +++++++------- src/py/flwr/proto/exec_pb2.pyi | 5 ++++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 4ce670472037..072e3ebe68f6 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -59,6 +59,7 @@ message ListResponse { string starting_at = 4; string running_at = 5; string finished_at = 6; + string elapsed_time = 7; } map run_info_dict = 1; diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index d7a0d11d31ed..c2429a84687b 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xc9\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xa0\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xdf\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xb6\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x12\x14\n\x0c\x65lapsed_time\x18\x07 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -42,11 +42,11 @@ _globals['_LISTREQUEST']._serialized_start=557 _globals['_LISTREQUEST']._serialized_end=621 _globals['_LISTRESPONSE']._serialized_start=624 - _globals['_LISTRESPONSE']._serialized_end=953 + _globals['_LISTRESPONSE']._serialized_end=975 _globals['_LISTRESPONSE_RUNINFO']._serialized_start=707 - _globals['_LISTRESPONSE_RUNINFO']._serialized_end=867 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=869 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=953 - _globals['_EXEC']._serialized_start=956 - _globals['_EXEC']._serialized_end=1177 + _globals['_LISTRESPONSE_RUNINFO']._serialized_end=889 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=891 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=975 + _globals['_EXEC']._serialized_start=978 + _globals['_EXEC']._serialized_end=1199 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 2f6d26e8bc4f..f2f36b0b35bf 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -116,6 +116,7 @@ class ListResponse(google.protobuf.message.Message): STARTING_AT_FIELD_NUMBER: builtins.int RUNNING_AT_FIELD_NUMBER: builtins.int FINISHED_AT_FIELD_NUMBER: builtins.int + ELAPSED_TIME_FIELD_NUMBER: builtins.int @property def run(self) -> flwr.proto.run_pb2.Run: ... @property @@ -124,6 +125,7 @@ class ListResponse(google.protobuf.message.Message): starting_at: typing.Text running_at: typing.Text finished_at: typing.Text + elapsed_time: typing.Text def __init__(self, *, run: typing.Optional[flwr.proto.run_pb2.Run] = ..., @@ -132,9 +134,10 @@ class ListResponse(google.protobuf.message.Message): starting_at: typing.Text = ..., running_at: typing.Text = ..., finished_at: typing.Text = ..., + elapsed_time: typing.Text = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["run",b"run","status",b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["finished_at",b"finished_at","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["elapsed_time",b"elapsed_time","finished_at",b"finished_at","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... class RunInfoDictEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor From 6ef34508f2618e1ebb5d6473f37956ddd0ecf794 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Nov 2024 20:49:50 +0000 Subject: [PATCH 09/15] update ExecServicer.List --- src/proto/flwr/proto/exec.proto | 2 +- src/py/flwr/proto/exec_pb2.py | 14 ++++---- src/py/flwr/proto/exec_pb2.pyi | 8 ++--- src/py/flwr/superexec/exec_servicer.py | 46 +++++++++++++++++--------- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 072e3ebe68f6..17e8e2fd76f9 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -59,7 +59,7 @@ message ListResponse { string starting_at = 4; string running_at = 5; string finished_at = 6; - string elapsed_time = 7; + string now = 7; } map run_info_dict = 1; diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index c2429a84687b..a4db71a01807 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xdf\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xb6\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x12\x14\n\x0c\x65lapsed_time\x18\x07 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xd6\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xad\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x12\x0b\n\x03now\x18\x07 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -42,11 +42,11 @@ _globals['_LISTREQUEST']._serialized_start=557 _globals['_LISTREQUEST']._serialized_end=621 _globals['_LISTRESPONSE']._serialized_start=624 - _globals['_LISTRESPONSE']._serialized_end=975 + _globals['_LISTRESPONSE']._serialized_end=966 _globals['_LISTRESPONSE_RUNINFO']._serialized_start=707 - _globals['_LISTRESPONSE_RUNINFO']._serialized_end=889 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=891 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=975 - _globals['_EXEC']._serialized_start=978 - _globals['_EXEC']._serialized_end=1199 + _globals['_LISTRESPONSE_RUNINFO']._serialized_end=880 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=882 + _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=966 + _globals['_EXEC']._serialized_start=969 + _globals['_EXEC']._serialized_end=1190 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index f2f36b0b35bf..6447186256c8 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -116,7 +116,7 @@ class ListResponse(google.protobuf.message.Message): STARTING_AT_FIELD_NUMBER: builtins.int RUNNING_AT_FIELD_NUMBER: builtins.int FINISHED_AT_FIELD_NUMBER: builtins.int - ELAPSED_TIME_FIELD_NUMBER: builtins.int + NOW_FIELD_NUMBER: builtins.int @property def run(self) -> flwr.proto.run_pb2.Run: ... @property @@ -125,7 +125,7 @@ class ListResponse(google.protobuf.message.Message): starting_at: typing.Text running_at: typing.Text finished_at: typing.Text - elapsed_time: typing.Text + now: typing.Text def __init__(self, *, run: typing.Optional[flwr.proto.run_pb2.Run] = ..., @@ -134,10 +134,10 @@ class ListResponse(google.protobuf.message.Message): starting_at: typing.Text = ..., running_at: typing.Text = ..., finished_at: typing.Text = ..., - elapsed_time: typing.Text = ..., + now: typing.Text = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["run",b"run","status",b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["elapsed_time",b"elapsed_time","finished_at",b"finished_at","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["finished_at",b"finished_at","now",b"now","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... class RunInfoDictEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index b29dbda2e028..8029262817c5 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -18,15 +18,17 @@ import time from collections.abc import Generator from logging import ERROR, INFO -from typing import Any, cast +from typing import Any import grpc +from flwr.common import now from flwr.common.constant import LOG_STREAM_INTERVAL, Status from flwr.common.logger import log from flwr.common.serde import ( configs_record_from_proto, run_status_to_proto, + run_to_proto, scalar_from_proto, user_config_from_proto, ) @@ -40,7 +42,7 @@ StreamLogsResponse, ) from flwr.server.superlink.ffs.ffs_factory import FfsFactory -from flwr.server.superlink.linkstate import LinkStateFactory +from flwr.server.superlink.linkstate import LinkState, LinkStateFactory from .executor import Executor @@ -121,24 +123,38 @@ def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListRespo # Handle `flwr ls --runs` if request.option == "--runs": run_ids = state.get_run_ids() - run_status_dict = state.get_run_status(run_ids) - return ListResponse( - run_status_dict={ - run_id: run_status_to_proto(run_status) - for run_id, run_status in run_status_dict.items() - } - ) + return _list_runs(run_ids, state) # Handle `flwr ls --run-id ` if request.option == "--run-id": - run_id = cast(int, scalar_from_proto(request.value)) + run_id = scalar_from_proto(request.value) if not isinstance(run_id, int): context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid run ID") - - status = state.get_run_status({run_id}).get(run_id) - return ListResponse( - run_status_dict={run_id: run_status_to_proto(status)} if status else {} - ) + return ListResponse() + return _list_runs({run_id}, state) # Unknown option context.abort(grpc.StatusCode.UNIMPLEMENTED, "Invalid option") return ListResponse() + + +def _list_runs(run_ids: set[int], state: LinkState) -> ListResponse: + """Create response for `flwr ls --runs` and `flwr ls --run-id `.""" + run_status_dict = state.get_run_status(run_ids) + run_info_dict: dict[int, ListResponse.RunInfo] = {} + for run_id, run_status in run_status_dict.items(): + run = state.get_run(run_id) + # Very unlikely, as we just retrieved the run status + if not run: + continue + timestamps = state.get_run_timestamps(run_id) + run_info_dict[run_id] = ListResponse.RunInfo( + run=run_to_proto(run), + status=run_status_to_proto(run_status), + pending_at=timestamps[0], + starting_at=timestamps[1], + running_at=timestamps[2], + finished_at=timestamps[3], + now=now().isoformat(), + ) + + return ListResponse(run_info_dict=run_info_dict) From 4ba1a7c55681db43ddc3ade86c13e2d79507307c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 13 Nov 2024 13:40:29 +0000 Subject: [PATCH 10/15] restore to main --- .../linkstate/in_memory_linkstate.py | 14 --------- .../server/superlink/linkstate/linkstate.py | 29 ------------------- .../superlink/linkstate/linkstate_test.py | 28 ------------------ .../superlink/linkstate/sqlite_linkstate.py | 20 ------------- 4 files changed, 91 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index e05455e5d310..b689fb8b0791 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -446,20 +446,6 @@ def get_run(self, run_id: int) -> Optional[Run]: return None return self.run_ids[run_id].run - def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: - """Retrieve the timestamps for the specified `run_id`.""" - with self.lock: - if run_id not in self.run_ids: - log(ERROR, "`run_id` is invalid") - return "", "", "", "" - run_record = self.run_ids[run_id] - return ( - run_record.pending_at, - run_record.starting_at, - run_record.running_at, - run_record.finished_at, - ) - def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs.""" with self.lock: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 6b57b8f82368..a9965d430a3a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -182,35 +182,6 @@ def get_run(self, run_id: int) -> Optional[Run]: The `Run` instance if found; otherwise, `None`. """ - @abc.abstractmethod - def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: - """Retrieve the timestamps for the specified `run_id`. - - Parameters - ---------- - run_id : int - The identifier of the run. - - Returns - ------- - tuple[str, str, str, str] - A tuple containing four ISO format timestamps, representing: - - pending_at : str - The timestamp when the run was created. - - starting_at : str - The timestamp when the run started. - - running_at : str - The timestamp when the run began running. - - finished_at : str - The timestamp when the run finished. - - Notes - ----- - If a particular timestamp is not available (e.g., if the run is still - starting and doesn't have a `running_at` timestamp), an empty - string will be returned for that timestamp. - """ - @abc.abstractmethod def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs. diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index bab155a3b42f..d53b1ee51f6e 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -75,34 +75,6 @@ def test_create_and_get_run(self) -> None: assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" - @parameterized.expand([(0,), (1,), (2,), (3,)]) # type: ignore - def test_get_run_timestamps(self, num_status_transitions: int) -> None: - """Test if get_run_timestamps works correctly.""" - # Prepare - state: LinkState = self.state_factory() - run_id = state.create_run( - None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() - ) - - # Execute - if num_status_transitions > 0: - state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - if num_status_transitions > 1: - state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - if num_status_transitions > 2: - state.update_run_status( - run_id, RunStatus(Status.FINISHED, SubStatus.COMPLETED, "") - ) - timestamps = state.get_run_timestamps(run_id) - - # Assert - prev_timestamp = 0.0 - for i in range(num_status_transitions + 1): - assert timestamps[i] != "" - timestamp = datetime.fromisoformat(timestamps[i]).timestamp() - assert timestamp > prev_timestamp - prev_timestamp = timestamp - def test_get_all_run_ids(self) -> None: """Test if get_run_ids works correctly.""" # Prepare diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 73ad4b00022f..ed0acf4213c1 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -935,26 +935,6 @@ def get_run(self, run_id: int) -> Optional[Run]: log(ERROR, "`run_id` does not exist.") return None - def get_run_timestamps(self, run_id: int) -> tuple[str, str, str, str]: - """Retrieve the timestamps for the specified `run_id`.""" - # Convert the uint64 value to sint64 for SQLite - sint64_run_id = convert_uint64_to_sint64(run_id) - query = """ - SELECT pending_at, starting_at, running_at, finished_at - FROM run WHERE run_id = ?; - """ - rows = self.query(query, (sint64_run_id,)) - if rows: - row = rows[0] - return ( - row["pending_at"], - row["starting_at"], - row["running_at"], - row["finished_at"], - ) - log(ERROR, "`run_id` is invalid") - return ("", "", "", "") - def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs.""" # Convert the uint64 value to sint64 for SQLite From 0de36e7806286a4b190687eace23fa0466426f56 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 13 Nov 2024 13:55:58 +0000 Subject: [PATCH 11/15] update the code based on revamped get_run method --- src/proto/flwr/proto/exec.proto | 13 ++----- src/py/flwr/proto/exec_pb2.py | 18 +++++----- src/py/flwr/proto/exec_pb2.pyi | 48 ++++++-------------------- src/py/flwr/superexec/exec_servicer.py | 25 +++----------- 4 files changed, 25 insertions(+), 79 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 17e8e2fd76f9..1e33b86eb2d6 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -52,15 +52,6 @@ message ListRequest { Scalar value = 2; } message ListResponse { - message RunInfo { - Run run = 1; - RunStatus status = 2; - string pending_at = 3; - string starting_at = 4; - string running_at = 5; - string finished_at = 6; - string now = 7; - } - - map run_info_dict = 1; + map run_dict = 1; + string now = 2; } diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index a4db71a01807..667269f8c055 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\xd6\x02\n\x0cListResponse\x12@\n\rrun_info_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListResponse.RunInfoDictEntry\x1a\xad\x01\n\x07RunInfo\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\x12%\n\x06status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x12\n\npending_at\x18\x03 \x01(\t\x12\x13\n\x0bstarting_at\x18\x04 \x01(\t\x12\x12\n\nrunning_at\x18\x05 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\x06 \x01(\t\x12\x0b\n\x03now\x18\x07 \x01(\t\x1aT\n\x10RunInfoDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .flwr.proto.ListResponse.RunInfo:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\x95\x01\n\x0cListResponse\x12\x37\n\x08run_dict\x18\x01 \x03(\x0b\x32%.flwr.proto.ListResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,8 +27,8 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._options = None - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_options = b'8\001' + _globals['_LISTRESPONSE_RUNDICTENTRY']._options = None + _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_options = b'8\001' _globals['_STARTRUNREQUEST']._serialized_start=138 _globals['_STARTRUNREQUEST']._serialized_end=389 _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=316 @@ -42,11 +42,9 @@ _globals['_LISTREQUEST']._serialized_start=557 _globals['_LISTREQUEST']._serialized_end=621 _globals['_LISTRESPONSE']._serialized_start=624 - _globals['_LISTRESPONSE']._serialized_end=966 - _globals['_LISTRESPONSE_RUNINFO']._serialized_start=707 - _globals['_LISTRESPONSE_RUNINFO']._serialized_end=880 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_start=882 - _globals['_LISTRESPONSE_RUNINFODICTENTRY']._serialized_end=966 - _globals['_EXEC']._serialized_start=969 - _globals['_EXEC']._serialized_end=1190 + _globals['_LISTRESPONSE']._serialized_end=773 + _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_start=710 + _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_end=773 + _globals['_EXEC']._serialized_start=776 + _globals['_EXEC']._serialized_end=997 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 6447186256c8..2920a5f54343 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -108,58 +108,30 @@ global___ListRequest = ListRequest class ListResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class RunInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_FIELD_NUMBER: builtins.int - STATUS_FIELD_NUMBER: builtins.int - PENDING_AT_FIELD_NUMBER: builtins.int - STARTING_AT_FIELD_NUMBER: builtins.int - RUNNING_AT_FIELD_NUMBER: builtins.int - FINISHED_AT_FIELD_NUMBER: builtins.int - NOW_FIELD_NUMBER: builtins.int - @property - def run(self) -> flwr.proto.run_pb2.Run: ... - @property - def status(self) -> flwr.proto.run_pb2.RunStatus: ... - pending_at: typing.Text - starting_at: typing.Text - running_at: typing.Text - finished_at: typing.Text - now: typing.Text - def __init__(self, - *, - run: typing.Optional[flwr.proto.run_pb2.Run] = ..., - status: typing.Optional[flwr.proto.run_pb2.RunStatus] = ..., - pending_at: typing.Text = ..., - starting_at: typing.Text = ..., - running_at: typing.Text = ..., - finished_at: typing.Text = ..., - now: typing.Text = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["run",b"run","status",b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["finished_at",b"finished_at","now",b"now","pending_at",b"pending_at","run",b"run","running_at",b"running_at","starting_at",b"starting_at","status",b"status"]) -> None: ... - - class RunInfoDictEntry(google.protobuf.message.Message): + class RunDictEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: builtins.int @property - def value(self) -> global___ListResponse.RunInfo: ... + def value(self) -> flwr.proto.run_pb2.Run: ... def __init__(self, *, key: builtins.int = ..., - value: typing.Optional[global___ListResponse.RunInfo] = ..., + value: typing.Optional[flwr.proto.run_pb2.Run] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - RUN_INFO_DICT_FIELD_NUMBER: builtins.int + RUN_DICT_FIELD_NUMBER: builtins.int + NOW_FIELD_NUMBER: builtins.int @property - def run_info_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___ListResponse.RunInfo]: ... + def run_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, flwr.proto.run_pb2.Run]: ... + now: typing.Text def __init__(self, *, - run_info_dict: typing.Optional[typing.Mapping[builtins.int, global___ListResponse.RunInfo]] = ..., + run_dict: typing.Optional[typing.Mapping[builtins.int, flwr.proto.run_pb2.Run]] = ..., + now: typing.Text = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_info_dict",b"run_info_dict"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["now",b"now","run_dict",b"run_dict"]) -> None: ... global___ListResponse = ListResponse diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 8029262817c5..9a0d3d29dfbf 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -27,7 +27,6 @@ from flwr.common.logger import log from flwr.common.serde import ( configs_record_from_proto, - run_status_to_proto, run_to_proto, scalar_from_proto, user_config_from_proto, @@ -139,22 +138,8 @@ def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListRespo def _list_runs(run_ids: set[int], state: LinkState) -> ListResponse: """Create response for `flwr ls --runs` and `flwr ls --run-id `.""" - run_status_dict = state.get_run_status(run_ids) - run_info_dict: dict[int, ListResponse.RunInfo] = {} - for run_id, run_status in run_status_dict.items(): - run = state.get_run(run_id) - # Very unlikely, as we just retrieved the run status - if not run: - continue - timestamps = state.get_run_timestamps(run_id) - run_info_dict[run_id] = ListResponse.RunInfo( - run=run_to_proto(run), - status=run_status_to_proto(run_status), - pending_at=timestamps[0], - starting_at=timestamps[1], - running_at=timestamps[2], - finished_at=timestamps[3], - now=now().isoformat(), - ) - - return ListResponse(run_info_dict=run_info_dict) + run_dict = {run_id: state.get_run(run_id) for run_id in run_ids} + return ListResponse( + run_dict={run_id: run_to_proto(run) for run_id, run in run_dict.items() if run}, + now=now().isoformat(), + ) From e8c18b299f8a0d6e9b43911d76ad67d3ac0e8d99 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 13 Nov 2024 14:55:46 +0000 Subject: [PATCH 12/15] add unit test --- src/py/flwr/superexec/exec_servicer_test.py | 68 +++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 3b50200d22f2..30c60b6d6a6d 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -16,9 +16,17 @@ import subprocess +import unittest +from datetime import datetime from unittest.mock import MagicMock, Mock +import grpc + +from flwr.common import ConfigsRecord, now +from flwr.common.serde import scalar_to_proto from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 +from flwr.server.superlink.ffs.ffs_factory import FfsFactory +from flwr.server.superlink.linkstate import LinkStateFactory from .exec_servicer import ExecServicer @@ -49,3 +57,63 @@ def test_start_run() -> None: # Execute response = servicer.StartRun(request, context_mock) assert response.run_id == 10 + + +class TestExecServicer(unittest.TestCase): + """Test the Exec API servicer.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.servicer = ExecServicer( + linkstate_factory=LinkStateFactory(":flwr-in-memory-state:"), + ffs_factory=FfsFactory("./tmp"), + executor=Mock(), + ) + self.state = self.servicer.linkstate_factory.state() + + def test_list_runs(self) -> None: + """Test List method of ExecServicer with --runs option.""" + # Prepare + run_ids = set() + for _ in range(3): + run_id = self.state.create_run( + "mock fabid", "mock fabver", "fake hash", {}, ConfigsRecord() + ) + run_ids.add(run_id) + + # Execute + response = self.servicer.List(Mock(option="--runs"), Mock()) + retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() + + # Assert + self.assertLess(abs(retrieved_timestamp - now().timestamp()), 1e-3) + self.assertEqual(set(response.run_dict.keys()), run_ids) + + def test_list_run_id(self) -> None: + """Test List method of ExecServicer with --run-id option.""" + # Prepare + for _ in range(3): + run_id = self.state.create_run( + "mock fabid", "mock fabver", "fake hash", {}, ConfigsRecord() + ) + + # Execute + response = self.servicer.List( + Mock(option="--run-id", value=scalar_to_proto(run_id)), Mock() + ) + retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() + + # Assert + self.assertLess(abs(retrieved_timestamp - now().timestamp()), 1e-3) + self.assertEqual(set(response.run_dict.keys()), {run_id}) + + def test_list_invalid_option(self) -> None: + """Test List method of ExecServicer with invalid option.""" + # Execute + mock_context = Mock() + self.servicer.List(Mock(option="--invalid"), mock_context) + + # Assert + mock_context.abort.assert_called_once_with( + grpc.StatusCode.UNIMPLEMENTED, "Invalid option" + ) From c7acc0d68a7ca8acacda0464d650ce0960dd8c65 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 14 Nov 2024 12:17:19 +0000 Subject: [PATCH 13/15] renaming --- src/proto/flwr/proto/exec.proto | 6 +++--- src/py/flwr/proto/exec_pb2.py | 22 +++++++++++----------- src/py/flwr/proto/exec_pb2.pyi | 8 ++++---- src/py/flwr/proto/exec_pb2_grpc.py | 26 +++++++++++++------------- src/py/flwr/proto/exec_pb2_grpc.pyi | 12 ++++++------ 5 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 1e33b86eb2d6..0daa8098e20e 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -30,7 +30,7 @@ service Exec { rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {} // flwr ls command - rpc List(ListRequest) returns (ListResponse) {} + rpc ListRuns(ListRunsRequest) returns (ListRunsResponse) {} } message StartRunRequest { @@ -47,11 +47,11 @@ message StreamLogsResponse { string log_output = 1; double latest_timestamp = 2; } -message ListRequest { +message ListRunsRequest { string option = 1; Scalar value = 2; } -message ListResponse { +message ListRunsResponse { map run_dict = 1; string now = 2; } diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 667269f8c055..11f2846a726d 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"@\n\x0bListRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\x95\x01\n\x0cListResponse\x12\x37\n\x08run_dict\x18\x01 \x03(\x0b\x32%.flwr.proto.ListResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xdd\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12;\n\x04List\x12\x17.flwr.proto.ListRequest\x1a\x18.flwr.proto.ListResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"D\n\x0fListRunsRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xe9\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,8 +27,8 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_LISTRESPONSE_RUNDICTENTRY']._options = None - _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_options = b'8\001' + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._options = None + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_options = b'8\001' _globals['_STARTRUNREQUEST']._serialized_start=138 _globals['_STARTRUNREQUEST']._serialized_end=389 _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=316 @@ -39,12 +39,12 @@ _globals['_STREAMLOGSREQUEST']._serialized_end=487 _globals['_STREAMLOGSRESPONSE']._serialized_start=489 _globals['_STREAMLOGSRESPONSE']._serialized_end=555 - _globals['_LISTREQUEST']._serialized_start=557 - _globals['_LISTREQUEST']._serialized_end=621 - _globals['_LISTRESPONSE']._serialized_start=624 - _globals['_LISTRESPONSE']._serialized_end=773 - _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_start=710 - _globals['_LISTRESPONSE_RUNDICTENTRY']._serialized_end=773 - _globals['_EXEC']._serialized_start=776 - _globals['_EXEC']._serialized_end=997 + _globals['_LISTRUNSREQUEST']._serialized_start=557 + _globals['_LISTRUNSREQUEST']._serialized_end=625 + _globals['_LISTRUNSRESPONSE']._serialized_start=628 + _globals['_LISTRUNSRESPONSE']._serialized_end=785 + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_start=722 + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_end=785 + _globals['_EXEC']._serialized_start=788 + _globals['_EXEC']._serialized_end=1021 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 2920a5f54343..91bd06172511 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -90,7 +90,7 @@ class StreamLogsResponse(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["latest_timestamp",b"latest_timestamp","log_output",b"log_output"]) -> None: ... global___StreamLogsResponse = StreamLogsResponse -class ListRequest(google.protobuf.message.Message): +class ListRunsRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor OPTION_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int @@ -104,9 +104,9 @@ class ListRequest(google.protobuf.message.Message): ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["option",b"option","value",b"value"]) -> None: ... -global___ListRequest = ListRequest +global___ListRunsRequest = ListRunsRequest -class ListResponse(google.protobuf.message.Message): +class ListRunsResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor class RunDictEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -134,4 +134,4 @@ class ListResponse(google.protobuf.message.Message): now: typing.Text = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["now",b"now","run_dict",b"run_dict"]) -> None: ... -global___ListResponse = ListResponse +global___ListRunsResponse = ListRunsResponse diff --git a/src/py/flwr/proto/exec_pb2_grpc.py b/src/py/flwr/proto/exec_pb2_grpc.py index 374fe9289878..63f9285fed58 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.py +++ b/src/py/flwr/proto/exec_pb2_grpc.py @@ -24,10 +24,10 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.FromString, ) - self.List = channel.unary_unary( - '/flwr.proto.Exec/List', - request_serializer=flwr_dot_proto_dot_exec__pb2.ListRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_exec__pb2.ListResponse.FromString, + self.ListRuns = channel.unary_unary( + '/flwr.proto.Exec/ListRuns', + request_serializer=flwr_dot_proto_dot_exec__pb2.ListRunsRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_exec__pb2.ListRunsResponse.FromString, ) @@ -48,7 +48,7 @@ def StreamLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def List(self, request, context): + def ListRuns(self, request, context): """flwr ls command """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -68,10 +68,10 @@ def add_ExecServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.FromString, response_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.SerializeToString, ), - 'List': grpc.unary_unary_rpc_method_handler( - servicer.List, - request_deserializer=flwr_dot_proto_dot_exec__pb2.ListRequest.FromString, - response_serializer=flwr_dot_proto_dot_exec__pb2.ListResponse.SerializeToString, + 'ListRuns': grpc.unary_unary_rpc_method_handler( + servicer.ListRuns, + request_deserializer=flwr_dot_proto_dot_exec__pb2.ListRunsRequest.FromString, + response_serializer=flwr_dot_proto_dot_exec__pb2.ListRunsResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -118,7 +118,7 @@ def StreamLogs(request, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def List(request, + def ListRuns(request, target, options=(), channel_credentials=None, @@ -128,8 +128,8 @@ def List(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/flwr.proto.Exec/List', - flwr_dot_proto_dot_exec__pb2.ListRequest.SerializeToString, - flwr_dot_proto_dot_exec__pb2.ListResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Exec/ListRuns', + flwr_dot_proto_dot_exec__pb2.ListRunsRequest.SerializeToString, + flwr_dot_proto_dot_exec__pb2.ListRunsResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/exec_pb2_grpc.pyi b/src/py/flwr/proto/exec_pb2_grpc.pyi index ffe03de7e584..550c282bface 100644 --- a/src/py/flwr/proto/exec_pb2_grpc.pyi +++ b/src/py/flwr/proto/exec_pb2_grpc.pyi @@ -19,9 +19,9 @@ class ExecStub: flwr.proto.exec_pb2.StreamLogsResponse] """Start log stream upon request""" - List: grpc.UnaryUnaryMultiCallable[ - flwr.proto.exec_pb2.ListRequest, - flwr.proto.exec_pb2.ListResponse] + ListRuns: grpc.UnaryUnaryMultiCallable[ + flwr.proto.exec_pb2.ListRunsRequest, + flwr.proto.exec_pb2.ListRunsResponse] """flwr ls command""" @@ -43,10 +43,10 @@ class ExecServicer(metaclass=abc.ABCMeta): pass @abc.abstractmethod - def List(self, - request: flwr.proto.exec_pb2.ListRequest, + def ListRuns(self, + request: flwr.proto.exec_pb2.ListRunsRequest, context: grpc.ServicerContext, - ) -> flwr.proto.exec_pb2.ListResponse: + ) -> flwr.proto.exec_pb2.ListRunsResponse: """flwr ls command""" pass From f6023d63198ebd0a1ef0cd9ffb688b4f077b33cf Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 14 Nov 2024 12:23:55 +0000 Subject: [PATCH 14/15] rename List to ListRuns --- src/py/flwr/superexec/exec_servicer.py | 16 +++++++++------- src/py/flwr/superexec/exec_servicer_test.py | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 9a0d3d29dfbf..59ddd2b2d2c0 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -33,8 +33,8 @@ ) from flwr.proto import exec_pb2_grpc # pylint: disable=E0611 from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 - ListRequest, - ListResponse, + ListRunsRequest, + ListRunsResponse, StartRunRequest, StartRunResponse, StreamLogsRequest, @@ -114,7 +114,9 @@ def StreamLogs( # pylint: disable=C0103 time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting - def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListResponse: + def ListRuns( + self, request: ListRunsRequest, context: grpc.ServicerContext + ) -> ListRunsResponse: """Handle `flwr ls` command.""" log(INFO, "ExecServicer.List") state = self.linkstate_factory.state() @@ -128,18 +130,18 @@ def List(self, request: ListRequest, context: grpc.ServicerContext) -> ListRespo run_id = scalar_from_proto(request.value) if not isinstance(run_id, int): context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid run ID") - return ListResponse() + return ListRunsResponse() return _list_runs({run_id}, state) # Unknown option context.abort(grpc.StatusCode.UNIMPLEMENTED, "Invalid option") - return ListResponse() + return ListRunsResponse() -def _list_runs(run_ids: set[int], state: LinkState) -> ListResponse: +def _list_runs(run_ids: set[int], state: LinkState) -> ListRunsResponse: """Create response for `flwr ls --runs` and `flwr ls --run-id `.""" run_dict = {run_id: state.get_run(run_id) for run_id in run_ids} - return ListResponse( + return ListRunsResponse( run_dict={run_id: run_to_proto(run) for run_id, run in run_dict.items() if run}, now=now().isoformat(), ) diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 30c60b6d6a6d..b66b92585d94 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -82,7 +82,7 @@ def test_list_runs(self) -> None: run_ids.add(run_id) # Execute - response = self.servicer.List(Mock(option="--runs"), Mock()) + response = self.servicer.ListRuns(Mock(option="--runs"), Mock()) retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() # Assert @@ -98,7 +98,7 @@ def test_list_run_id(self) -> None: ) # Execute - response = self.servicer.List( + response = self.servicer.ListRuns( Mock(option="--run-id", value=scalar_to_proto(run_id)), Mock() ) retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() @@ -111,7 +111,7 @@ def test_list_invalid_option(self) -> None: """Test List method of ExecServicer with invalid option.""" # Execute mock_context = Mock() - self.servicer.List(Mock(option="--invalid"), mock_context) + self.servicer.ListRuns(Mock(option="--invalid"), mock_context) # Assert mock_context.abort.assert_called_once_with( From 47724fa1cf5667f726e551054ea59a276ed178fd Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 14 Nov 2024 14:34:48 +0000 Subject: [PATCH 15/15] refactor --- src/proto/flwr/proto/exec.proto | 5 +---- src/py/flwr/proto/exec_pb2.py | 16 ++++++------- src/py/flwr/proto/exec_pb2.pyi | 15 +++++-------- src/py/flwr/superexec/exec_servicer.py | 19 ++++------------ src/py/flwr/superexec/exec_servicer_test.py | 25 +++++---------------- 5 files changed, 25 insertions(+), 55 deletions(-) diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 0daa8098e20e..583c42ff5704 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -47,10 +47,7 @@ message StreamLogsResponse { string log_output = 1; double latest_timestamp = 2; } -message ListRunsRequest { - string option = 1; - Scalar value = 2; -} +message ListRunsRequest { optional uint64 run_id = 1; } message ListRunsResponse { map run_dict = 1; string now = 2; diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 11f2846a726d..2240988e87a0 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"D\n\x0fListRunsRequest\x12\x0e\n\x06option\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xe9\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\x32\xe9\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -40,11 +40,11 @@ _globals['_STREAMLOGSRESPONSE']._serialized_start=489 _globals['_STREAMLOGSRESPONSE']._serialized_end=555 _globals['_LISTRUNSREQUEST']._serialized_start=557 - _globals['_LISTRUNSREQUEST']._serialized_end=625 - _globals['_LISTRUNSRESPONSE']._serialized_start=628 - _globals['_LISTRUNSRESPONSE']._serialized_end=785 - _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_start=722 - _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_end=785 - _globals['_EXEC']._serialized_start=788 - _globals['_EXEC']._serialized_end=1021 + _globals['_LISTRUNSREQUEST']._serialized_end=606 + _globals['_LISTRUNSRESPONSE']._serialized_start=609 + _globals['_LISTRUNSRESPONSE']._serialized_end=766 + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_start=703 + _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_end=766 + _globals['_EXEC']._serialized_start=769 + _globals['_EXEC']._serialized_end=1002 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 91bd06172511..08e0b1c14346 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -92,18 +92,15 @@ global___StreamLogsResponse = StreamLogsResponse class ListRunsRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - OPTION_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - option: typing.Text - @property - def value(self) -> flwr.proto.transport_pb2.Scalar: ... + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - option: typing.Text = ..., - value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + run_id: typing.Optional[builtins.int] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["option",b"option","value",b"value"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_run_id",b"_run_id","run_id",b"run_id"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_run_id",b"_run_id","run_id",b"run_id"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["_run_id",b"_run_id"]) -> typing.Optional[typing_extensions.Literal["run_id"]]: ... global___ListRunsRequest = ListRunsRequest class ListRunsResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 59ddd2b2d2c0..3a484ea8c47c 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -28,7 +28,6 @@ from flwr.common.serde import ( configs_record_from_proto, run_to_proto, - scalar_from_proto, user_config_from_proto, ) from flwr.proto import exec_pb2_grpc # pylint: disable=E0611 @@ -122,23 +121,13 @@ def ListRuns( state = self.linkstate_factory.state() # Handle `flwr ls --runs` - if request.option == "--runs": - run_ids = state.get_run_ids() - return _list_runs(run_ids, state) + if not request.HasField("run_id"): + return _create_list_runs_response(state.get_run_ids(), state) # Handle `flwr ls --run-id ` - if request.option == "--run-id": - run_id = scalar_from_proto(request.value) - if not isinstance(run_id, int): - context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid run ID") - return ListRunsResponse() - return _list_runs({run_id}, state) + return _create_list_runs_response({request.run_id}, state) - # Unknown option - context.abort(grpc.StatusCode.UNIMPLEMENTED, "Invalid option") - return ListRunsResponse() - -def _list_runs(run_ids: set[int], state: LinkState) -> ListRunsResponse: +def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse: """Create response for `flwr ls --runs` and `flwr ls --run-id `.""" run_dict = {run_id: state.get_run(run_id) for run_id in run_ids} return ListRunsResponse( diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index b66b92585d94..6045d6eb1a63 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -20,11 +20,11 @@ from datetime import datetime from unittest.mock import MagicMock, Mock -import grpc - from flwr.common import ConfigsRecord, now -from flwr.common.serde import scalar_to_proto -from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 +from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 + ListRunsRequest, + StartRunRequest, +) from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate import LinkStateFactory @@ -82,7 +82,7 @@ def test_list_runs(self) -> None: run_ids.add(run_id) # Execute - response = self.servicer.ListRuns(Mock(option="--runs"), Mock()) + response = self.servicer.ListRuns(ListRunsRequest(), Mock()) retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() # Assert @@ -98,22 +98,9 @@ def test_list_run_id(self) -> None: ) # Execute - response = self.servicer.ListRuns( - Mock(option="--run-id", value=scalar_to_proto(run_id)), Mock() - ) + response = self.servicer.ListRuns(ListRunsRequest(run_id=run_id), Mock()) retrieved_timestamp = datetime.fromisoformat(response.now).timestamp() # Assert self.assertLess(abs(retrieved_timestamp - now().timestamp()), 1e-3) self.assertEqual(set(response.run_dict.keys()), {run_id}) - - def test_list_invalid_option(self) -> None: - """Test List method of ExecServicer with invalid option.""" - # Execute - mock_context = Mock() - self.servicer.ListRuns(Mock(option="--invalid"), mock_context) - - # Assert - mock_context.abort.assert_called_once_with( - grpc.StatusCode.UNIMPLEMENTED, "Invalid option" - )