Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom strategy tutorial #1623

Merged
merged 9 commits into from
Feb 5, 2023
103 changes: 85 additions & 18 deletions doc/source/tutorial/Flower-3-Building-a-Strategy-PyTorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@
"source": [
"## Build a Strategy from scratch\n",
"\n",
"[WIP - add description]"
"Let’s overwrite the configure_fit method such that it passes a higher learning rate (potentially also other hyperparameters to an optimizer) to a fraction of the client. We will keep the sampling of the clients in the standard way and then change the configuration dictionary (one of the FitIns elements)."
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
Expand All @@ -331,6 +331,7 @@
},
"outputs": [],
"source": [
"from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg\n",
"from typing import Callable, Union\n",
"\n",
"from flwr.common import (\n",
Expand All @@ -350,6 +351,22 @@
"\n",
"\n",
"class FedCustom(fl.server.strategy.Strategy):\n",
"\n",
" def __init__(\n",
" self,\n",
" fraction_fit: float = 1.0,\n",
" fraction_evaluate: float = 1.0,\n",
" min_fit_clients: int = 2,\n",
" min_evaluate_clients: int = 2,\n",
" min_available_clients: int = 2\n",
" ) -> None:\n",
" super().__init__()\n",
" self.fraction_fit = fraction_fit\n",
" self.fraction_evaluate = fraction_evaluate\n",
" self.min_fit_clients = min_fit_clients\n",
" self.min_evaluate_clients = min_evaluate_clients\n",
" self.min_available_clients = min_available_clients\n",
"\n",
" def __repr__(self) -> str:\n",
" return \"FedCustom\"\n",
"\n",
Expand All @@ -366,9 +383,26 @@
" ) -> List[Tuple[ClientProxy, FitIns]]:\n",
" \"\"\"Configure the next round of training.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" \n",
" return []\n",
" # Sample clients\n",
" sample_size, min_num_clients = self.num_fit_clients(\n",
" client_manager.num_available()\n",
" )\n",
" clients = client_manager.sample(\n",
" num_clients=sample_size, min_num_clients=min_num_clients\n",
" )\n",
"\n",
" # Create custom configs\n",
" n_clients = len(clients)\n",
" half_clients = n_clients // 2\n",
" standard_config = { \"lr\": 0.001 }\n",
" higher_lr_config = { \"lr\": 0.003 }\n",
" fit_configurations = []\n",
" for idx, client in enumerate(clients):\n",
" if idx < half_clients:\n",
" fit_configurations.append((client, FitIns(parameters, standard_config)))\n",
" else:\n",
" fit_configurations.append((client, FitIns(parameters, higher_lr_config)))\n",
" return fit_configurations\n",
"\n",
" def aggregate_fit(\n",
" self,\n",
Expand All @@ -378,18 +412,33 @@
" ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
" \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
"\n",
" return None, {}\n",
" weights_results = [\n",
" (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
" for _, fit_res in results\n",
" ]\n",
" parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))\n",
" metrics_aggregated = {}\n",
" return parameters_aggregated, metrics_aggregated\n",
"\n",
" def configure_evaluate(\n",
" self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
" ) -> List[Tuple[ClientProxy, EvaluateIns]]:\n",
" \"\"\"Configure the next round of evaluation.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
"\n",
" return []\n",
" if self.fraction_evaluate == 0.0:\n",
" return []\n",
" config = {}\n",
" evaluate_ins = EvaluateIns(parameters, config)\n",
"\n",
" # Sample clients\n",
" sample_size, min_num_clients = self.num_evaluation_clients(\n",
" client_manager.num_available()\n",
" )\n",
" clients = client_manager.sample(\n",
" num_clients=sample_size, min_num_clients=min_num_clients\n",
" )\n",
"\n",
" # Return client/config pairs\n",
" return [(client, evaluate_ins) for client in clients]\n",
"\n",
" def aggregate_evaluate(\n",
" self,\n",
Expand All @@ -399,18 +448,36 @@
" ) -> Tuple[Optional[float], Dict[str, Scalar]]:\n",
" \"\"\"Aggregate evaluation losses using weighted average.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" if not results:\n",
" return None, {}\n",
"\n",
" return None, {}\n",
" loss_aggregated = weighted_loss_avg(\n",
" [\n",
" (evaluate_res.num_examples, evaluate_res.loss)\n",
" for _, evaluate_res in results\n",
" ]\n",
" )\n",
" metrics_aggregated = {}\n",
" return loss_aggregated, metrics_aggregated\n",
"\n",
" def evaluate(\n",
" self, server_round: int, parameters: Parameters\n",
" ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
" \"\"\"Evaluate model parameters using an evaluation function.\"\"\"\n",
" \"\"\"Evaluate global model parameters using an evaluation function.\"\"\"\n",
"\n",
" # TODO WIP - add implementation\n",
" # Let's assume we won't perform the global model evaluation on the server side.\n",
" return None\n",
"\n",
" return None"
" def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
" \"\"\"Return the sample size and the required number of available\n",
" clients.\"\"\"\n",
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
" num_clients = int(num_available_clients * self.fraction_fit)\n",
" return max(num_clients, self.min_fit_clients), self.min_available_clients\n",
"\n",
" def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
" \"\"\"Use a fraction of available clients for evaluation.\"\"\"\n",
" num_clients = int(num_available_clients * self.fraction_evaluate)\n",
" return max(num_clients, self.min_evaluate_clients), self.min_available_clients"
]
},
{
Expand Down Expand Up @@ -447,7 +514,7 @@
"source": [
"## Recap\n",
"\n",
"[WIP - add description]"
"In this notebook, we’ve seen how to implement a custom strategy. It allows very flexible client creation, fitting, and evaluation. You have to overwrite the abstract methods of the Strategy class. And additionally, you can pass custom functions to your new class initializer to make it even more robust. "
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
Expand All @@ -469,7 +536,7 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Flower-2-Strategies-in-FL-PyTorch.ipynb",
"name": "Flower-3-Building-a-Strategy-PyTorch.ipynb",
"provenance": [],
"toc_visible": true
},
Expand Down