Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Implement ExecServicer.ListRuns RPC for flwr ls #4460

Merged
merged 16 commits into from
Nov 14, 2024
9 changes: 9 additions & 0 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ package flwr.proto;
import "flwr/proto/fab.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/recordset.proto";
import "flwr/proto/run.proto";

// Exec service: RPCs for starting a run, streaming its logs, and listing runs.
service Exec {
  // Start run upon request
  rpc StartRun(StartRunRequest) returns (StartRunResponse) {}

  // Start log stream upon request
  rpc StreamLogs(StreamLogsRequest) returns (stream StreamLogsResponse) {}

  // flwr ls command: list all runs, or a single run when `run_id` is set
  rpc ListRuns(ListRunsRequest) returns (ListRunsResponse) {}
}

message StartRunRequest {
Expand All @@ -43,3 +47,8 @@ message StreamLogsResponse {
string log_output = 1;
double latest_timestamp = 2;
}
// Request for Exec.ListRuns. `run_id` is optional: when it is unset every run
// is listed; when it is set only that run is looked up.
message ListRunsRequest { optional uint64 run_id = 1; }
// Response for Exec.ListRuns.
message ListRunsResponse {
  // Runs keyed by run ID. IDs for which no run was found are omitted.
  map<uint64, Run> run_dict = 1;
  // Server-side timestamp taken when the response was built, formatted as an
  // ISO 8601 string.
  string now = 2;
}
35 changes: 22 additions & 13 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions src/py/flwr/proto/exec_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ isort:skip_file
import builtins
import flwr.proto.fab_pb2
import flwr.proto.recordset_pb2
import flwr.proto.run_pb2
import flwr.proto.transport_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
Expand Down Expand Up @@ -88,3 +89,46 @@ class StreamLogsResponse(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["latest_timestamp",b"latest_timestamp","log_output",b"log_output"]) -> None: ...
global___StreamLogsResponse = StreamLogsResponse

class ListRunsRequest(google.protobuf.message.Message):
    """Typing stub for the `ListRunsRequest` protobuf message.

    The message carries a single optional `run_id` field whose presence can be
    queried with `HasField("run_id")`.
    """
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    RUN_ID_FIELD_NUMBER: builtins.int
    run_id: builtins.int
    def __init__(self,
        *,
        run_id: typing.Optional[builtins.int] = ...,
        ) -> None: ...
    # `_run_id` is the oneof group that wraps the optional `run_id` field; both
    # names are accepted by HasField/ClearField (see WhichOneof below).
    def HasField(self, field_name: typing_extensions.Literal["_run_id",b"_run_id","run_id",b"run_id"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing_extensions.Literal["_run_id",b"_run_id","run_id",b"run_id"]) -> None: ...
    def WhichOneof(self, oneof_group: typing_extensions.Literal["_run_id",b"_run_id"]) -> typing.Optional[typing_extensions.Literal["run_id"]]: ...
# Module-level alias for the class above.
global___ListRunsRequest = ListRunsRequest

class ListRunsResponse(google.protobuf.message.Message):
    """Typing stub for the `ListRunsResponse` protobuf message.

    Exposes `run_dict`, a map from integer run ID to `Run` message, and `now`,
    a text field.
    """
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    class RunDictEntry(google.protobuf.message.Message):
        """Entry type of the `run_dict` map: an integer key and a `Run` value."""
        DESCRIPTOR: google.protobuf.descriptor.Descriptor
        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: builtins.int
        @property
        def value(self) -> flwr.proto.run_pb2.Run: ...
        def __init__(self,
            *,
            key: builtins.int = ...,
            value: typing.Optional[flwr.proto.run_pb2.Run] = ...,
            ) -> None: ...
        # Only the message-typed `value` field supports presence checks.
        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

    RUN_DICT_FIELD_NUMBER: builtins.int
    NOW_FIELD_NUMBER: builtins.int
    # `run_dict` is exposed as a read-only property returning the map container.
    @property
    def run_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, flwr.proto.run_pb2.Run]: ...
    now: typing.Text
    def __init__(self,
        *,
        run_dict: typing.Optional[typing.Mapping[builtins.int, flwr.proto.run_pb2.Run]] = ...,
        now: typing.Text = ...,
        ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["now",b"now","run_dict",b"run_dict"]) -> None: ...
# Module-level alias for the class above.
global___ListRunsResponse = ListRunsResponse
34 changes: 34 additions & 0 deletions src/py/flwr/proto/exec_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.FromString,
)
self.ListRuns = channel.unary_unary(
'/flwr.proto.Exec/ListRuns',
request_serializer=flwr_dot_proto_dot_exec__pb2.ListRunsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_exec__pb2.ListRunsResponse.FromString,
)


