This is an official PyTorch implementation of Adan. See the paper here. If you find our Adan helpful or inspiring for your projects, please cite this paper and also star this repository. Thanks!
@article{xie2024adan,
title={Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
author={Xie, Xingyu and Zhou, Pan and Li, Huan and Lin, Zhouchen and Yan, Shuicheng},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2024},
publisher={IEEE}
}
- Adan is supported in the framework NeMo from NVIDIA.
- Adan is the default optimizer for the High-Fidelity Text-to-3D Generation Project. See more details in Consistent3D.
- Adan is the default optimizer for Masked Diffusion Transformer V2. See more details in MDT V2.
- Adan is supported in the Project D-Adaptation from Meta AI.
- Adan will be supported in the Paddle project from Baidu (PaddlePaddle).
- Adan is supported in Timm from Hugging Face.
- Adan is the default optimizer in the text-to-3D DreamFusion Project. See more results here.
- Adan is supported in the MMClassification of the OpenMMLab project. The log and example of using Adan to train ViT-B are here.
- A third-party TensorFlow implementation is available at DenisVorotyntsev/Adan.
- A third-party JAX implementation is available in DeepMind/optax.
- 🔥🔥🔥 Results on large language models, like MoE and GPT2, are released.
- FusedAdan, with a smaller memory footprint, is released.
python3 -m pip install git+https://github.com/sail-sg/Adan.git
FusedAdan is installed by default. If you want to use the original (unfused) Adan, please install it with:
git clone https://github.com/sail-sg/Adan.git
cd Adan
python3 setup.py install --unfused
To make Adan easy to use, we briefly provide some intuitive instructions below, followed by some general experimental tips, and finally more details (e.g., specific commands and hyper-parameters) for each experiment in the paper.
Step 1. Add the Adan-dependent hyper-parameters to the config:
parser.add_argument('--max-grad-norm', type=float, default=0.0, help='if the l2 norm of the gradient is larger than this hyper-parameter, then we clip the gradient (default: 0.0, no gradient clipping)')
parser.add_argument('--weight-decay', type=float, default=0.02, help='weight decay, similar one used in AdamW (default: 0.02)')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', help='optimizer epsilon to avoid the bad case where the second-order moment is zero (default: None, use opt default 1e-8 in Adan)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='optimizer betas in Adan (default: None, use opt default [0.98, 0.92, 0.99] in Adan)')
parser.add_argument('--no-prox', action='store_true', default=False, help='whether to perform weight decay like AdamW (default: False)')
opt-betas: To keep consistent with our usage habits, the betas used in the code are one minus the betas in the paper (e.g., the default [0.98, 0.92, 0.99] in the code corresponds to beta1 = 0.02, beta2 = 0.08, beta3 = 0.01 in the paper).
foreach (bool): If True, Adan will use the torch._foreach implementation. It is faster but uses slightly more memory.
no-prox: It determines the update rule of parameters with weight decay. By default (no_prox=False), Adan updates the parameters in the proximal way presented in Algorithm 1 of the paper; when no_prox=True, it instead updates the parameters like AdamW. Both rules are sketched below.
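The following is a sketch of the two update rules in the paper's notation; here $\boldsymbol{\theta}_k$ denotes the parameters, $\mathbf{m}_k$ and $\mathbf{v}_k$ the momentum terms, $\boldsymbol{\eta}_k$ the element-wise step size, $\lambda$ the weight-decay coefficient, and $\beta_2$ the second momentum coefficient; please check the paper for the precise definitions.

With no_prox left at its default (False), the proximal update of Algorithm 1 is used:

$$\boldsymbol{\theta}_{k+1} = (1+\lambda\eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ \left(\mathbf{m}_k + (1-\beta_2)\mathbf{v}_k\right)\right].$$

With no_prox=True, the parameters are decayed directly, as in AdamW:

$$\boldsymbol{\theta}_{k+1} = (1-\lambda\eta)\,\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ \left(\mathbf{m}_k + (1-\beta_2)\mathbf{v}_k\right).$$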
Step 2. Create the Adan optimizer as follows. In this step, we can directly replace the vanilla optimizer with Adan:
from adan import Adan
optimizer = Adan(param, lr=args.lr, weight_decay=args.weight_decay, betas=args.opt_betas, eps=args.opt_eps, max_grad_norm=args.max_grad_norm, no_prox=args.no_prox)
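For illustration, a minimal, self-contained sketch of dropping Adan into a standard PyTorch training loop is given below. The toy model, data, and hyper-parameter values are placeholders for this sketch only; real training should use the per-experiment settings from the configs.

```python
import torch
from torch import nn
from adan import Adan

# Toy model and loss, purely for illustration.
model = nn.Linear(128, 10)
criterion = nn.CrossEntropyLoss()

# Placeholder hyper-parameters; real experiments should use the values
# from the corresponding config files.
optimizer = Adan(model.parameters(),
                 lr=1e-2,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8,
                 weight_decay=0.02,
                 max_grad_norm=0.0,  # 0.0 disables the built-in gradient clipping
                 no_prox=False)      # False: proximal update; True: AdamW-style decay

for step in range(100):
    x = torch.randn(32, 128)         # random inputs
    y = torch.randint(0, 10, (32,))  # random labels

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```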
- To keep Adan simple, we do not use the restart strategy in any experiment except Table 12 in the paper. But Table 12 shows that the restart strategy can further slightly improve the performance of Adan.
- Adan often allows one to use a large peak learning rate at which other optimizers, e.g., Adam and AdamW, often fail. For example, in all experiments except for the MAE pre-training and LSTM, the learning rate used by Adan is 5-10 times larger than that in Adam/AdamW.
- Adan is relatively robust to beta1, beta2, and beta3, especially for beta2. If you want better performance, you can first tune beta3 and then beta1.
- Adan has a slightly higher GPU memory cost than Adam/AdamW on a single node. However, this problem can be solved using the ZeroRedundancyOptimizer, which shares optimizer states across distributed data-parallel processes to reduce the per-process memory footprint; see the sketch after this list. Specifically, when using the ZeroRedundancyOptimizer on more than two GPUs, Adan and Adam consume almost the same amount of memory.
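As mentioned in the last tip, a minimal sketch of sharding Adan's optimizer states with PyTorch's ZeroRedundancyOptimizer under DistributedDataParallel might look as follows. The toy model is a placeholder, and the script is assumed to be launched with torchrun.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

from adan import Adan

# Assumes launch via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).cuda()      # toy model, for illustration only
model = DDP(model, device_ids=[local_rank])

# ZeroRedundancyOptimizer shards Adan's optimizer states across the
# data-parallel processes, reducing the per-GPU memory footprint.
optimizer = ZeroRedundancyOptimizer(model.parameters(),
                                    optimizer_class=Adan,
                                    lr=1e-2,
                                    betas=(0.98, 0.92, 0.99),
                                    weight_decay=0.02)
```

The training loop itself is unchanged: call loss.backward() and optimizer.step() as usual.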
Please refer to the following links for detailed steps. In these detailed steps, we even include the docker images for reproducibility.
- Instruction for ViTs, ResNets, and ConvNext.
- Instruction for MAE.
- Instruction for BERT.
- Instruction for Transformer-XL.
- Instruction for GPT2.
- Results for text-to-3D DreamFusion.
To investigate the efficacy of the Adan optimizer for LLMs, we conducted pre-training experiments using MoE models. The experiments utilized the RedPajama-v2 dataset with three configurations, each consisting of 8 experts: 8x0.1B (0.5B trainable parameters in total), 8x0.3B (2B trainable parameters), and 8x0.6B (4B trainable parameters). These models were trained on sampled data of 10B to 300B tokens; the exact token counts are listed in the table below.
Model Size | 8x0.1B | 8x0.1B | 8x0.1B | 8x0.3B | 8x0.3B | 8x0.3B | 8x0.6B |
---|---|---|---|---|---|---|---|
Token Size | 10B | 30B | 100B | 30B | 100B | 300B | 300B |
AdamW | 2.722 | 2.550 | 2.427 | 2.362 | 2.218 | 2.070 | 2.023 |
Adan | 2.697 | 2.513 | 2.404 | 2.349 | 2.206 | 2.045 | 2.010 |
We provide the config and log for GPT2-345m pre-trained on the dataset that comes from BigCode and evaluated on the HumanEval dataset by zero-shot learning. HumanEval is used to measure functional correctness for synthesizing programs from docstrings. It consists of 164 original programming problems, assessing language comprehension, algorithms, and simple mathematics, with some comparable to simple software interview questions. We set Temperature = 0.8 during evaluation.
Model | Steps | pass@1 | pass@10 | pass@100 | Download |
---|---|---|---|---|---|
GPT2-345m (Adam) | 300k | 0.0840 | 0.209 | 0.360 | log&config |
GPT2-345m (Adan) | 150k | 0.0843 | 0.221 | 0.377 | log&config |
Adan obtains comparable results with only half the training cost.
For your convenience, we provide the configs and log files for the experiments on ImageNet-1k.
Model | Epoch | Training Setting | Acc. (%) | Config | Batch Size | Download |
---|---|---|---|---|---|---|
ViT-S | 150 | I | 80.1 | config | 2048 | log/model |
ViT-S | 150 | II | 79.6 | config | 2048 | log/model |
ViT-S | 300 | I | 81.1 | config | 2048 | log/model |
ViT-S | 300 | II | 80.7 | config | 2048 | log/model |
ViT-B | 150 | II | 81.7 | config | 2048 | log/model |
ViT-B | 300 | II | 82.6 | config | 2048 | log/model |
ResNet-50 | 100 | I | 78.1 | config | 2048 | log/model |
ResNet-50 | 200 | I | 79.7 | config | 2048 | log/model |
ResNet-50 | 300 | I | 80.2 | config | 2048 | log/model |
ResNet-101 | 100 | I | 80.0 | config | 2048 | log/model |
ResNet-101 | 200 | I | 81.6 | config | 2048 | log/model |
ResNet-101 | 300 | I | 81.9 | config | 2048 | log/model |
ConvNext-tiny | 150 | II | 81.7 | config | 2048 | log/model |
ConvNext-tiny | 300 | II | 82.4 | config | 2048 | log/model |
MAE-small | 800+100 | --- | 83.8 | config | 4096/2048 | log-pretrain/log-finetune/model |
MAE-Large | 800+50 | --- | 85.9 | config | 4096/2048 | log-pretrain/log-finetune/model |
We give the configs and log files of the BERT-base model pre-trained on the Bookcorpus and Wikipedia datasets and fine-tuned on GLUE tasks. Note that we provide the config, log file, and detailed instructions for BERT-base in the folder ./NLP/BERT.
Pretraining | Config | Batch Size | Log | Model |
---|---|---|---|---|
Adan | config | 256 | log | model |
Fine-tuning on GLUE-Task | Metric | Result | Config |
---|---|---|---|
CoLA | Matthew's corr. | 64.6 | config |
SST-2 | Accuracy | 93.2 | config |
STS-B | Pearson corr. | 89.3 | config |
QQP | Accuracy | 91.2 | config |
MNLI | Matched acc./Mismatched acc. | 85.7/85.6 | config |
QNLI | Accuracy | 91.3 | config |
RTE | Accuracy | 73.3 | config |
For fine-tuning on GLUE tasks, see the total batch size in the corresponding config files.
We provide the config and log for Transformer-XL-base trained on the WikiText-103 dataset. The total batch size for this experiment is 60*4.
Model | Steps | Test PPL | Download |
---|---|---|---|
Baseline (Adam) | 200k | 24.2 | log&config |
Transformer-XL-base | 50k | 26.2 | log&config |
Transformer-XL-base | 100k | 24.2 | log&config |
Transformer-XL-base | 200k | 23.5 | log&config |
We show the results of the text-to-3D task supported by the DreamFusion Project. More visualization results can be found here.
Examples generated from the text prompt "Sydney opera house, aerial view" with Adam and Adan:
opera-adan.mp4
opera-adam.mp4
A brief comparison of peak memory and wall-clock duration for the optimizers is as follows. The duration is the total time of 200 calls to optimizer.step(). We further compare Adam and FusedAdan in great detail on GPT-2; see more results here. A sketch of how such a measurement can be made is given after the table.
Model | Model Size (MB) | Adam Peak (MB) | Adan Peak (MB) | FusedAdan Peak (MB) | Adam Time (ms) | Adan Time (ms) | FusedAdan Time (ms) |
---|---|---|---|---|---|---|---|
ResNet-50 | 25 | 7142 | 7195 | 7176 | 9.0 | 4.2 | 1.9 |
ResNet-101 | 44 | 10055 | 10215 | 10160 | 17.5 | 7.0 | 3.4 |
ViT-B | 86 | 9755 | 9758 | 9758 | 8.9 | 12.3 | 4.3 |
Swin-B | 87 | 16118 | 16202 | 16173 | 17.9 | 12.8 | 4.9 |
ConvNext-B | 88 | 17353 | 17389 | 17377 | 19.1 | 15.6 | 5.0 |
Swin-L | 196 | 24299 | 24316 | 24310 | 17.5 | 28.1 | 10.1 |
ConvNext-L | 197 | 26025 | 26055 | 26044 | 18.6 | 31.1 | 10.2 |
ViT-L | 304 | 25652 | 25658 | 25656 | 18.0 | 43.2 | 15.1 |
GPT-2 | 758 | 25096 | 25406 | 25100 | 49.9 | 107.7 | 37.4 |
GPT-2 | 1313 | 34357 | 38595 | 34363 | 81.8 | 186.0 | 64.4 |
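The following is an illustrative sketch of how peak memory and the total optimizer.step() time over 200 iterations can be measured. It is not the exact benchmarking script used for the table above; the model choice and batch size are placeholders.

```python
import time

import torch
from torchvision.models import resnet50

from adan import Adan

device = 'cuda'
model = resnet50().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = Adan(model.parameters(), lr=1e-2, weight_decay=0.02)

# Placeholder batch; the real benchmark settings may differ.
x = torch.randn(32, 3, 224, 224, device=device)
y = torch.randint(0, 1000, (32,), device=device)

torch.cuda.reset_peak_memory_stats(device)
step_time = 0.0
for _ in range(200):
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    optimizer.step()                         # time only the optimizer update
    torch.cuda.synchronize(device)
    step_time += time.perf_counter() - start

print(f'total time of 200 optimizer.step(): {step_time * 1e3:.1f} ms')
print(f'peak GPU memory: {torch.cuda.max_memory_allocated(device) / 2**20:.0f} MB')
```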