Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom strategy tutorial #1623

Merged
merged 9 commits into from
Feb 5, 2023
102 changes: 84 additions & 18 deletions doc/source/tutorial/Flower-3-Building-a-Strategy-PyTorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@
"source": [
"## Build a Strategy from scratch\n",
"\n",
"[WIP - add description]"
"Let’s overwrite the `configure_fit` method such that it passes a higher learning rate (potentially also other hyperparameters) to the optimizer of a fraction of the clients. We will keep the sampling of the clients as it is in `FedAvg` and then change the configuration dictionary (one of the `FitIns` attributes)."
]
},
{
Expand All @@ -331,6 +331,7 @@
},
"outputs": [],
"source": [
"from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg\n",
"from typing import Callable, Union\n",
"\n",
"from flwr.common import (\n",
Expand All @@ -350,6 +351,22 @@
"\n",
"\n",
"class FedCustom(fl.server.strategy.Strategy):\n",
"\n",
" def __init__(\n",
" self,\n",
" fraction_fit: float = 1.0,\n",
" fraction_evaluate: float = 1.0,\n",
" min_fit_clients: int = 2,\n",
" min_evaluate_clients: int = 2,\n",
" min_available_clients: int = 2\n",
" ) -> None:\n",
" super().__init__()\n",
" self.fraction_fit = fraction_fit\n",
" self.fraction_evaluate = fraction_evaluate\n",
" self.min_fit_clients = min_fit_clients\n",
" self.min_evaluate_clients = min_evaluate_clients\n",
" self.min_available_clients = min_available_clients\n",
"\n",
" def __repr__(self) -> str:\n",
" return \"FedCustom\"\n",
"\n",
Expand All @@ -366,9 +383,26 @@
" ) -> List[Tuple[ClientProxy, FitIns]]:\n",
" \"\"\"Configure the next round of training.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" \n",
" return []\n",
" # Sample clients\n",
" sample_size, min_num_clients = self.num_fit_clients(\n",
" client_manager.num_available()\n",
" )\n",
" clients = client_manager.sample(\n",
" num_clients=sample_size, min_num_clients=min_num_clients\n",
" )\n",
"\n",
" # Create custom configs\n",
" n_clients = len(clients)\n",
" half_clients = n_clients // 2\n",
" standard_config = { \"lr\": 0.001 }\n",
" higher_lr_config = { \"lr\": 0.003 }\n",
" fit_configurations = []\n",
" for idx, client in enumerate(clients):\n",
" if idx < half_clients:\n",
" fit_configurations.append((client, FitIns(parameters, standard_config)))\n",
" else:\n",
" fit_configurations.append((client, FitIns(parameters, higher_lr_config)))\n",
" return fit_configurations\n",
"\n",
" def aggregate_fit(\n",
" self,\n",
Expand All @@ -378,18 +412,33 @@
" ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
" \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
"\n",
" return None, {}\n",
" weights_results = [\n",
" (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
" for _, fit_res in results\n",
" ]\n",
" parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))\n",
" metrics_aggregated = {}\n",
" return parameters_aggregated, metrics_aggregated\n",
"\n",
" def configure_evaluate(\n",
" self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
" ) -> List[Tuple[ClientProxy, EvaluateIns]]:\n",
" \"\"\"Configure the next round of evaluation.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
"\n",
" return []\n",
" if self.fraction_evaluate == 0.0:\n",
" return []\n",
" config = {}\n",
" evaluate_ins = EvaluateIns(parameters, config)\n",
"\n",
" # Sample clients\n",
" sample_size, min_num_clients = self.num_evaluation_clients(\n",
" client_manager.num_available()\n",
" )\n",
" clients = client_manager.sample(\n",
" num_clients=sample_size, min_num_clients=min_num_clients\n",
" )\n",
"\n",
" # Return client/config pairs\n",
" return [(client, evaluate_ins) for client in clients]\n",
"\n",
" def aggregate_evaluate(\n",
" self,\n",
Expand All @@ -399,18 +448,35 @@
" ) -> Tuple[Optional[float], Dict[str, Scalar]]:\n",
" \"\"\"Aggregate evaluation losses using weighted average.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" if not results:\n",
" return None, {}\n",
"\n",
" return None, {}\n",
" loss_aggregated = weighted_loss_avg(\n",
" [\n",
" (evaluate_res.num_examples, evaluate_res.loss)\n",
" for _, evaluate_res in results\n",
" ]\n",
" )\n",
" metrics_aggregated = {}\n",
" return loss_aggregated, metrics_aggregated\n",
"\n",
" def evaluate(\n",
" self, server_round: int, parameters: Parameters\n",
" ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
" \"\"\"Evaluate model parameters using an evaluation function.\"\"\"\n",
" \"\"\"Evaluate global model parameters using an evaluation function.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" # Let's assume we won't perform the global model evaluation on the server side.\n",
" return None\n",
"\n",
" return None"
" def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
" \"\"\"Return sample size and required number of clients.\"\"\"\n",
" num_clients = int(num_available_clients * self.fraction_fit)\n",
" return max(num_clients, self.min_fit_clients), self.min_available_clients\n",
"\n",
" def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
" \"\"\"Use a fraction of available clients for evaluation.\"\"\"\n",
" num_clients = int(num_available_clients * self.fraction_evaluate)\n",
" return max(num_clients, self.min_evaluate_clients), self.min_available_clients"
]
},
{
Expand Down Expand Up @@ -447,7 +513,7 @@
"source": [
"## Recap\n",
"\n",
"[WIP - add description]"
"In this notebook, we’ve seen how to implement a custom strategy. A custom strategy enables granular control over client node configuration, result aggregation, and more. To define a custom strategy, you only have to overwrite the abstract methods of the (abstract) base class `Strategy`. To make custom strategies even more powerful, you can pass custom functions to the constructor of your new class (`__init__`) and then call these functions whenever needed. "
]
},
{
Expand All @@ -469,7 +535,7 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Flower-2-Strategies-in-FL-PyTorch.ipynb",
"name": "Flower-3-Building-a-Strategy-PyTorch.ipynb",
"provenance": [],
"toc_visible": true
},
Expand Down