MatMamba: A Matryoshka State Space Model
Abhinav Shukla, Sai Vemprala, Aditya Kusupati, Ashish Kapoor
https://arxiv.org/abs/2410.06718
MatMamba is a general sequence processing architecture based on Mamba2. It introduces a nested Matryoshka structure within the Mamba2 block. We jointly train a few chosen granularities to get a single model from which we can flexibly extract a large number of nested submodels for adaptive inference, based on the available deployment compute.
For example, one could train a single 7B model (with the same weights) while explicitly optimizing nested submodels of 3.5B, 1.75B, and 875M parameters. At inference time, we could then use the 7B model on a large GPU, the 875M model on a phone, and an interpolated 2.3B model (via Mix'N'Match) on a medium-sized GPU. We explicitly train only a few submodels (four in this case), but Matryoshka-style learning gives us hundreds of nested submodels for free.
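To make the nesting concrete, here is a minimal sketch of the joint Matryoshka objective, using the forward-pass API shown later in this README. This is an illustration, not the repository's training code; the level set and the equal loss weighting are assumptions.
import torch.nn.functional as F
# Sketch: each step optimizes the same weights at every chosen granularity,
# so the smaller submodels stay nested inside the larger ones.
mrl_levels = [1, 2, 4, 8]  # full width, d_model/2, d_model/4, d_model/8 (assumed)
def matryoshka_loss(model, input_ids, labels):
    total = 0.0
    for level in mrl_levels:
        logits = model(input_ids, mrl_level=level).logits
        total = total + F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
    return total / len(mrl_levels)  # equal weighting is an assumption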
The requirements for MatMamba are almost the same as those of the Mamba2 repository. To install the matmamba package and set up a fresh conda environment with all necessary dependencies, run the following script:
bash scripts/setup_env.sh
And then:
conda activate matmamba
Like a Transformer or Mamba2 block, a MatMamba2 block takes in a tensor of shape (batch_size, seq_len, d_model) and returns a tensor of the same shape. Based on the available compute, we can choose how many dimensions (and heads) are used internally.
import torch
from matmamba import MatMamba2
matmamba_block = MatMamba2(
    d_model=512,
    d_state=128,
).cuda()
b, l, d = 8, 1024, 512
x = torch.randn((b, l, d)).cuda()
# Without any optional args/config, the block is a regular Mamba2 block
y1 = matmamba_block(x)
assert y1.shape == (b, l, d)
# If we want to use a fraction of the `d_model` dims, we can pass the `mrl_level` argument
# An `mrl_level` of 2 means that `d_model/2` dims will be used
y2 = matmamba_block(x, mrl_level=2)
# `y2` is also (b, l, d), but only half the dims are used internally
assert y2.shape == (b, l, d)
# We can also manually specify the exact number of dims using `mixnmatch_dims`
# For example, if we want to use exactly 64 dims:
matmamba_block.mixnmatch = True
matmamba_block.mixnmatch_dims = 64
y3 = matmamba_block(x)
assert y3.shape == (b, l, d)
# Set mixnmatch to False to revert to the default behavior
matmamba_block.mixnmatch = False
matmamba_block.mixnmatch_dims = matmamba_block.d_model
Note that the first time you run the script, it may be slightly slow due to JIT / Triton auto-tuning per granularity. Subsequent calls will be faster.
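If you plan to benchmark, a small warm-up loop along the lines below triggers that auto-tuning up front; the set of levels is just whichever granularities you intend to serve (an assumption here, not an API requirement).
# Warm up each granularity once so the Triton kernels are auto-tuned before timing
with torch.no_grad():
    for level in [1, 2, 4, 8]:  # example granularities: d, d/2, d/4, d/8
        matmamba_block(x, mrl_level=level)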
See matmamba/mamba2.py for the implementation of the MatMamba2 block.
We can make a vision model (MatMamba-Vision) and a language model (MatMamba-LM) using the MatMamba block.
MatMamba-Vision works very similarly to a ViT: the Transformer blocks are replaced by MatMamba blocks, and the [CLS] token is moved to the end of the sequence (a suffix) due to the causal nature of Mamba2. We attach a classification head for tasks like image classification on ImageNet, but the backbone can be used for any vision task.
import torch
from matmamba import MatMamba2Vision, MatMamba2VisionConfig
config = MatMamba2VisionConfig(
    d_model=1024,
    n_layer=20,
    d_intermediate=0,
    n_classes=1000,
    patch_size=16,
    drop_path_rate=0.1,
    proj_drop_rate=0.1,
)
model = MatMamba2Vision(config).cuda()
x = torch.randn((8, 3, 224, 224)).cuda() # Dummy image batch
y = model(x)
assert y.shape == (8, 1000)
# mrl_level and mixnmatch_dims can be used here as well
y2 = model(x, mrl_level=2)
assert y2.shape == (8, 1000)
for layer in model.layers:
    layer.mixer.mixnmatch = True
    layer.mixer.mixnmatch_dims = 256
y3 = model(x)
assert y3.shape == (8, 1000)
You can also directly load a pretrained model using the from_pretrained method:
model = model.from_pretrained("scaledfoundations/MatMamba-Vision-135M-ImageNet")
See matmamba/matmamba2_vision.py for the implementation of the vision backbone, and train_imagenet.py for the training code.
[coming soon]
We can also make a Causal Language Model using the MatMamba block.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from matmamba import MatMambaLMHeadModel
model = MatMambaLMHeadModel(MambaConfig(n_layer=24, d_model=768)).cuda()
vocab_size = 50280
b, l = 8, 1024
# Dummy input batch of token ids, can come from any tokenizer
x = torch.randint(0, vocab_size, (b, l)).cuda()
y = model(x).logits
assert y.shape == (b, l, vocab_size)
# mrl_level and mixnmatch_dims can be used here as well
y2 = model(x, mrl_level=2).logits
assert y2.shape == (b, l, vocab_size)
for layer in model.backbone.layers:
    layer.mixer.mixnmatch = True
    layer.mixer.mixnmatch_dims = 384
y3 = model(x).logits
assert y3.shape == (b, l, vocab_size)
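The widths do not have to be uniform across layers: giving each layer its own mixnmatch_dims is the Mix'N'Match interpolation mentioned at the top, and is how intermediate model sizes between the trained granularities are obtained. The particular schedule below is arbitrary, purely for illustration:
# Mix'N'Match: a different width per layer (this schedule is just an example)
n_layers = len(model.backbone.layers)
for i, layer in enumerate(model.backbone.layers):
    layer.mixer.mixnmatch = True
    # full 768 dims for the first half of the layers, 384 for the second half
    layer.mixer.mixnmatch_dims = 768 if i < n_layers // 2 else 384
y4 = model(x).logits
assert y4.shape == (b, l, vocab_size)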
You can also directly load a pretrained model using the from_pretrained method:
model = model.from_pretrained("scaledfoundations/MatMamba-LM-1.4B-FineWeb")
See matmamba/mixer_seq_simple.py for the implementation of the language model backbone, and train_fineweb.py for the training code.
[coming soon]
You can find all pretrained models (MatMamba-Vision and MatMamba-LM) from the paper on Hugging Face in the MatMamba collection.
| Model Name | Training Dataset | d_model | Training Granularities | Link to Weights |
|---|---|---|---|---|
| MatMamba-Vision-35M | ImageNet | 512 | 512, 256, 128, 64 | weights |
| MatMamba-Vision-135M | ImageNet | 1024 | 1024, 512, 256, 128 | weights |
| MatMamba-LM-130M | FineWeb | 768 | 768, 384, 192, 96 | weights |
| MatMamba-LM-370M | FineWeb | 1024 | 1024, 512, 256, 128 | weights |
| MatMamba-LM-790M | FineWeb | 1536 | 1536, 768, 384, 192 | weights |
| MatMamba-LM-1.4B | FineWeb | 2048 | 2048, 1024, 512, 256 | weights |
If you use this code, or otherwise find our work valuable, please cite:
@article{shukla2024matmamba,
  title={MatMamba: A Matryoshka State Space Model},
  author={Shukla, Abhinav and Vemprala, Sai and Kusupati, Aditya and Kapoor, Ashish},
  journal={arXiv preprint arXiv:2410.06718},
  year={2024}
}