Accepted by ICLR 2023!
(Update: released the RepOpt-GhostNet model and reproducible code.)
This is the official repository of Re-parameterizing Your Optimizers rather than Architectures.
If you find the paper or this repository helpful, please consider citing
@article{ding2022re,
title={Re-parameterizing Your Optimizers rather than Architectures},
author={Ding, Xiaohan and Chen, Honghao and Zhang, Xiangyu and Huang, Kaiqi and Han, Jungong and Ding, Guiguang},
journal={arXiv preprint arXiv:2205.15242},
year={2022}
}
RepOptimizer and RepOpt-VGG have been used in YOLOv6 (paper, code) and deployed in commercial applications. The methodology of Structural Re-parameterization also plays a critical role in YOLOv7 (paper, code).
- Code
- PyTorch pretrained models
- PyTorch training code
Tired of reading proofs? We provide a script demonstrating the equivalence of GR and CSLA (GR = CSLA) in both the SGD and AdamW cases.
Run the following script to verify the equivalence on an experimental model. No GPU is required.
python check_equivalency.py
We also provide a more complicated example (RepGhostNet) to verify the equivalence. Just run
python check_equivalency_repghost.py
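To make the GR = CSLA claim concrete without the proof, here is a minimal pure-Python sketch. It is not check_equivalency.py and does not use the repository's code. It assumes the simplest setting: a linear model, squared loss, plain SGD without momentum or weight decay, and a CSLA block with two trainable branches scaled by constants a and b. Under these assumptions, training the two branches should match training one merged weight that starts from a·w1 + b·w2 (Rule of Initialization) and whose gradient is multiplied by a² + b² (Rule of Iteration). All function names are illustrative.

```python
import random

def dot(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def mse_grad(w, x, y):
    """Gradient of 0.5 * (w . x - y)^2 with respect to w."""
    err = dot(w, x) - y
    return [err * xi for xi in x]

def train_csla(w1, w2, a, b, data, lr, steps):
    """CSLA: two trainable branches whose outputs are scaled by constants a and b.
    Returns the merged weight a * w1 + b * w2 after training."""
    w1, w2 = list(w1), list(w2)
    for step in range(steps):
        x, y = data[step % len(data)]
        merged = [a * u + b * v for u, v in zip(w1, w2)]
        g = mse_grad(merged, x, y)                       # dL/d(merged weight)
        w1 = [u - lr * a * gi for u, gi in zip(w1, g)]   # chain rule: dL/dw1 = a * g
        w2 = [v - lr * b * gi for v, gi in zip(w2, g)]   # chain rule: dL/dw2 = b * g
    return [a * u + b * v for u, v in zip(w1, w2)]

def train_gr(w, grad_mult, data, lr, steps):
    """GR: a single weight trained with plain SGD, gradient scaled by grad_mult."""
    w = list(w)
    for step in range(steps):
        x, y = data[step % len(data)]
        g = mse_grad(w, x, y)
        w = [wi - lr * grad_mult * gi for wi, gi in zip(w, g)]
    return w

random.seed(0)
dim, a, b, lr = 4, 2.0, 0.5, 0.01
data = [([random.uniform(-1, 1) for _ in range(dim)], random.uniform(-1, 1)) for _ in range(16)]
w1 = [random.uniform(-1, 1) for _ in range(dim)]
w2 = [random.uniform(-1, 1) for _ in range(dim)]

csla_result = train_csla(w1, w2, a, b, data, lr, steps=200)
# Start GR from the merged weight and multiply every gradient by a^2 + b^2.
gr_result = train_gr([a * u + b * v for u, v in zip(w1, w2)], a * a + b * b, data, lr, steps=200)

max_diff = max(abs(p - q) for p, q in zip(csla_result, gr_result))
print(f"max |CSLA - GR| = {max_diff:.2e}")  # only floating-point rounding remains
```

Dropping the Grad Mult (passing 1.0 instead of a² + b²) gives a visibly different result, which is why the multiplier, not just the merged initialization, carries the CSLA structure into training.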
RepOptimizers currently support two update rules (SGD with momentum and AdamW) and two models (RepOpt-VGG and RepOpt-GhostNet). While re-designing the code of RepOptimizer, I decided to separate the update-rule-related behaviors from the model-specific behaviors.
The key components of the new implementation (see repoptimizer/) include:
- Model: repoptvgg_model.py and repoptghostnet_model.py define the model architectures, including the target and search structures.
- Model-specific Handler: a RepOptimizerHandler defines the model-specific behavior of RepOptimizer given the searched scales, which includes 1) re-initializing the model (i.e., the Rule of Initialization) and 2) generating the Grad Mults (i.e., the Rule of Iteration). For example, RepOptVGGHandler (see repoptvgg_impl.py) implements the formulas presented in the paper.
- Update rule: repoptimizer_sgd.py and repoptimizer_adamw.py define the behavior of RepOptimizers based on different update rules. A RepOptimizer differs from its regular counterpart (torch.optim.SGD or torch.optim.AdamW) in two ways:
  - RepOptimizers take one more argument, grad_mult_map, which is the output of RepOptimizerHandler and is stored in memory. It is a dict whose keys are the parameters (torch.nn.Parameter) and whose values are the corresponding Grad Mults (torch.Tensor).
  - In the step function, RepOptimizers use the Grad Mults when updating the parameters. For the details, see the step functions in repoptimizer_sgd.py (SGD) and repoptimizer_adamw.py (AdamW).
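The two rules and the grad_mult_map argument can be illustrated with a small pure-Python sketch. It assumes a hypothetical RepVGG-style block with a trainable k x k branch scaled by a constant alpha and a trainable 1x1 branch scaled by a constant beta, stores each parameter as a flat list of floats instead of a torch.Tensor, and keys grad_mult_map by parameter name rather than by torch.nn.Parameter. make_grad_mult, init_merged_kernel, and SGDWithGradMult are illustrative names, not the repository's RepOptVGGHandler or RepOptimizer classes; applying the Grad Mult before weight decay and momentum is also an assumption of this sketch.

```python
def make_grad_mult(alpha, beta, k=3):
    """Rule of Iteration (sketch): alpha^2 on every tap of a k x k kernel, plus
    beta^2 on the center tap, which the 1x1 branch shares after merging."""
    mult = [[alpha ** 2] * k for _ in range(k)]
    mult[k // 2][k // 2] += beta ** 2
    return mult

def init_merged_kernel(w_kxk, w_1x1, alpha, beta):
    """Rule of Initialization (sketch): fold alpha * kxk + beta * 1x1 into one kernel."""
    k = len(w_kxk)
    merged = [[alpha * w for w in row] for row in w_kxk]
    merged[k // 2][k // 2] += beta * w_1x1
    return merged

class SGDWithGradMult:
    """Hypothetical SGD with momentum: each gradient is multiplied element-wise by
    its Grad Mult before weight decay and momentum are applied."""

    def __init__(self, params, grad_mult_map, lr, momentum=0.9, weight_decay=0.0):
        self.params = params                # name -> flat list of floats
        self.grad_mult_map = grad_mult_map  # name -> flat list of the same length
        self.lr, self.momentum, self.weight_decay = lr, momentum, weight_decay
        self.buffers = {}                   # name -> momentum buffer

    def step(self, grads):
        for name, p in list(self.params.items()):
            d_p = grads[name]
            mult = self.grad_mult_map.get(name)
            if mult is not None:            # parameters without a Grad Mult use plain SGD
                d_p = [g * m for g, m in zip(d_p, mult)]
            if self.weight_decay:
                d_p = [g + self.weight_decay * w for g, w in zip(d_p, p)]
            buf = self.buffers.get(name)
            buf = list(d_p) if buf is None else [self.momentum * v + g for v, g in zip(buf, d_p)]
            self.buffers[name] = buf
            self.params[name] = [w - self.lr * v for w, v in zip(p, buf)]

def flatten(rows):
    return [v for row in rows for v in row]

# Usage: build the merged kernel and its Grad Mult, then hand both to the optimizer.
alpha, beta = 2.0, 0.5
kernel = init_merged_kernel([[1.0] * 3 for _ in range(3)], 1.0, alpha, beta)
grad_mult = make_grad_mult(alpha, beta)
opt = SGDWithGradMult({"conv.weight": flatten(kernel)},
                      {"conv.weight": flatten(grad_mult)}, lr=0.1)
opt.step({"conv.weight": [1.0] * 9})
```

Because the multiplication happens inside step(), the model only ever holds the single merged kernel during training; the extra branches appear only in the derivation of the initial weights and the Grad Mults.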
We have released the models pre-trained with this codebase.
Model | ImageNet-1K top-1 acc (%) | #Params | Download |
---|---|---|---|
RepOpt-VGG-B1 | 78.62 | 51.8M | Google Drive, Baidu Cloud |
RepOpt-VGG-B2 | 79.68 | 80.3M | Google Drive, Baidu Cloud |
RepOpt-VGG-L1 | 79.82 | 76.0M | Google Drive, Baidu Cloud |
RepOpt-VGG-L2 | 80.47 | 118.1M | Google Drive, Baidu Cloud |
The following examples use RepOpt-VGG-B1. You may replace RepOpt-VGG-B1 with RepOpt-VGG-B2, RepOpt-VGG-L1, or RepOpt-VGG-L2 as you wish.
You may test our released models by
python -m torch.distributed.launch --nproc_per_node {your_num_gpus} --master_port 12349 main_repopt.py --arch RepOpt-VGG-B1-target --tag test --eval --resume RepOpt-VGG-B1-acc78.62.pth --data-path /path/to/imagenet --batch-size 32 --opts DATA.DATASET imagenet
To reproduce RepOpt-VGG-B1, you may build a RepOptimizer with our released constants (RepOpt-VGG-B1-scales.pth) and train RepOpt-VGG-B1 with it:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/imagenet --arch RepOpt-VGG-B1-target --batch-size 32 --tag experiment --scales-path RepOpt-VGG-B1-scales.pth --opts TRAIN.EPOCHS 120 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 DATA.DATASET imagenet
The log and weights will be saved to output/RepOpt-VGG-B1-target/experiment/
Instead of using our released scales, you may run Hyper-Search yourself:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/search/dataset --arch RepOpt-VGG-B1-hs --batch-size 32 --tag search --opts TRAIN.EPOCHS 240 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 TRAIN.WARMUP_EPOCHS 10 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET cf100 TRAIN.CLIP_GRAD 5.0
(Note: because the model seems too big for such a small dataset, we use gradient clipping to stabilize the search. However, do not use gradient clipping when training with RepOptimizer, because it would break the equivalence.)
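Why gradient clipping breaks the equivalence can be seen in a one-step, pure-Python sketch. It assumes plain SGD, a CSLA block with two trainable branches scaled by constants a and b (so their gradients are a·g and b·g, where g is the gradient with respect to the merged weight), and clipping by global L2 norm applied to the raw gradients before the optimizer step, which is where a call such as torch.nn.utils.clip_grad_norm_ would normally sit. The helper names are hypothetical.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

def clip_by_norm(v, max_norm):
    """Rescale v so that its L2 norm is at most max_norm."""
    n = l2_norm(v)
    coef = min(1.0, max_norm / n) if n > 0 else 1.0
    return [x * coef for x in v]

def csla_update(g, a, b, lr, max_norm=None):
    """Change of the merged weight a*w1 + b*w2 after one SGD step on both branches."""
    g1, g2 = [a * x for x in g], [b * x for x in g]
    if max_norm is not None:
        clipped = clip_by_norm(g1 + g2, max_norm)  # one global norm over both branches
        g1, g2 = clipped[:len(g)], clipped[len(g):]
    return [-lr * (a * x + b * y) for x, y in zip(g1, g2)]

def gr_update(g, a, b, lr, max_norm=None):
    """Change of the single GR weight: clip the raw gradient, then apply the Grad Mult."""
    if max_norm is not None:
        g = clip_by_norm(g, max_norm)
    return [-lr * (a * a + b * b) * x for x in g]

g, a, b, lr = [3.0, 4.0], 2.0, 0.5, 0.1
print(csla_update(g, a, b, lr), gr_update(g, a, b, lr))            # equal up to rounding
print(csla_update(g, a, b, lr, 1.0), gr_update(g, a, b, lr, 1.0))  # differ: clipping sees different norms
```

Without clipping, both updates equal -lr(a² + b²)g. With clipping, the CSLA branch gradients are rescaled according to a norm of sqrt(a² + b²)|g|, while the GR gradient is rescaled according to |g|; when both are clipped, the merged updates differ by a factor of sqrt(a² + b²).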
The weights of the search model will be saved to output/RepOpt-VGG-B1-hs/search/latest.pth
Then you may train the target model with the searched scales:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/imagenet --arch RepOpt-VGG-B1-target --batch-size 32 --tag experiment --scales-path output/RepOpt-VGG-B1-hs/search/latest.pth --opts TRAIN.EPOCHS 120 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 DATA.DATASET imagenet
Given the searched scales (saved in a .pth file), you may conveniently build a RepOptimizer and a RepOpt-VGG model and use them just like common optimizers and models. Please see the function build_RepOptVGG_and_SGD_optimizer_from_pth.
RepGhostNet is a recently proposed lightweight model built with Structural Re-parameterization. The training-time forward function of a block can be formulated as output = batch_norm(depthwise_convolution(x)) + batch_norm(x). With RepOptimizer, the parallel batch norm (referred to as the "fusion layer" in the RepGhostNet paper) can be removed even during training. Similar to RepVGG and RepOpt-VGG, we design the CSLA model by replacing the batch norm layers with constant or trainable scaling layers, and design the Grad Mults of RepOptimizer accordingly.
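The following derivation sketches, for one channel c, why the parallel branch can be removed and compensated by a Grad Mult. It assumes a CSLA block whose two scaling layers are constants α_c and β_c (the repository may instead use trainable scaling layers, and its actual constants come from Hyper-Search), plain SGD with learning rate λ and no weight decay, and writes δ for a depthwise kernel that is 1 at the center tap and 0 elsewhere, so that δ ∗ x_c = x_c.

```latex
\begin{aligned}
y_c &= \alpha_c\,(W_c * x_c) + \beta_c\,x_c = (\alpha_c W_c + \beta_c\,\delta) * x_c = W'_c * x_c,\\
\frac{\partial L}{\partial W_c^{(i)}} &= \alpha_c\,\frac{\partial L}{\partial W'^{(i)}_c},\qquad
W_c^{(i+1)} = W_c^{(i)} - \lambda\,\alpha_c\,\frac{\partial L}{\partial W'^{(i)}_c},\\
W'^{(i+1)}_c &= \alpha_c W_c^{(i+1)} + \beta_c\,\delta
  = W'^{(i)}_c - \lambda\,\alpha_c^{2}\,\frac{\partial L}{\partial W'^{(i)}_c}.
\end{aligned}
```

Under these assumptions, the single depthwise kernel is initialized as W'^{(0)}_c = α_c W^{(0)}_c + β_c δ (Rule of Initialization) and its gradient is multiplied by α_c² (Rule of Iteration). The sketch omits weight decay: decaying W' would also shrink the constant β_c δ term, which the CSLA model never decays.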
Model | ImageNet-1K top-1 acc (%) | Download |
---|---|---|
RepGhostNet-0.5x (our implementation) | 66.51 | Google Drive, Baidu Cloud |
RepOpt-GhostNet-0.5x | 66.50 | Google Drive, Baidu Cloud |
We trained the original RepGhostNet-0.5x with this codebase and got a top-1 accuracy of 66.51%.
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/imagenet --arch ghost-rep --batch-size 128 --tag reproduce --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4
The log and weights will be saved to output/ghost-rep/reproduce/
You may reproduce RepOpt-GhostNet with our released scales
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/imagenet --arch ghost-target --batch-size 128 --tag reproduce --scales-path RepOptGhostNet_0_5x_scales.pth --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4
Alternatively, run Hyper-Search first and then train with the searched scales:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/cifar100 --arch ghost-hs --batch-size 128 --tag reproduce --opts TRAIN.EPOCHS 600 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 TRAIN.WARMUP_EPOCHS 10 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET cf100 TRAIN.CLIP_GRAD 5.0
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repopt.py --data-path /path/to/imagenet --arch ghost-target --batch-size 128 --tag reproduce --scales-path output/ghost-hs/reproduce/latest.pth --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4
This project is released under the MIT license. Please see the LICENSE file for more information.
xiaohding@gmail.com (The original Tsinghua mailbox dxh17@mails.tsinghua.edu.cn will expire in several months)
Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en
Homepage: https://dingxiaohan.xyz/
My open-sourced papers and repos:
The Structural Re-parameterization Universe:
- RepLKNet (CVPR 2022): a powerful, efficient architecture with very large kernels (31x31) and guidelines for using large kernels in modern CNNs.
  Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
  code.
- RepOptimizer (ICLR 2023) uses Gradient Re-parameterization to train powerful models efficiently. The training-time RepOpt-VGG is as simple as the inference-time model. It also addresses the problem of quantization.
  Re-parameterizing Your Optimizers rather than Architectures
  code.
- RepVGG (CVPR 2021): a super simple and powerful VGG-style ConvNet architecture. Up to 84.16% ImageNet top-1 accuracy!
  RepVGG: Making VGG-style ConvNets Great Again
  code.
- RepMLP (CVPR 2022): an MLP-style building block and architecture.
  RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality
  code.
- ResRep (ICCV 2021): state-of-the-art channel pruning (ResNet-50, 55% FLOPs reduction, 76.15% acc).
  ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting
  code.
- ACB (ICCV 2019) is a CNN component without any inference-time costs. It is the first work of our Structural Re-parameterization Universe.
  ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks
  code.
- DBB (CVPR 2021) is a CNN component with higher performance than ACB and still no inference-time costs. Sometimes I call it ACNet v2 because "DBB" is 2 bits larger than "ACB" in ASCII (lol).
  Diverse Branch Block: Building a Convolution as an Inception-like Unit
  code.
Model compression and acceleration:
- (CVPR 2019) Channel pruning: Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure
  code
- (ICML 2019) Channel pruning: Approximated Oracle Filter Pruning for Destructive CNN Width Optimization
  code
- (NeurIPS 2019) Unstructured pruning: Global Sparse Momentum SGD for Pruning Very Deep Neural Networks
  code