# Gradients without Backpropagation - JAX Implementation

This repository contains a JAX implementation of the methods described in the paper *Gradients without Backpropagation*.

Sometimes, all we want is to get rid of backpropagation of errors and estimate an unbiased gradient of the loss function during a single inference pass :)

## Overview

The code demonstrates how to train a simple MLP on MNIST using either forward gradients (the estimate $(\nabla f(\boldsymbol{\theta}) \cdot \boldsymbol{v}) \boldsymbol{v}$), computed with a JVP (Jacobian-vector product, i.e. forward-mode AD), or the traditional VJP (vector-Jacobian product, i.e. reverse-mode AD). To investigate how stable and scalable the forward-gradient method is (the variance of the estimate is proportional to the number of parameters), you can increase the `--num_layers` parameter.
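As a rough, self-contained sketch (not the actual code in `train.py`; `loss_fn`, the parameter shapes, and the Gaussian tangent distribution are placeholder assumptions), the forward-gradient estimate can be computed in JAX with a single `jax.jvp` call:

```python
import jax
import jax.numpy as jnp

# Toy loss standing in for the MLP loss in train.py (hypothetical, just to keep
# the sketch self-contained).
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def forward_gradient(params, x, y, key):
    # Sample a random tangent direction v with the same pytree structure as params.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    v = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, p.shape) for k, p in zip(keys, leaves)]
    )
    # Single forward pass: jax.jvp returns f(params) and the directional
    # derivative (∇f(params) · v) without any reverse-mode sweep.
    loss, dir_deriv = jax.jvp(lambda p: loss_fn(p, x, y), (params,), (v,))
    # Forward gradient estimate: (∇f · v) v, an unbiased estimate of ∇f.
    grad_est = jax.tree_util.tree_map(lambda t: dir_deriv * t, v)
    return loss, grad_est

# Example usage with dummy data.
key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))
loss, grads = forward_gradient(params, x, y, key)

# Reverse-mode (VJP) baseline for comparison:
loss_ref, grads_ref = jax.value_and_grad(loss_fn)(params, x, y)
```

The reverse-mode baseline needs only `jax.value_and_grad`, which is essentially what a VJP-based training path uses.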

Note: this does not appear to scale efficiently beyond roughly 10 layers, because the variance of the gradient estimate grows with the number of parameters of the network.
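For intuition (a standard calculation for Gaussian perturbation vectors, not something stated in this repository): with $\boldsymbol{v} \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_n)$ the estimate is unbiased, but its expected squared error grows roughly linearly with the number of parameters $n$:

$$
\mathbb{E}\big[(\nabla f(\boldsymbol{\theta}) \cdot \boldsymbol{v})\,\boldsymbol{v}\big] = \nabla f(\boldsymbol{\theta}),
\qquad
\mathbb{E}\big[\lVert (\nabla f(\boldsymbol{\theta}) \cdot \boldsymbol{v})\,\boldsymbol{v} - \nabla f(\boldsymbol{\theta}) \rVert^2\big] = (n + 1)\,\lVert \nabla f(\boldsymbol{\theta}) \rVert^2 .
$$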

## Comparison

## Requirements

- JAX <3
- optax (for learning rate scheduling)
- wandb (optional, for logging)
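These can typically be installed with pip (the plain `jax` wheel is CPU-only; GPU/TPU builds need the platform-specific instructions from the JAX docs):

```bash
pip install jax optax wandb
```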

## Usage

To replicate MLP training with forward gradients on MNIST, simply run `train.py`:

```bash
python train.py
```
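For example, to stress the forward-gradient estimator with a deeper network (only `--num_layers` is documented above; any other flags are defined in `train.py`):

```bash
python train.py --num_layers 10
```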