From b88569a66ffc1965d64227c45cb51bfc292d8934 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 4 Sep 2023 21:14:36 +0100 Subject: [PATCH 01/21] update pyproject, format, typos --- baselines/fedprox/fedprox/__init__.py | 1 + baselines/fedprox/fedprox/client.py | 28 ++++----- baselines/fedprox/fedprox/conf/config.yaml | 4 +- baselines/fedprox/fedprox/conf/fedavg.yaml | 4 +- baselines/fedprox/fedprox/dataset.py | 5 +- .../fedprox/fedprox/dataset_preparation.py | 33 +++++----- baselines/fedprox/fedprox/main.py | 11 +++- baselines/fedprox/fedprox/server.py | 7 ++- baselines/fedprox/fedprox/strategy.py | 11 ++-- baselines/fedprox/fedprox/utils.py | 25 ++++---- baselines/fedprox/pyproject.toml | 62 +++++++++++++++++-- 11 files changed, 127 insertions(+), 64 deletions(-) diff --git a/baselines/fedprox/fedprox/__init__.py b/baselines/fedprox/fedprox/__init__.py index e69de29bb2d..c150b5cc2ef 100644 --- a/baselines/fedprox/fedprox/__init__.py +++ b/baselines/fedprox/fedprox/__init__.py @@ -0,0 +1 @@ +"""Fedprox package.""" diff --git a/baselines/fedprox/fedprox/client.py b/baselines/fedprox/fedprox/client.py index 6944c81a4ad..67db341015c 100644 --- a/baselines/fedprox/fedprox/client.py +++ b/baselines/fedprox/fedprox/client.py @@ -12,7 +12,6 @@ from omegaconf import DictConfig from torch.utils.data import DataLoader -from fedprox.dataset import load_datasets from fedprox.models import test, train @@ -40,11 +39,11 @@ def __init__( self.straggler_schedule = straggler_schedule def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: - """Returns the parameters of the current net.""" + """Return the parameters of the current net.""" return [val.cpu().numpy() for _, val in self.net.state_dict().items()] def set_parameters(self, parameters: NDArrays) -> None: - """Changes the parameters of the model using the given ones.""" + """Change the parameters of the model using the given ones.""" params_dict = zip(self.net.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) self.net.load_state_dict(state_dict, strict=True) @@ -52,7 +51,7 @@ def set_parameters(self, parameters: NDArrays) -> None: def fit( self, parameters: NDArrays, config: Dict[str, Scalar] ) -> Tuple[NDArrays, int, Dict]: - """Implements distributed fit function for a given client.""" + """Implement distributed fit function for a given client.""" self.set_parameters(parameters) # At each round check if the client is a straggler, @@ -88,7 +87,7 @@ def fit( self.device, epochs=num_epochs, learning_rate=self.learning_rate, - proximal_mu=config["proximal_mu"], + proximal_mu=float(config["proximal_mu"]), ) return self.get_parameters({}), len(self.trainloader), {"is_straggler": False} @@ -96,7 +95,7 @@ def fit( def evaluate( self, parameters: NDArrays, config: Dict[str, Scalar] ) -> Tuple[float, int, Dict]: - """Implements distributed evaluation for a given client.""" + """Implement distributed evaluation for a given client.""" self.set_parameters(parameters) loss, accuracy = test(self.net, self.valloader, self.device) return float(loss), len(self.valloader), {"accuracy": float(accuracy)} @@ -111,10 +110,8 @@ def gen_client_fn( learning_rate: float, stragglers: float, model: DictConfig, -) -> Tuple[ - Callable[[str], FlowerClient], DataLoader -]: # pylint: disable=too-many-arguments - """Generates the client function that creates the Flower Clients. +) -> Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments + """Generate the client function that creates the Flower Clients. Parameters ---------- @@ -139,13 +136,11 @@ def gen_client_fn( Returns ------- - Tuple[Callable[[str], FlowerClient], DataLoader] - A tuple containing the client function that creates Flower Clients and - the DataLoader that will be used for testing + Callable[[str], FlowerClient] + A client function that creates Flower Clients. """ - - # Defines a staggling schedule for each clients, i.e at which round will they - # be a straggler. This is done so at each round the proportion of staggling + # Defines a straggling schedule for each clients, i.e at which round will they + # be a straggler. This is done so at each round the proportion of straggling # clients is respected stragglers_mat = np.transpose( np.random.choice( @@ -155,7 +150,6 @@ def gen_client_fn( def client_fn(cid: str) -> FlowerClient: """Create a Flower client representing a single organization.""" - # Load model device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") net = instantiate(model).to(device) diff --git a/baselines/fedprox/fedprox/conf/config.yaml b/baselines/fedprox/fedprox/conf/config.yaml index f62198ae722..05763f89d2d 100644 --- a/baselines/fedprox/fedprox/conf/config.yaml +++ b/baselines/fedprox/fedprox/conf/config.yaml @@ -29,12 +29,12 @@ model: strategy: _target_: flwr.server.strategy.FedProx - fraction_fit: 0.00001 # because we want the number of clients to sample on each roudn to be solely defined by min_fit_clients + fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients fraction_evaluate: 0.0 min_fit_clients: ${clients_per_round} min_evaluate_clients: 0 min_available_clients: ${clients_per_round} evaluate_metrics_aggregation_fn: _target_: fedprox.strategy.weighted_average - _partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actuallly calls the function (in aggregate_evaluate()) + _partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actually calls the function (in aggregate_evaluate()) proximal_mu: ${mu} diff --git a/baselines/fedprox/fedprox/conf/fedavg.yaml b/baselines/fedprox/fedprox/conf/fedavg.yaml index 3fa9d2007e6..5b6c3873ec5 100644 --- a/baselines/fedprox/fedprox/conf/fedavg.yaml +++ b/baselines/fedprox/fedprox/conf/fedavg.yaml @@ -34,11 +34,11 @@ model: strategy: _target_: fedprox.strategy.FedAvgWithStragglerDrop #! this points to FedAvgWithStragglerDrop class in strategy.py, Note that we need the full module path (including `fedprox`) - fraction_fit: 0.00001 # because we want the number of clients to sample on each roudn to be solely defined by min_fit_clients + fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients fraction_evaluate: 0.0 min_fit_clients: ${clients_per_round} min_available_clients: ${clients_per_round} min_evaluate_clients: 0 evaluate_metrics_aggregation_fn: _target_: fedprox.strategy.weighted_average - _partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actuallly calls the function (in aggregate_evaluate()) + _partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actually calls the function (in aggregate_evaluate()) diff --git a/baselines/fedprox/fedprox/dataset.py b/baselines/fedprox/fedprox/dataset.py index 035ee987f06..e740ef0f67a 100644 --- a/baselines/fedprox/fedprox/dataset.py +++ b/baselines/fedprox/fedprox/dataset.py @@ -17,7 +17,7 @@ def load_datasets( # pylint: disable=too-many-arguments batch_size: Optional[int] = 32, seed: Optional[int] = 42, ) -> Tuple[DataLoader, DataLoader, DataLoader]: - """Creates the dataloaders to be fed into the model. + """Create the dataloaders to be fed into the model. Parameters ---------- @@ -36,7 +36,8 @@ def load_datasets( # pylint: disable=too-many-arguments Returns ------- Tuple[DataLoader, DataLoader, DataLoader] - The DataLoader for training, the DataLoader for validation, the DataLoader for testing. + The DataLoader for training, the DataLoader for validation, the DataLoader + for testing. """ print(f"Dataset partitioning config: {config}") datasets, testset = _partition_data( diff --git a/baselines/fedprox/fedprox/dataset_preparation.py b/baselines/fedprox/fedprox/dataset_preparation.py index 24fc64f1238..068f4fc4c77 100644 --- a/baselines/fedprox/fedprox/dataset_preparation.py +++ b/baselines/fedprox/fedprox/dataset_preparation.py @@ -1,3 +1,4 @@ +"""Functions for dataset download and processing.""" from typing import List, Optional, Tuple import numpy as np @@ -8,7 +9,7 @@ def _download_data() -> Tuple[Dataset, Dataset]: - """Downloads (if necessary) and returns the MNIST dataset. + """Download (if necessary) and returns the MNIST dataset. Returns ------- @@ -30,8 +31,9 @@ def _partition_data( balance: Optional[bool] = False, seed: Optional[int] = 42, ) -> Tuple[List[Dataset], Dataset]: - """Split training set into iid or non iid partitions to simulate the - federated setting. + """Split training set into iid or non iid partitions to simulate the federated. + + setting. Parameters ---------- @@ -39,8 +41,9 @@ def _partition_data( The number of clients that hold a part of the data iid : bool, optional Whether the data should be independent and identically distributed between - the clients or if the data should first be sorted by labels and distributed by chunks - to each client (used to test the convergence in a worst case scenario), by default False + the clients or if the data should first be sorted by labels and distributed + by chunks to each client (used to test the convergence in a worst case scenario) + , by default False power_law: bool, optional Whether to follow a power-law distribution when assigning number of samples for each client, defaults to True @@ -53,13 +56,14 @@ def _partition_data( Returns ------- Tuple[List[Dataset], Dataset] - A list of dataset for each client and a single dataset to be use for testing the model. + A list of dataset for each client and a single dataset to be use for testing + the model. """ trainset, testset = _download_data() if balance: trainset = _balance_classes(trainset, seed) - + partition_size = int(len(trainset) / num_clients) lengths = [partition_size] * num_clients @@ -179,9 +183,10 @@ def _power_law_split( mean: float = 0.0, sigma: float = 2.0, ) -> Dataset: - """Partitions the dataset following a power-law distribution. It follows - the implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with - default values set accordingly. + """Partition the dataset following a power-law distribution. It follows the. + + implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with default + values set accordingly. Parameters ---------- @@ -205,15 +210,14 @@ def _power_law_split( Dataset The partitioned training dataset. """ - targets = sorted_trainset.targets - full_idx = range(len(targets)) + full_idx = list(range(len(targets))) class_counts = np.bincount(sorted_trainset.targets) labels_cs = np.cumsum(class_counts) labels_cs = [0] + labels_cs[:-1].tolist() - partitions_idx = [] + partitions_idx: List[List[int]] = [] num_classes = len(np.bincount(targets)) hist = np.zeros(num_classes, dtype=np.int32) @@ -243,7 +247,8 @@ def _power_law_split( (num_classes, int(num_partitions / num_classes), num_labels_per_partition), ) remaining_per_class = class_counts - hist - # obtain how many samples each partition should be assigned for each of the labels it contains + # obtain how many samples each partition should be assigned for each of the + # labels it contains probs = ( remaining_per_class.reshape(-1, 1, 1) * probs diff --git a/baselines/fedprox/fedprox/main.py b/baselines/fedprox/fedprox/main.py index af4971af5df..682c4dd271c 100644 --- a/baselines/fedprox/fedprox/main.py +++ b/baselines/fedprox/fedprox/main.py @@ -1,5 +1,7 @@ """Runs CNN federated learning for MNIST dataset.""" +from typing import Dict, Union + import flwr as fl import hydra from hydra.core.hydra_config import HydraConfig @@ -10,17 +12,18 @@ from fedprox.dataset import load_datasets from fedprox.utils import save_results_as_pickle +FitConfig = Dict[str, Union[bool, float]] + @hydra.main(config_path="conf", config_name="config", version_base=None) def main(cfg: DictConfig) -> None: - """Main function to run CNN federated learning on MNIST. + """Rrun CNN federated learning on MNIST. Parameters ---------- cfg : DictConfig An omegaconf object that stores the hydra config. """ - # print config structured as YAML print(OmegaConf.to_yaml(cfg)) @@ -53,7 +56,9 @@ def main(cfg: DictConfig) -> None: def get_on_fit_config(): def fit_config_fn(server_round: int): # resolve and convert to python dict - fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True) + fit_config: FitConfig = OmegaConf.to_container( # type: ignore + cfg.fit_config, resolve=True + ) fit_config["curr_round"] = server_round # add round info return fit_config diff --git a/baselines/fedprox/fedprox/server.py b/baselines/fedprox/fedprox/server.py index 8bd063ad289..d7557123c48 100644 --- a/baselines/fedprox/fedprox/server.py +++ b/baselines/fedprox/fedprox/server.py @@ -1,3 +1,4 @@ +"""Flower Server.""" from collections import OrderedDict from typing import Callable, Dict, Optional, Tuple @@ -17,7 +18,7 @@ def gen_evaluate_fn( ) -> Callable[ [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]] ]: - """Generates the function for centralized evaluation. + """Generate the function for centralized evaluation. Parameters ---------- @@ -28,7 +29,8 @@ def gen_evaluate_fn( Returns ------- - Callable[ [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]] ] + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] The centralized evaluation function. """ @@ -37,7 +39,6 @@ def evaluate( ) -> Optional[Tuple[float, Dict[str, Scalar]]]: # pylint: disable=unused-argument """Use the entire CIFAR-10 test set for evaluation.""" - net = instantiate(model) params_dict = zip(net.state_dict().keys(), parameters_ndarrays) state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index a53f2d1f91e..1769bf08be7 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -1,3 +1,4 @@ +"""Flower strategy.""" from typing import List, Tuple, Union from flwr.common import Metrics @@ -7,7 +8,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - """Aggregation function for weighted average during evaluation. + """Aggregate with weighted average during evaluation. Parameters ---------- @@ -24,7 +25,6 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] # Aggregate and return custom metric (weighted average) - print("here and nothing is breaking!!!") return {"accuracy": int(sum(accuracies)) / int(sum(examples))} @@ -37,13 +37,14 @@ def aggregate_fit( results: List[Tuple[ClientProxy, FitRes]], failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ): - """Here we discard all the models sent by the clients that were - stragglers in this round.""" + """Discard all the models sent by the clients that were stragglers. + in this round. + """ # Record which client was a straggler in this round stragglers_mask = [res.metrics["is_straggler"] for _, res in results] - print(f"Num stragglers in round: {sum(stragglers_mask)}") + # print(f"Num stragglers in round: {sum(stragglers_mask)}") # keep those results that are not from stragglers results = [res for i, res in enumerate(results) if not (stragglers_mask[i])] diff --git a/baselines/fedprox/fedprox/utils.py b/baselines/fedprox/fedprox/utils.py index 2fa3b64966b..5bc8d584ffe 100644 --- a/baselines/fedprox/fedprox/utils.py +++ b/baselines/fedprox/fedprox/utils.py @@ -12,16 +12,16 @@ def plot_metric_from_history( hist: History, - save_plot_path: Path, + save_plot_path: str, suffix: Optional[str] = "", ) -> None: - """Function to plot from Flower server History. + """Plot from Flower server History. Parameters ---------- hist : History Object containing evaluation for all rounds. - save_plot_path : Path + save_plot_path : str Folder to save the plot to. suffix: Optional[str] Optional string to add at the end of the filename for the plot. @@ -55,10 +55,10 @@ def plot_metric_from_history( def save_results_as_pickle( history: History, file_path: Union[str, Path], - extra_results: Optional[Dict] = {}, - default_filename: Optional[str] = "results.pkl", + extra_results: Optional[Dict] = None, + default_filename: str = "results.pkl", ) -> None: - """Saves results from simulation to pickle. + """Save results from simulation to pickle. Parameters ---------- @@ -77,22 +77,23 @@ def save_results_as_pickle( File used by default if file_path points to a directory instead to a file. Default: "results.pkl" """ - path = Path(file_path) # ensure path exists path.mkdir(exist_ok=True, parents=True) def _add_random_suffix(path_: Path): - """Adds a randomly generated suffix to the file name (so it doesn't - overwrite the file).""" + """Add a randomly generated suffix to the file name (so it doesn't. + + overwrite the file). + """ print(f"File `{path_}` exists! ") suffix = token_hex(4) print(f"New results to be saved with suffix: {suffix}") return path_.parent / (path_.stem + "_" + suffix + ".pkl") def _complete_path_with_default_name(path_: Path): - """Appends the default file name to the path.""" + """Append the default file name to the path.""" print("Using default filename") return path_ / default_filename @@ -105,7 +106,9 @@ def _complete_path_with_default_name(path_: Path): print(f"Results will be saved into: {path}") - data = {"history": history, **extra_results} + data = {"history": history} + if extra_results is not None: + data = {**data, **extra_results} # save results to pickle with open(str(path), "wb") as handle: diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index d1ba004e18c..ea774580d68 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -42,6 +42,8 @@ flwr = "1.3.0" ray = "1.11.1" hydra-core = "1.3.2" matplotlib = "3.7.1" +jupyter = "^1.0.0" +pandas = "^2.0.3" [tool.poetry.dev-dependencies] isort = "==5.11.5" @@ -52,6 +54,7 @@ pylint = "==2.8.2" flake8 = "==3.9.2" pytest = "==6.2.4" pytest-watch = "==4.2.0" +ruff = "==0.0.272" types-requests = "==2.27.7" [tool.isort] @@ -66,10 +69,6 @@ use_parentheses = true line-length = 88 target-version = ["py38", "py39", "py310", "py311"] -[tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" -good-names = "i,j,k,_,x,y,X,Y" - [tool.pytest.ini_options] minversion = "6.2" addopts = "-qq" @@ -81,6 +80,59 @@ testpaths = [ ignore_missing_imports = true strict = false plugins = "numpy.typing.mypy_plugin" -[tool.pylint.messages_control] + +[tool.pylint."MESSAGES CONTROL"] disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" signature-mutators="hydra.main.main" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" From 33dfe8d2155b9783b37db399ab63436cac7972d8 Mon Sep 17 00:00:00 2001 From: javier Date: Wed, 6 Sep 2023 12:33:14 +0000 Subject: [PATCH 02/21] readme images go in _static --- .../fedprox/{docs => _static}/FedProx_mnist.png | Bin 1 file changed, 0 insertions(+), 0 deletions(-) rename baselines/fedprox/{docs => _static}/FedProx_mnist.png (100%) diff --git a/baselines/fedprox/docs/FedProx_mnist.png b/baselines/fedprox/_static/FedProx_mnist.png similarity index 100% rename from baselines/fedprox/docs/FedProx_mnist.png rename to baselines/fedprox/_static/FedProx_mnist.png From f4208177e3d688da1b16bfc4f84800eda08ceb7d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 6 Sep 2023 14:11:48 +0100 Subject: [PATCH 03/21] more formatting --- baselines/fedprox/fedprox/client.py | 1 + baselines/fedprox/fedprox/dataset_preparation.py | 3 +++ baselines/fedprox/fedprox/models.py | 4 ++-- baselines/fedprox/fedprox/strategy.py | 4 ++-- baselines/fedprox/fedprox/utils.py | 4 ++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/baselines/fedprox/fedprox/client.py b/baselines/fedprox/fedprox/client.py index 67db341015c..e168a9ee24a 100644 --- a/baselines/fedprox/fedprox/client.py +++ b/baselines/fedprox/fedprox/client.py @@ -15,6 +15,7 @@ from fedprox.models import test, train +# pylint: disable=too-many-arguments class FlowerClient( fl.client.NumPyClient ): # pylint: disable=too-many-instance-attributes diff --git a/baselines/fedprox/fedprox/dataset_preparation.py b/baselines/fedprox/fedprox/dataset_preparation.py index 068f4fc4c77..6c44c7411c3 100644 --- a/baselines/fedprox/fedprox/dataset_preparation.py +++ b/baselines/fedprox/fedprox/dataset_preparation.py @@ -24,6 +24,7 @@ def _download_data() -> Tuple[Dataset, Dataset]: return trainset, testset +# pylint: disable=too-many-locals def _partition_data( num_clients, iid: Optional[bool] = False, @@ -175,6 +176,7 @@ def _sort_by_class( return sorted_dataset +# pylint: disable=too-many-locals, too-many-arguments def _power_law_split( sorted_trainset: Dataset, num_partitions: int, @@ -249,6 +251,7 @@ def _power_law_split( remaining_per_class = class_counts - hist # obtain how many samples each partition should be assigned for each of the # labels it contains + # pylint: disable=too-many-function-args probs = ( remaining_per_class.reshape(-1, 1, 1) * probs diff --git a/baselines/fedprox/fedprox/models.py b/baselines/fedprox/fedprox/models.py index d44980a8609..d6f20de1acb 100644 --- a/baselines/fedprox/fedprox/models.py +++ b/baselines/fedprox/fedprox/models.py @@ -61,7 +61,7 @@ class LogisticRegression(nn.Module): def __init__(self, num_classes: int) -> None: super().__init__() - self.fc = nn.Linear(28 * 28, num_classes) + self.linear = nn.Linear(28 * 28, num_classes) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -76,7 +76,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: torch.Tensor The resulting Tensor after it has passed through the network """ - output_tensor = self.fc(torch.flatten(input_tensor, 1)) + output_tensor = self.linear(torch.flatten(input_tensor, 1)) return output_tensor diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index 1769bf08be7..89bdc76032d 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -21,7 +21,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: The weighted average metric. """ # Multiply accuracy of each client by number of examples used - accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics] examples = [num_examples for num_examples, _ in metrics] # Aggregate and return custom metric (weighted average) @@ -47,7 +47,7 @@ def aggregate_fit( # print(f"Num stragglers in round: {sum(stragglers_mask)}") # keep those results that are not from stragglers - results = [res for i, res in enumerate(results) if not (stragglers_mask[i])] + results = [res for i, res in enumerate(results) if not stragglers_mask[i]] # call the parent `aggregate_fit()` (i.e. that in standard FedAvg) return super().aggregate_fit(server_round, results, failures) diff --git a/baselines/fedprox/fedprox/utils.py b/baselines/fedprox/fedprox/utils.py index 5bc8d584ffe..9200778bfb0 100644 --- a/baselines/fedprox/fedprox/utils.py +++ b/baselines/fedprox/fedprox/utils.py @@ -32,12 +32,12 @@ def plot_metric_from_history( if metric_type == "centralized" else hist.metrics_distributed ) - rounds, values = zip(*metric_dict["accuracy"]) + _, values = zip(*metric_dict["accuracy"]) # let's extract centralised loss (main metric reported in FedProx paper) rounds_loss, values_loss = zip(*hist.losses_centralized) - fig, axs = plt.subplots(nrows=2, ncols=1, sharex="row") + _, axs = plt.subplots(nrows=2, ncols=1, sharex="row") axs[0].plot(np.asarray(rounds_loss), np.asarray(values_loss)) axs[1].plot(np.asarray(rounds_loss), np.asarray(values)) From 3d8d2d2231a917fe0facf5e8796c3a3c4c787c17 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 6 Sep 2023 14:46:40 +0100 Subject: [PATCH 04/21] more --- baselines/fedprox/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/baselines/fedprox/README.md b/baselines/fedprox/README.md index 1834a0ccc62..14a7b2be5fc 100644 --- a/baselines/fedprox/README.md +++ b/baselines/fedprox/README.md @@ -1,15 +1,15 @@ --- title: Federated Optimization in Heterogeneous Networks url: https://arxiv.org/abs/1812.06127 -labels: [image classification, cross-device, stragglers] # please add between 4 and 10 single-word (maybe two-words) labels (e.g. "system heterogeneity", "image classification", "asynchronous", "weight sharing", "cross-silo") -dataset: [mnist] # list of datasets you include in your baseline +labels: [image classification, cross-device, stragglers] +dataset: [mnist] --- # FedProx: Federated Optimization in Heterogeneous Networks > Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper. -**Paper:** https://arxiv.org/abs/1812.06127 +**Paper:** [arxiv.org/abs/1812.06127](https://arxiv.org/abs/1812.06127) **Authors:** Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar and Virginia Smith. @@ -34,7 +34,7 @@ dataset: [mnist] # list of datasets you include in your baseline * A logistic regression model used in the FedProx paper for MNIST (see `models/LogisticRegression`). This is the model used by default. * A two-layer CNN network as used in the FedAvg paper (see `models/Net`) -**Dataset:** This baseline only includes the MNIST dataset. By default it will be partitioned into 1000 clients following a pathological split where each client has examples of two (out of ten) class labels. The number of examples in each client is derived by sampling from a powerlaw distribution. The settings are as follow: +**Dataset:** This baseline only includes the MNIST dataset. By default, it will be partitioned into 1000 clients following a pathological split where each client has examples of two (out of ten) class labels. The number of examples in each client is derived by sampling from a powerlaw distribution. The settings are as follows: | Dataset | #classes | #partitions | partitioning method | partition settings | | :------ | :---: | :---: | :---: | :---: | @@ -96,7 +96,7 @@ python -m fedprox.main --config-name fedavg ## Expected results -With the following command we run both FedProx and FedAvg configurations while iterating through different values of `mu` and `stragglers_fraction`. We ran each experiment five times (this is achieved by artificially adding an extra element to the config but that it doesn't have an impact on the FL setting `'+repeat_num=range(5)'`) +With the following command, we run both FedProx and FedAvg configurations while iterating through different values of `mu` and `stragglers_fraction`. We ran each experiment five times (this is achieved by artificially adding an extra element to the config but it doesn't have an impact on the FL setting `'+repeat_num=range(5)'`) ```bash python -m fedprox.main --multirun mu=0.0,2.0 stragglers_fraction=0.0,0.5,0.9 '+repeat_num=range(5)' From 83cd88cbe98d22f905edb599ad019214b350a936 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 6 Sep 2023 20:41:42 +0100 Subject: [PATCH 05/21] update fedprox pyproject; added changelog entry --- baselines/fedprox/pyproject.toml | 5 ++--- doc/source/ref-changelog.md | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index ea774580d68..23009001b8a 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -38,8 +38,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.8.15, <3.12.0" -flwr = "1.3.0" -ray = "1.11.1" +flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" matplotlib = "3.7.1" jupyter = "^1.0.0" @@ -49,7 +48,7 @@ pandas = "^2.0.3" isort = "==5.11.5" black = "==23.1.0" docformatter = "==1.5.1" -mypy = "==0.961" +mypy = "==1.4.1" pylint = "==2.8.2" flake8 = "==3.9.2" pytest = "==6.2.4" diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 215dc5de8b1..b6dc3608c5a 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -6,6 +6,8 @@ - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301).[#2305](https://github.com/adap/flower/pull/2305)) +- **New or updates baselines** + * FedProx ([#2286](https://github.com/adap/flower/pull/2286)) ### Incompatible changes - **Remove support for Python 3.7** ([#2280](https://github.com/adap/flower/pull/2280), [#2299](https://github.com/adap/flower/pull/2299)) From 12af3744478e0bdca95e6d84b5658f58fb674d98 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 6 Sep 2023 20:52:42 +0100 Subject: [PATCH 06/21] format changelog --- doc/source/ref-changelog.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index b6dc3608c5a..4fcb28439d6 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -7,7 +7,9 @@ - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301).[#2305](https://github.com/adap/flower/pull/2305)) - **New or updates baselines** - * FedProx ([#2286](https://github.com/adap/flower/pull/2286)) + + - FedProx ([#2286](https://github.com/adap/flower/pull/2286)) + ### Incompatible changes - **Remove support for Python 3.7** ([#2280](https://github.com/adap/flower/pull/2280), [#2299](https://github.com/adap/flower/pull/2299)) From b52571ee2d8678afb824ffe5b2ec76c7e80df706 Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 7 Sep 2023 08:11:21 +0100 Subject: [PATCH 07/21] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- baselines/fedprox/fedprox/__init__.py | 2 +- baselines/fedprox/fedprox/main.py | 2 +- baselines/fedprox/fedprox/server.py | 2 ++ baselines/fedprox/fedprox/strategy.py | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/baselines/fedprox/fedprox/__init__.py b/baselines/fedprox/fedprox/__init__.py index c150b5cc2ef..1c264dd8e64 100644 --- a/baselines/fedprox/fedprox/__init__.py +++ b/baselines/fedprox/fedprox/__init__.py @@ -1 +1 @@ -"""Fedprox package.""" +"""FedProx package.""" diff --git a/baselines/fedprox/fedprox/main.py b/baselines/fedprox/fedprox/main.py index 682c4dd271c..16fb6376dcb 100644 --- a/baselines/fedprox/fedprox/main.py +++ b/baselines/fedprox/fedprox/main.py @@ -17,7 +17,7 @@ @hydra.main(config_path="conf", config_name="config", version_base=None) def main(cfg: DictConfig) -> None: - """Rrun CNN federated learning on MNIST. + """Run CNN federated learning on MNIST. Parameters ---------- diff --git a/baselines/fedprox/fedprox/server.py b/baselines/fedprox/fedprox/server.py index d7557123c48..fffe20eb600 100644 --- a/baselines/fedprox/fedprox/server.py +++ b/baselines/fedprox/fedprox/server.py @@ -1,4 +1,6 @@ """Flower Server.""" + + from collections import OrderedDict from typing import Callable, Dict, Optional, Tuple diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index 89bdc76032d..23fd6695d85 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -1,4 +1,6 @@ """Flower strategy.""" + + from typing import List, Tuple, Union from flwr.common import Metrics From 1f52b4e163d1330630a1ace07e46e7bc8812b6ad Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 7 Sep 2023 09:12:00 +0100 Subject: [PATCH 08/21] updated pylint typechecks --- baselines/fedprox/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index 23009001b8a..0d4978102b4 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -85,6 +85,9 @@ disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import good-names = "i,j,k,_,x,y,X,Y" signature-mutators="hydra.main.main" +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + [[tool.mypy.overrides]] module = [ "importlib.metadata.*", From 526c5e2d044dbc13dbf79373585b7dd8a94927ea Mon Sep 17 00:00:00 2001 From: javier Date: Thu, 7 Sep 2023 19:01:47 +0000 Subject: [PATCH 09/21] py10, torch2, small edits --- baselines/fedprox/README.md | 9 +-------- baselines/fedprox/fedprox/server.py | 5 +++++ baselines/fedprox/pyproject.toml | 4 +++- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/baselines/fedprox/README.md b/baselines/fedprox/README.md index 14a7b2be5fc..57f14d0a30a 100644 --- a/baselines/fedprox/README.md +++ b/baselines/fedprox/README.md @@ -56,17 +56,10 @@ The following table shows the main hyperparameters for this baseline with their ## Environment Setup -To construct the Python environment follow these steps: +To construct the Python environment simply run: ```bash -# install the base Poetry environment poetry install - -# activate the environment -poetry shell - -# install PyTorch with GPU support. Please note this baseline is very lightweight so it can run fine on a CPU. -pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 ``` ## Running the Experiments diff --git a/baselines/fedprox/fedprox/server.py b/baselines/fedprox/fedprox/server.py index fffe20eb600..3f0f00bea63 100644 --- a/baselines/fedprox/fedprox/server.py +++ b/baselines/fedprox/fedprox/server.py @@ -47,6 +47,11 @@ def evaluate( net.load_state_dict(state_dict, strict=True) net.to(device) + # We could compile the model here but we are not going to do it because + # running test() is so lightweight that the overhead of compiling the model + # negate any potential speedup. Please note this is specific to the model and + # dataset used in this baseline. In general, compiling the model is worth it + loss, accuracy = test(net, testloader, device=device) # return statistics return loss, {"accuracy": accuracy} diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index 0d4978102b4..ee127ac19fe 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -37,12 +37,14 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.8.15, <3.12.0" +python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" matplotlib = "3.7.1" jupyter = "^1.0.0" pandas = "^2.0.3" +torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl"} [tool.poetry.dev-dependencies] isort = "==5.11.5" From 720e59475854adcc85309a9a4d783805fa90ad7f Mon Sep 17 00:00:00 2001 From: javier Date: Thu, 7 Sep 2023 19:10:56 +0000 Subject: [PATCH 10/21] format --- doc/source/ref-changelog.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index a1ea9f18fe7..11afa002287 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -6,17 +6,14 @@ - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301).[#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307)) - - **New or updates baselines** - FedProx ([#2286](https://github.com/adap/flower/pull/2286)) - - **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317)) Flower received many improvements under the hood, too many to list here. - ### Incompatible changes - **Remove support for Python 3.7** ([#2280](https://github.com/adap/flower/pull/2280), [#2299](https://github.com/adap/flower/pull/2299), [2304](https://github.com/adap/flower/pull/2304), [#2306](https://github.com/adap/flower/pull/2306)) From e15dc10277310268d1551311268b6b80d291cf57 Mon Sep 17 00:00:00 2001 From: javier Date: Thu, 7 Sep 2023 20:14:15 +0000 Subject: [PATCH 11/21] . --- baselines/fedprox/pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index ee127ac19fe..fff3fbc0369 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -41,8 +41,6 @@ python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" matplotlib = "3.7.1" -jupyter = "^1.0.0" -pandas = "^2.0.3" torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl"} torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl"} From a6cbafbbbce57c4c7bc53628326e596f4ad2c05d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 7 Sep 2023 22:11:54 +0100 Subject: [PATCH 12/21] attempt --- baselines/fedprox/pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index fff3fbc0369..726141cafca 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -37,12 +37,12 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.10.0, <3.11.0" +python = ">=3.9.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" matplotlib = "3.7.1" -torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl"} -torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl"} +torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp39-cp39-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp39-cp39-linux_x86_64.whl"} [tool.poetry.dev-dependencies] isort = "==5.11.5" From b304df093d4eb42b53bb9d1d0a0822117b981f26 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 7 Sep 2023 22:23:08 +0100 Subject: [PATCH 13/21] revert and try workflow change --- .github/workflows/baselines.yml | 1 + baselines/fedprox/pyproject.toml | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index 3ed321d54df..f118774a292 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -116,6 +116,7 @@ jobs: changed_dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" cd "${changed_dir}" python -m poetry install + python -m poetry config virtualenvs.create true - name: Test if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | diff --git a/baselines/fedprox/pyproject.toml b/baselines/fedprox/pyproject.toml index 726141cafca..fff3fbc0369 100644 --- a/baselines/fedprox/pyproject.toml +++ b/baselines/fedprox/pyproject.toml @@ -37,12 +37,12 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.9.0, <3.11.0" +python = ">=3.10.0, <3.11.0" flwr = { extras = ["simulation"], version = "1.5.0" } hydra-core = "1.3.2" matplotlib = "3.7.1" -torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp39-cp39-linux_x86_64.whl"} -torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp39-cp39-linux_x86_64.whl"} +torch = { url = "https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl"} [tool.poetry.dev-dependencies] isort = "==5.11.5" From 8db7e8d6029030058e44d2253b62ac0c5f9f54cb Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 10 Sep 2023 11:10:12 +0100 Subject: [PATCH 14/21] works? --- .github/workflows/baselines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index f118774a292..bf878fbe9c8 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -109,7 +109,7 @@ jobs: if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | From 22ad8dc35e5efb76a196681a0f61f1be4de5f68a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 10 Sep 2023 11:17:35 +0100 Subject: [PATCH 15/21] . --- .github/workflows/baselines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index bf878fbe9c8..ef9e016e690 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -109,7 +109,7 @@ jobs: if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: - python-version: 3.9 + python-version: 3.10 - name: Install dependencies if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | From 64f9f941aabb1fa94bb32d18f950a8a92aa2cda5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 10 Sep 2023 11:21:37 +0100 Subject: [PATCH 16/21] . --- .github/workflows/baselines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index ef9e016e690..75d2e7af6f0 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -109,7 +109,7 @@ jobs: if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: - python-version: 3.10 + python-version: '3.10' - name: Install dependencies if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | From 0852a05b2ee0ab513b8dc4af0cb178e61b5d4278 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 10 Sep 2023 11:30:44 +0100 Subject: [PATCH 17/21] reverting workflow tests --- .github/workflows/baselines.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml index 75d2e7af6f0..3ed321d54df 100644 --- a/.github/workflows/baselines.yml +++ b/.github/workflows/baselines.yml @@ -109,14 +109,13 @@ jobs: if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' uses: ./.github/actions/bootstrap with: - python-version: '3.10' + python-version: 3.8 - name: Install dependencies if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | changed_dir="${{ steps.validate_changed_baselines_dirs.outputs.changed_dir }}" cd "${changed_dir}" python -m poetry install - python -m poetry config virtualenvs.create true - name: Test if: steps.validate_changed_baselines_dirs.outputs.baseline_changed == 'true' run: | From 929cc32140ff1860fa16caca0e54d20112387b2c Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 11 Sep 2023 09:40:18 +0100 Subject: [PATCH 18/21] minor updates --- baselines/fedprox/fedprox/models.py | 2 +- baselines/fedprox/fedprox/strategy.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/baselines/fedprox/fedprox/models.py b/baselines/fedprox/fedprox/models.py index d6f20de1acb..b9e1d567e79 100644 --- a/baselines/fedprox/fedprox/models.py +++ b/baselines/fedprox/fedprox/models.py @@ -1,4 +1,4 @@ -"""CNN model architecutre, training, and testing functions for MNIST.""" +"""CNN model architecture, training, and testing functions for MNIST.""" from typing import List, Tuple diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index 23fd6695d85..3f77b2e2aa0 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -46,8 +46,6 @@ def aggregate_fit( # Record which client was a straggler in this round stragglers_mask = [res.metrics["is_straggler"] for _, res in results] - # print(f"Num stragglers in round: {sum(stragglers_mask)}") - # keep those results that are not from stragglers results = [res for i, res in enumerate(results) if not stragglers_mask[i]] From 9bec94f5809cdae87dfade5690a5382f11ecfae4 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 11 Sep 2023 13:11:24 +0100 Subject: [PATCH 19/21] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- baselines/fedprox/README.md | 2 +- doc/source/ref-changelog.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/baselines/fedprox/README.md b/baselines/fedprox/README.md index 57f14d0a30a..da527fbcb82 100644 --- a/baselines/fedprox/README.md +++ b/baselines/fedprox/README.md @@ -56,7 +56,7 @@ The following table shows the main hyperparameters for this baseline with their ## Environment Setup -To construct the Python environment simply run: +To construct the Python environment, simply run: ```bash poetry install diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index bf4d9e137e9..12b3973ac23 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -8,7 +8,7 @@ - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301).[#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327)) -- **New or updates baselines** +- **Update Flower Baselines** - FedProx ([#2286](https://github.com/adap/flower/pull/2286)) From d8156679e13a314d51f049261f19db62fc73f150 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 11 Sep 2023 13:15:12 +0100 Subject: [PATCH 20/21] fix --- baselines/fedprox/fedprox/strategy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index 3f77b2e2aa0..f32e633fad6 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -40,8 +40,6 @@ def aggregate_fit( failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ): """Discard all the models sent by the clients that were stragglers. - - in this round. """ # Record which client was a straggler in this round stragglers_mask = [res.metrics["is_straggler"] for _, res in results] From 08bc367f6a98e20e09384a97c8cd37dae067ee7a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 11 Sep 2023 13:43:00 +0100 Subject: [PATCH 21/21] . --- baselines/fedprox/fedprox/strategy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/baselines/fedprox/fedprox/strategy.py b/baselines/fedprox/fedprox/strategy.py index f32e633fad6..5086d229da2 100644 --- a/baselines/fedprox/fedprox/strategy.py +++ b/baselines/fedprox/fedprox/strategy.py @@ -39,8 +39,7 @@ def aggregate_fit( results: List[Tuple[ClientProxy, FitRes]], failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ): - """Discard all the models sent by the clients that were stragglers. - """ + """Discard all the models sent by the clients that were stragglers.""" # Record which client was a straggler in this round stragglers_mask = [res.metrics["is_straggler"] for _, res in results]