-
Notifications
You must be signed in to change notification settings - Fork 922
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add flower-client-authentication code example (#2999)
Co-authored-by: Daniel J. Beutel <daniel@flower.ai> Co-authored-by: Chong Shen Ng <ngchongshen@gmail.com>
- Loading branch information
1 parent
aab7dfa
commit 6242c95
Showing
7 changed files
with
397 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Flower Client Authentication with PyTorch 🧪 | ||
|
||
> 🧪 = This example covers experimental features that might change in future versions of Flower | ||
> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. | ||
The following steps describe how to start a long-running Flower server (SuperLink) and a long-running Flower client (SuperNode) with client authentication enabled. | ||
|
||
## Project Setup | ||
|
||
Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: | ||
|
||
```shell | ||
git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/flower-client-authentication . && rm -rf _tmp && cd flower-client-authentication | ||
``` | ||
|
||
This will create a new directory called `flower-client-authentication` with the following project structure: | ||
|
||
```bash | ||
$ tree . | ||
. | ||
├── certificate.conf # <-- configuration for OpenSSL | ||
├── generate.sh # <-- generate certificates and keys | ||
├── pyproject.toml # <-- project dependencies | ||
├── client.py # <-- contains `ClientApp` | ||
├── server.py # <-- contains `ServerApp` | ||
└── task.py # <-- task-specific code (model, data) | ||
``` | ||
|
||
## Install dependencies | ||
|
||
Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`: | ||
|
||
```shell | ||
# From a new python environment, run: | ||
pip install . | ||
``` | ||
|
||
Then, to verify that everything works correctly you can run the following command: | ||
|
||
```shell | ||
python3 -c "import flwr" | ||
``` | ||
|
||
If you don't see any errors you're good to go! | ||
|
||
## Generate public and private keys | ||
|
||
```bash | ||
./generate.sh | ||
``` | ||
|
||
`generate.sh` is a script that (by default) generates certificates for creating a secure TLS connection | ||
and three private and public key pairs for one server and two clients. | ||
You can generate more keys by specifying the number of client credentials that you wish to generate. | ||
The script also generates a CSV file that includes each of the generated (client) public keys. | ||
|
||
⚠️ Note that this script should only be used for development purposes and not for creating production key pairs. | ||
|
||
```bash | ||
./generate.sh {your_number_of_clients} | ||
``` | ||
|
||
## Start the long-running Flower server (SuperLink) | ||
|
||
To start a long-running Flower server and enable client authentication is very easy; all you need to do is type | ||
`--require-client-authentication` followed by the path to the known `client_public_keys.csv`, server's private key | ||
`server_credentials`, and server's public key `server_credentials.pub`. Notice that you can only enable client | ||
authentication with a secure TLS connection. | ||
|
||
```bash | ||
flower-superlink \ | ||
--certificates certificates/ca.crt certificates/server.pem certificates/server.key \ | ||
--require-client-authentication keys/client_public_keys.csv keys/server_credentials keys/server_credentials.pub | ||
``` | ||
|
||
## Start the long-running Flower client (SuperNode) | ||
|
||
In a new terminal window, start the first long-running Flower client: | ||
|
||
```bash | ||
flower-client-app client:app \ | ||
--root-certificates certificates/ca.crt \ | ||
--server 127.0.0.1:9092 \ | ||
--authentication-keys keys/client_credentials_1 keys/client_credentials_1.pub | ||
``` | ||
|
||
In yet another new terminal window, start the second long-running Flower client: | ||
|
||
```bash | ||
flower-client-app client:app \ | ||
--root-certificates certificates/ca.crt \ | ||
--server 127.0.0.1:9092 \ | ||
--authentication-keys keys/client_credentials_2 keys/client_credentials_2.pub | ||
``` | ||
|
||
If you generated more than 2 client credentials, you can add more clients by opening new terminal windows and running the command | ||
above. Don't forget to specify the correct client private and public keys for each client instance you created. | ||
|
||
## Run the Flower App | ||
|
||
With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower ServerApp: | ||
|
||
```bash | ||
flower-server-app server:app --root-certificates certificates/ca.crt --dir ./ --server 127.0.0.1:9091 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
[req] | ||
default_bits = 4096 | ||
prompt = no | ||
default_md = sha256 | ||
req_extensions = req_ext | ||
distinguished_name = dn | ||
|
||
[dn] | ||
C = DE | ||
ST = HH | ||
O = Flower | ||
CN = localhost | ||
|
||
[req_ext] | ||
subjectAltName = @alt_names | ||
|
||
[alt_names] | ||
DNS.1 = localhost | ||
IP.1 = ::1 | ||
IP.2 = 127.0.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Dict | ||
from flwr.common import NDArrays, Scalar | ||
from flwr.client import ClientApp, NumPyClient | ||
|
||
from task import ( | ||
Net, | ||
DEVICE, | ||
load_data, | ||
get_parameters, | ||
set_parameters, | ||
train, | ||
test, | ||
) | ||
|
||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data() | ||
|
||
|
||
# Define Flower client and client_fn | ||
class FlowerClient(NumPyClient): | ||
def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: | ||
return get_parameters(net) | ||
|
||
def fit(self, parameters, config): | ||
set_parameters(net, parameters) | ||
results = train(net, trainloader, testloader, epochs=1, device=DEVICE) | ||
return get_parameters(net), len(trainloader.dataset), results | ||
|
||
def evaluate(self, parameters, config): | ||
set_parameters(net, parameters) | ||
loss, accuracy = test(net, testloader) | ||
return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
def client_fn(cid: str): | ||
return FlowerClient().to_client() | ||
|
||
|
||
app = ClientApp( | ||
client_fn=client_fn, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#!/bin/bash | ||
# This script will generate all certificates if ca.crt does not exist | ||
|
||
set -e | ||
# Change directory to the script's directory | ||
cd "$(dirname "${BASH_SOURCE[0]}")" | ||
|
||
CERT_DIR=certificates | ||
|
||
# Generate directories if not exists | ||
mkdir -p $CERT_DIR | ||
|
||
# Clearing any existing files in the certificates directory | ||
rm -f $CERT_DIR/* | ||
|
||
# Generate the root certificate authority key and certificate based on key | ||
openssl genrsa -out $CERT_DIR/ca.key 4096 | ||
openssl req \ | ||
-new \ | ||
-x509 \ | ||
-key $CERT_DIR/ca.key \ | ||
-sha256 \ | ||
-subj "/C=DE/ST=HH/O=CA, Inc." \ | ||
-days 365 -out $CERT_DIR/ca.crt | ||
|
||
# Generate a new private key for the server | ||
openssl genrsa -out $CERT_DIR/server.key 4096 | ||
|
||
# Create a signing CSR | ||
openssl req \ | ||
-new \ | ||
-key $CERT_DIR/server.key \ | ||
-out $CERT_DIR/server.csr \ | ||
-config certificate.conf | ||
|
||
# Generate a certificate for the server | ||
openssl x509 \ | ||
-req \ | ||
-in $CERT_DIR/server.csr \ | ||
-CA $CERT_DIR/ca.crt \ | ||
-CAkey $CERT_DIR/ca.key \ | ||
-CAcreateserial \ | ||
-out $CERT_DIR/server.pem \ | ||
-days 365 \ | ||
-sha256 \ | ||
-extfile certificate.conf \ | ||
-extensions req_ext | ||
|
||
KEY_DIR=keys | ||
|
||
mkdir -p $KEY_DIR | ||
|
||
rm -f $KEY_DIR/* | ||
|
||
ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/server_credentials" -C "" | ||
|
||
generate_client_credentials() { | ||
local num_clients=${1:-2} | ||
for ((i=1; i<=num_clients; i++)) | ||
do | ||
ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/client_credentials_$i" -C "" | ||
done | ||
} | ||
|
||
generate_client_credentials "$1" | ||
|
||
printf "%s" "$(cat "${KEY_DIR}/client_credentials_1.pub" | sed 's/.$//')" > $KEY_DIR/client_public_keys.csv | ||
for ((i=2; i<=${1:-2}; i++)) | ||
do | ||
printf ",%s" "$(sed 's/.$//' < "${KEY_DIR}/client_credentials_$i.pub")" >> $KEY_DIR/client_public_keys.csv | ||
done | ||
printf "\n" >> $KEY_DIR/client_public_keys.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "flower-client-authentication" | ||
version = "0.1.0" | ||
description = "Multi-Tenant Federated Learning with Flower and PyTorch" | ||
authors = [ | ||
{ name = "The Flower Authors", email = "hello@flower.ai" }, | ||
] | ||
dependencies = [ | ||
"flwr-nightly[rest,simulation]", | ||
"torch==1.13.1", | ||
"torchvision==0.14.1", | ||
"tqdm==4.65.0" | ||
] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["."] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import List, Tuple | ||
|
||
import flwr as fl | ||
from flwr.common import Metrics | ||
from flwr.server.strategy.fedavg import FedAvg | ||
from flwr.server import ServerApp | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Multiply accuracy of each client by number of examples used | ||
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] | ||
train_accuracies = [ | ||
num_examples * m["train_accuracy"] for num_examples, m in metrics | ||
] | ||
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] | ||
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return { | ||
"train_loss": sum(train_losses) / sum(examples), | ||
"train_accuracy": sum(train_accuracies) / sum(examples), | ||
"val_loss": sum(val_losses) / sum(examples), | ||
"val_accuracy": sum(val_accuracies) / sum(examples), | ||
} | ||
|
||
|
||
# Define strategy | ||
strategy = FedAvg( | ||
fraction_fit=1.0, # Select all available clients | ||
fraction_evaluate=0.0, # Disable evaluation | ||
min_available_clients=2, | ||
fit_metrics_aggregation_fn=weighted_average, | ||
) | ||
|
||
|
||
app = ServerApp( | ||
config=fl.server.ServerConfig(num_rounds=3), | ||
strategy=strategy, | ||
) |
Oops, something went wrong.