Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flower-client-authentication code example #2999

Merged
merged 23 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
959ecce
Add flower-client-authentication code example
danielnugraha Feb 22, 2024
dceae03
Small fixes
danielnugraha Feb 22, 2024
2b3401e
Add readme.md
danielnugraha Feb 22, 2024
7b4811a
Format readme
danielnugraha Feb 22, 2024
b6314b9
Merge remote-tracking branch 'origin' into add-code-example
danielnugraha Apr 24, 2024
5b76097
Merge branch 'main' into add-code-example
danieljanes Apr 29, 2024
790c923
Merge branch 'main' into add-code-example
danieljanes May 1, 2024
8e09159
Update pyproject.toml with hatchling build, remove requirements.txt (…
chongshenng May 1, 2024
bb09566
Update examples/flower-client-authentication/README.md
danielnugraha May 3, 2024
637b375
Update examples/flower-client-authentication/client.py
danielnugraha May 3, 2024
9c7e7c4
Update examples/flower-client-authentication/client.py
danielnugraha May 3, 2024
f1a45ad
Update examples/flower-client-authentication/server.py
danielnugraha May 3, 2024
80e4c57
Update examples/flower-client-authentication/README.md
danielnugraha May 3, 2024
fb21ead
Update examples/flower-client-authentication/README.md
danielnugraha May 3, 2024
511bc39
Update examples/flower-client-authentication/README.md
danielnugraha May 3, 2024
0e53cf4
Update examples/flower-client-authentication/README.md
danielnugraha May 3, 2024
1046cf1
Merge remote-tracking branch 'origin' into add-code-example
danielnugraha May 3, 2024
f471ee5
Fix code example
danielnugraha May 3, 2024
a33f72f
Fix readme to use secure connection
danielnugraha May 3, 2024
8ce9297
Format
danielnugraha May 3, 2024
ca37b81
Implement review feedback
danielnugraha May 3, 2024
4fcf085
Format
danielnugraha May 3, 2024
bbaedbe
Merge branch 'main' into add-code-example
danieljanes May 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions examples/flower-client-authentication/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Flower Client Authentication with PyTorch 🧪

> 🧪 = This example covers experimental features that might change in future versions of Flower
> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.

The following steps describe how to start a long-running Flower server (SuperLink) and a long-running Flower client (SuperNode) with client authentication enabled.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/flower-client-authentication . && rm -rf _tmp && cd flower-client-authentication
```

This will create a new directory called `flower-client-authentication` with the following project structure:

```bash
$ tree .
.
├── certificate.conf # <-- configuration for OpenSSL
├── generate.sh # <-- generate certificates and keys
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
├── pyproject.toml # <-- project dependencies
├── client.py # <-- contains `ClientApp`
├── server.py # <-- contains `ServerApp`
└── task.py # <-- task-specific code (model, data)
```

## Install dependencies

Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`:

```shell
# From a new python environment, run:
pip install .
```

Then, to verify that everything works correctly you can run the following command:

```shell
python3 -c "import flwr"
```

If you don't see any errors you're good to go!

## Generate public and private keys

```bash
./generate.sh
```

`generate.sh` is a script that (by default) generates certificates for creating a secure TLS connection
and three private and public key pairs for one server and two clients.
You can generate more keys by specifying the number of client credentials that you wish to generate.
The script also generates a CSV file that includes each of the generated (client) public keys.

⚠️ Note that this script should only be used for development purposes and not for creating production key pairs.

```bash
./generate.sh {your_number_of_clients}
```

## Start the long-running Flower server (SuperLink)

To start a long-running Flower server and enable client authentication is very easy; all you need to do is type
`--require-client-authentication` followed by the path to the known `client_public_keys.csv`, server's private key
`server_credentials`, and server's public key `server_credentials.pub`. Notice that you can only enable client
authentication with a secure TLS connection.
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved

```bash
flower-superlink \
--certificates certificates/ca.crt certificates/server.pem certificates/server.key \
--require-client-authentication keys/client_public_keys.csv keys/server_credentials keys/server_credentials.pub
```

## Start the long-running Flower client (SuperNode)

In a new terminal window, start the first long-running Flower client:

```bash
flower-client-app client:app \
--root-certificates certificates/ca.crt \
--server 127.0.0.1:9092 \
--authentication-keys keys/client_credentials_1 keys/client_credentials_1.pub
```

In yet another new terminal window, start the second long-running Flower client:

```bash
flower-client-app client:app \
--root-certificates certificates/ca.crt \
--server 127.0.0.1:9092 \
--authentication-keys keys/client_credentials_2 keys/client_credentials_2.pub
```

If you generated more than 2 client credentials, you can add more clients by opening new terminal windows and running the command
above. Don't forget to specify the correct client private and public keys for each client instance you created.

## Run the Flower App

With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower ServerApp:

```bash
flower-server-app server:app --root-certificates certificates/ca.crt --dir ./ --server 127.0.0.1:9091
```
20 changes: 20 additions & 0 deletions examples/flower-client-authentication/certificate.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[req]
default_bits = 4096
prompt = no
default_md = sha256
req_extensions = req_ext
distinguished_name = dn

[dn]
C = DE
ST = HH
O = Flower
CN = localhost

[req_ext]
subjectAltName = @alt_names

[alt_names]
DNS.1 = localhost
IP.1 = ::1
IP.2 = 127.0.0.1
43 changes: 43 additions & 0 deletions examples/flower-client-authentication/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Dict
from flwr.common import NDArrays, Scalar
from flwr.client import ClientApp, NumPyClient

from task import (
Net,
DEVICE,
load_data,
get_parameters,
set_parameters,
train,
test,
)


# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()


# Define Flower client and client_fn
class FlowerClient(NumPyClient):
def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
return get_parameters(net)

def fit(self, parameters, config):
set_parameters(net, parameters)
results = train(net, trainloader, testloader, epochs=1, device=DEVICE)
return get_parameters(net), len(trainloader.dataset), results

def evaluate(self, parameters, config):
set_parameters(net, parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
return FlowerClient().to_client()


app = ClientApp(
client_fn=client_fn,
)
72 changes: 72 additions & 0 deletions examples/flower-client-authentication/generate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/bash
# This script will generate all certificates if ca.crt does not exist

set -e
# Change directory to the script's directory
cd "$(dirname "${BASH_SOURCE[0]}")"

CERT_DIR=certificates

# Generate directories if not exists
mkdir -p $CERT_DIR

# Clearing any existing files in the certificates directory
rm -f $CERT_DIR/*

# Generate the root certificate authority key and certificate based on key
openssl genrsa -out $CERT_DIR/ca.key 4096
openssl req \
-new \
-x509 \
-key $CERT_DIR/ca.key \
-sha256 \
-subj "/C=DE/ST=HH/O=CA, Inc." \
-days 365 -out $CERT_DIR/ca.crt

# Generate a new private key for the server
openssl genrsa -out $CERT_DIR/server.key 4096

# Create a signing CSR
openssl req \
-new \
-key $CERT_DIR/server.key \
-out $CERT_DIR/server.csr \
-config certificate.conf

# Generate a certificate for the server
openssl x509 \
-req \
-in $CERT_DIR/server.csr \
-CA $CERT_DIR/ca.crt \
-CAkey $CERT_DIR/ca.key \
-CAcreateserial \
-out $CERT_DIR/server.pem \
-days 365 \
-sha256 \
-extfile certificate.conf \
-extensions req_ext

KEY_DIR=keys

mkdir -p $KEY_DIR

rm -f $KEY_DIR/*

ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/server_credentials" -C ""

generate_client_credentials() {
local num_clients=${1:-2}
for ((i=1; i<=num_clients; i++))
do
ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/client_credentials_$i" -C ""
done
}

generate_client_credentials "$1"

printf "%s" "$(cat "${KEY_DIR}/client_credentials_1.pub" | sed 's/.$//')" > $KEY_DIR/client_public_keys.csv
for ((i=2; i<=${1:-2}; i++))
do
printf ",%s" "$(sed 's/.$//' < "${KEY_DIR}/client_credentials_$i.pub")" >> $KEY_DIR/client_public_keys.csv
done
printf "\n" >> $KEY_DIR/client_public_keys.csv
20 changes: 20 additions & 0 deletions examples/flower-client-authentication/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "flower-client-authentication"
version = "0.1.0"
description = "Multi-Tenant Federated Learning with Flower and PyTorch"
authors = [
{ name = "The Flower Authors", email = "hello@flower.ai" },
]
dependencies = [
"flwr-nightly[rest,simulation]",
"torch==1.13.1",
"torchvision==0.14.1",
"tqdm==4.65.0"
]

[tool.hatch.build.targets.wheel]
packages = ["."]
42 changes: 42 additions & 0 deletions examples/flower-client-authentication/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics
from flwr.server.strategy.fedavg import FedAvg
from flwr.server import ServerApp


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
examples = [num_examples for num_examples, _ in metrics]

# Multiply accuracy of each client by number of examples used
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
train_accuracies = [
num_examples * m["train_accuracy"] for num_examples, m in metrics
]
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

# Aggregate and return custom metric (weighted average)
return {
"train_loss": sum(train_losses) / sum(examples),
"train_accuracy": sum(train_accuracies) / sum(examples),
"val_loss": sum(val_losses) / sum(examples),
"val_accuracy": sum(val_accuracies) / sum(examples),
}


# Define strategy
strategy = FedAvg(
fraction_fit=1.0, # Select all available clients
fraction_evaluate=0.0, # Disable evaluation
min_available_clients=2,
fit_metrics_aggregation_fn=weighted_average,
)


app = ServerApp(
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)
Loading