radarFudan/mamba-minimal-jax

mamba-minimal-jax

A simple, minimal implementation of the Mamba SSM in a single JAX file.
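At its core, the Mamba SSM discretizes a diagonal state matrix A with an input-dependent step size Δ and runs a linear recurrence over the sequence: x_t = exp(Δ_t A) x_{t-1} + Δ_t B_t u_t, then y_t = C_t x_t + D u_t. The following is a minimal, unbatched sketch of that selective scan using jax.lax.scan, modeled on the selective_scan function in mamba-minimal. The signature and array shapes are assumptions and may not match this repository's model.py.

```python
import jax
import jax.numpy as jnp


def selective_scan(u, delta, A, B, C, D):
    """Sequential selective scan for a single sequence (no batch dimension).

    Assumed shapes for this sketch:
      u, delta: (L, d_in)   input and per-step, per-channel step sizes
      A:        (d_in, N)   diagonal state matrix, one row per channel
      B, C:     (L, N)      input-dependent projections
      D:        (d_in,)     skip connection
    Returns y with shape (L, d_in).
    """
    # Discretize: exp(Δ A) for the state matrix, Δ * B * u for the input term.
    deltaA = jnp.exp(jnp.einsum('ld,dn->ldn', delta, A))
    deltaB_u = jnp.einsum('ld,ln,ld->ldn', delta, B, u)

    def step(x, inputs):
        dA, dBu, c = inputs
        x = dA * x + dBu                 # x_t = exp(Δ_t A) x_{t-1} + Δ_t B_t u_t
        y = jnp.einsum('dn,n->d', x, c)  # y_t = C_t x_t (skip term added below)
        return x, y

    x0 = jnp.zeros(A.shape)
    _, ys = jax.lax.scan(step, x0, (deltaA, deltaB_u, C))
    return ys + u * D
```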

Plan:

  1. Finish model.py. Done.
  2. Convert the PyTorch weights to JAX weights. Done.
  3. Check that greedy generation gives the same results as PyTorch. Done.
  4. Implement an associative scan to speed up the state update. Done in the speedup branch; see the discussion in srush/annotated-mamba#1.
  5. Implement proper weight initialization so that the model can be trained from scratch.
  6. Implement the step function for Mamba inference.
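Item 4 relies on the state update being a first-order linear recurrence, h_t = a_t * h_{t-1} + b_t, where in Mamba a_t plays the role of exp(Δ_t A) and b_t of Δ_t B_t u_t. Composing two such affine updates yields another one, so the pairs (a_t, b_t) can be combined with an associative operator and evaluated in parallel by jax.lax.associative_scan. A minimal sketch, not the code from the speedup branch; the function names are hypothetical:

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    # `left` covers earlier steps, `right` later ones. Applying
    # h -> a_l*h + b_l and then h -> a_r*h + b_r gives
    # h -> (a_r*a_l)*h + (a_r*b_l + b_r), which is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence_assoc(a, b):
    """All states of h_t = a_t * h_{t-1} + b_t (with h_{-1} = 0) via a parallel scan."""
    _, h = jax.lax.associative_scan(combine, (a, b), axis=0)
    return h


def linear_recurrence_seq(a, b):
    """The same recurrence computed step by step with jax.lax.scan, for comparison."""
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h
    _, h = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return h
```

Both functions accept extra trailing dimensions, so the (L, d_in, N) tensors of the selective scan can be passed directly. The associative version replaces L strictly sequential steps with a logarithmic-depth tree of combines, which is where a speedup on parallel hardware can come from.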

From mamba-minimal

Featuring:

  • Numerical output equivalent to the official implementation for both the forward and backward pass
  • Simplified, readable, annotated code
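Equivalence in the forward and backward pass can be checked mechanically: run both implementations on the same input, then compare their outputs and their gradients. A minimal sketch of such a check in JAX. assert_equivalent and cumsum_scan are hypothetical names, and the cumulative-sum example only stands in for a real comparison of two Mamba implementations (comparing against the official PyTorch code would also require moving tensors between frameworks).

```python
import jax
import jax.numpy as jnp


def assert_equivalent(f_ref, f_test, x, atol=1e-5):
    """Raise AssertionError unless two implementations agree forward and backward.

    The backward pass is compared through the gradient of a scalar loss
    (the sum of the outputs) with respect to the input x.
    """
    y_ref, y_test = f_ref(x), f_test(x)
    if not jnp.allclose(y_ref, y_test, atol=atol):
        raise AssertionError("forward outputs differ")
    g_ref = jax.grad(lambda z: f_ref(z).sum())(x)
    g_test = jax.grad(lambda z: f_test(z).sum())(x)
    if not jnp.allclose(g_ref, g_test, atol=atol):
        raise AssertionError("gradients differ")
    return True


def cumsum_scan(x):
    """Cumulative sum written as an explicit sequential scan."""
    def step(carry, v):
        carry = carry + v
        return carry, carry
    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys


# Example: compare the built-in jnp.cumsum against the hand-written scan.
assert_equivalent(jnp.cumsum, cumsum_scan, jnp.arange(1.0, 6.0))
```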

Does NOT include:

  • Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most of the implementation simple for readability.
  • Proper parameter initialization (though this could be added without sacrificing readability)
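Most of the missing initialization concerns the SSM-specific parameters. A minimal sketch, assuming the scheme used in the official state-spaces/mamba code as we understand it: A is initialized S4D-Real style so that each row is -1, -2, ..., -N (stored as A_log = log(-A)), D is set to ones, and the bias of the Δ projection is chosen so that softplus(bias) equals a step size sampled log-uniformly between dt_min and dt_max. The default values below are taken from that code to the best of our knowledge; init_ssm_params and its dictionary layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_ssm_params(key, d_inner, d_state, dt_min=1e-3, dt_max=1e-1, dt_floor=1e-4):
    """Sketch of SSM parameter initialization (names and layout are hypothetical)."""
    # S4D-Real: A[d, n] = -(n + 1). Store log(-A) so that A = -exp(A_log) stays negative.
    A = jnp.tile(jnp.arange(1, d_state + 1, dtype=jnp.float32), (d_inner, 1))
    A_log = jnp.log(A)

    # Sample Δ log-uniformly in [dt_min, dt_max), clamped from below by dt_floor.
    u = jax.random.uniform(key, (d_inner,))
    dt = jnp.exp(u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min))
    dt = jnp.maximum(dt, dt_floor)
    # Inverse softplus: log(exp(dt) - 1), written stably, so softplus(dt_bias) == dt.
    dt_bias = dt + jnp.log(-jnp.expm1(-dt))

    # Skip connection starts as the identity.
    D = jnp.ones((d_inner,))
    return {"A_log": A_log, "D": D, "dt_bias": dt_bias}
```

Ordinary linear and embedding weights would still need their own initialization, which this sketch leaves out.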

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# generate() is not imported above; it is the generation helper used in demo.ipynb
generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
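Item 6 of the plan concerns inference. During generation the recurrence only needs the current state, so each new token can be processed in constant time instead of re-running the scan over the whole prefix. A minimal sketch of the SSM part of such a step, assuming unbatched inputs and the same discretization as mamba-minimal (exp(ΔA) for A, ΔB for B). ssm_step is a hypothetical name, and a complete Mamba step would also have to cache the last few inputs of the block's causal depthwise convolution, which this sketch omits.

```python
import jax.numpy as jnp


def ssm_step(h, x_t, delta_t, A, B_t, C_t, D):
    """One recurrent SSM update for a single token (no batch dimension).

    Assumed shapes: h (d_in, N); x_t, delta_t, D (d_in,); A (d_in, N);
    B_t, C_t (N,). Returns the new state h and the output y_t (d_in,).
    """
    dA = jnp.exp(delta_t[:, None] * A)                    # discretized state matrix
    dBx = delta_t[:, None] * B_t[None, :] * x_t[:, None]  # discretized input term
    h = dA * h + dBx                                      # h_t = exp(Δ A) h_{t-1} + Δ B x_t
    y_t = h @ C_t + D * x_t                               # y_t = C h_t + D x_t
    return h, y_t
```

Starting from a zero state and calling ssm_step once per token produces the same outputs as running the full scan over the sequence, while keeping only a (d_in, N) state in memory.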

References

The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba

The minimal PyTorch implementation is here: https://github.com/johnma2006/mamba-minimal
