This repo contains an implementation of the Muon
optimizer described in this thread.
Muon is the fastest known optimizer across diverse training scenarios, including CIFAR-10 training
and GPT-2-scale language modeling.
```
pip install git+https://github.com/KellerJordan/Muon
```
Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and any 0- or 1-D parameters (biases, gains, etc.) should be optimized by a backup optimizer such as AdamW instead. Muon provides an internal AdamW backup for these parameters, so you don't have to create an extra optimizer.
```python
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)

from muon import Muon

# Find ≥2D parameters in the body of the network -- these will be optimized by Muon
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
# Find everything else -- these will be optimized by AdamW
adamw_params = [p for p in model.body.parameters() if p.ndim < 2]
adamw_params.extend(model.head.parameters())
adamw_params.extend(model.embed.parameters())
# Create the optimizer
optimizer = Muon(muon_params, lr=0.02, momentum=0.95,
                 adamw_params=adamw_params, adamw_lr=3e-4, adamw_betas=(0.90, 0.95), adamw_wd=0.01)
```
You'll have to replace `model.body`, `model.head`, and `model.embed` with whatever subset is appropriate for your model. (E.g., for a ConvNet, `muon_params` should be all the convolutional filters, and `adamw_params` should be everything else: the classifier head and any gains or biases in the model.)
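A minimal sketch of that ConvNet split is shown below. The `SimpleConvNet` class and its layer names are hypothetical and only illustrate which parameters land in each group; substitute your own model.

```python
import torch.nn as nn
from muon import Muon

# Hypothetical ConvNet used only for illustration -- not part of this repo.
class SimpleConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(128, num_classes)

    def forward(self, x):
        return self.head(self.features(x).flatten(1))

model = SimpleConvNet()

# Convolutional filter weights are 4D (i.e., >=2D), so they go to Muon.
muon_params = [p for p in model.features.parameters() if p.ndim >= 2]
# BatchNorm gains/biases (1D) and the classifier head go to the internal AdamW.
adamw_params = [p for p in model.features.parameters() if p.ndim < 2]
adamw_params.extend(model.head.parameters())

optimizer = Muon(muon_params, lr=0.02, momentum=0.95,
                 adamw_params=adamw_params, adamw_lr=3e-4,
                 adamw_betas=(0.90, 0.95), adamw_wd=0.01)
```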
If you're replacing an already-tuned AdamW with Muon, the only thing you should need to tune is Muon's learning rate. The AdamW hyperparameters should be set to whatever you were already using.
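For example, if your previous run used `torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)` (the values here just mirror the snippet above), a reasonable starting point is to carry those numbers over verbatim and sweep only Muon's `lr`:

```python
# Hypothetical before/after: the AdamW hyperparameters are copied from the
# old run unchanged; only Muon's lr is swept.
# old: torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)
optimizer = Muon(muon_params, lr=0.02, momentum=0.95,  # lr is the knob to tune
                 adamw_params=adamw_params,
                 adamw_lr=3e-4,                        # same lr as the old AdamW
                 adamw_betas=(0.90, 0.95),             # same betas
                 adamw_wd=0.01)                        # same weight decay
```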
For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see here.
See this thread for more info including the connection to Shampoo.