
Self-Compressing Neural Networks in JAX

This is me learning JAX by porting geohot's tinygrad implementation of this paper. I'm curious how it compares speed-wise to the tinygrad and PyTorch versions. Since this is my first encounter with JAX, the code surely isn't optimal, so pull requests are encouraged. I'll benchmark the implementations and share the results here in this repo.

Of the tools I've tried so far (Numba, Triton, raw CUDA, and now JAX), JAX and raw CUDA felt the most intuitive to me, so this port was definitely worth doing.
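For context, the paper's core trick is to make the quantization bit depth a learnable parameter and to add the model's size to the loss. Below is a rough JAX sketch of that idea as I understand it; the names (`fake_quantize`, `average_bits`, `GAMMA`) are illustrative and not this repo's actual API.

```python
import jax
import jax.numpy as jnp

def fake_quantize(w, e, b):
    # Quantize w to b bits at scale 2**e (both learnable, per the paper).
    # round() uses a straight-through estimator, so gradients reach w
    # directly, e through the scale, and b through the clip bounds.
    scale = 2.0 ** e
    x = w / scale
    x = x + jax.lax.stop_gradient(jnp.round(x) - x)  # STE round
    q = jnp.clip(x, -(2.0 ** (b - 1.0)), 2.0 ** (b - 1.0) - 1.0)
    return q * scale

def average_bits(b):
    # Size term: mean (non-negative) bit depth; channels whose learned
    # depth reaches zero contribute nothing and can be pruned.
    return jax.nn.relu(b).mean()

# Combined objective, with GAMMA trading accuracy against size:
# loss = task_loss + GAMMA * average_bits(b)
```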

How to Run the Code

python -m venv self-compressing-nn-jax
source self-compressing-nn-jax/bin/activate
pip install -e .
python3 train_mnist.py
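If you're new to JAX like me, the heart of train_mnist.py is presumably a jitted update step along these lines. This is a minimal, self-contained sketch assuming optax for the optimizer (the actual script may differ), with the size term from the sketch above left out for brevity:

```python
import jax
import jax.numpy as jnp
import optax  # assumed optimizer library; the repo may use something else

def forward(params, x):
    # Tiny MLP stand-in for the real model.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

optimizer = optax.adam(1e-3)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        logits = forward(p, x)
        # The self-compression objective would add GAMMA * average_bits(...) here.
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

# usage with dummy MNIST-shaped data:
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": 0.01 * jax.random.normal(key1, (784, 128)), "b1": jnp.zeros(128),
    "w2": 0.01 * jax.random.normal(key2, (128, 10)), "b2": jnp.zeros(10),
}
opt_state = optimizer.init(params)
x, y = jnp.zeros((32, 784)), jnp.zeros((32,), dtype=jnp.int32)
params, opt_state, loss = train_step(params, opt_state, x, y)
```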

Benchmarks (coming soon)

tinygrad

PyTorch

JAX

Training Run

This run showed two sudden drops that I still need to investigate.

Model Size vs. Training Accuracy
