Commit 3f5f54c: pfedpara: pre-release

yehias21 committed Jan 7, 2024
1 parent 4455932 commit 3f5f54c
Showing 7 changed files with 438 additions and 83 deletions.
102 changes: 88 additions & 14 deletions baselines/fedpara/fedpara/client.py
@@ -1,8 +1,8 @@
"""Client for FedPara."""

from collections import OrderedDict
-from typing import Callable, Dict, List, Tuple
-
+from typing import Callable, Dict, List, Tuple, Optional
+import copy
import flwr as fl
import torch
from flwr.common import NDArrays, Scalar
@@ -60,24 +60,98 @@ def fit(
{},
)

class PFedParaClient(fl.client.NumPyClient):
    """Personalized FedPara client."""
def __init__(
self,
cid: int,
net: torch.nn.Module,
train_loader: DataLoader,
test_dataset: List[DataLoader],
device: str,
num_epochs: int,
state_path: str,
):

self.cid = cid
self.net = net
self.train_loader = train_loader
self.test_dataset = test_dataset
self.device = torch.device(device)
self.num_epochs = num_epochs
self.state_path = state_path

def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def _set_parameters(self, parameters: NDArrays) -> None:
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)

def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Train the network on the training set."""
self._set_parameters(parameters)
print(f"Client {self.cid} Training...")

train(
self.net,
self.train_loader,
self.device,
epochs=self.num_epochs,
hyperparams=config,
round=config["curr_round"],
)

return (
self.get_parameters({}),
len(self.train_loader),
{},
)

    def evaluate(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[float, int, Dict]:
        """Evaluate the network on the test set."""
        self._set_parameters(parameters)
        print(f"Client {self.cid} Evaluating...")

        # NumPyClient.evaluate must return (loss, num_examples, metrics).
        # `test` is assumed to be the evaluation helper imported from
        # fedpara.models (in the collapsed import block above), returning
        # (loss, accuracy).
        loss, accuracy = test(self.net, self.test_dataset[self.cid], self.device)
        return loss, len(self.test_dataset[self.cid]), {"accuracy": accuracy}
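
The constructor stores a state_path, but none of the methods shown here use it yet. In pFedPara the personalized sublayers must persist on the client between rounds (simulated clients are re-created each round), so persistence along these lines is presumably intended. A minimal sketch; the helper names and file layout are assumptions, not part of this commit:

    import os

    def _save_local_state(self) -> None:
        """Persist the state dict so personalized layers survive between rounds."""
        os.makedirs(self.state_path, exist_ok=True)
        torch.save(
            self.net.state_dict(),
            os.path.join(self.state_path, f"client_{self.cid}.pth"),
        )

    def _load_local_state(self) -> None:
        """Restore previously saved personalized state, if any."""
        path = os.path.join(self.state_path, f"client_{self.cid}.pth")
        if os.path.exists(path):
            self.net.load_state_dict(torch.load(path), strict=False)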

def gen_client_fn(
train_loaders: List[DataLoader],
model: DictConfig,
num_epochs: int,
-) -> Callable[[str], FlowerClient]:
+    args: Dict,
+    test_loader: Optional[List[DataLoader]] = None,
+    state_path: Optional[str] = None,
+) -> Callable[[str], fl.client.NumPyClient]:
"""Return a function which creates a new FlowerClient for a given cid."""

-    def client_fn(cid: str) -> FlowerClient:
+    def client_fn(cid: str) -> fl.client.NumPyClient:
"""Create a new FlowerClient for a given cid."""
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-        return FlowerClient(
-            cid=int(cid),
-            net=instantiate(model).to(device),
-            train_loader=train_loaders[int(cid)],
-            device=device,
-            num_epochs=num_epochs,
-        )
+        cid = int(cid)
+        if args["algorithm"] in ("pfedpara", "fedper"):
+            return PFedParaClient(
+                cid=cid,
+                net=instantiate(model).to(args["device"]),
+                train_loader=train_loaders[cid],
+                test_dataset=copy.deepcopy(test_loader),
+                device=args["device"],
+                num_epochs=num_epochs,
+                state_path=state_path,
+            )
+        else:
+            return FlowerClient(
+                cid=cid,
+                net=instantiate(model).to(args["device"]),
+                train_loader=train_loaders[cid],
+                device=args["device"],
+                num_epochs=num_epochs,
+            )

return client_fn
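
For orientation, a sketch of how this factory is typically wired into a Flower simulation, mirroring the start_simulation call truncated at the end of main.py below. The concrete values are illustrative, and train_loaders, test_loaders, cfg, and strategy are assumed to be built elsewhere:

    client_fn = gen_client_fn(
        train_loaders=train_loaders,
        model=cfg.model,
        num_epochs=cfg.num_epochs,
        args={"device": "cuda", "algorithm": "pfedpara"},
        test_loader=test_loaders,
        state_path="./state/",
    )
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=100,
        config=fl.server.ServerConfig(num_rounds=100),
        strategy=strategy,
    )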
@@ -6,39 +6,34 @@ num_rounds: 100
clients_per_round: 10
num_epochs: 5
batch_size: 10

state_path: ./state/
server_device: cuda

client_resources:
num_cpus: 2
num_gpus: 0.0625

dataset_config:
-  name: FEMNIST
-  partition: non-iid # redundant
-  num_classes: 62
-  alpha: 0 # redundant
+  name: MNIST
+  num_classes: 10
+  shard_size: 300
+  data_seed: ${seed}
model:
-  _target_: fedpara.models.VGG
+  _target_: fedpara.models.FC
  num_classes: ${dataset_config.num_classes}
-  conv_type: lowrank # lowrank or standard
+  weights: lowrank # lowrank or standard
  activation: relu # relu or leaky_relu
-  ratio: 0.1 # lowrank ratio
+  ratio: 0.5 # lowrank ratio

hyperparams:
-  eta_l: 0.1
+  eta_l: 0.01
learning_decay: 0.999
momentum: 0.0
weight_decay: 0

strategy:
_target_: fedpara.strategy.pFedPara
algorithm: pFedPara
fraction_fit: 0.00001
fraction_evaluate: 0.0
-  min_evaluate_clients: 0
+  min_evaluate_clients: ${clients_per_round}
min_fit_clients: ${clients_per_round}
min_available_clients: ${clients_per_round}
  accept_failures: false
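
For readers new to Hydra: instantiate(cfg.model), used in client.py above and main.py below, resolves _target_ to an import path and passes the sibling keys as keyword arguments. With the model node above it is roughly equivalent to the following sketch, assuming FC accepts these keyword arguments:

    from fedpara.models import FC

    net = FC(num_classes=10, weights="lowrank", activation="relu", ratio=0.5)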
97 changes: 63 additions & 34 deletions baselines/fedpara/fedpara/dataset.py
@@ -6,47 +6,75 @@
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

-from fedpara.dataset_preparation import DatasetSplit, iid, noniid
+from fedpara.dataset_preparation import DatasetSplit, iid, noniid, mnist_niid


def load_datasets(
config, num_clients, batch_size
) -> Tuple[List[DataLoader], DataLoader]:
"""Load the dataset and return the dataloaders for the clients and the server."""
print("Loading data...")
-    if config.name == "CIFAR10":
-        Dataset = datasets.CIFAR10
-    elif config.name == "CIFAR100":
-        Dataset = datasets.CIFAR100
-    else:
-        raise NotImplementedError
+    match config.name:
+        case "CIFAR10":
+            Dataset = datasets.CIFAR10
+        case "CIFAR100":
+            Dataset = datasets.CIFAR100
+        case "MNIST":
+            Dataset = datasets.MNIST
+        case _:
+            raise NotImplementedError
data_directory = f"./data/{config.name.lower()}/"
ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
try:
with open(ds_path, "rb") as file:
train_datasets = pickle.load(file)
except FileNotFoundError:
dataset_train = Dataset(
data_directory, train=True, download=True, transform=transform_train
)
if config.partition == "iid":
train_datasets = iid(dataset_train, num_clients)
else:
train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
match config.name:
case "CIFAR10" | "CIFAR100":
ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
try:
with open(ds_path, "rb") as file:
train_datasets = pickle.load(file)
except FileNotFoundError:
dataset_train = Dataset(
data_directory, train=True, download=True, transform=transform_train
)
if config.partition == "iid":
train_datasets = iid(dataset_train, num_clients)
else:
train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
pickle.dump(train_datasets, open(ds_path, "wb"))
train_datasets = train_datasets.values()

case "MNIST":
ds_path = f"{data_directory}train_{num_clients}.pkl"
transform_train = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
transform_test = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
try:
train_datasets = pickle.load(open(ds_path, "rb"))
except FileNotFoundError:
dataset_train = Dataset(
data_directory, train=True, download=True, transform=transform_train
)
train_datasets = mnist_niid(dataset_train, num_clients, config.shard_size, config.data_seed)
pickle.dump(train_datasets, open(ds_path, "wb"))

dataset_test = Dataset(
data_directory, train=False, download=True, transform=transform_test
)
@@ -58,7 +86,8 @@ def load_datasets(
shuffle=True,
num_workers=2,
)
-        for ids in train_datasets.values()
+        for ids in train_datasets
]

return train_loaders, test_loader
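
A sketch of how load_datasets is consumed, with a config node mirroring the dataset_config shown above (num_clients=100 is illustrative):

    from omegaconf import OmegaConf
    from fedpara.dataset import load_datasets

    config = OmegaConf.create(
        {"name": "MNIST", "num_classes": 10, "shard_size": 300, "data_seed": 42}
    )
    train_loaders, test_loader = load_datasets(config, num_clients=100, batch_size=10)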

15 changes: 14 additions & 1 deletion baselines/fedpara/fedpara/dataset_preparation.py
@@ -11,7 +11,8 @@

import numpy as np
from torch.utils.data import Dataset

+import logging
+from collections import Counter

class DatasetSplit(Dataset):
"""An abstract Dataset class wrapped around Pytorch Dataset class."""
@@ -97,3 +98,15 @@ def noniid(dataset, no_participants, alpha=0.5):
for j in range(no_classes):
clas_weight[i, j] = float(datasize[i, j]) / float((train_img_size[i]))
return per_participant_list, clas_weight
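
The tail of the Dirichlet partitioner above returns both the per-client index lists and a per-client class-weight matrix. A typical call, following the signature in the hunk header (the dataset variable and client count are illustrative):

    per_participant_list, clas_weight = noniid(dataset_train, no_participants=100, alpha=0.5)
    # clas_weight[i, j] is the fraction of client i's samples belonging to class j.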

def mnist_niid(dataset: Dataset, num_clients: int, shard_size: int, seed: int) -> np.ndarray:
    """Partition MNIST as described in https://arxiv.org/pdf/1602.05629.pdf."""
    # Sort example indices by label so that each shard is (almost) single-class.
    indices = np.argsort(dataset.targets.numpy())
    logging.debug(Counter(dataset.targets[indices].numpy()))
    # Split the sorted indices into shards ("silos") of `shard_size` examples,
    # then randomly assign shards to clients.
    silos = np.array_split(indices, len(dataset) // shard_size)
    np.random.seed(seed + 17)
    np.random.shuffle(silos)
    clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1)
    logging.debug(clients.shape)
    logging.debug(Counter([len(Counter(dataset.targets[client].numpy())) for client in clients]))
    return clients
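
The shard arithmetic behind mnist_niid, using shard_size: 300 from the config above (the client count is illustrative):

    num_images = 60_000                            # MNIST training set
    shard_size = 300                               # dataset_config.shard_size
    num_shards = num_images // shard_size          # 200 label-sorted shards
    num_clients = 100                              # illustrative
    shards_per_client = num_shards // num_clients  # 2 shards, i.e. roughly 2 classes per client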
64 changes: 47 additions & 17 deletions baselines/fedpara/fedpara/main.py
@@ -32,28 +32,58 @@ def main(cfg: DictConfig) -> None:
)

# 3. Define clients
-    client_fn = client.gen_client_fn(
-        train_loaders=train_loaders,
-        model=cfg.model,
-        num_epochs=cfg.num_epochs,
-    )
+    # In this scheme, the responsibility of choosing the clients lies with the client manager.
+    if cfg.strategy.min_evaluate_clients:
+        client_fn = client.gen_client_fn(
+            train_loaders=train_loaders,
+            test_loader=test_loader,
+            model=cfg.model,
+            num_epochs=cfg.num_epochs,
+            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
+            state_path=cfg.state_path,
+        )
+    else:
+        client_fn = client.gen_client_fn(
+            train_loaders=train_loaders,
+            model=cfg.model,
+            num_epochs=cfg.num_epochs,
+            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
+        )

-    evaluate_fn = server.gen_evaluate_fn(
-        num_clients=cfg.num_clients,
-        test_loader=test_loader,
-        model=cfg.model,
-        device=cfg.server_device,
-    )
+    if not cfg.strategy.min_evaluate_clients:
+        evaluate_fn = server.gen_evaluate_fn(
+            num_clients=cfg.num_clients,
+            test_loader=test_loader,
+            model=cfg.model,
+            device=cfg.server_device,
+            state_path=cfg.state_path,
+        )

def get_on_fit_config():
def fit_config_fn(server_round: int):
fit_config = OmegaConf.to_container(cfg.hyperparams, resolve=True)
fit_config["curr_round"] = server_round
return fit_config

return fit_config_fn
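
With the hyperparams node from the config above, fit_config_fn hands each client a plain dict per round, along these lines (the round number is illustrative):

    {
        "eta_l": 0.01,
        "learning_decay": 0.999,
        "momentum": 0.0,
        "weight_decay": 0,
        "curr_round": 7,
    }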

net_glob = instantiate(cfg.model)

# 4. Define strategy
-    strategy = instantiate(
-        cfg.strategy,
-        evaluate_fn=evaluate_fn,
-        on_fit_config_fn=server.get_on_fit_config(dict(cfg.hyperparams)),
-        initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
-    )
+    if cfg.strategy.min_evaluate_clients:
+        strategy = instantiate(
+            cfg.strategy,
+            on_fit_config_fn=get_on_fit_config(),
+            initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
+        )
+    else:
+        strategy = instantiate(
+            cfg.strategy,
+            evaluate_fn=evaluate_fn,
+            on_fit_config_fn=get_on_fit_config(),
+            initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
+        )


# 5. Start Simulation
history = fl.simulation.start_simulation(