class ExecServicer(object):
Expand All @@ -43,6 +48,13 @@ def StreamLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ListRuns(self, request, context):
    """flwr ls command

    Base implementation: marks the call as UNIMPLEMENTED on the gRPC context
    and raises NotImplementedError.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')


def add_ExecServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -56,6 +68,11 @@ def add_ExecServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_exec__pb2.StreamLogsRequest.FromString,
response_serializer=flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.SerializeToString,
),
'ListRuns': grpc.unary_unary_rpc_method_handler(
servicer.ListRuns,
request_deserializer=flwr_dot_proto_dot_exec__pb2.ListRunsRequest.FromString,
response_serializer=flwr_dot_proto_dot_exec__pb2.ListRunsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Exec', rpc_method_handlers)
Expand Down Expand Up @@ -99,3 +116,20 @@ def StreamLogs(request,
flwr_dot_proto_dot_exec__pb2.StreamLogsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def ListRuns(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke `/flwr.proto.Exec/ListRuns` against `target` as a unary-unary call.

    Serializes `request` as a ListRunsRequest and deserializes the reply as a
    ListRunsResponse; all remaining arguments are forwarded unchanged to
    `grpc.experimental.unary_unary`.
    """
    # NOTE(review): arguments are passed positionally, with `insecure` before
    # `call_credentials` -- the reverse of this method's own parameter order.
    # Presumably this matches grpc.experimental.unary_unary's signature;
    # confirm before reordering anything here.
    return grpc.experimental.unary_unary(request, target, '/flwr.proto.Exec/ListRuns',
        flwr_dot_proto_dot_exec__pb2.ListRunsRequest.SerializeToString,
        flwr_dot_proto_dot_exec__pb2.ListRunsResponse.FromString,
        options, channel_credentials,
        insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
13 changes: 13 additions & 0 deletions src/py/flwr/proto/exec_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class ExecStub:
flwr.proto.exec_pb2.StreamLogsResponse]
"""Start log stream upon request"""

ListRuns: grpc.UnaryUnaryMultiCallable[
flwr.proto.exec_pb2.ListRunsRequest,
flwr.proto.exec_pb2.ListRunsResponse]
"""flwr ls command"""


class ExecServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand All @@ -37,5 +42,13 @@ class ExecServicer(metaclass=abc.ABCMeta):
"""Start log stream upon request"""
pass

@abc.abstractmethod
def ListRuns(self,
    request: flwr.proto.exec_pb2.ListRunsRequest,
    context: grpc.ServicerContext,
) -> flwr.proto.exec_pb2.ListRunsResponse:
    """flwr ls command

    Abstract in the typing stub: takes a ListRunsRequest and the gRPC servicer
    context, and returns a ListRunsResponse.
    """
    pass


# Typing stub only: declares the signature of the generated helper that takes
# an ExecServicer and a grpc.Server and returns nothing.
def add_ExecServicer_to_server(servicer: ExecServicer, server: grpc.Server) -> None: ...
33 changes: 31 additions & 2 deletions src/py/flwr/superexec/exec_servicer.py
panh99 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,25 @@

import grpc

from flwr.common import now
from flwr.common.constant import LOG_STREAM_INTERVAL, Status
from flwr.common.logger import log
from flwr.common.serde import configs_record_from_proto, user_config_from_proto
from flwr.common.serde import (
configs_record_from_proto,
run_to_proto,
user_config_from_proto,
)
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
ListRunsRequest,
ListRunsResponse,
StartRunRequest,
StartRunResponse,
StreamLogsRequest,
StreamLogsResponse,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory

from .executor import Executor

