Example for Custom Metrics calculation during Federated Learning #1958

Merged · 34 commits · Jan 23, 2024

Commits
92e598c
Update sim.ipynb
gubertoli Mar 6, 2022
7d58b6f
Merge branch 'main' into main
danieljanes Mar 14, 2022
2311465
Merge branch 'main' into main
danieljanes Mar 14, 2022
b559f88
Merge branch 'adap:main' into main
gubertoli Mar 15, 2022
6372cc7
Merge branch 'adap:main' into main
gubertoli Jun 28, 2022
64c7124
Merge branch 'adap:main' into main
gubertoli Jul 12, 2022
0770547
Change to comply with .fit() tuple requirements
gubertoli Jul 12, 2022
6a9076f
Merge branch 'main' into main
danieljanes Jul 13, 2022
99645ef
Merge branch 'adap:main' into main
gubertoli Jul 13, 2022
6774045
Merge branch 'adap:main' into main
gubertoli Jun 22, 2023
1eefe76
custom metrics example
gubertoli Jun 22, 2023
0edb274
Merge branch 'main' into extra_metrics
gubertoli Oct 18, 2023
e441f7a
Merge branch 'main' into extra_metrics
gubertoli Jan 16, 2024
63ee989
Format and test ok
gubertoli Jan 16, 2024
5e25c3c
README
gubertoli Jan 16, 2024
947897a
Merge branch 'main' into extra_metrics
danieljanes Jan 17, 2024
949b4ee
Update examples/custom-metrics/requirements.txt
gubertoli Jan 17, 2024
8298aec
Update examples/custom-metrics/client.py
gubertoli Jan 17, 2024
c5be003
Update to FlowerClient class and added e-mail
gubertoli Jan 17, 2024
41043b7
Using flwr-datasets and tested with pip and poetry
gubertoli Jan 17, 2024
5762f59
Merge branch 'main' into extra_metrics
gubertoli Jan 17, 2024
76e0711
Merge branch 'main' into extra_metrics
gubertoli Jan 17, 2024
8474a75
Merge branch 'main' into extra_metrics
gubertoli Jan 17, 2024
9d7002f
Merge branch 'main' into extra_metrics
danieljanes Jan 18, 2024
56e0b36
Uppercase comment
gubertoli Jan 18, 2024
d3eebe7
Uppercase comment
gubertoli Jan 18, 2024
290ac5a
Add comment
gubertoli Jan 18, 2024
8d1c380
Add comment about waiting for server.py
gubertoli Jan 18, 2024
4b0978d
Add comment about strategy definition
gubertoli Jan 18, 2024
5d9319b
Fix typos
gubertoli Jan 18, 2024
8e29ba8
Fix typo
gubertoli Jan 18, 2024
01e3ba2
Add missing reference to run.sh
gubertoli Jan 18, 2024
e5ecd6b
Improving docstring about mean average and about weighted average
gubertoli Jan 18, 2024
73f2cac
Merge branch 'main' into extra_metrics
yan-gao-GY Jan 23, 2024
106 changes: 106 additions & 0 deletions examples/custom-metrics/README.md
@@ -0,0 +1,106 @@
# Flower Example using Custom Metrics

This simple example demonstrates how to calculate custom metrics across multiple clients, beyond the standard ones built into ML frameworks. In this case, it demonstrates the use of readily available `scikit-learn` metrics: accuracy, recall, precision, and F1-score.

Once both the test labels (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), any other standard or custom metric can be calculated.
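
For instance, a hypothetical extension (not part of this example) could add `balanced_accuracy_score` alongside the existing metrics; a minimal sketch with stand-in arrays:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Stand-ins for the y_test / y_pred arrays available in `client.py`
y_test = np.array([0, 1, 1, 2])
y_pred = np.array([0, 1, 0, 2])

# Hypothetical extra metric, added alongside acc/rec/prec/f1
bal_acc = balanced_accuracy_score(y_test, y_pred)
print(bal_acc)
```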

The main takeaways of this implementation are:

- the use of `output_dict` on the client side - inside the `evaluate` method in `client.py`
- the use of `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, as part of the `strategy` in `server.py` (see the sketch after this list)
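
In outline, the client's `evaluate` method returns `(loss, num_examples, output_dict)`, and the server-side aggregation function receives a list of `(num_examples, output_dict)` pairs. A minimal, runnable sketch of that aggregation step with hard-coded sample values (the full version is in `server.py` below):

```python
import numpy as np

# Each entry mimics what one client reports after `evaluate`:
# (num_examples, output_dict)
client_metrics = [
    (10000, {"acc": 0.10, "rec": 0.10, "prec": 0.10, "f1": 0.10}),
    (10000, {"acc": 0.34, "rec": 0.34, "prec": 0.34, "f1": 0.34}),
]


def average_metrics(metrics):
    # Unweighted mean of each metric across all clients
    return {
        key: float(np.mean([metric[key] for _, metric in metrics]))
        for key in ["acc", "rec", "prec", "f1"]
    }


print(average_metrics(client_metrics))  # {'acc': 0.22, 'rec': 0.22, ...}
```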

This example is based on the `quickstart-tensorflow` example with CIFAR-10, source [here](https://flower.dev/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve CIFAR-10.

Because CIFAR-10 classification is a multi-class problem, some changes to how the metrics are calculated are required, namely `average='micro'` for the scikit-learn metrics and `np.argmax` to turn model outputs into class labels; for binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on the client side, as sketched below.
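
For illustration, a minimal sketch of the multi-class handling; the reconstruction-error metric is a hypothetical autoencoder example, not part of this PR:

```python
import numpy as np
from sklearn.metrics import f1_score

# The model outputs one probability per class; argmax picks the predicted class
y_proba = np.array([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
y_pred = np.argmax(y_proba, axis=1)
y_true = np.array([1, 2])

# `average='micro'` aggregates counts over all classes (required for multi-class)
f1 = f1_score(y_true, y_pred, average="micro")

# Hypothetical custom metric for an autoencoder: mean squared reconstruction
# error between inputs `x` and reconstructions `x_hat`
x = np.random.rand(4, 8)
x_hat = np.random.rand(4, 8)
reconstruction_error = float(np.mean((x - x_hat) ** 2))
```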

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
```

This will create a new directory called `custom-metrics` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- server.py
-- run.sh
-- README.md
```

### Installing Dependencies

Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the commands below in your terminal to create a virtual environment and install the dependencies listed in `requirements.txt`.

```shell
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

## Run Federated Learning with Custom Metrics

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
python server.py
```

Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each:

```shell
python client.py
```

Alternatively you can run all of it in one shell as follows:

```shell
python server.py &
# Wait for a few seconds to give the server enough time to start, then:
python client.py &
python client.py
```

or

```shell
chmod +x run.sh
./run.sh
```

You will see that Keras is starting a federated training. Have a look at the [Flower quickstart documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to check that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development), as shown below.
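
For example, the `model.fit()` call in `client.py` would become:

```python
# Limit each local epoch to 3 optimizer steps for a quick smoke test
model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
```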

Running `run.sh` will result in the following output (after 3 rounds):

```shell
INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed {
'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)],
'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)],
'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)]
}
```
71 changes: 71 additions & 0 deletions examples/custom-metrics/client.py
@@ -0,0 +1,71 @@
import os

import flwr as fl
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from flwr_datasets import FederatedDataset


# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


# Load model (MobileNetV2)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load data with Flower Datasets (CIFAR-10)
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
train = fds.load_full("train")
test = fds.load_full("test")

# Using NumPy format
train_np = train.with_format("numpy")
test_np = test.with_format("numpy")
x_train, y_train = train_np["img"], train_np["label"]
x_test, y_test = test_np["img"], test_np["label"]


# Function to compute extra evaluation metrics with scikit-learn
def eval_learning(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    rec = recall_score(
        y_test, y_pred, average="micro"
    )  # average argument required for multi-class
    prec = precision_score(y_test, y_pred, average="micro")
    f1 = f1_score(y_test, y_pred, average="micro")
    return acc, rec, prec, f1


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return model.get_weights()

    def fit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        y_pred = model.predict(x_test)
        y_pred = np.argmax(y_pred, axis=1).reshape(
            -1, 1
        )  # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable

        acc, rec, prec, f1 = eval_learning(y_test, y_pred)
        output_dict = {
            "accuracy": accuracy,  # accuracy from tensorflow model.evaluate
            "acc": acc,
            "rec": rec,
            "prec": prec,
            "f1": f1,
        }
        return loss, len(x_test), output_dict


# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
19 changes: 19 additions & 0 deletions examples/custom-metrics/pyproject.toml
@@ -0,0 +1,19 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "custom-metrics"
version = "0.1.0"
description = "Federated Learning with Flower and Custom Metrics"
authors = [
    "The Flower Authors <hello@flower.dev>",
    "Gustavo Bertoli <gubertoli -at- gmail.com>"
]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = { version = "*", extras = ["vision"] }
scikit-learn = "^1.2.2"
tensorflow = "==2.12.0"
4 changes: 4 additions & 0 deletions examples/custom-metrics/requirements.txt
@@ -0,0 +1,4 @@
flwr>=1.0,<2.0
flwr-datasets[vision]
scikit-learn>=1.2.2
tensorflow==2.12.0
15 changes: 15 additions & 0 deletions examples/custom-metrics/run.sh
@@ -0,0 +1,15 @@
#!/bin/bash

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in `seq 0 1`; do
    echo "Starting client $i"
    python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
58 changes: 58 additions & 0 deletions examples/custom-metrics/server.py
@@ -0,0 +1,58 @@
import flwr as fl
import numpy as np


# Define metrics aggregation function
def average_metrics(metrics):
    """Aggregate metrics from multiple clients by calculating mean averages.

    Parameters:
    - metrics (list): A list containing tuples, where each tuple represents metrics for a client.
        Each tuple is structured as (num_examples, metric), where:
        - num_examples (int): The number of examples used to compute the metrics.
        - metric (dict): A dictionary containing custom metrics provided as `output_dict`
            in the `evaluate` method from `client.py`.

    Returns:
        A dictionary with the aggregated metrics, computed as unweighted means. The keys
        of the dictionary represent the different metrics:
        - 'accuracy': Mean accuracy calculated by TensorFlow.
        - 'acc': Mean accuracy from scikit-learn.
        - 'rec': Mean recall from scikit-learn.
        - 'prec': Mean precision from scikit-learn.
        - 'f1': Mean F1 score from scikit-learn.

    Note: If a weighted average is required, the `num_examples` parameter can be
    leveraged, as sketched after this function.

    Example:
        Example `metrics` list for two clients after the last round:
        [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}),
         (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})]
    """
    # num_examples is discarded here (unpacked into _), so each client counts equally
    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
    accuracies = np.mean([metric["acc"] for _, metric in metrics])
    recalls = np.mean([metric["rec"] for _, metric in metrics])
    precisions = np.mean([metric["prec"] for _, metric in metrics])
    f1s = np.mean([metric["f1"] for _, metric in metrics])

    return {
        "accuracy": accuracies_tf,
        "acc": accuracies,
        "rec": recalls,
        "prec": precisions,
        "f1": f1s,
    }
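

# A hedged sketch (not part of the original example): as the docstring notes,
# `num_examples` can be leveraged to weight each client's contribution by its
# number of evaluation examples instead of taking a plain mean
def weighted_average_metrics(metrics):
    """Aggregate metrics weighted by each client's number of examples."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    return {
        key: sum(num_examples * metric[key] for num_examples, metric in metrics)
        / total_examples
        for key in ["accuracy", "acc", "rec", "prec", "f1"]
    }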


# Define strategy and the custom aggregation function for the evaluation metrics
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics)


# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)