Flower Baseline: FedAvg MNIST (#1497)
charlesbvll authored Jan 9, 2023
1 parent bdb2ee8 commit da1c8f7
Showing 24 changed files with 741 additions and 0 deletions.
4 changes: 4 additions & 0 deletions baselines/flwr_baselines/publications/fedavg_mnist/.gitignore
@@ -0,0 +1,4 @@
dataset/
outputs/
playground.ipynb
multirun/
87 changes: 87 additions & 0 deletions baselines/flwr_baselines/publications/fedavg_mnist/README.md
@@ -0,0 +1,87 @@
# Federated Averaging MNIST

The following baseline replicates the experiments in *Communication-Efficient Learning of Deep Networks from Decentralized Data* (McMahan et al., 2017), the paper that coined the term Federated Learning and proposed the FederatedAveraging (FedAvg) algorithm.

**Paper Abstract:**

<center>
<i>Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve speech recognition and text entry, and image models can automatically select good photos. However, this rich data is often privacy sensitive, large in quantity, or both, which may preclude logging to the data center and training there using conventional approaches. We advocate an alternative that leaves the training data distributed on the mobile devices, and learns a shared model by aggregating locally-computed updates. We term this decentralized approach Federated Learning. We present a practical method for the federated learning of deep networks based on iterative model averaging, and conduct an extensive empirical evaluation, considering five different model architectures and four datasets. These experiments demonstrate the approach is robust to the unbalanced and non-IID data distributions that are a defining characteristic of this setting. Communication costs are the principal constraint, and we show a reduction in required communication rounds by 10–100× as compared to synchronized stochastic gradient descent.</i>
</center>

**Paper Authors:**

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas.


Note: If you use this implementation in your work, please remember to cite the original authors of the paper.

**[Link to paper.](https://arxiv.org/abs/1602.05629)**

## Training Setup

### CNN Architecture

The CNN architecture used for the **Federated Averaging MNIST** baseline follows the one detailed in the paper:

| Layer | Details|
| ----- | ------ |
| 1 | Conv2D(1, 32, 5, 1, 1) <br/> ReLU, MaxPool2D(2, 2, 1) |
| 2 | Conv2D(32, 64, 5, 1, 1) <br/> ReLU, MaxPool2D(2, 2, 1) |
| 3 | FC(64 * 7 * 7, 512) <br/> ReLU |
| 4 | FC(512, 10) |
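
For reference, here is a minimal PyTorch sketch of this architecture. The baseline's actual network lives in `model.py`, which is not rendered on this page, so treat this as an illustration of the table above rather than the exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """CNN sketched from the layer table above (28x28 MNIST input, 10 classes)."""

    def __init__(self) -> None:
        super().__init__()
        # Conv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=1)
        # MaxPool2D(kernel_size, stride, padding)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
```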

### Training Parameters

| Description | Value |
| ----------- | ----- |
| loss | cross entropy loss |
| optimizer | SGD |
| learning rate | 0.1 (by default) |
| local epochs | 5 (by default) |
| local batch size | 10 (by default) |
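
Combining these, each client's local update is plain mini-batch SGD on the cross-entropy loss. The training routine itself is defined in `model.py`, which is not rendered on this page, so the following is only a hedged sketch of what `model.train` is assumed to look like, matching the call signature used in `client.py` below:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    epochs: int = 5,
    learning_rate: float = 0.1,
) -> None:
    """Run local training with SGD and cross-entropy loss (see table above)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    net.train()
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
```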

## Running experiments

The `config.yaml` file containing all the tunable hyperparameters and the necessary variables can be found under the `conf` folder.
[Hydra](https://hydra.cc/docs/tutorials/) is used to manage the different parameters with which experiments can be run.

To run using the default parameters, just enter `python main.py`. If some parameters need to be overridden, you can do so as in the following example:

```sh
python main.py num_epochs=5 num_rounds=1000 iid=True
```

Results will be stored as timestamped folders inside either `outputs` or `multirun`, depending on whether you perform a single run or a multi-run.
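
If you want to sweep over several values at once, Hydra's standard multi-run mode should also work here, e.g. `python main.py --multirun num_epochs=1,5 iid=True,False` (parameter names as defined in `conf/config.yaml`).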

### Example output

To help visualize results, the script also plots evaluation curves. Here is an example:

<p align="center">
<img src="docs/centralized_metrics.png" alt="Centralized evaluation results" width="400">
</p>

You will also find the saved history in the `docs/results/` folder, where `C` refers to the number of clients, `B` to the batch size, `E` to the number of local epochs, `R` to the number of rounds, and `stag` to the proportion of clients that are unreachable at each round.


Empty file.
122 changes: 122 additions & 0 deletions baselines/flwr_baselines/publications/fedavg_mnist/client.py
@@ -0,0 +1,122 @@
# pylint: disable=too-many-arguments
"""Defines the MNIST Flower Client and a function to instantiate it."""


from collections import OrderedDict
from typing import Callable, List, Tuple

import flwr as fl
import numpy as np
import torch
from torch.utils.data import DataLoader

from flwr_baselines.publications.fedavg_mnist import model
from flwr_baselines.publications.fedavg_mnist.dataset import load_datasets


class FlowerClient(fl.client.NumPyClient):
"""Standard Flower client for CNN training."""

def __init__(
self,
net: torch.nn.Module,
trainloader: DataLoader,
valloader: DataLoader,
device: torch.device,
num_epochs: int,
learning_rate: float,
):
self.net = net
self.trainloader = trainloader
self.valloader = valloader
self.device = device
self.num_epochs = num_epochs
self.learning_rate = learning_rate

def get_parameters(self, config) -> List[np.ndarray]:
"""Returns the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def set_parameters(self, parameters: List[np.ndarray]) -> None:
"""Changes the parameters of the model using the given ones."""
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)

def fit(
self, parameters: List[np.ndarray], config
) -> Tuple[List[np.ndarray], int, dict]:
"""Implements distributed fit function for a given client."""
self.set_parameters(parameters)
model.train(
self.net,
self.trainloader,
self.device,
epochs=self.num_epochs,
learning_rate=self.learning_rate,
)
return self.get_parameters(self.net), len(self.trainloader), {}

def evaluate(self, parameters: List[np.ndarray], config):
"""Implements distributed evaluation for a given client."""
self.set_parameters(parameters)
loss, accuracy = model.test(self.net, self.valloader, self.device)
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def gen_client_fn(
device: torch.device,
iid: bool,
num_clients: int,
num_epochs: int,
batch_size: int,
learning_rate: float,
) -> Tuple[Callable[[str], FlowerClient], DataLoader]:
"""Generates the client function that creates the Flower Clients.
Parameters
----------
device : torch.device
        The device on which the clients will train and test.
iid : bool
The way to partition the data for each client, i.e. whether the data
should be independent and identically distributed between the clients
or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario)
num_clients : int
The number of clients present in the setup
num_epochs : int
The number of local epochs each client should run the training for before
sending it to the server.
batch_size : int
The size of the local batches each client trains on.
learning_rate : float
The learning rate for the SGD optimizer of clients.
Returns
-------
Tuple[Callable[[str], FlowerClient], DataLoader]
A tuple containing the client function that creates Flower Clients and
the DataLoader that will be used for testing
"""
trainloaders, valloaders, testloader = load_datasets(
iid=iid, num_clients=num_clients, batch_size=batch_size
)

def client_fn(cid: str) -> FlowerClient:
"""Create a Flower client representing a single organization."""

# Load model
net = model.Net().to(device)

# Note: each client gets a different trainloader/valloader, so each client
# will train and evaluate on their own unique data
trainloader = trainloaders[int(cid)]
valloader = valloaders[int(cid)]

# Create a single Flower client representing a single organization
return FlowerClient(
net, trainloader, valloader, device, num_epochs, learning_rate
)

return client_fn, testloader
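
The entry point that consumes `gen_client_fn` (presumably `main.py`, which is not rendered on this page) is expected to plug the client function and test loader into a Flower simulation. The following is only a hedged sketch of such wiring, using default values from `conf/config.yaml`; the real `main.py` may differ (for instance, it presumably also evaluates centrally on `testloader` and plots the curves shown in the README):

```python
import flwr as fl
import torch

from flwr_baselines.publications.fedavg_mnist import client

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLIENTS = 10

# Build the per-client factory and the centralized test loader.
client_fn, testloader = client.gen_client_fn(
    device=DEVICE,
    iid=False,
    num_clients=NUM_CLIENTS,
    num_epochs=5,
    batch_size=10,
    learning_rate=0.1,
)

# Plain FedAvg with every client participating in each round.
strategy = fl.server.strategy.FedAvg(fraction_fit=1.0, fraction_evaluate=1.0)

history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=10),
    strategy=strategy,
)
```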
121 changes: 121 additions & 0 deletions baselines/flwr_baselines/publications/fedavg_mnist/dataset.py
@@ -0,0 +1,121 @@
"""MNIST dataset utilities for federated learning."""


from typing import List, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split
from torchvision.datasets import MNIST


def load_datasets(
num_clients: int = 10,
iid: Optional[bool] = True,
val_ratio: float = 0.1,
batch_size: Optional[int] = 32,
seed: Optional[int] = 42,
) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
"""Creates the dataloaders to be fed into the model.
Parameters
----------
num_clients : int, optional
The number of clients that hold a part of the data, by default 10
iid : bool, optional
Whether the data should be independent and identically distributed between the
clients or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario), by default True
val_ratio : float, optional
The ratio of training data that will be used for validation (between 0 and 1),
by default 0.1
batch_size : int, optional
The size of the batches to be fed into the model, by default 32
seed : int, optional
        Used to set a fixed seed to replicate experiments, by default 42
Returns
-------
    Tuple[List[DataLoader], List[DataLoader], DataLoader]
        A list of DataLoaders for training (one per client), a list of DataLoaders
        for validation (one per client), and a single DataLoader for testing.
"""
datasets, testset = _partition_data(num_clients, iid, seed)
# Split each partition into train/val and create DataLoader
trainloaders = []
valloaders = []
for dataset in datasets:
len_val = int(len(dataset) / (1 / val_ratio))
len_train = len(dataset) - len_val
lengths = [len_train, len_val]
ds_train, ds_val = random_split(
dataset, lengths, torch.Generator().manual_seed(seed)
)
trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
valloaders.append(DataLoader(ds_val, batch_size=batch_size))
return trainloaders, valloaders, DataLoader(testset, batch_size=batch_size)


def _download_data() -> Tuple[Dataset, Dataset]:
"""Downloads (if necessary) and returns the MNIST dataset.
Returns
-------
Tuple[MNIST, MNIST]
The dataset for training and the dataset for testing MNIST.
"""
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
trainset = MNIST("./dataset", train=True, download=True, transform=transform)
testset = MNIST("./dataset", train=False, download=True, transform=transform)
return trainset, testset


def _partition_data(
num_clients: int = 10,
iid: Optional[bool] = True,
seed: Optional[int] = 42,
) -> Tuple[List[Dataset], Dataset]:
"""Split training set into iid or non iid partitions to simulate the
federated setting.
Parameters
----------
num_clients : int, optional
The number of clients that hold a part of the data, by default 10
iid : bool, optional
Whether the data should be independent and identically distributed between
the clients or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario), by default True
seed : int, optional
        Used to set a fixed seed to replicate experiments, by default 42
Returns
-------
Tuple[List[Dataset], Dataset]
        A list of datasets, one for each client, and a single dataset to be used for testing the model.
"""
trainset, testset = _download_data()
partition_size = int(len(trainset) / num_clients)
lengths = [partition_size] * num_clients
if iid:
datasets = random_split(trainset, lengths, torch.Generator().manual_seed(seed))
else:
shard_size = int(partition_size / 2)
idxs = trainset.targets.argsort()
sorted_data = Subset(trainset, idxs)
tmp = []
for idx in range(num_clients * 2):
tmp.append(
Subset(sorted_data, np.arange(shard_size * idx, shard_size * (idx + 1)))
)
idxs_list = torch.randperm(
num_clients * 2, generator=torch.Generator().manual_seed(seed)
)
datasets = [
ConcatDataset((tmp[idxs_list[2 * i]], tmp[idxs_list[2 * i + 1]]))
for i in range(num_clients)
]

return datasets, testset
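
As a quick sanity check of the partitioning above, `load_datasets` can be called directly; the numbers in the comments follow from MNIST's 60,000/10,000 train/test split and the defaults in this file:

```python
from flwr_baselines.publications.fedavg_mnist.dataset import load_datasets

# Ten IID partitions of MNIST, each split 90/10 into train/validation.
trainloaders, valloaders, testloader = load_datasets(
    num_clients=10, iid=True, val_ratio=0.1, batch_size=32
)

print(len(trainloaders), len(valloaders))  # 10 10
print(len(trainloaders[0].dataset))        # 5400 training samples per client
print(len(valloaders[0].dataset))          # 600 validation samples per client
print(len(testloader.dataset))             # 10000 test samples
```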
9 changes: 9 additions & 0 deletions baselines/flwr_baselines/publications/fedavg_mnist/conf/config.yaml
@@ -0,0 +1,9 @@
num_clients: 10
num_rounds: 10
num_epochs: 5
batch_size: 10
iid: False
client_fraction: 1.0
expected_maximum: 0.9924
learning_rate: 0.1
save_path: "docs/results"