From b896059da70ae1760f2fa214ae36af1f606cface Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Tue, 5 Mar 2024 22:31:20 +0100 Subject: [PATCH 01/13] Introduce ClientApp train/evaluate/query --- src/py/flwr/client/client_app.py | 210 ++++++++++++++++++++++++++++--- 1 file changed, 196 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 9de6516c7a3..0ed33c7bd69 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -16,7 +16,7 @@ import importlib -from typing import List, Optional, cast +from typing import Callable, List, Optional, cast from flwr.client.message_handler.message_handler import ( handle_legacy_message_from_msgtype, @@ -25,6 +25,8 @@ from flwr.client.typing import ClientFn, Mod from flwr.common import Context, Message +from .typing import ClientAppCallable + class ClientApp: """Flower ClientApp. @@ -54,25 +56,205 @@ class ClientApp: def __init__( self, - client_fn: ClientFn, # Only for backward compatibility + client_fn: Optional[ClientFn] = None, # Only for backward compatibility mods: Optional[List[Mod]] = None, ) -> None: # Create wrapper function for `handle` - def ffn( - message: Message, - context: Context, - ) -> Message: # pylint: disable=invalid-name - out_message = handle_legacy_message_from_msgtype( - client_fn=client_fn, message=message, context=context - ) - return out_message - - # Wrap mods around the wrapped handle function - self._call = make_ffn(ffn, mods if mods is not None else []) + if client_fn is not None: + + def ffn( + message: Message, + context: Context, + ) -> Message: # pylint: disable=invalid-name + out_message = handle_legacy_message_from_msgtype( + client_fn=client_fn, message=message, context=context + ) + return out_message + + # Wrap mods around the wrapped handle function + self._call = make_ffn(ffn, mods if mods is not None else []) + else: + self._call = None + + # Step functions + self._train: Optional[ClientAppCallable] = None + self._evaluate: Optional[ClientAppCallable] = None + self._query: Optional[ClientAppCallable] = None def __call__(self, message: Message, context: Context) -> Message: """Execute `ClientApp`.""" - return self._call(message, context) + if self._call: + return self._call(message, context) + + if message.metadata.message_type == "train": + return self._train(message, context) + if message.metadata.message_type == "evaluate": + return self._evaluate(message, context) + if message.metadata.message_type == "query": + return self._query(message, context) + raise ValueError(f"Unknown message_type: {message.metadata.message_type}") + + def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the train fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.train() + >>> def train(message: Message, context: Context) -> Message: + >>> print("ClientApp training running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: + """Register the main fn with the ServerApp object.""" + if self._call: + raise ValueError( + """Use either `@app.train()` or `client_fn`, but not both. + + Use the `ClientApp` with an existing `client_fn`: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid) -> Client: + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp() + >>> client_fn=client_fn, + >>> ) + + Use the `ClientApp` with a custom train function: + + >>> app = ClientApp() + >>> + >>> @app.train() + >>> def train(message: Message, context: Context) -> Message: + >>> print("ClientApp training running") + >>> # Create and return an echo reply message + >>> return message.create_reply( + >>> content=message.content(), ttl="" + >>> ) + """, + ) + + # Register provided function with the ClientApp object + self._train = train_fn + + # Return provided function unmodified + return train_fn + + return train_decorator + + def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the evaluate fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.evaluate() + >>> def evaluate(message: Message, context: Context) -> Message: + >>> print("ClientApp evaluation running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: + """Register the main fn with the ServerApp object.""" + if self._call: + raise ValueError( + """Use either `@app.evaluate()` or `client_fn`, but not both. + + Use the `ClientApp` with an existing `client_fn`: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid) -> Client: + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp() + >>> client_fn=client_fn, + >>> ) + + Use the `ClientApp` with a custom evaluate function: + + >>> app = ClientApp() + >>> + >>> @app.evaluate() + >>> def evaluate(message: Message, context: Context) -> Message: + >>> print("ClientApp evaluation running") + >>> # Create and return an echo reply message + >>> return message.create_reply( + >>> content=message.content(), ttl="" + >>> ) + """, + ) + + # Register provided function with the ClientApp object + self._evaluate = evaluate_fn + + # Return provided function unmodified + return evaluate_fn + + return evaluate_decorator + + def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: + """Return a decorator that registers the query fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.query() + >>> def query(message: Message, context: Context) -> Message: + >>> print("ClientApp query running") + >>> # Create and return an echo reply message + >>> return message.create_reply(content=message.content(), ttl="") + """ + + def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: + """Register the main fn with the ServerApp object.""" + if self._call: + raise ValueError( + """Use either `@app.query()` or `client_fn`, but not both. + + Use the `ClientApp` with an existing `client_fn`: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid) -> Client: + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp() + >>> client_fn=client_fn, + >>> ) + + Use the `ClientApp` with a custom query function: + + >>> app = ClientApp() + >>> + >>> @app.query() + >>> def query(message: Message, context: Context) -> Message: + >>> print("ClientApp query running") + >>> # Create and return an echo reply message + >>> return message.create_reply( + >>> content=message.content(), ttl="" + >>> ) + """, + ) + + # Register provided function with the ClientApp object + self._query = query_fn + + # Return provided function unmodified + return query_fn + + return query_decorator class LoadClientAppError(Exception): From 16b93fd2b2de39443b9001c2b0ea55cc1c069e39 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 11:16:46 +0100 Subject: [PATCH 02/13] Use MessageType --- src/py/flwr/client/client_app.py | 8 ++++---- src/py/flwr/common/constant.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 0ed33c7bd69..4ca907aab49 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -23,7 +23,7 @@ ) from flwr.client.mod.utils import make_ffn from flwr.client.typing import ClientFn, Mod -from flwr.common import Context, Message +from flwr.common import Context, Message, MessageType from .typing import ClientAppCallable @@ -86,11 +86,11 @@ def __call__(self, message: Message, context: Context) -> Message: if self._call: return self._call(message, context) - if message.metadata.message_type == "train": + if message.metadata.message_type == MessageType.TRAIN: return self._train(message, context) - if message.metadata.message_type == "evaluate": + if message.metadata.message_type == MessageType.EVALUATE: return self._evaluate(message, context) - if message.metadata.message_type == "query": + if message.metadata.message_type == MessageType.QUERY: return self._query(message, context) raise ValueError(f"Unknown message_type: {message.metadata.message_type}") diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index d3f429586a0..c7a9baa113e 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -47,6 +47,7 @@ class MessageType: TRAIN = "train" EVALUATE = "evaluate" + QUERY = "query" def __new__(cls) -> MessageType: """Prevent instantiation.""" From 857fcfaa22c3d9bbf0b641a1c69f7f1cd0bdf7d4 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 12:20:16 +0100 Subject: [PATCH 03/13] Refactor __call__ --- src/py/flwr/client/client_app.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 4ca907aab49..c5f87fd3adc 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -59,6 +59,8 @@ def __init__( client_fn: Optional[ClientFn] = None, # Only for backward compatibility mods: Optional[List[Mod]] = None, ) -> None: + self._call: Optional[ClientAppCallable] = None + # Create wrapper function for `handle` if client_fn is not None: @@ -73,8 +75,6 @@ def ffn( # Wrap mods around the wrapped handle function self._call = make_ffn(ffn, mods if mods is not None else []) - else: - self._call = None # Step functions self._train: Optional[ClientAppCallable] = None @@ -83,15 +83,25 @@ def ffn( def __call__(self, message: Message, context: Context) -> Message: """Execute `ClientApp`.""" + # Execute message using `client_fn` if self._call: return self._call(message, context) + # Execute message using a new if message.metadata.message_type == MessageType.TRAIN: - return self._train(message, context) + if self._train: + return self._train(message, context) + raise ValueError("No `train` function registered") if message.metadata.message_type == MessageType.EVALUATE: - return self._evaluate(message, context) + if self._evaluate: + return self._evaluate(message, context) + raise ValueError("No `evaluate` function registered") if message.metadata.message_type == MessageType.QUERY: - return self._query(message, context) + if self._query: + return self._query(message, context) + raise ValueError("No `query` function registered") + + # Message type did not match one of the known message types abvoe raise ValueError(f"Unknown message_type: {message.metadata.message_type}") def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: From d8023210ba7e020cc89e92ea2e308258130591a3 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 12:57:27 +0100 Subject: [PATCH 04/13] Refactor error handling --- src/py/flwr/client/client_app.py | 118 +++++++++---------------------- 1 file changed, 34 insertions(+), 84 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index c5f87fd3adc..a1948ae82db 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -121,34 +121,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: """Register the main fn with the ServerApp object.""" if self._call: - raise ValueError( - """Use either `@app.train()` or `client_fn`, but not both. - - Use the `ClientApp` with an existing `client_fn`: - - >>> class FlowerClient(NumPyClient): - >>> # ... - >>> - >>> def client_fn(cid) -> Client: - >>> return FlowerClient().to_client() - >>> - >>> app = ClientApp() - >>> client_fn=client_fn, - >>> ) - - Use the `ClientApp` with a custom train function: - - >>> app = ClientApp() - >>> - >>> @app.train() - >>> def train(message: Message, context: Context) -> Message: - >>> print("ClientApp training running") - >>> # Create and return an echo reply message - >>> return message.create_reply( - >>> content=message.content(), ttl="" - >>> ) - """, - ) + raise _registration_error(MessageType.TRAIN) # Register provided function with the ClientApp object self._train = train_fn @@ -175,34 +148,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: """Register the main fn with the ServerApp object.""" if self._call: - raise ValueError( - """Use either `@app.evaluate()` or `client_fn`, but not both. - - Use the `ClientApp` with an existing `client_fn`: - - >>> class FlowerClient(NumPyClient): - >>> # ... - >>> - >>> def client_fn(cid) -> Client: - >>> return FlowerClient().to_client() - >>> - >>> app = ClientApp() - >>> client_fn=client_fn, - >>> ) - - Use the `ClientApp` with a custom evaluate function: - - >>> app = ClientApp() - >>> - >>> @app.evaluate() - >>> def evaluate(message: Message, context: Context) -> Message: - >>> print("ClientApp evaluation running") - >>> # Create and return an echo reply message - >>> return message.create_reply( - >>> content=message.content(), ttl="" - >>> ) - """, - ) + raise _registration_error(MessageType.EVALUATE) # Register provided function with the ClientApp object self._evaluate = evaluate_fn @@ -229,34 +175,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: """Register the main fn with the ServerApp object.""" if self._call: - raise ValueError( - """Use either `@app.query()` or `client_fn`, but not both. - - Use the `ClientApp` with an existing `client_fn`: - - >>> class FlowerClient(NumPyClient): - >>> # ... - >>> - >>> def client_fn(cid) -> Client: - >>> return FlowerClient().to_client() - >>> - >>> app = ClientApp() - >>> client_fn=client_fn, - >>> ) - - Use the `ClientApp` with a custom query function: - - >>> app = ClientApp() - >>> - >>> @app.query() - >>> def query(message: Message, context: Context) -> Message: - >>> print("ClientApp query running") - >>> # Create and return an echo reply message - >>> return message.create_reply( - >>> content=message.content(), ttl="" - >>> ) - """, - ) + raise _registration_error(MessageType.QUERY) # Register provided function with the ClientApp object self._query = query_fn @@ -314,3 +233,34 @@ def load_client_app(module_attribute_str: str) -> ClientApp: ) from None return cast(ClientApp, attribute) + + +def _registration_error(fn_name: str) -> ValueError: + return ValueError( + f"""Use either `@app.{fn_name}()` or `client_fn`, but not both. + + Use the `ClientApp` with an existing `client_fn`: + + >>> class FlowerClient(NumPyClient): + >>> # ... + >>> + >>> def client_fn(cid) -> Client: + >>> return FlowerClient().to_client() + >>> + >>> app = ClientApp() + >>> client_fn=client_fn, + >>> ) + + Use the `ClientApp` with a custom {fn_name} function: + + >>> app = ClientApp() + >>> + >>> @app.{fn_name}() + >>> def {fn_name}(message: Message, context: Context) -> Message: + >>> print("ClientApp {fn_name} running") + >>> # Create and return an echo reply message + >>> return message.create_reply( + >>> content=message.content(), ttl="" + >>> ) + """, + ) From bbbeb956ad52b7984c2a6ef438453de97be1bd58 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 21:39:55 +0100 Subject: [PATCH 05/13] Add app-pytorch example for ClientApp functions --- examples/app-pytorch/client_new.py | 24 ++++++++++++++++++++++++ src/py/flwr/client/app.py | 3 +++ 2 files changed, 27 insertions(+) create mode 100644 examples/app-pytorch/client_new.py diff --git a/examples/app-pytorch/client_new.py b/examples/app-pytorch/client_new.py new file mode 100644 index 00000000000..d10a60c3ccd --- /dev/null +++ b/examples/app-pytorch/client_new.py @@ -0,0 +1,24 @@ +import flwr +from flwr.common import Message, Context + + +# Run via `flower-client-app client:app` +app = flwr.client.ClientApp() + + +@app.train() +def train(msg: Message, ctx: Context): + print("`train` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + + +@app.evaluate() +def eval(msg: Message, ctx: Context): + print("`evaluate` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + + +@app.query() +def q(msg: Message, ctx: Context): + print("`query` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 43781776f78..b47eabd006c 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -445,6 +445,8 @@ def _load_client_app() -> ClientApp: time.sleep(3) # Wait for 3s before asking again continue + log(INFO, "Received message") + # Handle control message out_message, sleep_duration = handle_control_message(message) if out_message: @@ -471,6 +473,7 @@ def _load_client_app() -> ClientApp: # Send send(out_message) + log(INFO, "Sent reply") # Unregister node if delete_node is not None: From 03e7b89db36dc55a63f5ae6343f50515a91841c9 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 22:14:53 +0100 Subject: [PATCH 06/13] Use MessageType in server_custom --- examples/app-pytorch/server_custom.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index 1f0cb0d26d9..db7f0db41cc 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -3,27 +3,22 @@ import time import flwr as fl -from flwr.server import Driver -from flwr.common import Context - from flwr.common import ( - ServerMessage, + Context, FitIns, ndarrays_to_parameters, - serde, parameters_to_ndarrays, - ClientMessage, NDArrays, Code, + Message, + MessageType, + Metrics, ) -from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 -from flwr.server.strategy.aggregate import aggregate -from flwr.common import Metrics -from flwr.server import History -from flwr.common import serde -from task import Net, get_parameters, set_parameters from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres -from flwr.common import Message +from flwr.server import Driver, History +from flwr.server.strategy.aggregate import aggregate + +from task import Net, get_parameters # Define metric aggregation function @@ -56,7 +51,6 @@ def main(driver: Driver, context: Context) -> None: """.""" print("RUNNING!!!!!") - anonymous_client_nodes = False num_client_nodes_per_round = 2 sleep_time = 1 num_rounds = 3 @@ -92,7 +86,7 @@ def main(driver: Driver, context: Context) -> None: for node_id in sampled_nodes: message = driver.create_message( content=recordset, - message_type="fit", + message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), ttl="", From 41075b2395c1f6cf3c40d6fd90a19f8499b85c85 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 22:42:46 +0100 Subject: [PATCH 07/13] Remove client_new --- examples/app-pytorch/client_new.py | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 examples/app-pytorch/client_new.py diff --git a/examples/app-pytorch/client_new.py b/examples/app-pytorch/client_new.py deleted file mode 100644 index d10a60c3ccd..00000000000 --- a/examples/app-pytorch/client_new.py +++ /dev/null @@ -1,24 +0,0 @@ -import flwr -from flwr.common import Message, Context - - -# Run via `flower-client-app client:app` -app = flwr.client.ClientApp() - - -@app.train() -def train(msg: Message, ctx: Context): - print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) - - -@app.evaluate() -def eval(msg: Message, ctx: Context): - print("`evaluate` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) - - -@app.query() -def q(msg: Message, ctx: Context): - print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) From 281b4f7053ff7f724f88f5f5b935a29874a50c2d Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 6 Mar 2024 22:46:17 +0100 Subject: [PATCH 08/13] Create ClientApp fn example --- examples/app-pytorch/client_new.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 examples/app-pytorch/client_new.py diff --git a/examples/app-pytorch/client_new.py b/examples/app-pytorch/client_new.py new file mode 100644 index 00000000000..f9be16db4f7 --- /dev/null +++ b/examples/app-pytorch/client_new.py @@ -0,0 +1,24 @@ +import flwr +from flwr.common import Message, Context + + +# Run via `flower-client-app client:app` +app = flwr.client.ClientApp() + + +@app.train() +def train(msg: Message, ctx: Context): + print("`train` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + + +@app.evaluate() +def eval(msg: Message, ctx: Context): + print("`evaluate` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + + +@app.query() +def query(msg: Message, ctx: Context): + print("`query` is not implemented, echoing original message") + return msg.create_reply(msg.content, ttl=msg.metadata.ttl) From fa2dd48fc155c0111ac80f5196648aaa328e7902 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 7 Mar 2024 16:57:38 +0100 Subject: [PATCH 09/13] Improve wording --- src/py/flwr/client/client_app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index a1948ae82db..139a8a32f67 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -119,7 +119,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: """ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: - """Register the main fn with the ServerApp object.""" + """Register the train fn with the ServerApp object.""" if self._call: raise _registration_error(MessageType.TRAIN) @@ -146,7 +146,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: """ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: - """Register the main fn with the ServerApp object.""" + """Register the evaluate fn with the ServerApp object.""" if self._call: raise _registration_error(MessageType.EVALUATE) @@ -173,7 +173,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: """ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: - """Register the main fn with the ServerApp object.""" + """Register the query fn with the ServerApp object.""" if self._call: raise _registration_error(MessageType.QUERY) From 9de20e1f3902893c3f466c175864801f774789fb Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 7 Mar 2024 21:08:39 +0100 Subject: [PATCH 10/13] Support mods --- src/py/flwr/client/client_app.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 139a8a32f67..5a9c76f8131 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -59,9 +59,10 @@ def __init__( client_fn: Optional[ClientFn] = None, # Only for backward compatibility mods: Optional[List[Mod]] = None, ) -> None: - self._call: Optional[ClientAppCallable] = None + self._mods: List[Mod] = mods if mods is not None else [] # Create wrapper function for `handle` + self._call: Optional[ClientAppCallable] = None if client_fn is not None: def ffn( @@ -124,7 +125,8 @@ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: raise _registration_error(MessageType.TRAIN) # Register provided function with the ClientApp object - self._train = train_fn + # Wrap mods around the wrapped step function + self._train = make_ffn(train_fn, self._mods) # Return provided function unmodified return train_fn @@ -151,7 +153,8 @@ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: raise _registration_error(MessageType.EVALUATE) # Register provided function with the ClientApp object - self._evaluate = evaluate_fn + # Wrap mods around the wrapped step function + self._evaluate = make_ffn(evaluate_fn, self._mods) # Return provided function unmodified return evaluate_fn @@ -178,7 +181,8 @@ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: raise _registration_error(MessageType.QUERY) # Register provided function with the ClientApp object - self._query = query_fn + # Wrap mods around the wrapped step function + self._query = make_ffn(query_fn, self._mods) # Return provided function unmodified return query_fn From d9db4b8a0df52a170c4f048cbd590b089cd06289 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 8 Mar 2024 14:43:10 +0100 Subject: [PATCH 11/13] Add low-level example --- examples/app-pytorch/client.py | 10 +++- .../{client_new.py => client_low_level.py} | 19 +++++-- examples/app-pytorch/server.py | 24 ++++++++- examples/app-pytorch/server_low_level.py | 52 +++++++++++++++++++ 4 files changed, 97 insertions(+), 8 deletions(-) rename examples/app-pytorch/{client_new.py => client_low_level.py} (52%) create mode 100644 examples/app-pytorch/server_low_level.py diff --git a/examples/app-pytorch/client.py b/examples/app-pytorch/client.py index 8095a2d7aa9..e429e2fe54f 100644 --- a/examples/app-pytorch/client.py +++ b/examples/app-pytorch/client.py @@ -20,8 +20,6 @@ # Define Flower client class FlowerClient(fl.client.NumPyClient): - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - return get_parameters(net) def fit(self, parameters, config): set_parameters(net, parameters) @@ -42,3 +40,11 @@ def client_fn(cid: str): app = fl.client.ClientApp( client_fn=client_fn, ) + + +# Legacy mode +if __name__ == "__main__": + fl.client.start_client( + server_address="127.0.0.1:8080", + client=FlowerClient().to_client(), + ) diff --git a/examples/app-pytorch/client_new.py b/examples/app-pytorch/client_low_level.py similarity index 52% rename from examples/app-pytorch/client_new.py rename to examples/app-pytorch/client_low_level.py index f9be16db4f7..2036bb0f920 100644 --- a/examples/app-pytorch/client_new.py +++ b/examples/app-pytorch/client_low_level.py @@ -2,23 +2,34 @@ from flwr.common import Message, Context +def hello_world_mod(msg, ctx, call_next) -> Message: + print("Hello, ... [pause for dramatic effect]") + out = call_next(msg, ctx) + print("[pause was long enough] ... World!") + return out + + # Run via `flower-client-app client:app` -app = flwr.client.ClientApp() +app = flwr.client.ClientApp( + mods=[ + hello_world_mod, + ], +) @app.train() def train(msg: Message, ctx: Context): print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + return msg.create_reply(msg.content, ttl="") @app.evaluate() def eval(msg: Message, ctx: Context): print("`evaluate` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + return msg.create_reply(msg.content, ttl="") @app.query() def query(msg: Message, ctx: Context): print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=msg.metadata.ttl) + return msg.create_reply(msg.content, ttl="") diff --git a/examples/app-pytorch/server.py b/examples/app-pytorch/server.py index fbf3f24a133..3acaca4fe84 100644 --- a/examples/app-pytorch/server.py +++ b/examples/app-pytorch/server.py @@ -1,7 +1,8 @@ from typing import List, Tuple import flwr as fl -from flwr.common import Metrics +from flwr.common import Metrics, ndarrays_to_parameters +from task import Net, get_parameters # Define metric aggregation function @@ -25,17 +26,36 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } +# Initialize model parameters +ndarrays = get_parameters(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=1.0, # Select all available clients fraction_evaluate=0.0, # Disable evaluation min_available_clients=2, fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, ) +# Define config +config = fl.server.ServerConfig(num_rounds=3) + + # Run via `flower-server-app server:app` app = fl.server.ServerApp( - config=fl.server.ServerConfig(num_rounds=3), + config=config, strategy=strategy, ) + + +# Legacy mode +if __name__ == "__main__": + fl.server.start_server( + server_address="0.0.0.0:8080", + config=config, + strategy=strategy, + ) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py new file mode 100644 index 00000000000..ea8c161c428 --- /dev/null +++ b/examples/app-pytorch/server_low_level.py @@ -0,0 +1,52 @@ +from typing import List, Tuple, Dict +import random +import time + +import flwr as fl +from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.server import Driver + + +# Run via `flower-server-app server:app` +app = fl.server.ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + """.""" + print("Starting test run") + for server_round in range(3): + print(f"Commencing server round {server_round + 1}") + + # Get node IDs + node_ids = driver.get_node_ids() + + # Create messages + recordset = RecordSet() + messages = [] + for node_id in node_ids: + message = driver.create_message( + content=recordset, + message_type=MessageType.TRAIN, + dst_node_id=node_id, + group_id=str(server_round), + ttl="", + ) + messages.append(message) + + # Send messages + message_ids = driver.push_messages(messages) + print(f"Pushed {len(message_ids)} messages: {message_ids}") + + # Wait for results, ignore empty message_ids + message_ids = [message_id for message_id in message_ids if message_id != ""] + all_replies: List[Message] = [] + while True: + replies = driver.pull_messages(message_ids=message_ids) + print(f"Got {len(replies)} results") + all_replies += replies + if len(all_replies) == len(message_ids): + break + time.sleep(3) + + print(f"Received {len(all_replies)} results") From 34ff6448829c2b977dfa4d67324394022680bc43 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 8 Mar 2024 14:48:40 +0100 Subject: [PATCH 12/13] Tweak messages --- examples/app-pytorch/client_low_level.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py index 2036bb0f920..f8c7eb05837 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -3,9 +3,9 @@ def hello_world_mod(msg, ctx, call_next) -> Message: - print("Hello, ... [pause for dramatic effect]") + print("Hello, ...[pause for dramatic effect]...") out = call_next(msg, ctx) - print("[pause was long enough] ... World!") + print("...[pause was long enough]... World!") return out From e4238d6e087dcc969f4b8e4b2cbd5bf8963e4608 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 8 Mar 2024 16:29:14 +0100 Subject: [PATCH 13/13] Update examples/app-pytorch/server_low_level.py --- examples/app-pytorch/server_low_level.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py index ea8c161c428..560babac1b9 100644 --- a/examples/app-pytorch/server_low_level.py +++ b/examples/app-pytorch/server_low_level.py @@ -13,7 +13,7 @@ @app.main() def main(driver: Driver, context: Context) -> None: - """.""" + """This is a stub example that simply sends and receives messages.""" print("Starting test run") for server_round in range(3): print(f"Commencing server round {server_round + 1}")