Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format baselines #2960

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions baselines/fedbn/docs/multirun_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
"source": [
"def fuse_by_dataset(losses):\n",
" \"\"\"Transform per-round history (list of dicts) into\n",
" \n",
"\n",
" a single dict, with values as lists.\"\"\"\n",
" fussed_losses = {}\n",
"\n",
" for _, loss_dict in losses:\n",
" for k,v in loss_dict.items():\n",
" for k, v in loss_dict.items():\n",
" if k in fussed_losses:\n",
" fussed_losses[k].append(v)\n",
" else:\n",
Expand All @@ -82,9 +82,14 @@
" res_list = []\n",
" for results in list(Path(path_multirun).glob(\"**/history.pkl\")):\n",
" data, config = read_pickle_and_config(results)\n",
" pre_train_loss = data['history'].metrics_distributed_fit['pre_train_losses']\n",
" pre_train_loss = data[\"history\"].metrics_distributed_fit[\"pre_train_losses\"]\n",
" fussed_losses = fuse_by_dataset(pre_train_loss)\n",
" res_list.append({'strategy': config['client']['client_label'], 'train_losses': fussed_losses})\n",
" res_list.append(\n",
" {\n",
" \"strategy\": config[\"client\"][\"client_label\"],\n",
" \"train_losses\": fussed_losses,\n",
" }\n",
" )\n",
" return res_list"
]
},
Expand All @@ -95,7 +100,7 @@
"outputs": [],
"source": [
"# Here replace with the path of the multi run you just completed\n",
"all_losses = process_multirun_data('../multirun/2023-11-15/22-28-28')"
"all_losses = process_multirun_data(\"../multirun/2023-11-15/22-28-28\")"
]
},
{
Expand All @@ -104,30 +109,29 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"def average_by_client_type(all_fused_lossed):\n",
" \"\"\"If there are multiple runs for the same strategy add them up,\n",
" \n",
"\n",
" average them later. This is useful if you run the `--multirun` with\n",
" the same configuration more than once.\"\"\"\n",
" \n",
"\n",
" # identify how many unique clients were used\n",
" to_plot = {}\n",
" for res in all_fused_lossed:\n",
" strategy = res['strategy']\n",
" strategy = res[\"strategy\"]\n",
" if strategy not in to_plot:\n",
" to_plot[strategy] = {}\n",
" \n",
" for dataset, train_loss in res['train_losses'].items():\n",
"\n",
" for dataset, train_loss in res[\"train_losses\"].items():\n",
" if dataset in to_plot[strategy]:\n",
" to_plot[strategy][dataset]['train_loss'] += np.array(train_loss)\n",
" to_plot[strategy][dataset]['run_count'] += 1\n",
" to_plot[strategy][dataset][\"train_loss\"] += np.array(train_loss)\n",
" to_plot[strategy][dataset][\"run_count\"] += 1\n",
" else:\n",
" to_plot[strategy][dataset] = {'train_loss': np.array(train_loss)}\n",
" to_plot[strategy][dataset]['run_count'] = 1\n",
" to_plot[strategy][dataset] = {\"train_loss\": np.array(train_loss)}\n",
" to_plot[strategy][dataset][\"run_count\"] = 1\n",
"\n",
" # print(to_plot)\n",
" return to_plot\n"
" return to_plot"
]
},
{
Expand Down Expand Up @@ -167,25 +171,25 @@
"print(datasets)\n",
"\n",
"num_datasets = len(datasets)\n",
"fig, axs = plt.subplots(figsize=(14,4), nrows=1, ncols=num_datasets)\n",
"fig, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=num_datasets)\n",
"\n",
"\n",
"for s_id, (strategy, results) in enumerate(to_plot.items()):\n",
" for i, dataset in enumerate(datasets):\n",
" loss = results[dataset]['train_loss']/results[dataset]['run_count']\n",
" loss = results[dataset][\"train_loss\"] / results[dataset][\"run_count\"]\n",
" axs[i].plot(range(len(loss)), loss, label=strategy)\n",
" axs[i].set_xlabel('Round')\n",
" axs[i].set_xlabel(\"Round\")\n",
" if i == 0:\n",
" axs[i].set_ylabel('Train Loss')\n",
" axs[i].set_ylabel(\"Train Loss\")\n",
"\n",
" axs[i].legend()\n",
"\n",
" if s_id==0:\n",
" if s_id == 0:\n",
" axs[i].grid()\n",
" axs[i].set_title(dataset)\n",
" axs[i].set_xticks(np.arange(0,100+1, 25))\n",
" axs[i].set_xticks(np.arange(0, 100 + 1, 25))\n",
"\n",
"saveFig('train_loss.png', fig)"
"saveFig(\"train_loss.png\", fig)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion baselines/hfedxgboost/hfedxgboost/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""hfedxgboost baseline package."""
"""Hfedxgboost baseline package."""
3 changes: 1 addition & 2 deletions baselines/hfedxgboost/hfedxgboost/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import aggregate
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate


class FedXgbNnAvg(FedAvg):
Expand Down