Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename protobuf messages #1214

Merged
merged 6 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/proto/flwr/proto/transport.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

syntax = "proto3";

package flower.transport;
package flwr.proto;

service FlowerService {
rpc Join(stream ClientMessage) returns (stream ServerMessage) {}
Expand Down Expand Up @@ -43,7 +43,7 @@ enum Reason {

message ServerMessage {
message Reconnect { int64 seconds = 1; }
message GetParameters {}
message GetParametersIns {}
message FitIns {
Parameters parameters = 1;
map<string, Scalar> config = 2;
Expand All @@ -52,19 +52,19 @@ message ServerMessage {
Parameters parameters = 1;
map<string, Scalar> config = 2;
}
message PropertiesIns { map<string, Scalar> config = 1; }
message GetPropertiesIns { map<string, Scalar> config = 1; }
oneof msg {
Reconnect reconnect = 1;
GetParameters get_parameters = 2;
GetParametersIns get_parameters_ins = 2;
FitIns fit_ins = 3;
EvaluateIns evaluate_ins = 4;
PropertiesIns properties_ins = 5;
GetPropertiesIns get_properties_ins = 5;
}
}

message ClientMessage {
message Disconnect { Reason reason = 1; }
message ParametersRes { Parameters parameters = 1; }
message GetParametersRes { Parameters parameters = 1; }
message FitRes {
Parameters parameters = 1;
int64 num_examples = 2;
Expand All @@ -75,21 +75,21 @@ message ClientMessage {
float loss = 2;
map<string, Scalar> metrics = 4;
}
message PropertiesRes {
message GetPropertiesRes {
Status status = 1;
map<string, Scalar> properties = 2;
}
oneof msg {
Disconnect disconnect = 1;
ParametersRes parameters_res = 2;
GetParametersRes get_parameters_res = 2;
FitRes fit_res = 3;
EvaluateRes evaluate_res = 4;
PropertiesRes properties_res = 5;
GetPropertiesRes get_properties_res = 5;
}
}

message Scalar {
// The following oneof contains all types that ProtoBuf considers to be
// The following `oneof` contains all types that ProtoBuf considers to be
// "Scalar Value Types". Commented-out types are listed for reference and
// might be enabled in future releases. Source:
// https://developers.google.com/protocol-buffers/docs/proto3#scalar
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
EvaluateRes,
FitIns,
FitRes,
ParametersRes,
PropertiesIns,
PropertiesRes,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
Scalar,
)

Expand All @@ -39,11 +39,11 @@
class PlainClient(Client):
"""Client implementation extending the low-level Client."""

def get_properties(self, ins: PropertiesIns) -> PropertiesRes:
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
# This method is not expected to be called
raise Exception()

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
# This method is not expected to be called
raise Exception()

Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,37 @@
EvaluateRes,
FitIns,
FitRes,
ParametersRes,
PropertiesIns,
PropertiesRes,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
)


class Client(ABC):
"""Abstract base class for Flower clients."""

def get_properties(self, ins: PropertiesIns) -> PropertiesRes:
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.

Parameters
----------
ins : PropertiesIns
ins : GetPropertiesIns
The get properties instructions received from the server containing
a dictionary of configuration values used to configure.

Returns
-------
PropertiesRes
GetPropertiesRes
Client's properties.
"""

@abstractmethod
def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
"""Return the current local model parameters.

Returns
-------
ParametersRes
GetParametersRes
The current local model parameters.
"""

Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
EvaluateRes,
FitIns,
FitRes,
ParametersRes,
PropertiesIns,
PropertiesRes,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
Status,
)

Expand All @@ -33,12 +33,12 @@
class OverridingClient(Client):
"""Client overriding `get_properties`."""

def get_properties(self, ins: PropertiesIns) -> PropertiesRes:
return PropertiesRes(
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
return GetPropertiesRes(
status=Status(code=Code.OK, message="Success"), properties={}
)

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
# This method is not expected to be called
raise Exception()

Expand All @@ -54,7 +54,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
class NotOverridingClient(Client):
"""Client not overriding `get_properties`."""

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
# This method is not expected to be called
raise Exception()

Expand Down
30 changes: 15 additions & 15 deletions src/py/flwr/client/grpc_client/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def handle(
if field == "reconnect":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect)
return disconnect_msg, sleep_duration, False
if field == "properties_ins":
return _get_properties(client, server_msg.properties_ins), 0, True
if field == "get_parameters":
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
if field == "get_parameters_ins":
return _get_parameters(client), 0, True
if field == "fit_ins":
return _fit(client, server_msg.fit_ins), 0, True
Expand All @@ -79,35 +79,35 @@ def _reconnect(


def _get_properties(
client: Client, properties_msg: ServerMessage.PropertiesIns
client: Client, get_properties_msg: ServerMessage.GetPropertiesIns
) -> ClientMessage:
# Check if client overrides get_properties
if not has_get_properties(client=client):
# If client does not override get_properties, don't call it
properties_res = typing.PropertiesRes(
get_properties_res = typing.GetPropertiesRes(
status=typing.Status(
code=typing.Code.GET_PARAMETERS_NOT_IMPLEMENTED,
message="Client does not implement get_properties",
),
properties={},
)
properties_res_proto = serde.properties_res_to_proto(properties_res)
return ClientMessage(properties_res=properties_res_proto)
get_properties_res_proto = serde.get_properties_res_to_proto(get_properties_res)
return ClientMessage(get_properties_res=get_properties_res_proto)

# Deserialize get_properties instruction
properties_ins = serde.properties_ins_from_proto(properties_msg)
# Request for properties
properties_res = client.get_properties(properties_ins)
get_properties_ins = serde.get_properties_ins_from_proto(get_properties_msg)
# Request properties
get_properties_res = client.get_properties(get_properties_ins)
# Serialize response
properties_res_proto = serde.properties_res_to_proto(properties_res)
return ClientMessage(properties_res=properties_res_proto)
get_properties_res_proto = serde.get_properties_res_to_proto(get_properties_res)
return ClientMessage(get_properties_res=get_properties_res_proto)


def _get_parameters(client: Client) -> ClientMessage:
# No need to deserialize get_parameters_msg (it's empty)
parameters_res = client.get_parameters()
parameters_res_proto = serde.parameters_res_to_proto(parameters_res)
return ClientMessage(parameters_res=parameters_res_proto)
get_parameters_res = client.get_parameters()
get_parameters_res_proto = serde.get_parameters_res_to_proto(get_parameters_res)
return ClientMessage(get_parameters_res=get_parameters_res_proto)


def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage:
Expand Down
30 changes: 15 additions & 15 deletions src/py/flwr/client/grpc_client/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
EvaluateRes,
FitIns,
FitRes,
ParametersRes,
PropertiesIns,
PropertiesRes,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
serde,
typing,
)
Expand All @@ -35,7 +35,7 @@
class FlowerClientWithoutProps(Client):
"""Flower client not implementing get_properties."""

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
pass

def fit(self, ins: FitIns) -> FitRes:
Expand All @@ -48,13 +48,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
class FlowerClientWithProps(Client):
"""Flower client implementing get_properties."""

def get_properties(self, ins: PropertiesIns) -> PropertiesRes:
return PropertiesRes(
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
return GetPropertiesRes(
status=typing.Status(code=typing.Code.OK, message="Success"),
properties={"str_prop": "val", "int_prop": 1},
)

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
pass

def fit(self, ins: FitIns) -> FitRes:
Expand All @@ -68,22 +68,22 @@ def test_client_without_get_properties() -> None:
"""Test client implementing get_properties."""
# Prepare
client = FlowerClientWithoutProps()
ins = ServerMessage.PropertiesIns()
msg = ServerMessage(properties_ins=ins)
ins = ServerMessage.GetPropertiesIns()
msg = ServerMessage(get_properties_ins=ins)

# Execute
actual_msg, actual_sleep_duration, actual_keep_going = handle(
client=client, server_msg=msg
)

# Assert
expected_properties_res = ClientMessage.PropertiesRes(
expected_get_properties_res = ClientMessage.GetPropertiesRes(
status=Status(
code=Code.GET_PARAMETERS_NOT_IMPLEMENTED,
message="Client does not implement get_properties",
)
)
expected_msg = ClientMessage(properties_res=expected_properties_res)
expected_msg = ClientMessage(get_properties_res=expected_get_properties_res)

assert actual_msg == expected_msg
assert actual_sleep_duration == 0
Expand All @@ -94,16 +94,16 @@ def test_client_with_get_properties() -> None:
"""Test client not implementing get_properties."""
# Prepare
client = FlowerClientWithProps()
ins = ServerMessage.PropertiesIns()
msg = ServerMessage(properties_ins=ins)
ins = ServerMessage.GetPropertiesIns()
msg = ServerMessage(get_properties_ins=ins)

# Execute
actual_msg, actual_sleep_duration, actual_keep_going = handle(
client=client, server_msg=msg
)

# Assert
expected_properties_res = ClientMessage.PropertiesRes(
expected_get_properties_res = ClientMessage.GetPropertiesRes(
status=Status(
code=Code.OK,
message="Success",
Expand All @@ -112,7 +112,7 @@ def test_client_with_get_properties() -> None:
properties={"str_prop": "val", "int_prop": 1}
),
)
expected_msg = ClientMessage(properties_res=expected_properties_res)
expected_msg = ClientMessage(get_properties_res=expected_get_properties_res)

assert actual_msg == expected_msg
assert actual_sleep_duration == 0
Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
EvaluateRes,
FitIns,
FitRes,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
Metrics,
ParametersRes,
PropertiesIns,
PropertiesRes,
Scalar,
Status,
parameters_to_weights,
Expand Down Expand Up @@ -171,19 +171,19 @@ class NumPyClientWrapper(Client):
def __init__(self, numpy_client: NumPyClient) -> None:
self.numpy_client = numpy_client

def get_properties(self, ins: PropertiesIns) -> PropertiesRes:
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return the current client properties."""
properties = self.numpy_client.get_properties(ins.config)
return PropertiesRes(
return GetPropertiesRes(
status=Status(code=Code.OK, message="Success"),
properties=properties,
)

def get_parameters(self) -> ParametersRes:
def get_parameters(self) -> GetParametersRes:
"""Return the current local model parameters."""
parameters = self.numpy_client.get_parameters()
parameters_proto = weights_to_parameters(parameters)
return ParametersRes(parameters=parameters_proto)
return GetParametersRes(parameters=parameters_proto)

def fit(self, ins: FitIns) -> FitRes:
"""Refine the provided weights using the locally held dataset."""
Expand Down
Loading