Expand Down Expand Up @@ -105,3 +112,25 @@ def StreamLogs( # pylint: disable=C0103
context.cancel()

time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting

def ListRuns(
    self, request: ListRunsRequest, context: grpc.ServicerContext
) -> ListRunsResponse:
    """Handle `flwr ls` command.

    Parameters
    ----------
    request : ListRunsRequest
        If `run_id` is set, only that run is looked up; otherwise every run ID
        returned by the link state is listed.
    context : grpc.ServicerContext
        The gRPC call context (not used by this handler).

    Returns
    -------
    ListRunsResponse
        The runs that were found, keyed by run ID, plus the current timestamp.
    """
    # Log under the real RPC name ("ListRuns", not the stale "List") so that
    # log lines can be matched to the method that produced them.
    log(INFO, "ExecServicer.ListRuns")
    state = self.linkstate_factory.state()

    # Handle `flwr ls --runs`
    if not request.HasField("run_id"):
        return _create_list_runs_response(state.get_run_ids(), state)
    # Handle `flwr ls --run-id <run_id>`
    return _create_list_runs_response({request.run_id}, state)


def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse:
    """Build the `ListRunsResponse` shared by `flwr ls --runs` and `--run-id`.

    Every ID in `run_ids` is looked up in `state`; IDs whose lookup yields a
    falsy result (no run found) are left out of the response. The response is
    stamped with the current time in ISO format.
    """
    # Phase 1: fetch every requested run from the link state
    fetched = {}
    for rid in run_ids:
        fetched[rid] = state.get_run(rid)

    # Phase 2: serialize only the runs that were actually found
    serialized = {}
    for rid, found in fetched.items():
        if not found:
            continue
        serialized[rid] = run_to_proto(found)

    timestamp = now().isoformat()
    return ListRunsResponse(run_dict=serialized, now=timestamp)
57 changes: 56 additions & 1 deletion src/py/flwr/superexec/exec_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@


import subprocess
import unittest
from datetime import datetime
from unittest.mock import MagicMock, Mock

from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.common import ConfigsRecord, now
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
ListRunsRequest,
StartRunRequest,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory

from .exec_servicer import ExecServicer

Expand Down Expand Up @@ -49,3 +57,50 @@ def test_start_run() -> None:
# Execute
response = servicer.StartRun(request, context_mock)
assert response.run_id == 10


class TestExecServicer(unittest.TestCase):
    """Test the Exec API servicer."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        self.servicer = ExecServicer(
            linkstate_factory=LinkStateFactory(":flwr-in-memory-state:"),
            ffs_factory=FfsFactory("./tmp"),
            executor=Mock(),
        )
        self.state = self.servicer.linkstate_factory.state()

    def _create_runs(self, count: int = 3) -> list[int]:
        """Create `count` runs in the link state and return their IDs in order."""
        return [
            self.state.create_run(
                "mock fabid", "mock fabver", "fake hash", {}, ConfigsRecord()
            )
            for _ in range(count)
        ]

    def _assert_timestamp_between(
        self, reported: str, before: datetime, after: datetime
    ) -> None:
        """Assert that the ISO timestamp `reported` lies within [before, after].

        Bracketing the RPC call with two clock readings avoids depending on how
        fast the test machine is, unlike a fixed sub-millisecond tolerance.
        """
        parsed = datetime.fromisoformat(reported)
        self.assertLessEqual(before, parsed)
        self.assertLessEqual(parsed, after)

    def test_list_runs(self) -> None:
        """Test ListRuns method of ExecServicer with --runs option."""
        # Prepare
        run_ids = set(self._create_runs())

        # Execute
        before = now()
        response = self.servicer.ListRuns(ListRunsRequest(), Mock())
        after = now()

        # Assert
        self._assert_timestamp_between(response.now, before, after)
        self.assertEqual(set(response.run_dict.keys()), run_ids)

    def test_list_run_id(self) -> None:
        """Test ListRuns method of ExecServicer with --run-id option."""
        # Prepare: several runs exist, but only the last one is requested
        run_id = self._create_runs()[-1]

        # Execute
        before = now()
        response = self.servicer.ListRuns(ListRunsRequest(run_id=run_id), Mock())
        after = now()

        # Assert
        self._assert_timestamp_between(response.now, before, after)
        self.assertEqual(set(response.run_dict.keys()), {run_id})