Flower Baseline: FedAvg MNIST (#1497)
1 parent bdb2ee8, commit da1c8f7

Showing 24 changed files with 741 additions and 0 deletions.

baselines/flwr_baselines/publications/fedavg_mnist/.gitignore (4 additions, 0 deletions)

dataset/
outputs/
playground.ipynb
multirun/
baselines/flwr_baselines/publications/fedavg_mnist/README.md (87 additions, 0 deletions)

# Federated Averaging MNIST

The following baseline replicates the experiments in *Communication-Efficient Learning of Deep Networks from Decentralized Data* (McMahan et al., 2017), which was the first paper to coin the term Federated Learning and to propose the FederatedAveraging algorithm.

**Paper Abstract:**

<center>
<i>Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve speech recognition and text entry, and image models can automatically select good photos. However, this rich data is often privacy sensitive, large in quantity, or both, which may preclude logging to the data center and training there using conventional approaches. We advocate an alternative that leaves the training data distributed on the mobile devices, and learns a shared model by aggregating locally-computed updates. We term this decentralized approach Federated Learning. We present a practical method for the federated learning of deep networks based on iterative model averaging, and conduct an extensive empirical evaluation, considering five different model architectures and four datasets. These experiments demonstrate the approach is robust to the unbalanced and non-IID data distributions that are a defining characteristic of this setting. Communication costs are the principal constraint, and we show a reduction in required communication rounds by 10–100× as compared to synchronized stochastic gradient descent.</i>
</center>

**Paper Authors:**

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas.

Note: If you use this implementation in your work, please remember to cite the original authors of the paper.

**[Link to paper.](https://arxiv.org/abs/1602.05629)**

## Training Setup

### CNN Architecture

The CNN architecture is detailed in the paper and is used to create the **Federated Averaging MNIST** baseline; a sketch of a matching PyTorch module follows the table below.

| Layer | Details |
| ----- | ------- |
| 1 | Conv2D(1, 32, 5, 1, 1) <br/> ReLU, MaxPool2D(2, 2, 1) |
| 2 | Conv2D(32, 64, 5, 1, 1) <br/> ReLU, MaxPool2D(2, 2, 1) |
| 3 | FC(64 * 7 * 7, 512) <br/> ReLU |
| 4 | FC(512, 10) |
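
As a rough guide, here is a minimal PyTorch sketch matching the table; the baseline's actual `model.py` is not among the files shown in this diff, so the class and attribute names are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """CNN from the table above, for 1x28x28 MNIST inputs."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 26x26 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 12x12 -> 7x7
        x = x.flatten(start_dim=1)            # 64 * 7 * 7 features
        x = F.relu(self.fc1(x))
        return self.fc2(x)
```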

### Training Parameters

| Description | Value |
| ----------- | ----- |
| loss | cross entropy loss |
| optimizer | SGD |
| learning rate | 0.1 (by default) |
| local epochs | 5 (by default) |
| local batch size | 10 (by default) |
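
The local training routine lives in `model.py`, which is likewise not among the files shown here. A minimal sketch consistent with the parameters above and with the `model.train(...)` call in `client.py` might look like:

```python
import torch
from torch.utils.data import DataLoader


def train(
    net: torch.nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    epochs: int,
    learning_rate: float,
) -> None:
    """Run local SGD with cross-entropy loss, as in the table above."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    net.train()
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
```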

## Running experiments

The `config.yaml` file, containing all the tunable hyperparameters and the necessary variables, can be found under the `conf` folder.
[Hydra](https://hydra.cc/docs/tutorials/) is used to manage the different parameters experiments can be run with.

To run using the default parameters, simply enter `python main.py`. If some parameters need to be overridden, you can do so as in the following example:

```sh
python main.py num_epochs=5 num_rounds=1000 iid=True
```

Results will be stored as timestamped folders inside either `outputs` or `multirun`, depending on whether you perform single- or multi-runs.
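
For instance, Hydra's standard `--multirun` flag can sweep several values in one command (a hypothetical sweep):

```sh
python main.py --multirun num_epochs=1,5 iid=True,False
```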

### Example output

To help visualize results, the script also plots evaluation curves. Here is an example:

<p align="center">
<img src="docs/centralized_metrics.png" alt="Centralized evaluation results" width="400">
</p>

You will also find the saved history in the `docs/results/` folder, where `C` refers to the number of clients, `B` the batch size, `E` the number of local epochs, `R` the number of rounds, and `stag` the proportion of clients that are unreachable at each round.
Empty file.
baselines/flwr_baselines/publications/fedavg_mnist/client.py (122 additions, 0 deletions)

# pylint: disable=too-many-arguments
"""Defines the MNIST Flower Client and a function to instantiate it."""


from collections import OrderedDict
from typing import Callable, List, Tuple

import flwr as fl
import numpy as np
import torch
from torch.utils.data import DataLoader

from flwr_baselines.publications.fedavg_mnist import model
from flwr_baselines.publications.fedavg_mnist.dataset import load_datasets


class FlowerClient(fl.client.NumPyClient):
    """Standard Flower client for CNN training."""

    def __init__(
        self,
        net: torch.nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
    ):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate

    def get_parameters(self, config) -> List[np.ndarray]:
        """Returns the parameters of the current net."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Changes the parameters of the model using the given ones."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config
    ) -> Tuple[List[np.ndarray], int, dict]:
        """Implements distributed fit function for a given client."""
        self.set_parameters(parameters)
        model.train(
            self.net,
            self.trainloader,
            self.device,
            epochs=self.num_epochs,
            learning_rate=self.learning_rate,
        )
        # Report the number of training examples so the server can do
        # example-weighted aggregation.
        return self.get_parameters(config), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: List[np.ndarray], config):
        """Implements distributed evaluation for a given client."""
        self.set_parameters(parameters)
        loss, accuracy = model.test(self.net, self.valloader, self.device)
        # Report the number of evaluation examples for weighted averaging.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def gen_client_fn(
    device: torch.device,
    iid: bool,
    num_clients: int,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
) -> Tuple[Callable[[str], FlowerClient], DataLoader]:
    """Generates the client function that creates the Flower Clients.

    Parameters
    ----------
    device : torch.device
        The device on which the clients will train and test.
    iid : bool
        The way to partition the data for each client, i.e. whether the data
        should be independent and identically distributed between the clients
        or if the data should first be sorted by labels and distributed by
        chunks to each client (used to test convergence in a worst-case
        scenario).
    num_clients : int
        The number of clients present in the setup.
    num_epochs : int
        The number of local epochs each client should run the training for
        before sending the updated model to the server.
    batch_size : int
        The size of the local batches each client trains on.
    learning_rate : float
        The learning rate for the SGD optimizer of clients.

    Returns
    -------
    Tuple[Callable[[str], FlowerClient], DataLoader]
        A tuple containing the client function that creates Flower Clients and
        the DataLoader that will be used for testing.
    """
    trainloaders, valloaders, testloader = load_datasets(
        iid=iid, num_clients=num_clients, batch_size=batch_size
    )

    def client_fn(cid: str) -> FlowerClient:
        """Create a Flower client representing a single organization."""

        # Load model
        net = model.Net().to(device)

        # Note: each client gets a different trainloader/valloader, so each client
        # will train and evaluate on their own unique data
        trainloader = trainloaders[int(cid)]
        valloader = valloaders[int(cid)]

        # Create a single Flower client representing a single organization
        return FlowerClient(
            net, trainloader, valloader, device, num_epochs, learning_rate
        )

    return client_fn, testloader
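
For context, here is a hypothetical sketch of how `gen_client_fn` could be wired into Flower's simulation engine; the baseline's actual `main.py` is not among the files shown here, so the values below are illustrative:

```python
import flwr as fl
import torch

from flwr_baselines.publications.fedavg_mnist.client import gen_client_fn

# Generate the client factory and the centralized test loader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
client_fn, testloader = gen_client_fn(
    device=device,
    iid=False,
    num_clients=10,
    num_epochs=5,
    batch_size=10,
    learning_rate=0.1,
)

# Run a federated simulation with plain FedAvg.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=10,
    config=fl.server.ServerConfig(num_rounds=10),
    strategy=fl.server.strategy.FedAvg(fraction_fit=1.0),
)
```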
baselines/flwr_baselines/publications/fedavg_mnist/dataset.py (121 additions, 0 deletions)

"""MNIST dataset utilities for federated learning.""" | ||
|
||
|
||
from typing import List, Optional, Tuple | ||
|
||
import numpy as np | ||
import torch | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split | ||
from torchvision.datasets import MNIST | ||
|
||
|
||
def load_datasets(
    num_clients: int = 10,
    iid: Optional[bool] = True,
    val_ratio: float = 0.1,
    batch_size: Optional[int] = 32,
    seed: Optional[int] = 42,
) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
    """Creates the dataloaders to be fed into the model.

    Parameters
    ----------
    num_clients : int, optional
        The number of clients that hold a part of the data, by default 10
    iid : bool, optional
        Whether the data should be independent and identically distributed
        between the clients or if the data should first be sorted by labels
        and distributed by chunks to each client (used to test convergence in
        a worst-case scenario), by default True
    val_ratio : float, optional
        The ratio of training data that will be used for validation (between
        0 and 1), by default 0.1
    batch_size : int, optional
        The size of the batches to be fed into the model, by default 32
    seed : int, optional
        Used to set a fixed seed to replicate experiments, by default 42

    Returns
    -------
    Tuple[List[DataLoader], List[DataLoader], DataLoader]
        A list of DataLoaders for training (one per client), a list of
        DataLoaders for validation (one per client), and a single DataLoader
        for testing.
    """
    datasets, testset = _partition_data(num_clients, iid, seed)
    # Split each partition into train/val and create DataLoaders
    trainloaders = []
    valloaders = []
    for dataset in datasets:
        len_val = int(len(dataset) * val_ratio)
        len_train = len(dataset) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(
            dataset, lengths, torch.Generator().manual_seed(seed)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    return trainloaders, valloaders, DataLoader(testset, batch_size=batch_size)


def _download_data() -> Tuple[Dataset, Dataset]:
    """Downloads (if necessary) and returns the MNIST dataset.

    Returns
    -------
    Tuple[MNIST, MNIST]
        The dataset for training and the dataset for testing MNIST.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    trainset = MNIST("./dataset", train=True, download=True, transform=transform)
    testset = MNIST("./dataset", train=False, download=True, transform=transform)
    return trainset, testset


def _partition_data(
    num_clients: int = 10,
    iid: Optional[bool] = True,
    seed: Optional[int] = 42,
) -> Tuple[List[Dataset], Dataset]:
    """Split the training set into IID or non-IID partitions to simulate the
    federated setting.

    Parameters
    ----------
    num_clients : int, optional
        The number of clients that hold a part of the data, by default 10
    iid : bool, optional
        Whether the data should be independent and identically distributed
        between the clients or if the data should first be sorted by labels
        and distributed by chunks to each client (used to test convergence in
        a worst-case scenario), by default True
    seed : int, optional
        Used to set a fixed seed to replicate experiments, by default 42

    Returns
    -------
    Tuple[List[Dataset], Dataset]
        A list of datasets, one for each client, and a single dataset to be
        used for testing the model.
    """
    trainset, testset = _download_data()
    partition_size = int(len(trainset) / num_clients)
    lengths = [partition_size] * num_clients
    if iid:
        datasets = random_split(trainset, lengths, torch.Generator().manual_seed(seed))
    else:
        # Sort the data by label, cut it into 2 * num_clients shards, and give
        # each client two randomly chosen shards (as in McMahan et al., 2017).
        shard_size = int(partition_size / 2)
        idxs = trainset.targets.argsort()
        sorted_data = Subset(trainset, idxs)
        tmp = []
        for idx in range(num_clients * 2):
            tmp.append(
                Subset(sorted_data, np.arange(shard_size * idx, shard_size * (idx + 1)))
            )
        idxs_list = torch.randperm(
            num_clients * 2, generator=torch.Generator().manual_seed(seed)
        )
        datasets = [
            ConcatDataset((tmp[idxs_list[2 * i]], tmp[idxs_list[2 * i + 1]]))
            for i in range(num_clients)
        ]

    return datasets, testset
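
As a quick, illustrative sanity check of the partitioning above (not part of the commit): with the 60,000 MNIST training examples, 10 clients, and the default `val_ratio=0.1`, each client ends up with 5,400 training and 600 validation examples:

```python
from flwr_baselines.publications.fedavg_mnist.dataset import load_datasets

trainloaders, valloaders, testloader = load_datasets(num_clients=10, iid=False)
print(len(trainloaders), len(valloaders))  # 10 10 (one loader per client)
print(len(trainloaders[0].dataset))        # 5400 = 6000 * (1 - val_ratio)
print(len(valloaders[0].dataset))          # 600  = 6000 * val_ratio
```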
Binary file added: baselines/flwr_baselines/publications/fedavg_mnist/docs/centralized_metrics.png (+55.2 KB)
baselines/flwr_baselines/publications/fedavg_mnist/docs/conf/config.yaml (9 additions, 0 deletions)

num_clients: 10
num_rounds: 10
num_epochs: 5
batch_size: 10
iid: False
client_fraction: 1.0
expected_maximum: 0.9924
learning_rate: 0.1
save_path: "docs/results"
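
These values are consumed through Hydra. A minimal sketch of an entry point reading this config (again, the actual `main.py` is not shown in this diff) could be:

```python
# Hypothetical Hydra entry point; the baseline's real main.py is not shown here.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # Hydra injects the parsed config.yaml, including any CLI overrides.
    print(cfg.num_clients, cfg.num_rounds, cfg.learning_rate)


if __name__ == "__main__":
    main()
```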
Binary file added: ...ons/fedavg_mnist/docs/results/centralized_metrics_cli=100_rds=1000_stag=0.0.png (+33.2 KB)
Binary file added: ...ions/fedavg_mnist/docs/results/centralized_metrics_cli=100_rds=200_stag=0.5.png (+37.3 KB)
Binary file added: ...ions/fedavg_mnist/docs/results/centralized_metrics_cli=100_rds=200_stag=0.9.png (+44.3 KB)
Binary file added: ...ations/fedavg_mnist/docs/results/centralized_metrics_cli=100_rds=200_stag=0.png (+34.5 KB)
Binary file added: ...lines/publications/fedavg_mnist/docs/results/hist_C=100_B=1000000_E=1_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ...ines/publications/fedavg_mnist/docs/results/hist_C=100_B=1000000_E=20_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ...lines/publications/fedavg_mnist/docs/results/hist_C=100_B=1000000_E=5_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ..._baselines/publications/fedavg_mnist/docs/results/hist_C=100_B=10_E=1_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ..._baselines/publications/fedavg_mnist/docs/results/hist_C=100_B=10_E=5_R=1000_stag=0.0.npy (+16.9 KB)
Binary file added: ..._baselines/publications/fedavg_mnist/docs/results/hist_C=100_B=50_E=1_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ..._baselines/publications/fedavg_mnist/docs/results/hist_C=100_B=50_E=5_R=1000_stag=0.0.npy (+34.5 KB)
Binary file added: ...flwr_baselines/publications/fedavg_mnist/docs/results/hist_C=10_B=10_E=1_R=5_stag=0.0.npy (+623 Bytes)
Binary file added: .../flwr_baselines/publications/fedavg_mnist/docs/results/hist_C=1_B=10_E=1_R=1_stag=0.0.npy (+511 Bytes)