diff --git a/baselines/fedpara/fedpara/client.py b/baselines/fedpara/fedpara/client.py
index 818312a57f2d..92c0b0484458 100644
--- a/baselines/fedpara/fedpara/client.py
+++ b/baselines/fedpara/fedpara/client.py
@@ -1,8 +1,8 @@
 """Client for FedPara."""
 from collections import OrderedDict
-from typing import Callable, Dict, List, Tuple
-
+from typing import Callable, Dict, List, Optional, Tuple
+import copy
 import flwr as fl
 import torch
 from flwr.common import NDArrays, Scalar
@@ -60,24 +60,98 @@ def fit(
             {},
         )
 
+
+class PFedParaClient(fl.client.NumPyClient):
+    """Personalized FedPara client."""
+
+    def __init__(
+        self,
+        cid: int,
+        net: torch.nn.Module,
+        train_loader: DataLoader,
+        test_dataset: List[DataLoader],
+        device: str,
+        num_epochs: int,
+        state_path: str,
+    ):
+        self.cid = cid
+        self.net = net
+        self.train_loader = train_loader
+        self.test_dataset = test_dataset
+        self.device = torch.device(device)
+        self.num_epochs = num_epochs
+        self.state_path = state_path
+
+    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+        """Return the parameters of the current net."""
+        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+
+    def _set_parameters(self, parameters: NDArrays) -> None:
+        params_dict = zip(self.net.state_dict().keys(), parameters)
+        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+        self.net.load_state_dict(state_dict, strict=True)
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Train the network on the training set."""
+        self._set_parameters(parameters)
+        print(f"Client {self.cid} Training...")
+
+        train(
+            self.net,
+            self.train_loader,
+            self.device,
+            epochs=self.num_epochs,
+            hyperparams=config,
+            round=config["curr_round"],
+        )
+
+        return (
+            self.get_parameters({}),
+            len(self.train_loader),
+            {},
+        )
+
+    def evaluate(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[float, int, Dict]:
+        """Evaluate the network on the test set."""
+        self._set_parameters(parameters)
+        print(f"Client {self.cid} Evaluating...")
+
+        # NumPyClient.evaluate must return (loss, num_examples, metrics);
+        # `test` is assumed to be imported from fedpara.models alongside `train`.
+        loss, accuracy = test(self.net, self.test_dataset[self.cid], self.device)
+        return (
+            loss,
+            len(self.test_dataset[self.cid]),
+            {"accuracy": accuracy},
+        )
+
 
 def gen_client_fn(
     train_loaders: List[DataLoader],
     model: DictConfig,
     num_epochs: int,
-) -> Callable[[str], FlowerClient]:
+    args: Dict,
+    test_loader: Optional[List[DataLoader]] = None,
+    state_path: Optional[str] = None,
+) -> Callable[[str], fl.client.NumPyClient]:
     """Return a function which creates a new FlowerClient for a given cid."""
 
-    def client_fn(cid: str) -> FlowerClient:
+    def client_fn(cid: str) -> fl.client.NumPyClient:
         """Create a new FlowerClient for a given cid."""
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-        return FlowerClient(
-            cid=int(cid),
-            net=instantiate(model).to(device),
-            train_loader=train_loaders[int(cid)],
-            device=device,
-            num_epochs=num_epochs,
-        )
-
+        cid = int(cid)
+        # The config spells the algorithm "pFedPara", so compare case-insensitively.
+        if args["algorithm"].lower() in ("pfedpara", "fedper"):
+            return PFedParaClient(
+                cid=cid,
+                net=instantiate(model).to(args["device"]),
+                train_loader=train_loaders[cid],
+                test_dataset=copy.deepcopy(test_loader),
+                device=args["device"],
+                num_epochs=num_epochs,
+                state_path=state_path,
+            )
+        return FlowerClient(
+            cid=cid,
+            net=instantiate(model).to(args["device"]),
+            train_loader=train_loaders[cid],
+            device=args["device"],
+            num_epochs=num_epochs,
+        )
+
     return client_fn
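Note on the `evaluate` fix above: Flower's `NumPyClient` expects `fit` to return `(parameters, num_examples, metrics)` and `evaluate` to return `(loss, num_examples, metrics)`; the original patch returned the evaluate tuple in the wrong order and called `train.test`, which is not callable. A minimal sketch of the expected contract (the `ToyClient` below is illustrative only, not part of the patch):

```python
from typing import Dict, Tuple

import flwr as fl
import numpy as np
from flwr.common import NDArrays, Scalar


class ToyClient(fl.client.NumPyClient):
    """Hypothetical one-parameter client showing the fit/evaluate return order."""

    def fit(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[NDArrays, int, Dict]:
        updated = [parameters[0] + 1.0]  # stand-in for local training
        return updated, 10, {}  # (updated parameters, num train examples, metrics)

    def evaluate(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[float, int, Dict]:
        loss = float(np.abs(parameters[0]).sum())  # stand-in for a real test pass
        return loss, 10, {"accuracy": 0.0}  # (loss, num test examples, metrics)
```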
diff --git a/baselines/fedpara/fedpara/conf/femnist.yaml b/baselines/fedpara/fedpara/conf/mnist.yaml
similarity index 58%
rename from baselines/fedpara/fedpara/conf/femnist.yaml
rename to baselines/fedpara/fedpara/conf/mnist.yaml
index dfd67c0ab935..8115cc3e11ab 100644
--- a/baselines/fedpara/fedpara/conf/femnist.yaml
+++ b/baselines/fedpara/fedpara/conf/mnist.yaml
@@ -6,7 +6,7 @@ num_rounds: 100
 clients_per_round: 10
 num_epochs: 5
 batch_size: 10
-
+state_path: ./state/
 server_device: cuda
 
 client_resources:
@@ -14,31 +14,26 @@ client_resources:
   num_gpus: 0.0625
 
 dataset_config:
-  name: FEMNIST
-  partition: non-iid #redundent
-  num_classes: 62
-  alpha: 0 # redundant
-
+  name: MNIST
+  num_classes: 10
+  shard_size: 300
+  data_seed: ${seed}
 
 model:
-  _target_: fedpara.models.VGG
+  _target_: fedpara.models.FC
   num_classes: ${dataset_config.num_classes}
-  conv_type: lowrank # lowrank or standard
+  param_type: lowrank # lowrank or standard; key must match FC's constructor argument
   activation: relu # relu or leaky_relu
-  ratio: 0.1 # lowrank ratio
+  ratio: 0.5 # low-rank ratio
 
 hyperparams:
-  eta_l: 0.1
+  eta_l: 0.01
   learning_decay: 0.999
-  momentum: 0.0
-  weight_decay: 0
 
 strategy:
   _target_: fedpara.strategy.pFedPara
   algorithm: pFedPara
   fraction_fit: 0.00001
   fraction_evaluate: 0.0
-  min_evaluate_clients: 0
+  min_evaluate_clients: ${clients_per_round}
   min_fit_clients: ${clients_per_round}
-  min_available_clients: ${clients_per_round}
-  accept_failures: false
-
+  min_available_clients: ${clients_per_round}
\ No newline at end of file
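Note on the interpolations above (`${seed}`, `${dataset_config.num_classes}`, `${clients_per_round}`): OmegaConf resolves them lazily at access time. A small sketch with stand-in values (not this baseline's actual config object):

```python
from omegaconf import OmegaConf

# Stand-in values; only the interpolation mechanics mirror mnist.yaml.
cfg = OmegaConf.create(
    """
    seed: 42
    clients_per_round: 10
    dataset_config:
      num_classes: 10
      data_seed: ${seed}
    model:
      num_classes: ${dataset_config.num_classes}
    strategy:
      min_fit_clients: ${clients_per_round}
    """
)
print(cfg.dataset_config.data_seed)  # 42, resolved from top-level `seed`
print(cfg.model.num_classes)         # 10, resolved from dataset_config
print(cfg.strategy.min_fit_clients)  # 10
```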
diff --git a/baselines/fedpara/fedpara/dataset.py b/baselines/fedpara/fedpara/dataset.py
index 6d207aec9524..8bc41f89608c 100644
--- a/baselines/fedpara/fedpara/dataset.py
+++ b/baselines/fedpara/fedpara/dataset.py
@@ -6,7 +6,7 @@
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms
 
-from fedpara.dataset_preparation import DatasetSplit, iid, noniid
+from fedpara.dataset_preparation import DatasetSplit, iid, mnist_niid, noniid
 
 
 def load_datasets(
@@ -14,39 +14,67 @@ def load_datasets(
 ) -> Tuple[List[DataLoader], DataLoader]:
     """Load the dataset and return the dataloaders for the clients and the server."""
     print("Loading data...")
-    if config.name == "CIFAR10":
-        Dataset = datasets.CIFAR10
-    elif config.name == "CIFAR100":
-        Dataset = datasets.CIFAR100
-    else:
-        raise NotImplementedError
+    match config.name:
+        case "CIFAR10":
+            Dataset = datasets.CIFAR10
+        case "CIFAR100":
+            Dataset = datasets.CIFAR100
+        case "MNIST":
+            Dataset = datasets.MNIST
+        case _:
+            raise NotImplementedError
     data_directory = f"./data/{config.name.lower()}/"
-    ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
-    transform_train = transforms.Compose(
-        [
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    )
-    transform_test = transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    )
-    try:
-        with open(ds_path, "rb") as file:
-            train_datasets = pickle.load(file)
-    except FileNotFoundError:
-        dataset_train = Dataset(
-            data_directory, train=True, download=True, transform=transform_train
-        )
-        if config.partition == "iid":
-            train_datasets = iid(dataset_train, num_clients)
-        else:
-            train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
+    match config.name:
+        case "CIFAR10" | "CIFAR100":
+            ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
+            transform_train = transforms.Compose(
+                [
+                    transforms.RandomCrop(32, padding=4),
+                    transforms.RandomHorizontalFlip(),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
+                ]
+            )
+            transform_test = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
+                ]
+            )
+            try:
+                with open(ds_path, "rb") as file:
+                    train_datasets = pickle.load(file)
+            except FileNotFoundError:
+                dataset_train = Dataset(
+                    data_directory, train=True, download=True, transform=transform_train
+                )
+                if config.partition == "iid":
+                    train_datasets = iid(dataset_train, num_clients)
+                else:
+                    train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
+                with open(ds_path, "wb") as file:
+                    pickle.dump(train_datasets, file)
+            train_datasets = train_datasets.values()
+
+        case "MNIST":
+            ds_path = f"{data_directory}train_{num_clients}.pkl"
+            transform_train = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            )
+            transform_test = transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+            )
+            try:
+                with open(ds_path, "rb") as file:
+                    train_datasets = pickle.load(file)
+            except FileNotFoundError:
+                dataset_train = Dataset(
+                    data_directory, train=True, download=True, transform=transform_train
+                )
+                train_datasets = mnist_niid(
+                    dataset_train, num_clients, config.shard_size, config.data_seed
+                )
+                with open(ds_path, "wb") as file:
+                    pickle.dump(train_datasets, file)
+
     dataset_test = Dataset(
         data_directory, train=False, download=True, transform=transform_test
     )
@@ -58,7 +86,8 @@
             shuffle=True,
             num_workers=2,
         )
-        for ids in train_datasets.values()
+        for ids in train_datasets
     ]
 
     return train_loaders, test_loader
diff --git a/baselines/fedpara/fedpara/dataset_preparation.py b/baselines/fedpara/fedpara/dataset_preparation.py
index ccdbea1edbe8..b76dcf576f22 100644
--- a/baselines/fedpara/fedpara/dataset_preparation.py
+++ b/baselines/fedpara/fedpara/dataset_preparation.py
@@ -11,7 +11,8 @@
 import numpy as np
 from torch.utils.data import Dataset
-
+import logging
+from collections import Counter
+
 
 class DatasetSplit(Dataset):
     """An abstract Dataset class wrapped around Pytorch Dataset class."""
@@ -97,3 +98,15 @@ def noniid(dataset, no_participants, alpha=0.5):
         for j in range(no_classes):
             clas_weight[i, j] = float(datasize[i, j]) / float((train_img_size[i]))
     return per_participant_list, clas_weight
+
+
+def mnist_niid(
+    dataset: Dataset, num_clients: int, shard_size: int, seed: int
+) -> np.ndarray:
+    """Partition via label-sorted shards, as in https://arxiv.org/pdf/1602.05629.pdf."""
+    # Sort sample indices by label so each shard holds (mostly) one class.
+    indices = np.argsort(dataset.targets.numpy())
+    logging.debug(Counter(dataset.targets[indices].numpy()))
+    # Split the sorted indices into shards (silos), then deal them out at random.
+    silos = np.array_split(indices, len(dataset) // shard_size)
+    np.random.seed(seed + 17)
+    np.random.shuffle(silos)
+    clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1)
+    logging.debug(clients.shape)
+    logging.debug(
+        Counter([len(Counter(dataset.targets[client].numpy())) for client in clients])
+    )
+    return clients
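Note on `mnist_niid`: it implements the label-sorted shard partitioning from McMahan et al. (https://arxiv.org/pdf/1602.05629.pdf): sort sample indices by label, cut them into fixed-size shards, shuffle the shards, and give each client an equal number of shards, so each client sees only a few classes. A self-contained sketch with synthetic labels, so it runs without torchvision:

```python
from collections import Counter

import numpy as np

rng = np.random.default_rng(17)
targets = rng.integers(0, 10, size=60_000)  # stand-in for MNIST labels
shard_size, num_clients = 300, 100

indices = np.argsort(targets)  # sort sample indices by label
shards = np.array(np.array_split(indices, len(targets) // shard_size))
rng.shuffle(shards)  # deal the 200 near-single-class shards out at random
clients = np.array(np.array_split(shards, num_clients)).reshape(num_clients, -1)

# With 2 shards per client, most clients hold samples from only ~2 classes.
print(Counter(len(set(targets[c])) for c in clients))
```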
diff --git a/baselines/fedpara/fedpara/main.py b/baselines/fedpara/fedpara/main.py
index 3f87a896eebf..dca22b7ef2da 100644
--- a/baselines/fedpara/fedpara/main.py
+++ b/baselines/fedpara/fedpara/main.py
@@ -32,28 +32,58 @@ def main(cfg: DictConfig) -> None:
     )
 
     # 3. Define clients
-    client_fn = client.gen_client_fn(
-        train_loaders=train_loaders,
-        model=cfg.model,
-        num_epochs=cfg.num_epochs,
-    )
+    # In this scheme the responsibility for choosing a client falls to the
+    # client manager: personalized algorithms evaluate on the clients.
+    if cfg.strategy.min_evaluate_clients:
+        client_fn = client.gen_client_fn(
+            train_loaders=train_loaders,
+            test_loader=test_loader,
+            model=cfg.model,
+            num_epochs=cfg.num_epochs,
+            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
+            state_path=cfg.state_path,
+        )
+    else:
+        client_fn = client.gen_client_fn(
+            train_loaders=train_loaders,
+            model=cfg.model,
+            num_epochs=cfg.num_epochs,
+            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
+        )
 
-    evaluate_fn = server.gen_evaluate_fn(
-        num_clients=cfg.num_clients,
-        test_loader=test_loader,
-        model=cfg.model,
-        device=cfg.server_device,
-    )
+    # Server-side evaluation is only needed when clients do not evaluate locally.
+    if not cfg.strategy.min_evaluate_clients:
+        evaluate_fn = server.gen_evaluate_fn(
+            num_clients=cfg.num_clients,
+            test_loader=test_loader,
+            model=cfg.model,
+            device=cfg.server_device,
+            state_path=cfg.state_path,
+        )
+
+    def get_on_fit_config():
+        def fit_config_fn(server_round: int):
+            fit_config = OmegaConf.to_container(cfg.hyperparams, resolve=True)
+            fit_config["curr_round"] = server_round
+            return fit_config
+
+        return fit_config_fn
 
     net_glob = instantiate(cfg.model)
 
     # 4. Define strategy
-    strategy = instantiate(
-        cfg.strategy,
-        evaluate_fn=evaluate_fn,
-        on_fit_config_fn=server.get_on_fit_config(dict(cfg.hyperparams)),
-        initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
-    )
+    if cfg.strategy.min_evaluate_clients:
+        strategy = instantiate(
+            cfg.strategy,
+            on_fit_config_fn=get_on_fit_config(),
+            initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
+        )
+    else:
+        strategy = instantiate(
+            cfg.strategy,
+            evaluate_fn=evaluate_fn,
+            on_fit_config_fn=get_on_fit_config(),
+            initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
+        )
 
     # 5. Start Simulation
     history = fl.simulation.start_simulation(
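Note on `get_on_fit_config` above: the returned closure is called by the strategy once per round, and its dict is delivered to every sampled client as `config` in `fit`, which is how `config["curr_round"]` reaches `PFedParaClient.fit`. A minimal sketch with stand-in hyperparameters:

```python
from typing import Dict

from flwr.common import Scalar


def get_on_fit_config():
    hyperparams = {"eta_l": 0.01, "learning_decay": 0.999}  # stand-in for cfg.hyperparams

    def fit_config_fn(server_round: int) -> Dict[str, Scalar]:
        fit_config = dict(hyperparams)  # copy so rounds do not share state
        fit_config["curr_round"] = server_round
        return fit_config

    return fit_config_fn


print(get_on_fit_config()(5))
# {'eta_l': 0.01, 'learning_decay': 0.999, 'curr_round': 5}
```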
diff --git a/baselines/fedpara/fedpara/models.py b/baselines/fedpara/fedpara/models.py
index 6670e1a6260e..233f76f935ec 100644
--- a/baselines/fedpara/fedpara/models.py
+++ b/baselines/fedpara/fedpara/models.py
@@ -2,7 +2,6 @@
 import math
 from typing import Dict, Tuple
-
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -10,7 +9,137 @@
 from torch import nn
 from torch.nn import init
 from torch.utils.data import DataLoader
+
+
+class LowRankNN(nn.Module):
+    """A pair of low-rank factors X, Y that materialize a dense weight."""
+
+    def __init__(self, in_features, out_features, rank, activation: str = "relu") -> None:
+        super().__init__()
+        self.X = nn.Parameter(torch.empty(size=(in_features, rank)), requires_grad=True)
+        self.Y = nn.Parameter(torch.empty(size=(out_features, rank)), requires_grad=True)
+
+        if activation == "leakyrelu":
+            activation = "leaky_relu"
+        init.kaiming_normal_(self.X, mode="fan_out", nonlinearity=activation)
+        init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity=activation)
+
+    def forward(self):
+        # Takes no input: this module only materializes the composed weight.
+        # F.linear expects a weight of shape (out_features, in_features).
+        return torch.einsum("or,ir->oi", self.Y, self.X)
+
+
+class Linear(nn.Module):
+    """Low-rank linear layer built from two LowRankNN factor pairs."""
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        ratio,
+        activation: str = "relu",
+        bias=True,
+        pfedpara=True,
+    ) -> None:
+        super().__init__()
+        rank = self._calc_from_ratio(ratio, in_features, out_features)
+        self.w1 = LowRankNN(in_features, out_features, rank, activation)
+        self.w2 = LowRankNN(in_features, out_features, rank, activation)
+        # A single bias is shared by the composed weight; None disables it.
+        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
+        self.pfedpara = pfedpara
+
+    @staticmethod
+    def _calc_from_ratio(ratio, in_features, out_features):
+        # Return the low rank of the sub-matrices for a given compression ratio.
+        # Minimum possible rank:
+        r1 = int(np.ceil(np.sqrt(out_features)))
+        r2 = int(np.ceil(np.sqrt(in_features)))
+        r = np.min((r1, r2))
+        # Maximum possible rank: the positive root of a*x^2 + b*x + c = 0 with
+        # a = in_features, b = out_features, c = -num_target_params / 2.
+        # r3 is floored because rounding up would exceed the parameter count
+        # of the original dense layer.
+        num_target_params = out_features * in_features
+        a, b, c = in_features, out_features, -num_target_params / 2
+        discriminant = b**2 - 4 * a * c
+        r3 = math.floor((-b + math.sqrt(discriminant)) / (2 * a))
+        return math.ceil((1 - ratio) * r + ratio * r3)
+
+    def forward(self, x):
+        if self.pfedpara:
+            # pFedPara: W = W1 * W2 + W1 (elementwise); w1 is the personal part.
+            w = self.w1() * self.w2() + self.w1()
+        else:
+            # FedPara: W = W1 * W2 (elementwise Hadamard product).
+            w = self.w1() * self.w2()
+        return F.linear(x, w, self.bias)
+
+
+class FC(nn.Module):
+    """Two-layer fully connected network with optional low-rank layers."""
+
+    def __init__(
+        self,
+        input_size=28**2,
+        hidden_size=256,
+        num_classes=10,
+        ratio=0.1,
+        param_type="standard",
+        activation: str = "relu",
+        method: str = "pfedpara",
+    ):
+        super().__init__()
+        # `method` selects the personal parameter set below; an explicit
+        # constructor argument is assumed here, since the original code read
+        # self.method without ever setting it.
+        self.method = method
+        if param_type == "standard":
+            self.fc1 = nn.Linear(input_size, hidden_size)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(hidden_size, num_classes)
+            self.softmax = nn.Softmax(dim=1)
+        elif param_type == "lowrank":
+            self.fc1 = Linear(input_size, hidden_size, ratio, activation)
+            self.relu = nn.ReLU()
+            self.fc2 = Linear(hidden_size, num_classes, ratio, activation)
+            self.softmax = nn.Softmax(dim=1)
+        else:
+            raise ValueError("param_type must be either standard or lowrank")
+
+    @property
+    def per_param(self):
+        """Return the personalized parameters of the model."""
+        if self.method == "pfedpara":
+            # Only the w1 factors of each layer, in the same format as state_dict.
+            params = {
+                "fc1.X": self.fc1.w1.X,
+                "fc1.Y": self.fc1.w1.Y,
+                "fc2.X": self.fc2.w1.X,
+                "fc2.Y": self.fc2.w1.Y,
+            }
+        elif self.method == "fedper":
+            # Under FedPer the whole last layer is personal.
+            params = {"fc2.w1": self.fc2.w1, "fc2.w2": self.fc2.w2}
+        else:
+            raise ValueError("method must be either pfedpara or fedper")
+        return params
+
+    def load_per_param(self, state_dict):
+        """Load the personalized parameters of the model."""
+        # Copy values in place so the attributes stay registered nn.Parameters.
+        if self.method == "pfedpara":
+            self.fc1.w1.X.data.copy_(state_dict["fc1.X"])
+            self.fc1.w1.Y.data.copy_(state_dict["fc1.Y"])
+            self.fc2.w1.X.data.copy_(state_dict["fc2.X"])
+            self.fc2.w1.Y.data.copy_(state_dict["fc2.Y"])
+        elif self.method == "fedper":
+            self.fc2.w1 = state_dict["fc2.w1"]
+            self.fc2.w2 = state_dict["fc2.w2"]
+        else:
+            raise ValueError("method must be either pfedpara or fedper")
+
+    @property
+    def model_size(self):
+        """Return trainable parameters (in millions) and model size (in MB)."""
+        total_trainable_params = (
+            sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
+        )
+        param_size = 0
+        for param in self.parameters():
+            param_size += param.nelement() * param.element_size()
+        buffer_size = 0
+        for buffer in self.buffers():
+            buffer_size += buffer.nelement() * buffer.element_size()
+        size_all_mb = (param_size + buffer_size) / 1024**2
+        return total_trainable_params, size_all_mb
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.softmax(out)
+        return out
+
 
 class LowRank(nn.Module):
     """Low-rank convolutional layer."""
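Note on the low-rank composition used by `Linear` above: FedPara re-parameterizes a dense weight as the Hadamard product of two low-rank matrices, `W = (X1 Y1^T) * (X2 Y2^T)`, which can reach rank r^2 while storing only 2r(in+out) parameters; pFedPara adds the first factor back, `W = W1 * W2 + W1`, and keeps `w1` as the personal part. A shape-level sketch (the rank is chosen arbitrarily here rather than via `_calc_from_ratio`):

```python
import torch

in_features, out_features, r = 784, 256, 16

# Two low-rank factor pairs; weights are (out_features, in_features) as F.linear expects.
X1, Y1 = torch.randn(out_features, r), torch.randn(in_features, r)
X2, Y2 = torch.randn(out_features, r), torch.randn(in_features, r)

W1 = X1 @ Y1.T  # factor pair w1 (also the personal part in pFedPara)
W2 = X2 @ Y2.T  # factor pair w2

W_fedpara = W1 * W2        # FedPara: elementwise (Hadamard) product
W_pfedpara = W1 * W2 + W1  # pFedPara: w1 doubles as the personal additive term

dense = in_features * out_features
lowrank = 2 * r * (in_features + out_features)
print(W_pfedpara.shape)  # torch.Size([256, 784])
print(f"dense: {dense} params, low-rank: {lowrank} params ({lowrank / dense:.1%})")
```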
diff --git a/baselines/fedpara/fedpara/test.ipynb b/baselines/fedpara/fedpara/test.ipynb
new file mode 100644
index 000000000000..a25ae10e7d6b
--- /dev/null
+++ b/baselines/fedpara/fedpara/test.ipynb
@@ -0,0 +1,85 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torchvision.datasets\n",
+    "import torchvision.transforms as transforms\n",
+    "import numpy as np\n",
+    "from collections import Counter\n",
+    "from torch.utils.data import Dataset\n",
+    "import logging"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = torchvision.datasets.MNIST\n",
+    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
+    "trainset = data(root='./data', train=True, download=True, transform=transform)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def mnist_niid(dataset: Dataset, num_clients: int, shard_size: int, seed: int) -> np.ndarray:\n",
+    "    # Sort sample indices by label so each shard holds (mostly) one class\n",
+    "    indices = np.argsort(dataset.targets.numpy())\n",
+    "    logging.debug(Counter(dataset.targets[indices].numpy()))\n",
+    "    # Split into shards (silos), then randomly assign them to clients\n",
+    "    silos = np.array_split(indices, len(dataset) // shard_size)\n",
+    "    np.random.seed(seed + 17)\n",
+    "    np.random.shuffle(silos)\n",
+    "    clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1)\n",
+    "    logging.debug(clients.shape)\n",
+    "    logging.debug(Counter([len(Counter(dataset.targets[client].numpy())) for client in clients]))\n",
+    "    return clients"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 98,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Counter({2: 82, 1: 11, 3: 7})"
+      ]
+     },
+     "execution_count": 98,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "flower",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}