diff --git a/examples/flower-client-authentication/README.md b/examples/flower-client-authentication/README.md
new file mode 100644
index 000000000000..7c724fc26f64
--- /dev/null
+++ b/examples/flower-client-authentication/README.md
@@ -0,0 +1,105 @@
+# Flower Client Authentication with PyTorch 🧪
+
+> 🧪 = This example covers experimental features that might change in future versions of Flower.
+> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
+
+The following steps describe how to start a long-running Flower server (SuperLink) and a long-running Flower client (SuperNode) with client authentication enabled.
+
+## Project Setup
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell and that will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/flower-client-authentication . && rm -rf _tmp && cd flower-client-authentication
+```
+
+This will create a new directory called `flower-client-authentication` with the following project structure:
+
+```bash
+$ tree .
+.
+├── certificate.conf      # <-- configuration for OpenSSL
+├── generate.sh           # <-- generate certificates and keys
+├── pyproject.toml        # <-- project dependencies
+├── client.py             # <-- contains `ClientApp`
+├── server.py             # <-- contains `ServerApp`
+└── task.py               # <-- task-specific code (model, data)
+```
+
+## Install dependencies
+
+Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`:
+
+```shell
+# From a new python environment, run:
+pip install .
+```
+
+Then, to verify that everything works correctly, you can run the following command:
+
+```shell
+python3 -c "import flwr"
+```
+
+If you don't see any errors, you're good to go!
+
+## Generate public and private keys
+
+```bash
+./generate.sh
+```
+
+`generate.sh` is a script that (by default) generates certificates for creating a secure TLS connection
+and three private/public key pairs: one for the server and two for the clients.
+You can generate more keys by specifying the number of client credentials that you wish to generate.
+The script also generates a CSV file that includes each of the generated (client) public keys.
+
+⚠️ Note that this script should only be used for development purposes and not for creating production key pairs.
+
+```bash
+./generate.sh {your_number_of_clients}
+```
+
+## Start the long-running Flower server (SuperLink)
+
+Starting a long-running Flower server with client authentication enabled is easy: pass
+`--require-client-authentication` followed by the paths to the known client public keys (`client_public_keys.csv`),
+the server's private key (`server_credentials`), and the server's public key (`server_credentials.pub`). Note that
+client authentication can only be enabled together with a secure TLS connection.
+
+```bash
+flower-superlink \
+    --certificates certificates/ca.crt certificates/server.pem certificates/server.key \
+    --require-client-authentication keys/client_public_keys.csv keys/server_credentials keys/server_credentials.pub
+```
+
+## Start the long-running Flower client (SuperNode)
+
+In a new terminal window, start the first long-running Flower client:
+
+```bash
+flower-client-app client:app \
+    --root-certificates certificates/ca.crt \
+    --server 127.0.0.1:9092 \
+    --authentication-keys keys/client_credentials_1 keys/client_credentials_1.pub
+```
+
+In yet another new terminal window, start the second long-running Flower client:
+
+```bash
+flower-client-app client:app \
+    --root-certificates certificates/ca.crt \
+    --server 127.0.0.1:9092 \
+    --authentication-keys keys/client_credentials_2 keys/client_credentials_2.pub
+```
+
+If you generated more than two client credentials, you can add more clients by opening new terminal windows and running the
+same command. Don't forget to specify the correct client private and public keys for each client instance you created.
+
+## Run the Flower App
+
+With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower ServerApp:
+
+```bash
+flower-server-app server:app --root-certificates certificates/ca.crt --dir ./ --server 127.0.0.1:9091
+```
diff --git a/examples/flower-client-authentication/certificate.conf b/examples/flower-client-authentication/certificate.conf
new file mode 100644
index 000000000000..ea97fcbb700d
--- /dev/null
+++ b/examples/flower-client-authentication/certificate.conf
@@ -0,0 +1,20 @@
+[req]
+default_bits = 4096
+prompt = no
+default_md = sha256
+req_extensions = req_ext
+distinguished_name = dn
+
+[dn]
+C = DE
+ST = HH
+O = Flower
+CN = localhost
+
+[req_ext]
+subjectAltName = @alt_names
+
+[alt_names]
+DNS.1 = localhost
+IP.1 = ::1
+IP.2 = 127.0.0.1
diff --git a/examples/flower-client-authentication/client.py b/examples/flower-client-authentication/client.py
new file mode 100644
index 000000000000..3c99d5a410c9
--- /dev/null
+++ b/examples/flower-client-authentication/client.py
@@ -0,0 +1,53 @@
+from typing import Dict
+from flwr.common import NDArrays, Scalar
+from flwr.client import ClientApp, NumPyClient
+
+from task import (
+    Net,
+    DEVICE,
+    load_data,
+    get_parameters,
+    set_parameters,
+    train,
+    test,
+)
+
+
+# Load model and data (simple CNN, CIFAR-10)
+net = Net().to(DEVICE)
+trainloader, testloader = load_data()
+
+
+# Define Flower client and client_fn
+class FlowerClient(NumPyClient):
+    """Flower client backed by the module-level `net` and data loaders."""
+
+    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+        """Return the current model weights as NumPy arrays."""
+        return get_parameters(net)
+
+    def fit(self, parameters, config):
+        """Load `parameters`, train for one epoch, and report the results.
+
+        Returns the updated weights, the number of training examples, and
+        the metrics dict produced by `train`.
+        """
+        set_parameters(net, parameters)
+        results = train(net, trainloader, testloader, epochs=1, device=DEVICE)
+        return get_parameters(net), len(trainloader.dataset), results
+
+    def evaluate(self, parameters, config):
+        """Load `parameters` and evaluate the model on the test set."""
+        set_parameters(net, parameters)
+        loss, accuracy = test(net, testloader)
+        return loss, len(testloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(cid: str):
+    """Create a Flower client; `cid` is ignored."""
+    return FlowerClient().to_client()
+
+
+app = ClientApp(
+    client_fn=client_fn,
+)
diff --git a/examples/flower-client-authentication/generate.sh b/examples/flower-client-authentication/generate.sh
new file mode 100755
index 000000000000..ebfdc17b80b5
--- /dev/null
+++ b/examples/flower-client-authentication/generate.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+# Regenerate all TLS certificates and authentication key pairs from scratch.
+# Existing files in certificates/ and keys/ are deleted first.
+#
+# Usage: ./generate.sh [NUM_CLIENTS]   (default: 2)
+
+set -e
+# Change directory to the script's directory
+cd "$(dirname "${BASH_SOURCE[0]}")"
+
+# Number of client key pairs; must be >= 1 because the CSV starts with client 1
+NUM_CLIENTS=${1:-2}
+if ! [[ "$NUM_CLIENTS" =~ ^[1-9][0-9]*$ ]]; then
+    echo "Usage: $0 [NUM_CLIENTS] (NUM_CLIENTS must be a positive integer)" >&2
+    exit 1
+fi
+
+CERT_DIR=certificates
+
+# Generate directories if not exists
+mkdir -p "$CERT_DIR"
+
+# Clearing any existing files in the certificates directory
+rm -f "$CERT_DIR"/*
+
+# Generate the root certificate authority key and certificate based on key
+openssl genrsa -out "$CERT_DIR/ca.key" 4096
+openssl req \
+    -new \
+    -x509 \
+    -key "$CERT_DIR/ca.key" \
+    -sha256 \
+    -subj "/C=DE/ST=HH/O=CA, Inc." \
+    -days 365 -out "$CERT_DIR/ca.crt"
+
+# Generate a new private key for the server
+openssl genrsa -out "$CERT_DIR/server.key" 4096
+
+# Create a signing CSR
+openssl req \
+    -new \
+    -key "$CERT_DIR/server.key" \
+    -out "$CERT_DIR/server.csr" \
+    -config certificate.conf
+
+# Generate a certificate for the server
+openssl x509 \
+    -req \
+    -in "$CERT_DIR/server.csr" \
+    -CA "$CERT_DIR/ca.crt" \
+    -CAkey "$CERT_DIR/ca.key" \
+    -CAcreateserial \
+    -out "$CERT_DIR/server.pem" \
+    -days 365 \
+    -sha256 \
+    -extfile certificate.conf \
+    -extensions req_ext
+
+KEY_DIR=keys
+
+mkdir -p "$KEY_DIR"
+
+rm -f "$KEY_DIR"/*
+
+ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/server_credentials" -C ""
+
+generate_client_credentials() {
+    local num_clients=${1:-2}
+    for ((i=1; i<=num_clients; i++))
+    do
+        ssh-keygen -t ecdsa -b 384 -N "" -f "${KEY_DIR}/client_credentials_$i" -C ""
+    done
+}
+
+generate_client_credentials "$NUM_CLIENTS"
+
+# Write every client public key, minus the last character of its .pub line,
+# as one comma-separated row
+printf "%s" "$(sed 's/.$//' < "${KEY_DIR}/client_credentials_1.pub")" > "$KEY_DIR/client_public_keys.csv"
+for ((i=2; i<=NUM_CLIENTS; i++))
+do
+    printf ",%s" "$(sed 's/.$//' < "${KEY_DIR}/client_credentials_$i.pub")" >> "$KEY_DIR/client_public_keys.csv"
+done
+printf "\n" >> "$KEY_DIR/client_public_keys.csv"
diff --git a/examples/flower-client-authentication/pyproject.toml b/examples/flower-client-authentication/pyproject.toml
new file mode 100644
index 000000000000..c3e606dd8585
--- /dev/null
+++ b/examples/flower-client-authentication/pyproject.toml
@@ -0,0 +1,20 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "flower-client-authentication"
+version = "0.1.0"
+description = "Federated Learning with client authentication using Flower and PyTorch"
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+dependencies = [
+    "flwr-nightly[rest,simulation]",
+    "torch==1.13.1",
+    "torchvision==0.14.1",
+    "tqdm==4.65.0"
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/examples/flower-client-authentication/server.py b/examples/flower-client-authentication/server.py
new file mode 100644
index 000000000000..d88dc1d1a641
--- /dev/null
+++ b/examples/flower-client-authentication/server.py
@@ -0,0 +1,42 @@
+from typing import List, Tuple
+
+import flwr as fl
+from flwr.common import Metrics
+from flwr.server.strategy.fedavg import FedAvg
+from flwr.server import ServerApp
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    """Aggregate client fit metrics as an example-weighted average.
+
+    Each item of `metrics` is `(num_examples, client_metrics)`. Returns an
+    empty dict when the clients report zero examples in total, rather than
+    raising `ZeroDivisionError`.
+    """
+    metric_names = ("train_loss", "train_accuracy", "val_loss", "val_accuracy")
+    total_examples = sum(num_examples for num_examples, _ in metrics)
+    if total_examples == 0:
+        return {}
+
+    # Weight each client's value by the number of examples it used
+    return {
+        name: sum(num_examples * m[name] for num_examples, m in metrics)
+        / total_examples
+        for name in metric_names
+    }
+
+
+# Define strategy
+strategy = FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+)
+
+
+app = ServerApp(
+    config=fl.server.ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
diff --git a/examples/flower-client-authentication/task.py b/examples/flower-client-authentication/task.py
new file mode 100644
index 000000000000..276aace885df
--- /dev/null
+++ b/examples/flower-client-authentication/task.py
@@ -0,0 +1,107 @@
+import warnings
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import Compose, Normalize, ToTensor
+from tqdm import tqdm
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def train(net, trainloader, valloader, epochs, device):
+    """Train the model on the training set.
+
+    After training, the model is evaluated on `trainloader` and `valloader`
+    and the resulting losses and accuracies are returned as a dict.
+    """
+    print("Starting training...")
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    net.train()
+    for _ in range(epochs):
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
+            optimizer.zero_grad()
+            loss = criterion(net(images), labels)
+            loss.backward()
+            optimizer.step()
+
+    # Evaluate on the device the model was trained on
+    train_loss, train_acc = test(net, trainloader, device)
+    val_loss, val_acc = test(net, valloader, device)
+
+    results = {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results
+
+
+def test(net, testloader, device=DEVICE):
+    """Validate the model on the test set.
+
+    Returns `(loss, accuracy)`: `loss` is the sum of the per-batch losses
+    and `accuracy` is the fraction of correctly classified examples.
+    `device` defaults to the module-level `DEVICE`.
+    """
+    net.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for images, labels in tqdm(testloader):
+            outputs = net(images.to(device))
+            labels = labels.to(device)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    return loss, accuracy
+
+
+def load_data():
+    """Load CIFAR-10 (training and test set)."""
+    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
+    testset = CIFAR10("./data", train=False, download=True, transform=trf)
+    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
+
+
+def get_parameters(net):
+    """Return the model's state-dict values as a list of NumPy arrays."""
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_parameters(net, parameters):
+    """Load `parameters` into `net` in state-dict key order."""
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)