Skip to content

Commit

Permalink
- Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yehias21 authored and = committed Jan 21, 2024
1 parent 3f5f54c commit 7ea0c2e
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 170 deletions.
2 changes: 2 additions & 0 deletions baselines/fedpara/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
outputs/
multirun/
client_states/
data/
75 changes: 49 additions & 26 deletions baselines/fedpara/fedpara/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from collections import OrderedDict
from typing import Callable, Dict, List, Tuple, Optional
import copy
import copy,os
import flwr as fl
import torch
from flwr.common import NDArrays, Scalar
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from fedpara.models import train
from fedpara.models import train,test


class FlowerClient(fl.client.NumPyClient):
Expand All @@ -34,7 +33,7 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def _set_parameters(self, parameters: NDArrays) -> None:
def set_parameters(self, parameters: NDArrays) -> None:
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)
Expand All @@ -43,7 +42,7 @@ def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Train the network on the training set."""
self._set_parameters(parameters)
self.set_parameters(parameters)

train(
self.net,
Expand All @@ -67,7 +66,7 @@ def __init__(
cid: int,
net: torch.nn.Module,
train_loader: DataLoader,
test_dataset: List[DataLoader],
test_loader: DataLoader,
device: str,
num_epochs: int,
state_path: str,
Expand All @@ -76,25 +75,38 @@ def __init__(
self.cid = cid
self.net = net
self.train_loader = train_loader
self.test_dataset = test_dataset
self.test_loader = test_loader
self.device = torch.device(device)
self.num_epochs = num_epochs
self.state_path = state_path

def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def _set_parameters(self, parameters: NDArrays) -> None:
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)

return [val.cpu().detach().numpy() for _, val in self.net.get_per_param().items()]


def _set_parameters(self, parameters: NDArrays, first_round=False) -> None:
    """Copy the given arrays into the local network.

    The arrays are paired positionally with the keys of
    ``self.net.state_dict()`` and each one is converted to a tensor.

    Parameters
    ----------
    parameters : NDArrays
        Arrays to install, in the same order as the network's state-dict keys.
    first_round : bool
        When true, the mapping is loaded as a complete state dict with
        ``strict=True``.  Otherwise it is passed to
        ``self.net.set_per_param``.
    """
    # Build the name -> tensor mapping once; both paths consume the same one.
    tensors = OrderedDict()
    for name, array in zip(self.net.state_dict().keys(), parameters):
        tensors[name] = torch.Tensor(array)

    if not first_round:
        self.net.set_per_param(tensors)
        return
    self.net.load_state_dict(tensors, strict=True)
def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Train the network on the training set."""
self._set_parameters(parameters)
if not os.path.isfile(self.state_path):
self._set_parameters(parameters,first_round=True)
else:
try:
self.net.load_state_dict(torch.load(self.state_path),strict=False)
except:
print(f"loading {self.state_path} state dict error")
self._set_parameters(parameters)

print(f"Client {self.cid} Training...")

