Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make client methods optional #1277

Merged
merged 13 commits into from
Jul 14, 2022
Merged
165 changes: 150 additions & 15 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,65 @@

import time
from logging import INFO
from typing import Optional, Union
from typing import Callable, Dict, List, Optional, Union

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
import numpy as np

from flwr.common import (
GRPC_MAX_MESSAGE_LENGTH,
parameters_to_weights,
weights_to_parameters,
)
from flwr.common.logger import log
from flwr.common.typing import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
Status,
)

from .client import Client
from .grpc_client.connection import grpc_connection
from .grpc_client.message_handler import handle
from .numpy_client import NumPyClient, NumPyClientWrapper
from .numpy_client import NumPyClient
from .numpy_client import has_evaluate as numpyclient_has_evaluate
from .numpy_client import has_fit as numpyclient_has_fit
from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
from .numpy_client import has_get_properties as numpyclient_has_get_properties

EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
NumPyClient.fit did not return a tuple with 3 elements.
The returned values should have the following type signature:

Tuple[List[np.ndarray], int, Dict[str, Scalar]]

Example
-------

model.get_weights(), 10, {"accuracy": 0.95}

"""

EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE = """
NumPyClient.evaluate did not return a tuple with 3 elements.
The returned values should have the following type signature:

Tuple[float, int, Dict[str, Scalar]]

Example
-------

0.5, 10, {"accuracy": 0.95}

"""


ClientLike = Union[Client, NumPyClient]


Expand Down Expand Up @@ -159,20 +207,10 @@ class `flwr.client.NumPyClient`.
>>> )
"""

# Wrap the NumPyClient
flower_client = NumPyClientWrapper(client)

# Delete get_properties method from NumPyClientWrapper if the user-provided
# NumPyClient instance does not implement get_properties. This enables the
# following call to start_client to handle NumPyClientWrapper instances like any
# other Client instance (which might or might not implement get_properties).
if not numpyclient_has_get_properties(client=client):
del NumPyClientWrapper.get_properties

# Start
start_client(
server_address=server_address,
client=flower_client,
client=_wrap_numpy_client(client=client),
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
)
Expand All @@ -181,5 +219,102 @@ class `flwr.client.NumPyClient`.
def to_client(client_like: ClientLike) -> Client:
"""Take any Client-like object and return it as a Client."""
if isinstance(client_like, NumPyClient):
return NumPyClientWrapper(numpy_client=client_like)
return _wrap_numpy_client(client=client_like)
return client_like


def _constructor(self: Client, numpy_client: NumPyClient) -> None:
self.numpy_client = numpy_client # type: ignore


