
RMSGD: Augmented SGD Optimizer

Official PyTorch implementation of the RMSGD optimizer from:

Exploiting Explainable Metrics for Augmented SGD
Mahdi S. Hosseini, Mathieu Tuli, Konstantinos N. Plataniotis
Accepted at the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2022)


We propose new explainability metrics that measure the redundant information in a network's layers, and we exploit this information to augment the Stochastic Gradient Descent (SGD) optimizer by adaptively adjusting the learning rate of each layer. We call this new optimizer RMSGD. RMSGD is fast, outperforms existing state-of-the-art optimizers, and generalizes well across experimental configurations.
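To make the idea of "redundant information in a layer" concrete, the sketch below computes one generic proxy: the stable rank of a layer's weight matrix, ‖W‖_F² / σ_max², divided by min(rows, cols). This is an illustrative assumption and not the metric defined in the paper. The helpers `spectral_norm`, `stable_rank_fraction` and `layer_learning_rates`, and the rule that scales a base learning rate by the score, are all hypothetical. A fraction near 1 means the layer's energy is spread over many directions; a fraction near 1/min(rows, cols) means it is concentrated in roughly one direction.

```python
import math
import random


def spectral_norm(w, iters=100, seed=0):
    """Largest singular value of w (a list of rows) via power iteration on W^T W."""
    rows, cols = len(w), len(w[0])
    rng = random.Random(seed)
    v = [rng.random() + 0.1 for _ in range(cols)]  # strictly positive start vector
    for _ in range(iters):
        u = [sum(w[i][j] * v[j] for j in range(cols)) for i in range(rows)]  # u = W v
        v = [sum(w[i][j] * u[i] for i in range(rows)) for j in range(cols)]  # v = W^T u
        norm_v = math.sqrt(sum(x * x for x in v))
        if norm_v == 0.0:
            return 0.0  # zero matrix
        v = [x / norm_v for x in v]
    u = [sum(w[i][j] * v[j] for j in range(cols)) for i in range(rows)]
    return math.sqrt(sum(x * x for x in u))  # ||W v|| for unit top singular vector v


def stable_rank_fraction(w):
    """Hypothetical redundancy proxy: stable rank divided by the maximum possible rank."""
    fro_sq = sum(x * x for row in w for x in row)
    sigma = spectral_norm(w)
    if sigma == 0.0:
        return 0.0
    return (fro_sq / sigma ** 2) / min(len(w), len(w[0]))


def layer_learning_rates(weights, base_lr=0.1):
    """Hypothetical rule: one learning rate per layer, base_lr scaled by its score."""
    return {name: base_lr * stable_rank_fraction(w) for name, w in weights.items()}


rates = layer_learning_rates(
    {"full_rank": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
     "rank_one": [[1.0, 2.0], [2.0, 4.0]]},
    base_lr=0.1,
)
```

In PyTorch, per-layer rates like these would typically be applied by giving each layer its own parameter group and setting `optimizer.param_groups[i]["lr"]`; whether RMSGD works that way is not stated here.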

Contents

This repository and branch contain the standalone optimizer, which can be installed with pip. Alternatively, you can copy the contents of src/rmsgd into your own project and use the optimizer as is.

For all code related to our paper, and to replicate its experiments, see the paper branch.

Installation

You can install rmsgd from PyPI with `pip install rmsgd`, or from source:

```bash
git clone https://github.com/mahdihosseini/RMSGD.git
cd RMSGD
pip install .
```
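After either installation route, a quick way to confirm that `rmsgd` is importable, and whether it came from pip or from a copied `src/rmsgd` folder, is sketched below. It assumes the import name and the pip distribution name are both `rmsgd`, as `pip install rmsgd` and `from rmsgd import RMSGD` suggest; `check_install` is a hypothetical helper, not part of the package.

```python
import importlib.util
from importlib import metadata


def check_install(module="rmsgd"):
    """Return the pip-installed version of `module`, or None if it cannot be imported."""
    if importlib.util.find_spec(module) is None:
        return None  # not importable at all
    try:
        return metadata.version(module)  # assumes distribution name == import name
    except metadata.PackageNotFoundError:
        # Importable (e.g. src/rmsgd copied into the project) but not installed via pip.
        return "unknown (not installed via pip)"


print(check_install())
```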

Usage

RMSGD can be used like any other optimizer, with one additional step:

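```python
from rmsgd import RMSGD
...
optimizer = RMSGD(...)
...
for epoch in range(num_epochs):
    for inputs, targets in data_loader:  # assumes (input, target) batches
        optimizer.zero_grad()
        output = network(inputs)
        loss = criterion(output, targets)  # e.g. torch.nn.CrossEntropyLoss()
        loss.backward()  # gradients must exist before optimizer.step()
        optimizer.step()
    optimizer.epoch_step()  # once per epoch, after all batches
```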

The only difference from a standard optimizer is that you must call `.epoch_step()` at the end of each epoch so that the optimizer can update its analysis of the network's layers.
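Judging from the usage above, RMSGD works on two timescales: `step()` updates weights every batch, while `.epoch_step()` refreshes the per-layer analysis once per epoch. The pure-Python toy below mimics that call pattern so it can be run without PyTorch. `ToyLayerwiseSGD`, its `score_fn` hook, and the smoothing rule lr ← β·lr + (1 − β)·base_lr·score are hypothetical and are not RMSGD's internals.

```python
class ToyLayerwiseSGD:
    """Hypothetical stand-in: plain SGD with one learning rate per layer,
    re-derived once per epoch from a user-supplied score function."""

    def __init__(self, layers, lr=0.1, beta=0.8, score_fn=None):
        self.layers = layers  # dict: layer name -> list of float parameters
        self.grads = {name: [0.0] * len(p) for name, p in layers.items()}
        self.lrs = {name: lr for name in layers}  # one learning rate per layer
        self.base_lr = lr
        self.beta = beta
        self.score_fn = score_fn or (lambda params: 1.0)

    def zero_grad(self):
        for g in self.grads.values():
            for i in range(len(g)):
                g[i] = 0.0

    def step(self):
        # Per-batch update: each layer uses its own current learning rate.
        for name, params in self.layers.items():
            lr = self.lrs[name]
            for i, g in enumerate(self.grads[name]):
                params[i] -= lr * g

    def epoch_step(self):
        # Per-epoch update: move each layer's rate toward base_lr * score.
        for name, params in self.layers.items():
            target = self.base_lr * self.score_fn(params)
            self.lrs[name] = self.beta * self.lrs[name] + (1 - self.beta) * target


# Usage: minimise (x - 3)^2 for every parameter, mirroring the loop in Usage.
layers = {"fc1": [0.0, 0.0], "fc2": [10.0]}
opt = ToyLayerwiseSGD(layers, lr=0.1, beta=0.8, score_fn=lambda p: 0.5)
for epoch in range(20):
    for _ in range(10):  # ten "batches" per epoch
        opt.zero_grad()
        for name, params in layers.items():
            for i, x in enumerate(params):
                opt.grads[name][i] = 2.0 * (x - 3.0)  # gradient of (x - 3)^2
        opt.step()
    opt.epoch_step()  # once per epoch, after all batches
```

Forgetting `epoch_step()` in this toy leaves every layer at its initial rate; in RMSGD it would presumably leave the layer analysis stale.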

Citation

```bibtex
@Article{hosseini2022rmsgd,
  author  = {Hosseini, Mahdi S. and Tuli, Mathieu and Plataniotis, Konstantinos N.},
  title   = {Exploiting Explainable Metrics for Augmented SGD},
  journal = {Accepted in IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year    = {2022},
}
```

License

This project is released under the MIT license. Please see the LICENSE file for more information.