train(
Expand All @@ -103,8 +115,11 @@ def fit(
self.device,
epochs=self.num_epochs,
hyperparams=config,
round=config["curr_round"],
epoch=config["curr_round"],
)
if self.state_path is not None:
with open(self.state_path, 'wb') as f:
torch.save(self.net.get_per_param(), f)

return (
self.get_parameters({}),
Expand All @@ -113,14 +128,21 @@ def fit(
)
def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[int, float, Dict]:
"""Evaluate the network on the test set."""
self._set_parameters(parameters)
print(f"Client {self.cid} Evaluating...")

return (
len(self.test_dataset[self.cid]),
train.test(self.net, self.test_dataset[self.cid], self.device),
{},
)
if not os.path.isfile(self.state_path):
self._set_parameters(parameters,first_round=True)
else:
try:
self.net.load_state_dict(torch.load(self.state_path),strict=False)
except:
print(f"loading {self.state_path} state dict error")
self._set_parameters(parameters)

print(f"Client {self.cid} Evaluating...")
self.net.to(self.device)
loss, accuracy = test(self.net, self.test_loader, device=self.device)
return loss, len(self.test_loader), {"accuracy": accuracy}


def gen_client_fn(
train_loaders: List[DataLoader],
Expand All @@ -135,15 +157,16 @@ def gen_client_fn(
def client_fn(cid: str) -> fl.client.NumPyClient:
"""Create a new FlowerClient for a given cid."""
cid = int(cid)
if args['algorithm'] == "pfedpara" or args['algorithm'] == "fedper":
if args['algorithm'].lower() == "pfedpara" or args['algorithm'] == "fedper":
cl_path = f"{state_path}/client_{cid}.pth"
return PFedParaClient(
cid=cid,
net=instantiate(model).to(args["device"]),
train_loader=train_loaders[cid],
test_dataset=copy.deepcopy(test_loader),
test_loader=copy.deepcopy(test_loader),
device=args["device"],
num_epochs=num_epochs,
state_path=state_path,
state_path=cl_path,
)
else:
return FlowerClient(
Expand Down
3 changes: 1 addition & 2 deletions baselines/fedpara/fedpara/conf/cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ model:
hyperparams:
eta_l: 0.1
learning_decay: 0.992
momentum: 0.0
weight_decay: 0


strategy:
_target_: fedpara.strategy.FedPara
Expand Down
2 changes: 0 additions & 2 deletions baselines/fedpara/fedpara/conf/cifar100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ model:
hyperparams:
eta_l: 0.1
learning_decay: 0.992
momentum: 0.0
weight_decay: 0

strategy:
_target_: fedpara.strategy.FedPara
Expand Down
14 changes: 9 additions & 5 deletions baselines/fedpara/fedpara/conf/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ num_rounds: 100
clients_per_round: 10
num_epochs: 5
batch_size: 10
state_pah: ./state/
server_device: cuda
state_path: ./client_states/
client_device: cuda
algorithm: pFedPara


client_resources:
num_cpus: 2
Expand All @@ -17,21 +19,23 @@ dataset_config:
name: MNIST
num_classes: 10
shard_size: 300

data_seed: ${seed}

model:
_target_: fedpara.models.FC
num_classes: ${dataset_config.num_classes}
weights: lowrank # lowrank or standard
param_type: lowrank # lowrank or standard
activation: relu # relu or leaky_relu
ratio: 0.5 # lowrank ratio
algorithm: ${algorithm}

hyperparams:
eta_l: 0.01
learning_decay: 0.999

strategy:
_target_: fedpara.strategy.pFedPara
algorithm: pFedPara
_target_: fedpara.strategy.FedAvg
fraction_fit: 0.00001
fraction_evaluate: 0.0
min_evaluate_clients: ${clients_per_round}
Expand Down
28 changes: 22 additions & 6 deletions baselines/fedpara/fedpara/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def load_datasets(
) -> Tuple[List[DataLoader], DataLoader]:
"""Load the dataset and return the dataloaders for the clients and the server."""
print("Loading data...")
match config.name:
match config['name']:
case "CIFAR10":
Dataset = datasets.CIFAR10
case "CIFAR100":
Expand All @@ -23,8 +23,8 @@ def load_datasets(
Dataset = datasets.MNIST
case _:
raise NotImplementedError
data_directory = f"./data/{config.name.lower()}/"
match config.name:
data_directory = f"./data/{config['name'].lower()}/"
match config['name']:
case "CIFAR10" | "CIFAR100":
ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
transform_train = transforms.Compose(
Expand All @@ -44,6 +44,12 @@ def load_datasets(
try:
with open(ds_path, "rb") as file:
train_datasets = pickle.load(file)
dataset_train = Dataset(
data_directory, train=True, download=False, transform=transform_train
)
dataset_test = Dataset(
data_directory, train=False, download=False, transform=transform_test
)
except FileNotFoundError:
dataset_train = Dataset(
data_directory, train=True, download=True, transform=transform_train
Expand All @@ -54,6 +60,9 @@ def load_datasets(
train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
pickle.dump(train_datasets, open(ds_path, "wb"))
train_datasets = train_datasets.values()
dataset_test = Dataset(
data_directory, train=False, download=True, transform=transform_test
)

case "MNIST":
ds_path = f"{data_directory}train_{num_clients}.pkl"
Expand All @@ -68,16 +77,23 @@ def load_datasets(
)
try:
train_datasets = pickle.load(open(ds_path, "rb"))
dataset_train = Dataset(
data_directory, train=True, download=False, transform=transform_train
)
dataset_test = Dataset(
data_directory, train=False, download=False, transform=transform_test
)
except FileNotFoundError:
dataset_train = Dataset(
data_directory, train=True, download=True, transform=transform_train
)
train_datasets = mnist_niid(dataset_train, num_clients, config.shard_size, config.data_seed)
pickle.dump(train_datasets, open(ds_path, "wb"))
dataset_test = Dataset(
data_directory, train=False, download=True, transform=transform_test
)


dataset_test = Dataset(
data_directory, train=False, download=True, transform=transform_test
)
test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=2)
train_loaders = [
DataLoader(
Expand Down
51 changes: 23 additions & 28 deletions baselines/fedpara/fedpara/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from fedpara import client, server, utils
from fedpara.dataset import load_datasets
from fedpara.utils import get_parameters, seed_everything
from fedpara.utils import get_parameters, save_results_as_pickle, seed_everything, set_client_state_save_path


@hydra.main(config_path="conf", config_name="cifar10", version_base=None)
@hydra.main(config_path="conf", config_name="mnist", version_base=None)
def main(cfg: DictConfig) -> None:
"""Run the baseline.
Expand All @@ -24,6 +23,9 @@ def main(cfg: DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))
seed_everything(cfg.seed)
OmegaConf.to_container(cfg, resolve=True)
if 'state_path' in cfg: state_path=set_client_state_save_path(cfg.state_path)
else: state_path = None

# 2. Prepare dataset
train_loaders, test_loader = load_datasets(
config=cfg.dataset_config,
Expand All @@ -33,31 +35,15 @@ def main(cfg: DictConfig) -> None:

# 3. Define clients
# In this scheme the responsibility of choosing the client is on the client manager
if cfg.strategy.min_evaluate_clients:
client_fn = client.gen_client_fn(
train_loaders=train_loaders,
test_loader=test_loader,
model=cfg.model,
num_epochs=cfg.num_epochs,
args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
state_path=cfg.state_path,
)
else :
client_fn = client.gen_client_fn(
train_loaders=train_loaders,
model=cfg.model,
num_epochs=cfg.num_epochs,
args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
)

if not cfg.strategy.min_evaluate_clients :
evaluate_fn = server.gen_evaluate_fn(
num_clients=cfg.num_clients,
test_loader=test_loader,
model=cfg.model,
device=cfg.server_device,
state_path=cfg.state_path,
)
client_fn = client.gen_client_fn(
train_loaders=train_loaders,
test_loader=test_loader,
model=cfg.model,
num_epochs=cfg.num_epochs,
args={"device": cfg.client_device, "algorithm": cfg.algorithm},
state_path=state_path,
)

def get_on_fit_config():
def fit_config_fn(server_round: int):
Expand All @@ -76,7 +62,15 @@ def fit_config_fn(server_round: int):
on_fit_config_fn=get_on_fit_config(),
initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
)

else :
evaluate_fn = server.gen_evaluate_fn(
num_clients=cfg.num_clients,
test_loader=test_loader,
model=cfg.model,
device=cfg.server_device,
state_path=cfg.state_path,
)
strategy = instantiate(
cfg.strategy,
evaluate_fn=evaluate_fn,
Expand All @@ -98,7 +92,8 @@ def fit_config_fn(server_round: int):
"_memory": 30 * 1024 * 1024 * 1024,
},
)

save_results_as_pickle(history, )

# 6. Save results
save_path = HydraConfig.get().runtime.output_dir
file_suffix = "_".join(
Expand Down
Loading

0 comments on commit 7ea0c2e

Please sign in to comment.