Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add PullServerAppInputs and PushServerAppOutputs rpcs to Driver service #4363

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/message.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
import "flwr/proto/fab.proto";
Expand All @@ -40,6 +41,14 @@ service Driver {

// Get FAB
rpc GetFab(GetFabRequest) returns (GetFabResponse) {}

// Pull ServerApp inputs
rpc PullServerAppProcessInputs(PullServerAppProcessInputsRequest)
returns (PullServerAppProcessInputsResponse) {}

// Push ServerApp outputs
rpc PushServerAppProcessOutputs(PushServerAppProcessOutputsRequest)
returns (PushServerAppProcessOutputsResponse) {}
panh99 marked this conversation as resolved.
Show resolved Hide resolved
}

// GetNodes messages
Expand All @@ -56,3 +65,18 @@ message PullTaskResRequest {
repeated string task_ids = 2;
}
message PullTaskResResponse { repeated TaskRes task_res_list = 1; }

// PullServerAppProcessInputs messages
message PullServerAppProcessInputsRequest { uint64 run_id = 1; }
message PullServerAppProcessInputsResponse {
Context context = 1;
Run run = 2;
Fab fab = 3;
}

// PushServerAppProcessOutputs messages
message PushServerAppProcessOutputsRequest {
uint64 run_id = 1;
Context context = 2;
}
message PushServerAppProcessOutputsResponse {}
39 changes: 24 additions & 15 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions src/py/flwr/proto/driver_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
isort:skip_file
"""
import builtins
import flwr.proto.fab_pb2
import flwr.proto.message_pb2
import flwr.proto.node_pb2
import flwr.proto.run_pb2
import flwr.proto.task_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
Expand Down Expand Up @@ -91,3 +94,59 @@ class PullTaskResResponse(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
global___PullTaskResResponse = PullTaskResResponse

class PullServerAppProcessInputsRequest(google.protobuf.message.Message):
"""PullServerAppProcessInputs messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
run_id: builtins.int
def __init__(self,
*,
run_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
global___PullServerAppProcessInputsRequest = PullServerAppProcessInputsRequest

class PullServerAppProcessInputsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CONTEXT_FIELD_NUMBER: builtins.int
RUN_FIELD_NUMBER: builtins.int
FAB_FIELD_NUMBER: builtins.int
@property
def context(self) -> flwr.proto.message_pb2.Context: ...
@property
def run(self) -> flwr.proto.run_pb2.Run: ...
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
def __init__(self,
*,
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
run: typing.Optional[flwr.proto.run_pb2.Run] = ...,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> None: ...
global___PullServerAppProcessInputsResponse = PullServerAppProcessInputsResponse

class PushServerAppProcessOutputsRequest(google.protobuf.message.Message):
"""PushServerAppProcessOutputs messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
CONTEXT_FIELD_NUMBER: builtins.int
run_id: builtins.int
@property
def context(self) -> flwr.proto.message_pb2.Context: ...
def __init__(self,
*,
run_id: builtins.int = ...,
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["context",b"context"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","run_id",b"run_id"]) -> None: ...
global___PushServerAppProcessOutputsRequest = PushServerAppProcessOutputsRequest

class PushServerAppProcessOutputsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(self,
) -> None: ...
global___PushServerAppProcessOutputsResponse = PushServerAppProcessOutputsResponse
68 changes: 68 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
)
self.PullServerAppProcessInputs = channel.unary_unary(
'/flwr.proto.Driver/PullServerAppProcessInputs',
request_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsResponse.FromString,
)
self.PushServerAppProcessOutputs = channel.unary_unary(
'/flwr.proto.Driver/PushServerAppProcessOutputs',
request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsResponse.FromString,
)


class DriverServicer(object):
Expand Down Expand Up @@ -93,6 +103,20 @@ def GetFab(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PullServerAppProcessInputs(self, request, context):
"""Pull ServerApp inputs
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PushServerAppProcessOutputs(self, request, context):
"""Push ServerApp outputs
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -126,6 +150,16 @@ def add_DriverServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.FromString,
response_serializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.SerializeToString,
),
'PullServerAppProcessInputs': grpc.unary_unary_rpc_method_handler(
servicer.PullServerAppProcessInputs,
request_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsResponse.SerializeToString,
),
'PushServerAppProcessOutputs': grpc.unary_unary_rpc_method_handler(
servicer.PushServerAppProcessOutputs,
request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Driver', rpc_method_handlers)
Expand Down Expand Up @@ -237,3 +271,37 @@ def GetFab(request,
flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def PullServerAppProcessInputs(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PullServerAppProcessInputs',
flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.PullServerAppProcessInputsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def PushServerAppProcessOutputs(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushServerAppProcessOutputs',
flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.PushServerAppProcessOutputsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
26 changes: 26 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ class DriverStub:
flwr.proto.fab_pb2.GetFabResponse]
"""Get FAB"""

PullServerAppProcessInputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.PullServerAppProcessInputsRequest,
flwr.proto.driver_pb2.PullServerAppProcessInputsResponse]
"""Pull ServerApp inputs"""

PushServerAppProcessOutputs: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.PushServerAppProcessOutputsRequest,
flwr.proto.driver_pb2.PushServerAppProcessOutputsResponse]
"""Push ServerApp outputs"""


class DriverServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand Down Expand Up @@ -90,5 +100,21 @@ class DriverServicer(metaclass=abc.ABCMeta):
"""Get FAB"""
pass

@abc.abstractmethod
def PullServerAppProcessInputs(self,
request: flwr.proto.driver_pb2.PullServerAppProcessInputsRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.PullServerAppProcessInputsResponse:
"""Pull ServerApp inputs"""
pass

@abc.abstractmethod
def PushServerAppProcessOutputs(self,
request: flwr.proto.driver_pb2.PushServerAppProcessOutputsRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.PushServerAppProcessOutputsResponse:
"""Push ServerApp outputs"""
pass


def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
16 changes: 16 additions & 0 deletions src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
GetNodesRequest,
GetNodesResponse,
PullServerAppProcessInputsRequest,
PullServerAppProcessInputsResponse,
PullTaskResRequest,
PullTaskResResponse,
PushServerAppProcessOutputsRequest,
PushServerAppProcessOutputsResponse,
PushTaskInsRequest,
PushTaskInsResponse,
)
Expand Down Expand Up @@ -200,6 +204,18 @@ def GetFab(

raise ValueError(f"Found no FAB with hash: {request.hash_str}")

def PullServerAppProcessInputs(
self, request: PullServerAppProcessInputsRequest, context: grpc.ServicerContext
) -> PullServerAppProcessInputsResponse:
"""Pull ServerApp process inputs."""
raise NotImplementedError()

def PushServerAppProcessOutputs(
self, request: PushServerAppProcessOutputsRequest, context: grpc.ServicerContext
) -> PushServerAppProcessOutputsResponse:
"""Push ServerApp process outputs."""
raise NotImplementedError()


def _raise_if(validation_error: bool, detail: str) -> None:
if validation_error:
Expand Down