Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add UpdateRunStatus RPC to Driver service #4393

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ service Driver {
// Push ServerApp outputs
rpc PushServerAppOutputs(PushServerAppOutputsRequest)
returns (PushServerAppOutputsResponse) {}

// Update the status of a given run
rpc UpdateRunStatus(UpdateRunStatusRequest)
returns (UpdateRunStatusResponse) {}
}

// GetNodes messages
Expand Down
22 changes: 22 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import (
ClientMessage,
Expand Down Expand Up @@ -910,3 +911,24 @@ def clientappstatus_from_proto(
if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
code = typing.ClientAppOutputCode.UNKNOWN_ERROR
return typing.ClientAppOutputStatus(code=code, message=msg.message)


# === Run status ===


def run_status_to_proto(run_status: typing.RunStatus) -> ProtoRunStatus:
"""Serialize `RunStatus` to ProtoBuf."""
return ProtoRunStatus(
status=run_status.status,
sub_status=run_status.sub_status,
details=run_status.details,
)


def run_status_from_proto(run_status_proto: ProtoRunStatus) -> typing.RunStatus:
"""Deserialize `RunStatus` from ProtoBuf."""
return typing.RunStatus(
status=run_status_proto.status,
sub_status=run_status_proto.sub_status,
details=run_status_proto.details,
)
4 changes: 2 additions & 2 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
)
self.UpdateRunStatus = channel.unary_unary(
'/flwr.proto.Driver/UpdateRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
)


class DriverServicer(object):
Expand Down Expand Up @@ -117,6 +122,13 @@ def PushServerAppOutputs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def UpdateRunStatus(self, request, context):
"""Update the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -160,6 +172,11 @@ def add_DriverServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString,
),
'UpdateRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.UpdateRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Driver', rpc_method_handlers)
Expand Down Expand Up @@ -305,3 +322,20 @@ def PushServerAppOutputs(request,
flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def UpdateRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/UpdateRunStatus',
flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
13 changes: 13 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class DriverStub:
flwr.proto.driver_pb2.PushServerAppOutputsResponse]
"""Push ServerApp outputs"""

UpdateRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.UpdateRunStatusRequest,
flwr.proto.run_pb2.UpdateRunStatusResponse]
"""Update the status of a given run"""


class DriverServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -116,5 +121,13 @@ class DriverServicer(metaclass=abc.ABCMeta):
"""Push ServerApp outputs"""
pass

@abc.abstractmethod
def UpdateRunStatus(self,
request: flwr.proto.run_pb2.UpdateRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.UpdateRunStatusResponse:
"""Update the status of a given run"""
pass


def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
16 changes: 16 additions & 0 deletions src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
context_to_proto,
fab_from_proto,
fab_to_proto,
run_status_from_proto,
run_to_proto,
user_config_from_proto,
)
Expand All @@ -55,6 +56,8 @@
CreateRunResponse,
GetRunRequest,
GetRunResponse,
UpdateRunStatusRequest,
UpdateRunStatusResponse,
)
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs import Ffs
Expand Down Expand Up @@ -253,6 +256,19 @@ def PushServerAppOutputs(
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
return PushServerAppOutputsResponse()

def UpdateRunStatus(
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
) -> UpdateRunStatusResponse:
"""Update the status of a run."""
log(DEBUG, "ControlServicer.UpdateRunStatus")
state = self.state_factory.state()

# Update the run status
state.update_run_status(
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
)
return UpdateRunStatusResponse()


def _raise_if(validation_error: bool, detail: str) -> None:
if validation_error:
Expand Down