
# Federated Deep Equilibrium Learning

This repository contains reference code for the paper Federated Deep Equilibrium Learning: Harnessing Compact Global Representations to Enhance Personalization (CIKM 2024).

Authors: Long Tan Le, Tuan Dung Nguyen, Tung-Anh Nguyen, Choong Seon Hong, Suranga Seneviratne, Wei Bao, Nguyen H. Tran

Paper Link: https://doi.org/10.1145/3627673.3679752


## Network Models

This project contains explicit models, such as ResNet and Transformer, as well as implicit DEQ models, such as DEQ-ResNet and DEQ-Transformer.

- ResNet, DEQ-ResNet: work with the FEMNIST, CIFAR-10, and CIFAR-100 datasets
- Transformer, DEQ-Transformer: work with the Shakespeare dataset
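
For intuition, here is a toy sketch of what makes a DEQ model implicit: instead of stacking many explicit layers, it solves for the fixed point z* = f(z*, x) of a single weight-tied layer. The names `f` and `deq_layer` below are illustrative, not this repository's API; the Anderson solver mirrors the `--fwd_solver anderson` default documented under Flags.

```python
# A toy DEQ layer: the output is the fixed point z* = f(z*, x) of a single
# weight-tied layer f, found with Anderson acceleration.
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def f(z, x, W):
    # One "layer", applied implicitly until convergence.
    return jnp.tanh(W @ z + x)

def deq_layer(x, W):
    solver = AndersonAcceleration(fixed_point_fun=f, maxiter=50, tol=1e-4)
    # jaxopt differentiates through the fixed point implicitly by default,
    # so deq_layer stays compatible with jax.grad.
    return solver.run(jnp.zeros_like(x), x, W).params

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.1 * jax.random.normal(key_w, (8, 8))  # small norm -> contractive f
x = jax.random.normal(key_x, (8,))
print(deq_layer(x, W))  # z* such that z* ≈ f(z*, x)
```

Because the output is defined by a fixed-point condition rather than an unrolled computation graph, gradients are obtained by implicit differentiation, which is why a separate backward solver (the `--bwd_solver normal_cg` default) appears in the flags below.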

## Datasets

- FEMNIST: 62 classes (10 digits, 26 lowercase, 26 uppercase letters), 28x28-pixel images, 200 clients (naturally non-IID).
- CIFAR-10: 60,000 32x32 color images in 10 classes (50,000 training, 10,000 test), 100 clients (non-IID by labels; a partition sketch follows the task table below).
- CIFAR-100: 60,000 32x32 color images in 100 classes (50,000 training, 10,000 test), 100 clients (non-IID by labels).
- Shakespeare: text dataset of Shakespeare dialogues, 200 clients (naturally non-IID).
| Task Name | Dataset | Model | Task Summary |
| --- | --- | --- | --- |
| cifar10_image | CIFAR-10 | ResNet34, DEQ-ResNet-M | Image Classification |
| cifar100_image | CIFAR-100 | ResNet34, DEQ-ResNet-M | Image Classification |
| femnist_image | FEMNIST | ResNet20, DEQ-ResNet-S | Character Recognition |
| shakespeare_character | Shakespeare | Transformer-8, DEQ-Transformer | Next-Character Prediction |
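
As referenced above, here is a hypothetical sketch of how a label-sharded non-IID split in the spirit of `--labels_per_client` might be constructed; `split_by_labels` is an illustrative helper, not the repository's implementation.

```python
# Hypothetical label-sharded split: each client only sees samples from a
# few classes, which makes local distributions non-IID by labels.
import numpy as np

def split_by_labels(labels, num_clients=100, labels_per_client=5, seed=0):
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    per_class = len(labels) // (len(classes) * num_clients)  # per (client, class)
    client_indices = []
    for _ in range(num_clients):
        chosen = rng.choice(classes, size=labels_per_client, replace=False)
        idx = []
        for cls in chosen:
            pool = np.flatnonzero(labels == cls)
            # Simplification: samples may be shared across clients here; a
            # disjoint split would carve each class pool into shards instead.
            idx.extend(rng.choice(pool, size=per_class, replace=False).tolist())
        client_indices.append(idx)
    return client_indices

labels = np.random.randint(0, 10, size=60000)  # stand-in for CIFAR-10 labels
parts = split_by_labels(labels)
print(len(parts), "clients,", len(parts[0]), "samples on client 0")
```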

## Directory structure

- `main.py`: the main driver script
- `utils/data_utils.py`: helper functions for reading datasets, data loading, logging, etc.
- `utils/jax_utils.py`: helper functions for losses, predictions, structs, etc.
- `utils/model_utils.py`: model definitions in JAX/Haiku
- `deq/`: implementations of DEQ
- `trainers/`: implementations of the algorithms
- `runners/`: scripts for starting an experiment
- `data/`: directory of datasets

## Environment

- OS: Ubuntu 20.04
- Python == 3.9
- JAX == 0.3.25
- Jaxlib == 0.3.25+cuda11.cudnn82
- jaxopt == 0.5.5
- optax == 0.1.4
- chex == 0.1.5
- dm-haiku == 0.0.9

Please follow the installation guidelines at https://github.com/google/jax to install JAX and Jaxlib versions compatible with your machine. The versions of the JAX-related packages (jaxopt, optax, chex, dm-haiku) may also need to be adjusted for compatibility with JAX and Jaxlib.

To install other dependencies: `pip3 install -r requirements.txt`

## Experiments

The general template command for running an experiment is:

`bash runners/<dataset>/run_<algorithm>.sh [other flags]`
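
For instance, a CIFAR-10 experiment might be launched as follows; the script name `run_fedeq.sh` is illustrative, so check `runners/` for the exact names:

```bash
# Illustrative invocation; see runners/ for the actual script names.
bash runners/cifar10/run_fedeq.sh -g 0
```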

### Flags

| Param | Full param | Description | Default |
| --- | --- | --- | --- |
|  | --trainer | algorithm to run | fedeq_vision |
| -m | --model | local model | deq_resnet_s |
| -c | --num_clients | total number of clients | 100 |
| -lpc | --labels_per_client | number of labels per client (CIFAR-10, CIFAR-100 only) | 5 |
| -d | --dataset | dataset name | femnist |
| -t | --num_rounds | number of global rounds | 100 |
| -lr | --learning_rate | learning rate of client optimizers | 0.01 |
| -cr | --client_rate | proportion of clients selected for training each round | 0.1 |
| -b | --batch_size | batch size | 10 |
| -le | --local_epochs | number of epochs for training the representation | 5 |
| -pe | --personalized_epochs | number of epochs for training personalized params | 3 |
| -fs | --fwd_solver | root-finding solver for DEQ models | anderson |
| -bs | --bwd_solver | backward solver for DEQ models | normal_cg |
|  | --rho | value of rho for ADMM consensus optimization | 0.01 |
| -fu | --frac_unseen | proportion of unseen clients | 0.0 |
| -r | --repeat | number of times to repeat the experiment | 1 |
| -g | --gpu | ID of the GPU used to run experiments | 0 |
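
Putting the flags together, a direct call to the driver script might look like the sketch below. This assumes `main.py` accepts the flags exactly as tabulated above; the values for `-d` and `-m` follow the naming pattern of the defaults (`femnist`, `deq_resnet_s`) and may differ in the actual code.

```bash
# Illustrative: FedEQ on CIFAR-10 with 5 labels per client, 100 clients,
# 10% client participation per round, Anderson forward solver, on GPU 0.
python3 main.py --trainer fedeq_vision -d cifar10 -m deq_resnet_m \
  -c 100 -lpc 5 -t 100 -lr 0.01 -cr 0.1 -b 10 \
  -fs anderson -bs normal_cg --rho 0.01 -g 0
```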

## Citation

@inproceedings{10.1145/3627673.3679752,
  author    = {Le, Long Tan and Nguyen, Tuan Dung and Nguyen, Tung-Anh and Hong, Choong Seon and Seneviratne, Suranga and Bao, Wei and Tran, Nguyen H.},
  title     = {Federated Deep Equilibrium Learning: Harnessing Compact Global Representations to Enhance Personalization},
  booktitle = {Proceedings of the 33rd ACM International Conference on Information and Knowledge Management},
  series    = {CIKM '24},
  year      = {2024},
  publisher = {Association for Computing Machinery},
  pages     = {1132--1142},
  doi       = {10.1145/3627673.3679752}
}

## References

### Motley

- Shanshan Wu, Tian Li, Zachary Charles, Yu Xiao, Ziyu Liu, Zheng Xu, and Virginia Smith. Motley: Benchmarking heterogeneity and personalization in federated learning, 2022.

### JAX

- James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: Composable transformations of Python+NumPy programs, 2018. https://github.com/google/jax

### Haiku

- Tom Hennigan, Trystan Cave, Ramana Kumar, and David Budden. Haiku: Sonnet for JAX, 2020. https://github.com/deepmind/dm-haiku

### Jaxopt

- Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, and Jean-Philippe Vert. Efficient and modular implicit differentiation. Advances in Neural Information Processing Systems, 2022. https://github.com/google/jaxopt

### FedJax

- Jae Hun Ro, Ananda Theertha Suresh, and Ke Wu. FedJAX: Federated learning simulation with JAX, 2021. https://fedjax.readthedocs.io/