diff --git a/baselines/fedbn/docs/multirun_plot.ipynb b/baselines/fedbn/docs/multirun_plot.ipynb index 6314e0033b9d..2660d1cc4e19 100644 --- a/baselines/fedbn/docs/multirun_plot.ipynb +++ b/baselines/fedbn/docs/multirun_plot.ipynb @@ -58,12 +58,12 @@ "source": [ "def fuse_by_dataset(losses):\n", " \"\"\"Transform per-round history (list of dicts) into\n", - " \n", + "\n", " a single dict, with values as lists.\"\"\"\n", " fussed_losses = {}\n", "\n", " for _, loss_dict in losses:\n", - " for k,v in loss_dict.items():\n", + " for k, v in loss_dict.items():\n", " if k in fussed_losses:\n", " fussed_losses[k].append(v)\n", " else:\n", @@ -82,9 +82,14 @@ " res_list = []\n", " for results in list(Path(path_multirun).glob(\"**/history.pkl\")):\n", " data, config = read_pickle_and_config(results)\n", - " pre_train_loss = data['history'].metrics_distributed_fit['pre_train_losses']\n", + " pre_train_loss = data[\"history\"].metrics_distributed_fit[\"pre_train_losses\"]\n", " fussed_losses = fuse_by_dataset(pre_train_loss)\n", - " res_list.append({'strategy': config['client']['client_label'], 'train_losses': fussed_losses})\n", + " res_list.append(\n", + " {\n", + " \"strategy\": config[\"client\"][\"client_label\"],\n", + " \"train_losses\": fussed_losses,\n", + " }\n", + " )\n", " return res_list" ] }, @@ -95,7 +100,7 @@ "outputs": [], "source": [ "# Here replace with the path of the multi run you just comleted\n", - "all_losses = process_multirun_data('../multirun/2023-11-15/22-28-28')" + "all_losses = process_multirun_data(\"../multirun/2023-11-15/22-28-28\")" ] }, { @@ -104,30 +109,29 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def average_by_client_type(all_fused_lossed):\n", " \"\"\"If there are multliple runs for the same strategy add them up,\n", - " \n", + "\n", " average them later. This is useful if you run the `--multirun` running\n", " more than one time the same configuration.\"\"\"\n", - " \n", + "\n", " # identify how many unique clients were used\n", " to_plot = {}\n", " for res in all_fused_lossed:\n", - " strategy = res['strategy']\n", + " strategy = res[\"strategy\"]\n", " if strategy not in to_plot:\n", " to_plot[strategy] = {}\n", - " \n", - " for dataset, train_loss in res['train_losses'].items():\n", + "\n", + " for dataset, train_loss in res[\"train_losses\"].items():\n", " if dataset in to_plot[strategy]:\n", - " to_plot[strategy][dataset]['train_loss'] += np.array(train_loss)\n", - " to_plot[strategy][dataset]['run_count'] += 1\n", + " to_plot[strategy][dataset][\"train_loss\"] += np.array(train_loss)\n", + " to_plot[strategy][dataset][\"run_count\"] += 1\n", " else:\n", - " to_plot[strategy][dataset] = {'train_loss': np.array(train_loss)}\n", - " to_plot[strategy][dataset]['run_count'] = 1\n", + " to_plot[strategy][dataset] = {\"train_loss\": np.array(train_loss)}\n", + " to_plot[strategy][dataset][\"run_count\"] = 1\n", "\n", " # print(to_plot)\n", - " return to_plot\n" + " return to_plot" ] }, { @@ -167,25 +171,25 @@ "print(datasets)\n", "\n", "num_datasets = len(datasets)\n", - "fig, axs = plt.subplots(figsize=(14,4), nrows=1, ncols=num_datasets)\n", + "fig, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=num_datasets)\n", "\n", "\n", "for s_id, (strategy, results) in enumerate(to_plot.items()):\n", " for i, dataset in enumerate(datasets):\n", - " loss = results[dataset]['train_loss']/results[dataset]['run_count']\n", + " loss = results[dataset][\"train_loss\"] / results[dataset][\"run_count\"]\n", " axs[i].plot(range(len(loss)), loss, label=strategy)\n", - " axs[i].set_xlabel('Round')\n", + " axs[i].set_xlabel(\"Round\")\n", " if i == 0:\n", - " axs[i].set_ylabel('Train Loss')\n", + " axs[i].set_ylabel(\"Train Loss\")\n", "\n", " axs[i].legend()\n", "\n", - " if s_id==0:\n", + " if s_id == 0:\n", " axs[i].grid()\n", " axs[i].set_title(dataset)\n", - " axs[i].set_xticks(np.arange(0,100+1, 25))\n", + " axs[i].set_xticks(np.arange(0, 100 + 1, 25))\n", "\n", - "saveFig('train_loss.png', fig)" + "saveFig(\"train_loss.png\", fig)" ] }, { diff --git a/baselines/hfedxgboost/hfedxgboost/__init__.py b/baselines/hfedxgboost/hfedxgboost/__init__.py index 543147a05591..ce8299b55c1a 100644 --- a/baselines/hfedxgboost/hfedxgboost/__init__.py +++ b/baselines/hfedxgboost/hfedxgboost/__init__.py @@ -1 +1 @@ -"""hfedxgboost baseline package.""" +"""Hfedxgboost baseline package.""" diff --git a/baselines/hfedxgboost/hfedxgboost/strategy.py b/baselines/hfedxgboost/hfedxgboost/strategy.py index eb067a89e5f0..b8e9c6704603 100644 --- a/baselines/hfedxgboost/hfedxgboost/strategy.py +++ b/baselines/hfedxgboost/hfedxgboost/strategy.py @@ -9,9 +9,8 @@ from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.logger import log from flwr.server.client_proxy import ClientProxy - -from flwr.server.strategy.aggregate import aggregate from flwr.server.strategy import FedAvg +from flwr.server.strategy.aggregate import aggregate class FedXgbNnAvg(FedAvg):