Skip to content

Commit

Permalink
format baselines (#2960)
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert-Steiner authored Feb 20, 2024
1 parent 29dfaff commit 8433a07
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
50 changes: 27 additions & 23 deletions baselines/fedbn/docs/multirun_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
"source": [
"def fuse_by_dataset(losses):\n",
" \"\"\"Transform per-round history (list of dicts) into\n",
" \n",
"\n",
" a single dict, with values as lists.\"\"\"\n",
" fussed_losses = {}\n",
"\n",
" for _, loss_dict in losses:\n",
" for k,v in loss_dict.items():\n",
" for k, v in loss_dict.items():\n",
" if k in fussed_losses:\n",
" fussed_losses[k].append(v)\n",
" else:\n",
Expand All @@ -82,9 +82,14 @@
" res_list = []\n",
" for results in list(Path(path_multirun).glob(\"**/history.pkl\")):\n",
" data, config = read_pickle_and_config(results)\n",
" pre_train_loss = data['history'].metrics_distributed_fit['pre_train_losses']\n",
" pre_train_loss = data[\"history\"].metrics_distributed_fit[\"pre_train_losses\"]\n",
" fussed_losses = fuse_by_dataset(pre_train_loss)\n",
" res_list.append({'strategy': config['client']['client_label'], 'train_losses': fussed_losses})\n",
" res_list.append(\n",
" {\n",
" \"strategy\": config[\"client\"][\"client_label\"],\n",
" \"train_losses\": fussed_losses,\n",
" }\n",
" )\n",
" return res_list"
]
},
Expand All @@ -95,7 +100,7 @@
"outputs": [],
"source": [
"# Here replace with the path of the multi run you just comleted\n",
"all_losses = process_multirun_data('../multirun/2023-11-15/22-28-28')"
"all_losses = process_multirun_data(\"../multirun/2023-11-15/22-28-28\")"
]
},
{
Expand All @@ -104,30 +109,29 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"def average_by_client_type(all_fused_lossed):\n",
" \"\"\"If there are multliple runs for the same strategy add them up,\n",
" \n",
"\n",
" average them later. This is useful if you run the `--multirun` running\n",
" more than one time the same configuration.\"\"\"\n",
" \n",
"\n",
" # identify how many unique clients were used\n",
" to_plot = {}\n",
" for res in all_fused_lossed:\n",
" strategy = res['strategy']\n",
" strategy = res[\"strategy\"]\n",
" if strategy not in to_plot:\n",
" to_plot[strategy] = {}\n",
" \n",
" for dataset, train_loss in res['train_losses'].items():\n",
"\n",
" for dataset, train_loss in res[\"train_losses\"].items():\n",
" if dataset in to_plot[strategy]:\n",
" to_plot[strategy][dataset]['train_loss'] += np.array(train_loss)\n",
" to_plot[strategy][dataset]['run_count'] += 1\n",
" to_plot[strategy][dataset][\"train_loss\"] += np.array(train_loss)\n",
" to_plot[strategy][dataset][\"run_count\"] += 1\n",
" else:\n",
" to_plot[strategy][dataset] = {'train_loss': np.array(train_loss)}\n",
" to_plot[strategy][dataset]['run_count'] = 1\n",
" to_plot[strategy][dataset] = {\"train_loss\": np.array(train_loss)}\n",
" to_plot[strategy][dataset][\"run_count\"] = 1\n",
"\n",
" # print(to_plot)\n",
" return to_plot\n"
" return to_plot"
]
},
{
Expand Down Expand Up @@ -167,25 +171,25 @@
"print(datasets)\n",
"\n",
"num_datasets = len(datasets)\n",
"fig, axs = plt.subplots(figsize=(14,4), nrows=1, ncols=num_datasets)\n",
"fig, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=num_datasets)\n",
"\n",
"\n",
"for s_id, (strategy, results) in enumerate(to_plot.items()):\n",
" for i, dataset in enumerate(datasets):\n",
" loss = results[dataset]['train_loss']/results[dataset]['run_count']\n",
" loss = results[dataset][\"train_loss\"] / results[dataset][\"run_count\"]\n",
" axs[i].plot(range(len(loss)), loss, label=strategy)\n",
" axs[i].set_xlabel('Round')\n",
" axs[i].set_xlabel(\"Round\")\n",
" if i == 0:\n",
" axs[i].set_ylabel('Train Loss')\n",
" axs[i].set_ylabel(\"Train Loss\")\n",
"\n",
" axs[i].legend()\n",
"\n",
" if s_id==0:\n",
" if s_id == 0:\n",
" axs[i].grid()\n",
" axs[i].set_title(dataset)\n",
" axs[i].set_xticks(np.arange(0,100+1, 25))\n",
" axs[i].set_xticks(np.arange(0, 100 + 1, 25))\n",
"\n",
"saveFig('train_loss.png', fig)"
"saveFig(\"train_loss.png\", fig)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion baselines/hfedxgboost/hfedxgboost/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""hfedxgboost baseline package."""
"""Hfedxgboost baseline package."""
3 changes: 1 addition & 2 deletions baselines/hfedxgboost/hfedxgboost/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import aggregate
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate


class FedXgbNnAvg(FedAvg):
Expand Down

0 comments on commit 8433a07

Please sign in to comment.