This repository contains the official PyTorch code for the paper "Federated Learning via Meta-Variational Dropout", published at NeurIPS 2023.
- Python >= 3.7.4
- GPU with CUDA >= 10.0 support
- Anaconda
Setup Environment
conda env create -f environment.yml
conda activate metavd
python main.py --model <model-name> --dataset <dataset-name> <other-options>
EX) Run Cifar10 Experiment with MetaVD
python main.py --model nvdpgaus --dataset cifar10
EX) Run Cifar100 Experiment with MetaVD and Heterogeneity level (Dirichlet alpha) of 5.0
python main.py --model nvdpgaus --dataset cifar100 --alpha 5.0
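Here, alpha controls client data heterogeneity: each class's samples are split across clients with Dirichlet(alpha) proportions, so a smaller alpha yields more skewed (non-IID) local label distributions. Below is a minimal sketch of this kind of Dirichlet partitioning; the function name and details are illustrative only and are not taken from this repository.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with Dirichlet(alpha) class proportions.

    Smaller alpha -> more heterogeneous (non-IID) client label distributions.
    Illustrative sketch only; not the partitioning code used in this repository.
    """
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)                  # all samples of class c
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, chunk in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(chunk.tolist())
    return client_indices

# Example: split 50,000 synthetic CIFAR10-like labels across 10 clients with alpha = 5.0.
labels = np.random.default_rng(0).integers(0, 10, size=50_000)
client_indices = dirichlet_partition(labels, num_clients=10, alpha=5.0)
```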
We currently support the following model and dataset options.

Model Name | Flag | Description |
---|---|---|
FedAvg | fedavg | Federated Averaging |
FedAvg + Finetuning | fedavgper | Personalized Federated Learning |
FedAvg + MetaVD | fedavgnvdpgausq | Federated Averaging with MetaVD (proposed in this work) |
FedAvg + SNIP | fedavgsnip | Federated Averaging with SNIP |
FedProx | fedprox | Federated Proximal Optimization |
FedBE | fedbe | Federated Learning with Bayesian Ensemble |
Reptile | reptile | Federated Learning with Reptile |
Reptile + VD | vdgausq | Reptile with VD |
Reptile + EnsembleVD | vdgausemq | Reptile with EnsembleVD |
Reptile + MetaVD | nvdpgausq | Reptile with MetaVD (proposed in this work) |
Reptile + SNIP | reptilesnip | Reptile with SNIP |
MAML | maml | Federated Learning with Model-Agnostic Meta-Learning |
MAML + MetaVD | mamlgausq | MAML with MetaVD (proposed in this work) |
MAML + SNIP | mamlsnip | MAML with SNIP |
PerFedAvg | perfedavg | HF-MAML |
PerFedAvg + MetaVD | perfedavgnvdpgausq | HF-MAML with MetaVD (proposed in this work) |
PerFedAvg + SNIP | perfedavgsnip | HF-MAML with SNIP |
Dataset Name | Flag | Description |
---|---|---|
Femnist | femnist | Federated EMNIST dataset |
Celeba | celeba | CelebA dataset |
MNIST | mnist | MNIST dataset |
Cifar10 | cifar10 | CIFAR10 dataset |
Cifar100 | cifar100 | CIFAR100 dataset |
EMNIST | emnist | Extended MNIST dataset |
FMNIST | fmnist | Fashion MNIST dataset |
Please see the argument parser in the main.py file to enable other options.
For all datasets, we set the number of rounds (num_rounds) to 1000 to ensure sufficient convergence, following convention. The batch size (local_bs) was set to 64, and the number of local steps (local_epochs) was set to 5. Personalization was executed with a batch size (adaptation_bs) of 64 and a 1-step update.

For all methods, we searched the server learning rate and the local SGD learning rate within identical ranges. The server learning rate η (server_lr) was explored within [0.6, 0.7, 0.8, 0.9, 1.0]. The local SGD learning rate (inner_lr) was explored within [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]. For MetaVD, an additional KL divergence weight parameter β (beta) is needed, and we set its optimal value to 10.
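As an illustrative example, a full run combining these settings might look like the command below. The long-form flags are assumed from the parameter names above (check the argument parser in main.py for the exact spellings), and the learning rates shown are one pick from the stated search ranges.

python main.py --model nvdpgaus --dataset cifar100 --alpha 5.0 --num_rounds 1000 --local_bs 64 --local_epochs 5 --adaptation_bs 64 --server_lr 1.0 --inner_lr 0.01 --beta 10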
Tensorboard Setup
cd runs
tensorboard --logdir=./ --port=7770 --samples_per_plugin image=100 --reload_multifile=True --reload_interval 30 --host=0.0.0.0
Access visualizations at localhost:7770.
If you find this work useful, please cite our paper:
@article{jeon2024federated,
  title={Federated Learning via Meta-Variational Dropout},
  author={Jeon, Insu and Hong, Minui and Yun, Junhyeog and Kim, Gunhee},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
We thank our colleagues for their valuable contributions.