def _get_properties(self: Client, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return the current client properties."""
properties = self.numpy_client.get_properties(config=ins.config) # type: ignore
return GetPropertiesRes(
status=Status(code=Code.OK, message="Success"),
properties=properties,
)


def _get_parameters(self: Client, ins: GetParametersIns) -> GetParametersRes:
"""Return the current local model parameters."""
parameters = self.numpy_client.get_parameters(config=ins.config) # type: ignore
parameters_proto = weights_to_parameters(parameters)
return GetParametersRes(
status=Status(code=Code.OK, message="Success"), parameters=parameters_proto
)


def _fit(self: Client, ins: FitIns) -> FitRes:
"""Refine the provided weights using the locally held dataset."""
# Deconstruct FitIns
parameters: List[np.ndarray] = parameters_to_weights(ins.parameters)

# Train
results = self.numpy_client.fit(parameters, ins.config) # type: ignore
if not (
len(results) == 3
and isinstance(results[0], list)
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)

# Return FitRes
parameters_prime, num_examples, metrics = results
parameters_prime_proto = weights_to_parameters(parameters_prime)
return FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=parameters_prime_proto,
num_examples=num_examples,
metrics=metrics,
)


def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
"""Evaluate the provided parameters using the locally held dataset."""
parameters: List[np.ndarray] = parameters_to_weights(ins.parameters)

results = self.numpy_client.evaluate(parameters, ins.config) # type: ignore
if not (
len(results) == 3
and isinstance(results[0], float)
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)

# Return EvaluateRes
loss, num_examples, metrics = results
return EvaluateRes(
status=Status(code=Code.OK, message="Success"),
loss=loss,
num_examples=num_examples,
metrics=metrics,
)


def _wrap_numpy_client(client: NumPyClient) -> Client:
member_dict: Dict[str, Callable] = { # type: ignore
"__init__": _constructor,
}

# Add wrapper type methods (if overridden)

if numpyclient_has_get_properties(client=client):
member_dict["get_properties"] = _get_properties

if numpyclient_has_get_parameters(client=client):
member_dict["get_parameters"] = _get_parameters

if numpyclient_has_fit(client=client):
member_dict["fit"] = _fit

if numpyclient_has_evaluate(client=client):
member_dict["evaluate"] = _evaluate

# Create wrapper class
wrapper_class = type("NumPyClientWrapper", (Client,), member_dict)

# Create and return an instance of the newly created class
return wrapper_class(numpy_client=client) # type: ignore
20 changes: 16 additions & 4 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Flower client (abstract base class)."""


from abc import ABC, abstractmethod
from abc import ABC

from flwr.common import (
EvaluateIns,
Expand Down Expand Up @@ -47,7 +47,6 @@ def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
The current client properties.
"""

@abstractmethod
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
"""Return the current local model parameters.

Expand All @@ -63,7 +62,6 @@ def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
The current local model parameters.
"""

@abstractmethod
def fit(self, ins: FitIns) -> FitRes:
"""Refine the provided weights using the locally held dataset.

Expand All @@ -81,7 +79,6 @@ def fit(self, ins: FitIns) -> FitRes:
such as the number of local training examples used for training.
"""

@abstractmethod
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
"""Evaluate the provided weights using the locally held dataset.

Expand All @@ -104,3 +101,18 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
def has_get_properties(client: Client) -> bool:
"""Check if Client implements get_properties."""
return type(client).get_properties != Client.get_properties


def has_get_parameters(client: Client) -> bool:
"""Check if Client implements get_parameters."""
return type(client).get_parameters != Client.get_parameters


def has_fit(client: Client) -> bool:
"""Check if Client implements fit."""
return type(client).fit != Client.fit


def has_evaluate(client: Client) -> bool:
"""Check if Client implements evaluate."""
return type(client).evaluate != Client.evaluate
100 changes: 86 additions & 14 deletions src/py/flwr/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
Status,
)

from .client import Client, has_get_properties
from .client import (
Client,
has_evaluate,
has_fit,
has_get_parameters,
has_get_properties,
)


class OverridingClient(Client):
Expand All @@ -53,19 +59,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:


class NotOverridingClient(Client):
"""Client not overriding `get_properties`."""

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
# This method is not expected to be called
raise Exception()

def fit(self, ins: FitIns) -> FitRes:
# This method is not expected to be called
raise Exception()

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
# This method is not expected to be called
raise Exception()
"""Client not overriding any Client method."""


def test_has_get_properties_true() -> None:
Expand All @@ -92,3 +86,81 @@ def test_has_get_properties_false() -> None:

# Assert
assert actual == expected


def test_has_get_parameters_true() -> None:
"""Test fit_clients."""
# Prepare
client = OverridingClient()
expected = True

# Execute
actual = has_get_parameters(client=client)

# Assert
assert actual == expected


def test_has_get_parameters_false() -> None:
"""Test fit_clients."""
# Prepare
client = NotOverridingClient()
expected = False

# Execute
actual = has_get_parameters(client=client)

# Assert
assert actual == expected


def test_has_fit_true() -> None:
"""Test fit_clients."""
# Prepare
client = OverridingClient()
expected = True

# Execute
actual = has_fit(client=client)

# Assert
assert actual == expected


def test_has_fit_false() -> None:
"""Test fit_clients."""
# Prepare
client = NotOverridingClient()
expected = False

# Execute
actual = has_fit(client=client)

# Assert
assert actual == expected


def test_has_evaluate_true() -> None:
"""Test fit_clients."""
# Prepare
client = OverridingClient()
expected = True

# Execute
actual = has_evaluate(client=client)

# Assert
assert actual == expected


def test_has_evaluate_false() -> None:
"""Test fit_clients."""
# Prepare
client = NotOverridingClient()
expected = False

# Execute
actual = has_evaluate(client=client)

# Assert
assert actual == expected
Loading