
refactor(framework) Update huggingface template for flwr new #4169

Merged on Sep 11, 2024 (10 commits)
48 changes: 19 additions & 29 deletions src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl
@@ -1,18 +1,11 @@
"""$project_name: A Flower / $framework_str app."""

import torch
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from transformers import AutoModelForSequenceClassification

from $import_name.task import (
get_weights,
load_data,
set_weights,
train,
test,
CHECKPOINT,
DEVICE,
)
from $import_name.task import get_weights, load_data, set_weights, test, train


# Flower client
@@ -22,37 +15,34 @@ class FlowerClient(NumPyClient):
         self.trainloader = trainloader
         self.testloader = testloader
         self.local_epochs = local_epochs
-
-    def get_parameters(self, config):
-        return get_weights(self.net)
-
-    def set_parameters(self, parameters):
-        set_weights(self.net, parameters)
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)
 
     def fit(self, parameters, config):
-        self.set_parameters(parameters)
-        train(
-            self.net,
-            self.trainloader,
-            epochs=self.local_epochs,
-        )
-        return self.get_parameters(config={}), len(self.trainloader), {}
+        set_weights(self.net, parameters)
+        train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
+        return get_weights(self.net), len(self.trainloader), {}
 
     def evaluate(self, parameters, config):
-        self.set_parameters(parameters)
-        loss, accuracy = test(self.net, self.testloader)
+        set_weights(self.net, parameters)
+        loss, accuracy = test(self.net, self.testloader, self.device)
         return float(loss), len(self.testloader), {"accuracy": accuracy}
 
 
 def client_fn(context: Context):
-    # Load model and data
-    net = AutoModelForSequenceClassification.from_pretrained(
-        CHECKPOINT, num_labels=2
-    ).to(DEVICE)
-
     # Get this client's dataset partition
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
-    trainloader, valloader = load_data(partition_id, num_partitions)
+    model_name = context.run_config["model-name"]
+    trainloader, valloader = load_data(partition_id, num_partitions, model_name)
+
+    # Load model
+    num_labels = context.run_config["num-labels"]
+    net = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=num_labels
+    )
 
     local_epochs = context.run_config["local-epochs"]
 
     # Return Client instance
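Note: the `get_weights` and `set_weights` helpers the client imports are defined in the task module and are not part of this diff. In Flower's PyTorch-based templates they typically round-trip the model's `state_dict` through a list of NumPy arrays; a minimal sketch (not the exact template code):

```python
from collections import OrderedDict

import torch


def get_weights(net):
    # Export model parameters as a list of NumPy ndarrays
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
    # Re-pair parameter names with the received ndarrays and load them back
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
```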
src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl
@@ -1,18 +1,33 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.server.strategy import FedAvg
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from transformers import AutoModelForSequenceClassification

from $import_name.task import get_weights


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]

# Initialize global model
model_name = context.run_config["model-name"]
num_labels = context.run_config["num-labels"]
net = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels
)

weights = get_weights(net)
initial_parameters = ndarrays_to_parameters(weights)

# Define strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
initial_parameters=initial_parameters,
)
config = ServerConfig(num_rounds=num_rounds)

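The diff is truncated before `server_fn` returns. In this template family the function typically ends by bundling the strategy and round config into `ServerAppComponents` and exposing a `ServerApp`; a sketch of the expected tail:

```python
    # Bundle strategy and run config for the ServerApp (truncated in the diff above)
    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
```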
29 changes: 16 additions & 13 deletions src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl
@@ -4,24 +4,25 @@ import warnings
 from collections import OrderedDict
 
 import torch
+import transformers
 from datasets.utils.logging import disable_progress_bar
 from evaluate import load as load_metric
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, DataCollatorWithPadding
-
-from flwr_datasets import FederatedDataset
-from flwr_datasets.partitioner import IidPartitioner
 
 
 warnings.filterwarnings("ignore", category=UserWarning)
-DEVICE = torch.device("cpu")
-CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
+warnings.filterwarnings("ignore", category=FutureWarning)
 disable_progress_bar()
+transformers.logging.set_verbosity_error()
 
 
 fds = None  # Cache FederatedDataset
 
 
-def load_data(partition_id: int, num_partitions: int):
+def load_data(partition_id: int, num_partitions: int, model_name: str):
     """Load IMDB data (training and eval)"""
     # Only initialize `FederatedDataset` once
     global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
     # Divide data: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
 
-    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True)
+        return tokenizer(
+            examples["text"], truncation=True, add_special_tokens=True, max_length=512
+        )
 
     partition_train_test = partition_train_test.map(tokenize_function, batched=True)
     partition_train_test = partition_train_test.remove_columns("text")
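The updated `tokenize_function` caps sequences at 512 tokens (the usual BERT-family limit) while still leaving padding to `DataCollatorWithPadding` at batch-collation time. A small illustration of the truncation behavior, using the template's default checkpoint as an assumption:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
encoding = tokenizer(
    "an extremely long review " * 1000,
    truncation=True,
    add_special_tokens=True,
    max_length=512,
)
print(len(encoding["input_ids"]))  # 512: truncated, special tokens included
```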
@@ -59,25 +62,25 @@ def load_data(partition_id: int, num_partitions: int):
     return trainloader, testloader
 
 
-def train(net, trainloader, epochs):
+def train(net, trainloader, epochs, device):
     optimizer = AdamW(net.parameters(), lr=5e-5)
     net.train()
     for _ in range(epochs):
         for batch in trainloader:
-            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            batch = {k: v.to(device) for k, v in batch.items()}
             outputs = net(**batch)
             loss = outputs.loss
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
 
-def test(net, testloader):
+def test(net, testloader, device):
     metric = load_metric("accuracy")
     loss = 0
     net.eval()
     for batch in testloader:
-        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        batch = {k: v.to(device) for k, v in batch.items()}
         with torch.no_grad():
             outputs = net(**batch)
             logits = outputs.logits
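With the module-level `DEVICE` and `CHECKPOINT` constants removed, callers now pass the model name and device explicitly. A hypothetical local smoke test of the refactored helpers (the `load_data`/`train`/`test` signatures come from this diff; the import path depends on the generated app name):

```python
import torch
from transformers import AutoModelForSequenceClassification

# Adjust to the generated package, e.g.: from myapp.task import load_data, test, train

model_name = "prajjwal1/bert-tiny"  # matches the default set in pyproject.toml below
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

trainloader, testloader = load_data(partition_id=0, num_partitions=10, model_name=model_name)
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

train(net, trainloader, epochs=1, device=device)
loss, accuracy = test(net, testloader, device)
print(f"loss={loss:.4f}, accuracy={accuracy:.4f}")
```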
src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
@@ -8,7 +8,7 @@ version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.10.0",
"flwr[simulation]>=1.11.0",
"flwr-datasets>=0.3.0",
"torch==2.2.1",
"transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"

 [tool.flwr.app.config]
 num-server-rounds = 3
+fraction-fit = 0.5
 local-epochs = 1
+model-name = "prajjwal1/bert-tiny"  # Set a larger model if you have access to more GPU resources
+num-labels = 2
 
 [tool.flwr.federations]
 default = "localhost"
 
 [tool.flwr.federations.localhost]
 options.num-supernodes = 10
+
+[tool.flwr.federations.localhost-gpu]
+options.num-supernodes = 10
+options.backend.client-resources.num-cpus = 4  # each ClientApp is assumed to use 4 CPUs
+options.backend.client-resources.num-gpus = 0.25  # at most 4 ClientApps will run on a given GPU
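The fractional `num-gpus` value is what bounds concurrency in the `localhost-gpu` federation: each simulated ClientApp reserves a quarter of a GPU, so at most four fit on one device at a time. A back-of-envelope check (values copied from the config above):

```python
# Resources each simulated ClientApp reserves (from [tool.flwr.federations.localhost-gpu])
num_gpus_per_clientapp = 0.25

# ClientApps that can share a single GPU concurrently
concurrent_per_gpu = int(1 / num_gpus_per_clientapp)
print(f"Up to {concurrent_per_gpu} ClientApps per GPU")  # Up to 4 ClientApps per GPU
```

Selecting this federation at run time is then a matter of pointing `flwr run` at `localhost-gpu` instead of the default `localhost`.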