Skip to content

Commit

Permalink
feat(framework) Add PullServerAppInputs and PushServerAppOutputs
Browse files Browse the repository at this point in the history
…rpcs to `Driver` service (#4363)
  • Loading branch information
panh99 authored Oct 24, 2024
1 parent ed97ac3 commit eb6d9be
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 15 deletions.
24 changes: 24 additions & 0 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/message.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
import "flwr/proto/fab.proto";
Expand All @@ -40,6 +41,14 @@ service Driver {

// Get FAB
rpc GetFab(GetFabRequest) returns (GetFabResponse) {}

// Pull ServerApp inputs
rpc PullServerAppInputs(PullServerAppInputsRequest)
returns (PullServerAppInputsResponse) {}

// Push ServerApp outputs
rpc PushServerAppOutputs(PushServerAppOutputsRequest)
returns (PushServerAppOutputsResponse) {}
}

// GetNodes messages
Expand All @@ -56,3 +65,18 @@ message PullTaskResRequest {
repeated string task_ids = 2;
}
message PullTaskResResponse { repeated TaskRes task_res_list = 1; }

// PullServerAppInputs messages
message PullServerAppInputsRequest { uint64 run_id = 1; }
message PullServerAppInputsResponse {
Context context = 1;
Run run = 2;
Fab fab = 3;
}

// PushServerAppOutputs messages
message PushServerAppOutputsRequest {
uint64 run_id = 1;
Context context = 2;
}
message PushServerAppOutputsResponse {}
39 changes: 24 additions & 15 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions src/py/flwr/proto/driver_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
isort:skip_file
"""
import builtins
import flwr.proto.fab_pb2
import flwr.proto.message_pb2
import flwr.proto.node_pb2
import flwr.proto.run_pb2
import flwr.proto.task_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
Expand Down Expand Up @@ -91,3 +94,59 @@ class PullTaskResResponse(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
global___PullTaskResResponse = PullTaskResResponse

class PullServerAppInputsRequest(google.protobuf.message.Message):
"""PullServerAppInputs messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
run_id: builtins.int
def __init__(self,
*,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
global___PullServerAppInputsRequest = PullServerAppInputsRequest

class PullServerAppInputsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CONTEXT_FIELD_NUMBER: builtins.int
RUN_FIELD_NUMBER: builtins.int
FAB_FIELD_NUMBER: builtins.int
@property
def context(self) -> flwr.proto.message_pb2.Context: ...
@property
def run(self) -> flwr.proto.run_pb2.Run: ...
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
def __init__(self,
*,
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
run: typing.Optional[flwr.proto.run_pb2.Run] = ...,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> None: ...
global___PullServerAppInputsResponse = PullServerAppInputsResponse

class PushServerAppOutputsRequest(google.protobuf.message.Message):
"""PushServerAppOutputs messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
CONTEXT_FIELD_NUMBER: builtins.int
run_id: builtins.int
@property
def context(self) -> flwr.proto.message_pb2.Context: ...
def __init__(self,
*,
run_id: builtins.int = ...,
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","run_id",b"run_id"]) -> None: ...
global___PushServerAppOutputsRequest = PushServerAppOutputsRequest

class PushServerAppOutputsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(self,
) -> None: ...
global___PushServerAppOutputsResponse = PushServerAppOutputsResponse
68 changes: 68 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
)
self.PullServerAppInputs = channel.unary_unary(
'/flwr.proto.Driver/PullServerAppInputs',
request_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
)
self.PushServerAppOutputs = channel.unary_unary(
'/flwr.proto.Driver/PushServerAppOutputs',
request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
)


class DriverServicer(object):
Expand Down Expand Up @@ -93,6 +103,20 @@ def GetFab(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PullServerAppInputs(self, request, context):
"""Pull ServerApp inputs
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PushServerAppOutputs(self, request, context):
"""Push ServerApp outputs
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -126,6 +150,16 @@ def add_DriverServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.FromString,
response_serializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.SerializeToString,
),
'PullServerAppInputs': grpc.unary_unary_rpc_method_handler(
servicer.PullServerAppInputs,
request_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.SerializeToString,
),
'PushServerAppOutputs': grpc.unary_unary_rpc_method_handler(
servicer.PushServerAppOutputs,
request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Driver', rpc_method_handlers)
Expand Down Expand Up @@ -237,3 +271,37 @@ def GetFab(request,
flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def PullServerAppInputs(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PullServerAppInputs',
flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def PushServerAppOutputs(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushServerAppOutputs',
flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
26 changes: 26 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ class DriverStub:
flwr.proto.fab_pb2.GetFabResponse]
"""Get FAB"""

PullServerAppInputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.PullServerAppInputsRequest,
flwr.proto.driver_pb2.PullServerAppInputsResponse]
"""Pull ServerApp inputs"""

PushServerAppOutputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.PushServerAppOutputsRequest,
flwr.proto.driver_pb2.PushServerAppOutputsResponse]
"""Push ServerApp outputs"""


class DriverServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -90,5 +100,21 @@ class DriverServicer(metaclass=abc.ABCMeta):
"""Get FAB"""
pass

@abc.abstractmethod
def PullServerAppInputs(self,
request: flwr.proto.driver_pb2.PullServerAppInputsRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.PullServerAppInputsResponse:
"""Pull ServerApp inputs"""
pass

@abc.abstractmethod
def PushServerAppOutputs(self,
request: flwr.proto.driver_pb2.PushServerAppOutputsRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.PushServerAppOutputsResponse:
"""Push ServerApp outputs"""
pass


def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
16 changes: 16 additions & 0 deletions src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
GetNodesRequest,
GetNodesResponse,
PullServerAppInputsRequest,
PullServerAppInputsResponse,
PullTaskResRequest,
PullTaskResResponse,
PushServerAppOutputsRequest,
PushServerAppOutputsResponse,
PushTaskInsRequest,
PushTaskInsResponse,
)
Expand Down Expand Up @@ -200,6 +204,18 @@ def GetFab(

raise ValueError(f"Found no FAB with hash: {request.hash_str}")

def PullServerAppInputs(
self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
) -> PullServerAppInputsResponse:
"""Pull ServerApp process inputs."""
raise NotImplementedError()

def PushServerAppOutputs(
self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
) -> PushServerAppOutputsResponse:
"""Push ServerApp process outputs."""
raise NotImplementedError()


def _raise_if(validation_error: bool, detail: str) -> None:
if validation_error:
Expand Down

0 comments on commit eb6d9be

Please sign in to comment.