Skip to content

Commit

Permalink
refactor(framework) Update mlx,tf and pytorch templates (#3933)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
  • Loading branch information
jafermarq and danieljanes authored Aug 5, 2024
1 parent abd833d commit 6a639c6
Show file tree
Hide file tree
Showing 14 changed files with 64 additions and 58 deletions.
15 changes: 8 additions & 7 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""$project_name: A Flower / $framework_str app."""

import torch
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import (
Net,
DEVICE,
load_data,
get_weights,
set_weights,
Expand All @@ -21,27 +21,28 @@ class FlowerClient(NumPyClient):
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)

def fit(self, parameters, config):
set_weights(self.net, parameters)
results = train(
train_loss = train(
self.net,
self.trainloader,
self.valloader,
self.local_epochs,
DEVICE,
self.device,
)
return get_weights(self.net), len(self.trainloader.dataset), results
return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}

def evaluate(self, parameters, config):
set_weights(self.net, parameters)
loss, accuracy = test(self.net, self.valloader)
loss, accuracy = test(self.net, self.valloader, self.device)
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
# Load model and data
net = Net().to(DEVICE)
net = Net()
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
trainloader, valloader = load_data(partition_id, num_partitions)
Expand Down
11 changes: 4 additions & 7 deletions src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@ from $import_name.task import load_data, load_model
# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(
self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
self, model, data, epochs, batch_size, verbose
):
self.model = model
self.x_train = x_train
self.y_train = y_train
self.x_test = x_test
self.y_test = y_test
self.x_train, self.y_train, self.x_test, self.y_test = data
self.epochs = epochs
self.batch_size = batch_size
self.verbose = verbose
Expand Down Expand Up @@ -46,14 +43,14 @@ def client_fn(context: Context):

partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
data = load_data(partition_id, num_partitions)
epochs = context.run_config["local-epochs"]
batch_size = context.run_config["batch-size"]
verbose = context.run_config.get("verbose")

# Return Client instance
return FlowerClient(
net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
net, data, epochs, batch_size, verbose
).to_client()


Expand Down
11 changes: 6 additions & 5 deletions src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ from flwr.server.strategy import FedAvg
from $import_name.task import Net, get_weights


# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)

def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]

# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)

# Define strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=parameters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@ from flwr.server.strategy import FedAvg

from $import_name.task import load_model

# Define config
config = ServerConfig(num_rounds=3)

parameters = ndarrays_to_parameters(load_model().get_weights())

def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Get parameters to initialize global model
parameters = ndarrays_to_parameters(load_model().get_weights())

# Define strategy
strategy = strategy = FedAvg(
fraction_fit=1.0,
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def load_data(partition_id: int, num_partitions: int):
fds = FederatedDataset(
dataset="ylecun/mnist",
partitioners={"train": partitioner},
trust_remote_code=True,
)
partition = fds.load_partition(partition_id)
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
Expand Down
32 changes: 13 additions & 19 deletions src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

Expand Down Expand Up @@ -66,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
return trainloader, testloader


def train(net, trainloader, valloader, epochs, device):
def train(net, trainloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
net.train()
running_loss = 0.0
for _ in range(epochs):
for batch in trainloader:
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
loss = criterion(net(images.to(device)), labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()

train_loss, train_acc = test(net, trainloader)
val_loss, val_acc = test(net, valloader)

results = {
"train_loss": train_loss,
"train_accuracy": train_acc,
"val_loss": val_loss,
"val_accuracy": val_acc,
}
return results
avg_trainloss = running_loss / len(trainloader)
return avg_trainloss


def test(net, testloader):
def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(DEVICE)
labels = batch["label"].to(DEVICE)
images = batch["img"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy


Expand Down
18 changes: 15 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import os

import tensorflow as tf
import keras
from keras import layers
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

Expand All @@ -12,8 +13,19 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def load_model():
# Load model and data (MobileNetV2, CIFAR-10)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
# Define a simple CNN for CIFAR-10 and set Adam optimizer
model = keras.Sequential(
[
keras.Input(shape=(32, 32, 3)),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(10, activation="softmax"),
]
)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr-datasets>=0.0.2,<1.0.0",
"flwr[simulation]>=1.10.0",
"flwr-datasets>=0.3.0",
"torch==2.2.1",
"transformers>=4.30.0,<5.0",
"evaluate>=0.4.0,<1.0",
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr[simulation]>=1.10.0",
"jax==0.4.13",
"jaxlib==0.4.13",
"scikit-learn==1.3.2",
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr-datasets[vision]>=0.0.2,<1.0.0",
"mlx==0.10.0",
"flwr[simulation]>=1.10.0",
"flwr-datasets[vision]>=0.3.0",
"mlx==0.16.1",
"numpy==1.24.4",
]

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr[simulation]>=1.10.0",
"numpy>=1.21.0",
]

Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr-datasets[vision]>=0.0.2,<1.0.0",
"flwr[simulation]>=1.10.0",
"flwr-datasets[vision]>=0.3.0",
"torch==2.2.1",
"torchvision==0.17.1",
]
Expand All @@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
fraction-fit = 0.5
local-epochs = 1

[tool.flwr.federations]
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr-datasets[vision]>=0.0.2,<1.0.0",
"flwr[simulation]>=1.10.0",
"flwr-datasets[vision]>=0.3.0",
"scikit-learn>=1.1.1",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.9.0,<2.0",
"flwr-datasets[vision]>=0.0.2,<1.0.0",
"flwr[simulation]>=1.10.0",
"flwr-datasets[vision]>=0.3.0",
"tensorflow>=2.11.1",
]

Expand Down

0 comments on commit 6a639c6

Please sign in to comment.