From 2fa173555325e0e57f3e7f07be4bc2bc713028fc Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Fri, 14 Apr 2023 12:43:15 +0200
Subject: [PATCH 1/7] Add weights averaging

---
 examples/mt-pytorch/driver.py   | 42 +++++++++++++++++++++++++++------
 examples/mt-pytorch/run-grpc.sh |  2 +-
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py
index a79b77cc4f31..c0b44e56d109 100644
--- a/examples/mt-pytorch/driver.py
+++ b/examples/mt-pytorch/driver.py
@@ -1,10 +1,20 @@
-from typing import List
+from typing import List, Tuple
 import random
 import time

 from flwr.driver import Driver
-from flwr.common import ServerMessage, FitIns, ndarrays_to_parameters, serde
+from flwr.common import (
+    ServerMessage,
+    FitIns,
+    ndarrays_to_parameters,
+    serde,
+    parameters_to_ndarrays,
+    ClientMessage,
+    NDArrays,
+    Code,
+)
 from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2
+from flwr.server.strategy.aggregate import aggregate

 from task import Net, get_parameters, set_parameters

@@ -13,9 +23,9 @@
 # -------------------------------------------------------------------------- Driver SDK

 anonymous_client_nodes = True
-num_client_nodes_per_round = 1
+num_client_nodes_per_round = 2
 sleep_time = 1
-num_rounds = 1
+num_rounds = 3
 parameters = ndarrays_to_parameters(get_parameters(net=Net()))

 # -------------------------------------------------------------------------- Driver SDK
@@ -90,7 +100,8 @@
                 ),
                 consumer=node_pb2.Node(
                     node_id=sampled_node_id,
-                    anonymous=anonymous_client_nodes,  # Must be True if we're working with anonymous clients
+                    anonymous=anonymous_client_nodes,
+                    # Must be True if we're working with anonymous clients
                 ),
                 legacy_server_message=server_message_proto,
             ),
@@ -137,10 +148,26 @@
         if len(all_task_res) == len(task_ids):
             break

-    # "Aggregate" results
-    node_messages = [task_res.task.legacy_client_message for task_res in all_task_res]
+    # Collect correct results
+    node_messages: List[ClientMessage] = []
+    for task_res in all_task_res:
+        if hasattr(task_res.task, "legacy_client_message"):
+            node_messages.append(task_res.task.legacy_client_message)
     print(f"Received {len(node_messages)} results")

+    weights_results: List[Tuple[NDArrays, int]] = []
+    for node_message in node_messages:
+        if hasattr(node_message, "fit_res"):
+            fit_res = node_message.fit_res
+            # Aggregate only if the status is OK
+            if fit_res.status.code == Code.OK.value:
+                weights_results.append(
+                    (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+                )
+    # Aggregate results - FedAvg
+    parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
+    parameters = parameters_aggregated
+
     time.sleep(sleep_time)

     # Repeat
@@ -148,3 +175,4 @@
 # -------------------------------------------------------------------------- Driver SDK
 driver.disconnect()
 # -------------------------------------------------------------------------- Driver SDK
+print("Driver disconnected")
diff --git a/examples/mt-pytorch/run-grpc.sh b/examples/mt-pytorch/run-grpc.sh
index 2aead43a34b5..a8a935786a9d 100755
--- a/examples/mt-pytorch/run-grpc.sh
+++ b/examples/mt-pytorch/run-grpc.sh
@@ -9,7 +9,7 @@ echo "Starting server"
 flower-server --rest &
 sleep 3  # Sleep for 3s to give the server enough time to start

-for i in `seq 0 0`; do
+for i in `seq 0 1`; do
     echo "Starting client $i"
     python client.py &
 done

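Note: `aggregate(weights_results)` used above (from `flwr.server.strategy.aggregate`) performs FedAvg-style aggregation: a layer-wise average of the clients' parameter arrays, weighted by each client's number of training examples. A minimal sketch of the idea in plain NumPy, for illustration only (not the library's implementation):

    from typing import List, Tuple

    import numpy as np

    def weighted_param_average(
        results: List[Tuple[List[np.ndarray], int]]
    ) -> List[np.ndarray]:
        """Average weights layer-wise, weighting each client by its example count."""
        total_examples = sum(num_examples for _, num_examples in results)
        num_layers = len(results[0][0])
        return [
            sum(weights[layer] * num_examples for weights, num_examples in results)
            / total_examples
            for layer in range(num_layers)
        ]

With two clients holding 100 and 300 examples, a layer value of 1.0 on the first client and 2.0 on the second averages to (100 * 1.0 + 300 * 2.0) / 400 = 1.75.
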
From 97083e63e33c156028c7c55f1377eba4dd5a34ad Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 2 May 2023 09:43:59 +0200
Subject: [PATCH 2/7] Use the ProtoBuf HasField instead of hasattr, reduce
 nesting

---
 examples/mt-pytorch/driver.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py
index c0b44e56d109..fcd9a27e154e 100644
--- a/examples/mt-pytorch/driver.py
+++ b/examples/mt-pytorch/driver.py
@@ -151,19 +151,21 @@
     # Collect correct results
     node_messages: List[ClientMessage] = []
     for task_res in all_task_res:
-        if hasattr(task_res.task, "legacy_client_message"):
+        if task_res.task.HasField("legacy_client_message"):
             node_messages.append(task_res.task.legacy_client_message)
     print(f"Received {len(node_messages)} results")

     weights_results: List[Tuple[NDArrays, int]] = []
     for node_message in node_messages:
-        if hasattr(node_message, "fit_res"):
-            fit_res = node_message.fit_res
-            # Aggregate only if the status is OK
-            if fit_res.status.code == Code.OK.value:
-                weights_results.append(
-                    (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
-                )
+        if not node_message.fit_res:
+            continue
+        fit_res = node_message.fit_res
+        # Aggregate only if the status is OK
+        if fit_res.status.code != Code.OK.value:
+            continue
+        weights_results.append(
+            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+        )
     # Aggregate results - FedAvg
     parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
     parameters = parameters_aggregated

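Note: the switch from `hasattr` to `HasField` in the patch above matters because generated protobuf message classes expose every declared field as an attribute, so `hasattr()` is true even when the field was never set; `HasField()` reports actual presence. A small illustrative sketch, assuming the generated `task_pb2.Task` message used in this example:

    from flwr.proto import task_pb2

    task = task_pb2.Task()  # legacy_client_message is not populated

    print(hasattr(task, "legacy_client_message"))  # True: the attribute always exists
    print(task.HasField("legacy_client_message"))  # False: the field was never set
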
Beutel" Date: Thu, 4 May 2023 13:51:04 +0100 Subject: [PATCH 4/7] Align driver.py and server.py in mt-pytorch --- examples/mt-pytorch/client.py | 13 ++++++---- examples/mt-pytorch/driver.py | 45 +++++++++++++++++++++++++++++++---- examples/mt-pytorch/server.py | 24 +++++++++++++++---- examples/mt-pytorch/task.py | 29 +++++++++++++++------- 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py index 0435ce57228a..b187b4477f13 100644 --- a/examples/mt-pytorch/client.py +++ b/examples/mt-pytorch/client.py @@ -1,4 +1,6 @@ +from typing import Dict import flwr as fl +from flwr.common import NDArrays, Scalar from task import ( Net, @@ -18,10 +20,13 @@ # Define Flower client class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + return get_parameters(net) + def fit(self, parameters, config): set_parameters(net, parameters) - train(net, trainloader, epochs=1) - return get_parameters(net), len(trainloader.dataset), {} + results = train(net, trainloader, testloader, epochs=1) + return get_parameters(net), len(trainloader.dataset), results def evaluate(self, parameters, config): set_parameters(net, parameters) @@ -31,7 +36,7 @@ def evaluate(self, parameters, config): # Start Flower client fl.client.start_numpy_client( - server_address="http://0.0.0.0:9093", + server_address="0.0.0.0:9092", client=FlowerClient(), - rest=True, + rest=False, ) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index fcd9a27e154e..d6fa7c8f5d46 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -15,14 +15,37 @@ ) from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2 from flwr.server.strategy.aggregate import aggregate - +from flwr.common import Metrics +from flwr.server import History +from flwr.common import serde from task import Net, get_parameters, set_parameters + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + examples = [num_examples for num_examples, _ in metrics] + + # Multiply accuracy of each client by number of examples used + train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] + train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] + val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] + val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_loss": sum(train_losses) / sum(examples), + "train_accuracy": sum(train_accuracies) / sum(examples), + "val_loss": sum(val_losses) / sum(examples), + "val_accuracy": sum(val_accuracies) / sum(examples), + } + + + # -------------------------------------------------------------------------- Driver SDK driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK -anonymous_client_nodes = True +anonymous_client_nodes = False num_client_nodes_per_round = 2 sleep_time = 1 num_rounds = 3 @@ -32,6 +55,7 @@ driver.connect() # -------------------------------------------------------------------------- Driver SDK +history = History() for server_round in range(num_rounds): print(f"Commencing server round {server_round + 1}") @@ -156,6 +180,7 @@ print(f"Received {len(node_messages)} results") weights_results: List[Tuple[NDArrays, int]] = [] + 
From ace797ec4798511afb1743cf3e28eb2a945bea85 Mon Sep 17 00:00:00 2001
From: "Daniel J. Beutel"
Date: Thu, 4 May 2023 13:51:04 +0100
Subject: [PATCH 4/7] Align driver.py and server.py in mt-pytorch

---
 examples/mt-pytorch/client.py | 13 ++++++----
 examples/mt-pytorch/driver.py | 45 +++++++++++++++++++++++++++++++----
 examples/mt-pytorch/server.py | 24 +++++++++++++++----
 examples/mt-pytorch/task.py   | 29 +++++++++++++++-------
 4 files changed, 90 insertions(+), 21 deletions(-)

diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py
index 0435ce57228a..b187b4477f13 100644
--- a/examples/mt-pytorch/client.py
+++ b/examples/mt-pytorch/client.py
@@ -1,4 +1,6 @@
+from typing import Dict
 import flwr as fl
+from flwr.common import NDArrays, Scalar

 from task import (
     Net,
@@ -18,10 +20,13 @@

 # Define Flower client
 class FlowerClient(fl.client.NumPyClient):
+    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+        return get_parameters(net)
+
     def fit(self, parameters, config):
         set_parameters(net, parameters)
-        train(net, trainloader, epochs=1)
-        return get_parameters(net), len(trainloader.dataset), {}
+        results = train(net, trainloader, testloader, epochs=1)
+        return get_parameters(net), len(trainloader.dataset), results

     def evaluate(self, parameters, config):
         set_parameters(net, parameters)
@@ -31,7 +36,7 @@ def evaluate(self, parameters, config):

 # Start Flower client
 fl.client.start_numpy_client(
-    server_address="http://0.0.0.0:9093",
+    server_address="0.0.0.0:9092",
     client=FlowerClient(),
-    rest=True,
+    rest=False,
 )
diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py
index fcd9a27e154e..d6fa7c8f5d46 100644
--- a/examples/mt-pytorch/driver.py
+++ b/examples/mt-pytorch/driver.py
@@ -15,14 +15,37 @@
 )
 from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2
 from flwr.server.strategy.aggregate import aggregate
-
+from flwr.common import Metrics
+from flwr.server import History
+from flwr.common import serde
 from task import Net, get_parameters, set_parameters
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Multiply accuracy of each client by number of examples used
+    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics]
+    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }
+
+
 # -------------------------------------------------------------------------- Driver SDK
 driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
 # -------------------------------------------------------------------------- Driver SDK

-anonymous_client_nodes = True
+anonymous_client_nodes = False
 num_client_nodes_per_round = 2
 sleep_time = 1
 num_rounds = 3
@@ -32,6 +55,7 @@
 driver.connect()
 # -------------------------------------------------------------------------- Driver SDK

+history = History()
 for server_round in range(num_rounds):
     print(f"Commencing server round {server_round + 1}")

@@ -156,6 +180,7 @@
     print(f"Received {len(node_messages)} results")

     weights_results: List[Tuple[NDArrays, int]] = []
+    metrics_results: List = []
     for node_message in node_messages:
         if not node_message.fit_res:
             continue
         fit_res = node_message.fit_res
         # Aggregate only if the status is OK
         if fit_res.status.code != Code.OK.value:
             continue
         weights_results.append(
             (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
         )
-    # Aggregate results - FedAvg
+        metrics_results.append((fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics)))
+
+    # Aggregate parameters (FedAvg)
     parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
     parameters = parameters_aggregated

+    # Aggregate metrics
+    metrics_aggregated = weighted_average(metrics_results)
+    history.add_metrics_distributed_fit(server_round=server_round, metrics=metrics_aggregated)
+    print("Round ", server_round, " metrics: ", metrics_aggregated)
+
+    # Slow down the start of the next round
     time.sleep(sleep_time)

-    # Repeat
+print("app_fit: losses_distributed %s", str(history.losses_distributed))
+print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit))
+print("app_fit: metrics_distributed %s", str(history.metrics_distributed))
+print("app_fit: losses_centralized %s", str(history.losses_centralized))
+print("app_fit: metrics_centralized %s", str(history.metrics_centralized))
 # -------------------------------------------------------------------------- Driver SDK
 driver.disconnect()
diff --git a/examples/mt-pytorch/server.py b/examples/mt-pytorch/server.py
index fe691a88aba0..61b0de90e102 100644
--- a/examples/mt-pytorch/server.py
+++ b/examples/mt-pytorch/server.py
@@ -6,20 +6,34 @@
 # Define metric aggregation function
 def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
-    # Multiply accuracy of each client by number of examples used
-    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
     examples = [num_examples for num_examples, _ in metrics]
+
+    # Multiply accuracy of each client by number of examples used
+    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics]
+    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

     # Aggregate and return custom metric (weighted average)
-    return {"accuracy": sum(accuracies) / sum(examples)}
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }


 # Define strategy
-strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)
+strategy = fl.server.strategy.FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+)

 # Start Flower server
 fl.server.start_server(
-    server_address="0.0.0.0:8080",
+    server_address="0.0.0.0:9092",
     config=fl.server.ServerConfig(num_rounds=3),
     strategy=strategy,
 )
diff --git a/examples/mt-pytorch/task.py b/examples/mt-pytorch/task.py
index b474048461a2..752b9b125a76 100644
--- a/examples/mt-pytorch/task.py
+++ b/examples/mt-pytorch/task.py
@@ -36,20 +36,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fc3(x)


-def train(net, trainloader, epochs):
+def train(net, trainloader, valloader, epochs, device: str = "cpu"):
     """Train the model on the training set."""
-    criterion = torch.nn.CrossEntropyLoss()
+    print("Starting training...")
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
     optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
-    batch_count = 0
+    net.train()
     for _ in range(epochs):
-        for images, labels in tqdm(trainloader):
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
             optimizer.zero_grad()
-            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
+            loss = criterion(net(images), labels)
+            loss.backward()
             optimizer.step()
-            batch_count += 1
-            if batch_count == 100:
-                break  # Just do a few batches
+    net.to("cpu")  # move model back to CPU
+
+    train_loss, train_acc = test(net, trainloader)
+    val_loss, val_acc = test(net, valloader)
+
+    results = {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results


 def test(net, testloader):
     """Validate the model on the test set."""

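Note: `weighted_average` above computes example-count-weighted means, so clients holding more data contribute proportionally more to the aggregated metrics. A quick worked example with made-up numbers (hypothetical values, purely to show the arithmetic), calling the `weighted_average` defined in `driver.py`/`server.py`:

    # Two hypothetical clients: (num_examples, metrics)
    metrics = [
        (100, {"train_loss": 0.5, "train_accuracy": 0.80, "val_loss": 0.6, "val_accuracy": 0.75}),
        (300, {"train_loss": 0.3, "train_accuracy": 0.90, "val_loss": 0.4, "val_accuracy": 0.85}),
    ]

    print(weighted_average(metrics))
    # train_loss     = (100 * 0.5  + 300 * 0.3 ) / 400 = 0.35
    # train_accuracy = (100 * 0.80 + 300 * 0.90) / 400 = 0.875
    # val_loss       = (100 * 0.6  + 300 * 0.4 ) / 400 = 0.45
    # val_accuracy   = (100 * 0.75 + 300 * 0.85) / 400 = 0.825
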
From 6df391fa8993f1f99259fb05cb5aa16430e7e283 Mon Sep 17 00:00:00 2001
From: "Daniel J. Beutel"
Date: Thu, 4 May 2023 18:54:18 +0100
Subject: [PATCH 5/7] Update changelog

---
 doc/source/changelog.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/doc/source/changelog.md b/doc/source/changelog.md
index 5d537615186c..d7518b016212 100644
--- a/doc/source/changelog.md
+++ b/doc/source/changelog.md
@@ -4,6 +4,10 @@

 ### What's new?

+- **Add parameter aggregation to** `mt-pytorch` **code example** ([#1785](https://github.com/adap/flower/pull/1785))
+
+  The `mt-pytorch` example shows how to aggregate parameters when writing a driver script. The included `driver.py` and `server.py` have been aligned to demonstrate both the low-level way and the high-level way of building server-side logic.
+
 ### Incompatible changes

 None

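Note: the changelog entry above refers to two ways of building the same server-side logic: the high-level way configures a strategy and lets `fl.server.start_server` run the rounds (see `server.py` in patch 4), while the low-level way drives each round by hand through the Driver API (see `driver.py`). A condensed sketch of the high-level entry point, with the values used in this example and the `weighted_average` from `server.py`:

    import flwr as fl

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # select all available clients for training
        fraction_evaluate=0.0,  # disable federated evaluation
        min_available_clients=2,
        fit_metrics_aggregation_fn=weighted_average,
    )

    fl.server.start_server(
        server_address="0.0.0.0:9092",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )
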
Beutel" Date: Thu, 4 May 2023 19:04:02 +0100 Subject: [PATCH 7/7] Format code --- examples/mt-pytorch/client.py | 2 +- examples/mt-pytorch/driver.py | 15 ++++++++++----- examples/mt-pytorch/server.py | 6 ++++-- examples/mt-pytorch/task.py | 1 + 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py index b187b4477f13..ed536adf27a1 100644 --- a/examples/mt-pytorch/client.py +++ b/examples/mt-pytorch/client.py @@ -22,7 +22,7 @@ class FlowerClient(fl.client.NumPyClient): def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: return get_parameters(net) - + def fit(self, parameters, config): set_parameters(net, parameters) results = train(net, trainloader, testloader, epochs=1) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index d6fa7c8f5d46..b3144de03bb9 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -24,10 +24,12 @@ # Define metric aggregation function def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] - + # Multiply accuracy of each client by number of examples used train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * m["train_accuracy"] for num_examples, m in metrics + ] val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] @@ -40,7 +42,6 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } - # -------------------------------------------------------------------------- Driver SDK driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None) # -------------------------------------------------------------------------- Driver SDK @@ -191,7 +192,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: weights_results.append( (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) ) - metrics_results.append((fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics))) + metrics_results.append( + (fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics)) + ) # Aggregate parameters (FedAvg) parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) @@ -199,7 +202,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # Aggregate metrics metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit(server_round=server_round, metrics=metrics_aggregated) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) print("Round ", server_round, " metrics: ", metrics_aggregated) # Slow down the start of the next round diff --git a/examples/mt-pytorch/server.py b/examples/mt-pytorch/server.py index 61b0de90e102..d96edd7d45ad 100644 --- a/examples/mt-pytorch/server.py +++ b/examples/mt-pytorch/server.py @@ -7,10 +7,12 @@ # Define metric aggregation function def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] - + # Multiply accuracy of each client by number of examples used train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] + train_accuracies = [ + num_examples * 
m["train_accuracy"] for num_examples, m in metrics + ] val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] diff --git a/examples/mt-pytorch/task.py b/examples/mt-pytorch/task.py index 752b9b125a76..ebdb508e16d9 100644 --- a/examples/mt-pytorch/task.py +++ b/examples/mt-pytorch/task.py @@ -64,6 +64,7 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): } return results + def test(net, testloader): """Validate the model on the test set.""" criterion = torch.nn.CrossEntropyLoss()