diff --git a/doc/source/changelog.md b/doc/source/changelog.md
index 5d537615186c..d7518b016212 100644
--- a/doc/source/changelog.md
+++ b/doc/source/changelog.md
@@ -4,6 +4,10 @@

 ### What's new?

+- **Add parameter aggregation to** `mt-pytorch` **code example** ([#1785](https://github.com/adap/flower/pull/1785))
+
+  The `mt-pytorch` example shows how to aggregate parameters when writing a driver script. The included `driver.py` and `server.py` have been aligned to demonstrate both the low-level way and the high-level way of building server-side logic.
+
 ### Incompatible changes

 None
diff --git a/examples/mt-pytorch/README.md b/examples/mt-pytorch/README.md
index 93a9ce7bd584..587dbc109026 100644
--- a/examples/mt-pytorch/README.md
+++ b/examples/mt-pytorch/README.md
@@ -8,16 +8,36 @@ This example contains highly experimental code. Please consult the regular PyTor
 ./dev/venv-reset.sh
 ```

-## Exec
+## Run with Driver API

-Terminal 1: start Driver API server
+Terminal 1: start Flower Driver API server

 ```bash
 flower-server
 ```

-Terminal 2: run driver script
+Terminal 2+3: start two clients
+
+```bash
+python client.py
+```
+
+Terminal 4: run driver script

 ```bash
 python driver.py
 ```
+
+## Run in legacy mode
+
+Terminal 1: start Flower server
+
+```bash
+python server.py
+```
+
+Terminal 2+3: start two clients
+
+```bash
+python client.py
+```
diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py
index c81048d85b90..ed536adf27a1 100644
--- a/examples/mt-pytorch/client.py
+++ b/examples/mt-pytorch/client.py
@@ -1,4 +1,6 @@
+from typing import Dict
 import flwr as fl
+from flwr.common import NDArrays, Scalar

 from task import (
     Net,
@@ -18,10 +20,13 @@

 # Define Flower client
 class FlowerClient(fl.client.NumPyClient):
+    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+        return get_parameters(net)
+
     def fit(self, parameters, config):
         set_parameters(net, parameters)
-        train(net, trainloader, epochs=1)
-        return get_parameters(net), len(trainloader.dataset), {}
+        results = train(net, trainloader, testloader, epochs=1)
+        return get_parameters(net), len(trainloader.dataset), results

     def evaluate(self, parameters, config):
         set_parameters(net, parameters)
@@ -31,7 +36,7 @@

 # Start Flower client
 fl.client.start_numpy_client(
-    server_address="0.0.0.0:9093",
+    server_address="0.0.0.0:9092",
     client=FlowerClient(),
-    rest=True,
+    rest=False,
 )
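The client change above makes `fit` return the metrics produced by `train` as its third element. For reference, here is a minimal sketch of the `(parameters, num_examples, metrics)` contract a `NumPyClient` implements; `DummyClient` and its fixed values are hypothetical stand-ins for illustration, not part of the diff:

```python
# Minimal sketch of the NumPyClient fit() contract. DummyClient and its
# hard-coded values are hypothetical; the metrics dict is keyed like the
# results dict that task.py's train() returns.
from typing import Dict, List, Tuple

import numpy as np
import flwr as fl


class DummyClient(fl.client.NumPyClient):
    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        return [np.zeros((2, 2))]

    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Return value shape: updated weights, number of training examples,
        # and a metrics dict for the server/driver to aggregate.
        return (
            [np.ones((2, 2))],
            100,
            {"train_loss": 0.5, "train_accuracy": 0.9,
             "val_loss": 0.6, "val_accuracy": 0.85},
        )
```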
diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py
index a79b77cc4f31..b3144de03bb9 100644
--- a/examples/mt-pytorch/driver.py
+++ b/examples/mt-pytorch/driver.py
@@ -1,27 +1,62 @@
-from typing import List
+from typing import List, Tuple
 import random
 import time

 from flwr.driver import Driver
-from flwr.common import ServerMessage, FitIns, ndarrays_to_parameters, serde
+from flwr.common import (
+    ServerMessage,
+    FitIns,
+    ndarrays_to_parameters,
+    serde,
+    parameters_to_ndarrays,
+    ClientMessage,
+    NDArrays,
+    Code,
+)
 from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2
-
+from flwr.server.strategy.aggregate import aggregate
+from flwr.common import Metrics
+from flwr.server import History
+from flwr.common import serde
 from task import Net, get_parameters, set_parameters

+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Multiply accuracy of each client by number of examples used
+    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    train_accuracies = [
+        num_examples * m["train_accuracy"] for num_examples, m in metrics
+    ]
+    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }
+
+
 # -------------------------------------------------------------------------- Driver SDK
 driver = Driver(driver_service_address="0.0.0.0:9091", certificates=None)
 # -------------------------------------------------------------------------- Driver SDK

-anonymous_client_nodes = True
-num_client_nodes_per_round = 1
+anonymous_client_nodes = False
+num_client_nodes_per_round = 2
 sleep_time = 1
-num_rounds = 1
+num_rounds = 3
 parameters = ndarrays_to_parameters(get_parameters(net=Net()))

 # -------------------------------------------------------------------------- Driver SDK
 driver.connect()
 # -------------------------------------------------------------------------- Driver SDK

+history = History()
 for server_round in range(num_rounds):
     print(f"Commencing server round {server_round + 1}")
@@ -90,7 +125,8 @@
                 ),
                 consumer=node_pb2.Node(
                     node_id=sampled_node_id,
-                    anonymous=anonymous_client_nodes,  # Must be True if we're working with anonymous clients
+                    anonymous=anonymous_client_nodes,
+                    # Must be True if we're working with anonymous clients
                 ),
                 legacy_server_message=server_message_proto,
             ),
@@ -137,14 +173,50 @@
         if len(all_task_res) == len(task_ids):
             break

-    # "Aggregate" results
-    node_messages = [task_res.task.legacy_client_message for task_res in all_task_res]
+    # Collect correct results
+    node_messages: List[ClientMessage] = []
+    for task_res in all_task_res:
+        if task_res.task.HasField("legacy_client_message"):
+            node_messages.append(task_res.task.legacy_client_message)
     print(f"Received {len(node_messages)} results")

+    weights_results: List[Tuple[NDArrays, int]] = []
+    metrics_results: List = []
+    for node_message in node_messages:
+        if not node_message.fit_res:
+            continue
+        fit_res = node_message.fit_res
+        # Aggregate only if the status is OK
+        if fit_res.status.code != Code.OK.value:
+            continue
+        weights_results.append(
+            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
+        )
+        metrics_results.append(
+            (fit_res.num_examples, serde.metrics_from_proto(fit_res.metrics))
+        )
+
+    # Aggregate parameters (FedAvg)
+    parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
+    parameters = parameters_aggregated
+
+    # Aggregate metrics
+    metrics_aggregated = weighted_average(metrics_results)
+    history.add_metrics_distributed_fit(
+        server_round=server_round, metrics=metrics_aggregated
+    )
+    print("Round ", server_round, " metrics: ", metrics_aggregated)
+
+    # Slow down the start of the next round
     time.sleep(sleep_time)

-    # Repeat
+print("app_fit: losses_distributed %s", str(history.losses_distributed))
+print("app_fit: metrics_distributed_fit %s", str(history.metrics_distributed_fit))
+print("app_fit: metrics_distributed %s", str(history.metrics_distributed))
+print("app_fit: losses_centralized %s", str(history.losses_centralized))
+print("app_fit: metrics_centralized %s", str(history.metrics_centralized))

 # -------------------------------------------------------------------------- Driver SDK
 driver.disconnect()
 # -------------------------------------------------------------------------- Driver SDK
+print("Driver disconnected")
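The driver-side parameter aggregation above uses Flower's `aggregate` helper, which computes the example-count-weighted mean that FedAvg prescribes. A quick toy check (the arrays and example counts below are invented for illustration) shows what that step computes:

```python
# Toy check of the FedAvg aggregation used in driver.py. The weights and
# example counts are made up; only the aggregate() call matches the diff.
import numpy as np
from flwr.server.strategy.aggregate import aggregate

# (ndarrays, num_examples) per client, as collected into weights_results
weights_results = [
    ([np.array([1.0, 1.0])], 100),  # client 1 trained on 100 examples
    ([np.array([3.0, 3.0])], 300),  # client 2 trained on 300 examples
]

# Example-weighted mean: (100 * 1.0 + 300 * 3.0) / 400 = 2.5
print(aggregate(weights_results))  # -> [array([2.5, 2.5])]
```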
diff --git a/examples/mt-pytorch/run-grpc.sh b/examples/mt-pytorch/run-grpc.sh
index 2aead43a34b5..a8a935786a9d 100755
--- a/examples/mt-pytorch/run-grpc.sh
+++ b/examples/mt-pytorch/run-grpc.sh
@@ -9,7 +9,7 @@ echo "Starting server"
 flower-server --rest &
 sleep 3  # Sleep for 3s to give the server enough time to start

-for i in `seq 0 0`; do
+for i in `seq 0 1`; do
     echo "Starting client $i"
     python client.py &
 done
diff --git a/examples/mt-pytorch/server.py b/examples/mt-pytorch/server.py
index fe691a88aba0..d96edd7d45ad 100644
--- a/examples/mt-pytorch/server.py
+++ b/examples/mt-pytorch/server.py
@@ -6,20 +6,36 @@

 # Define metric aggregation function
 def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
-    # Multiply accuracy of each client by number of examples used
-    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
     examples = [num_examples for num_examples, _ in metrics]

+    # Multiply accuracy of each client by number of examples used
+    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    train_accuracies = [
+        num_examples * m["train_accuracy"] for num_examples, m in metrics
+    ]
+    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]
+
     # Aggregate and return custom metric (weighted average)
-    return {"accuracy": sum(accuracies) / sum(examples)}
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }


 # Define strategy
-strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)
+strategy = fl.server.strategy.FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+)

 # Start Flower server
 fl.server.start_server(
-    server_address="0.0.0.0:8080",
+    server_address="0.0.0.0:9092",
     config=fl.server.ServerConfig(num_rounds=3),
     strategy=strategy,
 )
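`server.py` and `driver.py` now share the same metric aggregation. As a sanity check on the arithmetic, here is a self-contained, condensed equivalent of `weighted_average` run on made-up client metrics; the dictionary-comprehension form is a compact rewrite for illustration, not the code from the diff:

```python
# Self-contained check of the weighted-average metric aggregation: a condensed
# equivalent of weighted_average() in server.py/driver.py, on invented numbers.
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    # Weight each client's metrics by its example count, then normalize
    total_examples = sum(num_examples for num_examples, _ in metrics)
    keys = metrics[0][1].keys()
    return {k: sum(n * m[k] for n, m in metrics) / total_examples for k in keys}


toy_metrics = [
    (100, {"train_loss": 0.4, "train_accuracy": 0.8}),
    (300, {"train_loss": 0.2, "train_accuracy": 0.9}),
]
print(weighted_average(toy_metrics))
# {'train_loss': 0.25, 'train_accuracy': 0.875}
```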
diff --git a/examples/mt-pytorch/task.py b/examples/mt-pytorch/task.py
index b474048461a2..ebdb508e16d9 100644
--- a/examples/mt-pytorch/task.py
+++ b/examples/mt-pytorch/task.py
@@ -36,19 +36,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fc3(x)


-def train(net, trainloader, epochs):
+def train(net, trainloader, valloader, epochs, device: str = "cpu"):
     """Train the model on the training set."""
-    criterion = torch.nn.CrossEntropyLoss()
+    print("Starting training...")
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
     optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
-    batch_count = 0
+    net.train()
     for _ in range(epochs):
-        for images, labels in tqdm(trainloader):
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
             optimizer.zero_grad()
-            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
+            loss = criterion(net(images), labels)
+            loss.backward()
             optimizer.step()
-            batch_count += 1
-            if batch_count == 100:
-                break  # Just do a few batches
+
+    net.to("cpu")  # move model back to CPU
+
+    train_loss, train_acc = test(net, trainloader)
+    val_loss, val_acc = test(net, valloader)
+
+    results = {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results


 def test(net, testloader):
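The new `train` unpacks a `(loss, accuracy)` pair from `test(net, ...)`, but the body of `test` lies outside this diff. For context, here is a sketch of a `test` implementation with that return shape, modeled on Flower's standard PyTorch quickstart; this is an assumption, not the file's actual code:

```python
# Sketch of a test() matching the (loss, accuracy) return that the new train()
# unpacks. task.py's real body is outside this diff, so this is an assumption
# modeled on Flower's PyTorch quickstart example.
import torch


def test(net, testloader, device: str = "cpu"):
    """Evaluate the model and return (average_loss, accuracy)."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.to(device)
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    return loss / len(testloader), correct / total
```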