This is the official PyTorch implementation of RepLKNet, from the following CVPR-2022 paper:
Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs.
The paper is now released on arXiv: https://arxiv.org/abs/2203.06717.
Update: all the pretrained models, ImageNet-1K models, and Cityscapes/ADE20K/COCO models have been released.
Update: released a script to visualize the Effective Receptive Field (ERF). To get the ERF of your own model, you only need to add a few lines of code!
Update: released the training commands and more examples.
If you find the paper or this repository helpful, please consider citing:
@article{replknet,
title={Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs},
author={Ding, Xiaohan and Zhang, Xiangyu and Zhou, Yizhuang and Han, Jungong and Ding, Guiguang and Sun, Jian},
journal={arXiv preprint arXiv:2203.06717},
year={2022}
}
framework | link |
---|---|
MegEngine (official) | https://github.com/megvii-research/RepLKNet |
PyTorch (official) | https://github.com/DingXiaoH/RepLKNet-pytorch |
Tensorflow | https://github.com/shkarupa-alex/tfreplknet |
PaddlePaddle | https://github.com/BR-IDL/PaddleViT/tree/develop/image_classification/RepLKNet |
... |
More re-implementations are welcome.
We have released an example for PyTorch. Please check setup.py and depthwise_conv2d_implicit_gemm.py (a replacement of torch.nn.Conv2d) in https://github.com/MegEngine/cutlass/tree/master/examples/19_large_depthwise_conv2d_torch_extension.
- Unzip cutlass.zip and enter the directory. This version of cutlass works fine with our large-kernel implementation and multiple python versions. You may alternatively use the cutlass branch maintained by the MegEngine team (clone https://github.com/MegEngine/cutlass), but you may need to be more careful with your python version (see this issue).
- Enter the example directory:
cd examples/19_large_depthwise_conv2d_torch_extension
- Build and install:
./setup.py install --user
If you get errors, check your CUDA_HOME.
- A quick check:
python depthwise_conv2d_implicit_gemm.py
- Add WHERE_YOU_CLONED_CUTLASS/examples/19_large_depthwise_conv2d_torch_extension to your PYTHONPATH so that
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
works anywhere. Then you may use DepthWiseConv2dImplicitGEMM as a replacement for nn.Conv2d.
- Set the environment variable so that RepLKNet will use the efficient implementation:
export LARGE_KERNEL_CONV_IMPL=WHERE_YOU_CLONED_CUTLASS/examples/19_large_depthwise_conv2d_torch_extension
Alternatively, you may simply modify the related code (get_conv2d) in replknet.py.
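The environment-variable switch above decides whether a convolution is routed to the fast kernel or to a plain nn.Conv2d. The pure-Python predicate below sketches that kind of dispatch. It is not the repo's get_conv2d: the conditions (depthwise, stride 1, no dilation, "same" padding, kernel larger than 5) and the function name can_use_large_kernel_impl are assumptions for illustration, so check get_conv2d in replknet.py for the actual rules.

```python
import os
from typing import Mapping, Optional


def can_use_large_kernel_impl(in_channels: int, out_channels: int,
                              kernel_size: int, stride: int, padding: int,
                              dilation: int, groups: int,
                              env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical check for routing a conv to DepthWiseConv2dImplicitGEMM.

    Assumed conditions (not copied from replknet.py): LARGE_KERNEL_CONV_IMPL
    is exported, and the conv is depthwise, stride 1, undilated,
    'same'-padded, with a kernel larger than 5x5.
    """
    env = os.environ if env is None else env
    if "LARGE_KERNEL_CONV_IMPL" not in env:
        return False  # no efficient implementation available: use nn.Conv2d
    is_depthwise = in_channels == out_channels == groups
    is_same_padding = padding == kernel_size // 2
    return (is_depthwise and stride == 1 and dilation == 1
            and is_same_padding and kernel_size > 5)


env = {"LARGE_KERNEL_CONV_IMPL": "/path/to/19_large_depthwise_conv2d_torch_extension"}
print(can_use_large_kernel_impl(128, 128, 31, 1, 15, 1, 128, env))  # 31x31 depthwise
print(can_use_large_kernel_impl(128, 128, 3, 1, 1, 1, 128, env))    # small kernel
```

When the predicate is true, a builder would add the exported path to sys.path, import DepthWiseConv2dImplicitGEMM and construct it in place of nn.Conv2d; otherwise it would fall back to a regular nn.Conv2d.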
It should work with a wide range of GPUs and PyTorch/CUDA versions. We suggest you try it first and check the environment only if you get errors. Our latest tests used both
- Ubuntu 18.04 + CUDA 11.3 + nvcc 11.3 + cudnn 8.2.0 + python 3.8.12 + pytorch 1.10 + gcc 7.3.0 + nccl 2.10.3 + NVIDIA driver 450.102.04 + V100 and A100 GPUs
- Ubuntu 18.04 + CUDA 10.2 + nvcc 10.0 + cudnn 7.6.5 + python 3.6.9 + pytorch 1.9 + gcc 7.5.0 + nccl 2.7.8 + NVIDIA driver 460.32.03 + 2080Ti and V100 GPUs
It is reported (see here) that a python version mismatch may result in an error (forward_fp32.cu(212): error: more than one instance of constructor "cutlass::Tensor4DCoord::Tensor4DCoord" ... or cutlass/include/cutlass/fast_math.h(741): error: no suitable conversion function from "__half" to "float" exists). Please upgrade or downgrade your python. We sincerely thank @sleeplessai and @ewrfcas for sharing their experience.
Our implementation mentioned in the paper has been integrated into MegEngine, which uses it automatically. If you would like to use it in other frameworks like Tensorflow, you may need to compile our released CUDA sources (the *.cu files in the above example should work with other frameworks) and load them with suitable tools, just like cutlass and torch.utils.cpp_extension in the PyTorch example. We would appreciate it if you could share your experience with us.
You may refer to the MegEngine source code: https://github.com/MegEngine/MegEngine/tree/8a2e92bd6c5ac02807b27d174dce090ee391000b/dnn/src/cuda/conv_bias/chanwise.
Pull requests (e.g., better or alternative implementations, or implementations on other frameworks) are welcome.
- Model code
- PyTorch pretrained models
- PyTorch large-kernel conv impl
- PyTorch training code
- PyTorch downstream models
- PyTorch downstream code
- A script to visualize the ERF
- How to obtain the shape bias
name | resolution | ImageNet-1K acc | #params | FLOPs | ImageNet-1K pretrained model |
---|---|---|---|---|---|
RepLKNet-31B | 224x224 | 83.5 | 79M | 15.3G | Google Drive, Baidu |
RepLKNet-31B | 384x384 | 84.8 | 79M | 45.1G | Google Drive, Baidu |
name | resolution | ImageNet-1K acc | #params | FLOPs | 22K pretrained model | 1K finetuned model |
---|---|---|---|---|---|---|
RepLKNet-31B | 224x224 | 85.2 | 79M | 15.3G | Google Drive, Baidu | Google Drive, Baidu |
RepLKNet-31B | 384x384 | 86.0 | 79M | 45.1G | - | Google Drive, Baidu |
RepLKNet-31L | 384x384 | 86.6 | 172M | 96.0G | Google Drive, Baidu | Google Drive, Baidu |
name | resolution | ImageNet-1K acc | #params | FLOPs | MegData-73M pretrained model | 1K finetuned model |
---|---|---|---|---|---|---|
RepLKNet-XL | 320x320 | 87.8 | 335M | 128.7G | Google Drive, Baidu | Google Drive, Baidu |
For RepLKNet-31B/L with 224x224 or 384x384, we use the "IMAGENET_DEFAULT_MEAN/STD" for preprocessing (see here). For example,
python -m torch.distributed.launch --nproc_per_node=8 main.py --model RepLKNet-31B --batch_size 32 --eval True --resume RepLKNet-31B_ImageNet-1K_224.pth --input_size 224
or
python -m torch.distributed.launch --nproc_per_node=8 main.py --model RepLKNet-31L --batch_size 32 --eval True --resume RepLKNet-31L_ImageNet-22K-to-1K_384.pth --input_size 384
For RepLKNet-XL, please note that we used mean=[0.5,0.5,0.5] and std=[0.5,0.5,0.5] for preprocessing on the MegData73M dataset as well as finetuning on ImageNet-1K. This mean/std setting is also referred to as "IMAGENET_INCEPTION_MEAN/STD" in timm, see here. Add --imagenet_default_mean_and_std false to use this mean/std setting (see here). As noted in the paper, we did not use small kernels for re-parameterization.
python -m torch.distributed.launch --nproc_per_node=8 main.py --model RepLKNet-XL --batch_size 32 --eval true --resume RepLKNet-XL_MegData73M_ImageNet1K.pth --imagenet_default_mean_and_std false --input_size 320
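The two preprocessing settings differ only in their constants. The NumPy sketch below uses the values timm defines for IMAGENET_DEFAULT_MEAN/STD and IMAGENET_INCEPTION_MEAN/STD; the helper name normalize and its boolean argument (mirroring --imagenet_default_mean_and_std) are hypothetical, and it assumes HxWx3 RGB images already scaled to [0, 1].

```python
import numpy as np

# Constants as defined in timm (timm.data.constants).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)


def normalize(img: np.ndarray, imagenet_default_mean_and_std: bool = True) -> np.ndarray:
    """Normalize an HxWx3 RGB image with values in [0, 1].

    Hypothetical helper: True -> ImageNet default stats (RepLKNet-31B/L),
    False -> mean = std = 0.5 (the setting used for RepLKNet-XL).
    """
    if imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
    # Per-channel statistics broadcast over the last (channel) axis.
    return (img - np.asarray(mean)) / np.asarray(std)


img = np.random.default_rng(0).random((4, 4, 3))
xl_input = normalize(img, imagenet_default_mean_and_std=False)  # values in [-1, 1]
```

With mean = std = 0.5, a value x in [0, 1] maps to 2x - 1. On the 0-255 pixel scale the same mapping is (255x - 127.5) / 127.5, which is why the downstream img_norm_cfg for RepLKNet-XL uses 127.5 for both mean and std.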
To verify the equivalence of Structural Re-parameterization (i.e., that the outputs before and after structural_reparam match), add --with_small_kernel_merged true.
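The equivalence rests on linearity: a bias-free conv followed by inference-mode BN is again a conv with a rescaled kernel plus a bias, and a small kernel zero-padded to the large size (centers aligned) produces the same output as the small kernel itself, so two parallel branches can be summed into one large kernel. The NumPy toy below checks this on a depthwise, stride-1, "same"-padded conv with arbitrary 7x7 and 3x3 sizes. It illustrates the idea only and is not the repo's structural_reparam code; all function names are hypothetical.

```python
import numpy as np


def depthwise_conv(x, w):
    """Naive depthwise cross-correlation, stride 1, zero 'same' padding.

    x: (C, H, W) input, w: (C, k, k) kernel with odd k. No bias.
    """
    c, h, wd = x.shape
    k = w.shape[-1]
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    out = np.empty((c, h, wd))
    for i in range(h):
        for j in range(wd):
            out[:, i, j] = np.sum(xp[:, i:i + k, j:j + k] * w, axis=(1, 2))
    return out


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN using per-channel running statistics."""
    s = (gamma / np.sqrt(var + eps))[:, None, None]
    return s * (y - mean[:, None, None]) + beta[:, None, None]


def fuse_conv_bn(w, gamma, beta, mean, var, eps=1e-5):
    """Fold BN into the preceding bias-free conv; returns (kernel, bias)."""
    scale = gamma / np.sqrt(var + eps)
    return w * scale[:, None, None], beta - mean * scale


def merge_small_kernel(w_large, b_large, w_small, b_small):
    """Zero-pad the small kernel to the large size and add the branches."""
    pad = (w_large.shape[-1] - w_small.shape[-1]) // 2
    w_small_padded = np.pad(w_small, ((0, 0), (pad, pad), (pad, pad)))
    return w_large + w_small_padded, b_large + b_small


def random_bn(rng, c):
    """Random (gamma, beta, running_mean, running_var) for c channels."""
    return (rng.uniform(0.5, 2.0, c), rng.standard_normal(c),
            rng.standard_normal(c), rng.uniform(0.5, 2.0, c))


rng = np.random.default_rng(0)
C = 2
x = rng.standard_normal((C, 8, 8))
w_large, w_small = rng.standard_normal((C, 7, 7)), rng.standard_normal((C, 3, 3))
bn_large, bn_small = random_bn(rng, C), random_bn(rng, C)

# Training-time structure: two conv+BN branches whose outputs are summed.
y_branches = (batch_norm(depthwise_conv(x, w_large), *bn_large)
              + batch_norm(depthwise_conv(x, w_small), *bn_small))

# Inference-time structure: a single large conv with a bias.
w_merged, b_merged = merge_small_kernel(*fuse_conv_bn(w_large, *bn_large),
                                        *fuse_conv_bn(w_small, *bn_small))
y_merged = depthwise_conv(x, w_merged) + b_merged[:, None, None]

print(np.allclose(y_branches, y_merged))  # prints True
```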
You may use multi-node training on a SLURM cluster with submitit. Please install:
pip install submitit
If you have limited GPU memory (e.g., 2080Ti), use --use_checkpoint true to save GPU memory.
Single machine (note --update_freq 4):
python -m torch.distributed.launch --nproc_per_node=8 main.py --model RepLKNet-31B --drop_path 0.5 --batch_size 64 --lr 4e-3 --update_freq 4 --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 10 --epochs 300 --output_dir your_training_dir
Four machines (note --update_freq 1):
python run_with_submitit.py --nodes 4 --ngpus 8 --model RepLKNet-31B --drop_path 0.5 --batch_size 64 --lr 4e-3 --update_freq 1 --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 10 --epochs 300 --job_dir your_training_dir
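The two launches use different --update_freq values so that they see the same global batch. Assuming the conventions of the ConvNeXt code this script is based on (--batch_size is per GPU and --update_freq is the number of gradient-accumulation steps), the effective batch size is the product below; the function name is hypothetical.

```python
def effective_batch_size(batch_size: int, ngpus: int, nodes: int = 1,
                         update_freq: int = 1) -> int:
    """Samples per optimizer step, assuming a per-GPU --batch_size and
    --update_freq gradient-accumulation steps (ConvNeXt convention)."""
    return nodes * ngpus * batch_size * update_freq


single_machine = effective_batch_size(64, ngpus=8, nodes=1, update_freq=4)
four_machines = effective_batch_size(64, ngpus=8, nodes=4, update_freq=1)
print(single_machine, four_machines)  # 2048 2048
```

If you change the number of nodes or GPUs, adjusting --update_freq to keep this product fixed means --lr 4e-3 still applies to the same global batch.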
In the following, we only present multi-machine commands. You may train with a single machine in a similar way.
python run_with_submitit.py --nodes 4 --ngpus 8 --model RepLKNet-31B --drop_path 0.8 --input_size 384 --batch_size 32 --lr 4e-4 --epochs 30 --weight_decay 1e-8 --update_freq 1 --cutmix 0 --mixup 0 --finetune RepLKNet-31B_ImageNet-1K_224.pth --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 1 --job_dir your_training_dir --layer_decay 0.7
python run_with_submitit.py --nodes 16 --ngpus 8 --model RepLKNet-31B --drop_path 0.1 --batch_size 32 --lr 4e-3 --update_freq 1 --warmup_epochs 5 --epochs 90 --data_set image_folder --nb_classes 21841 --disable_eval true --data_path /path/to/imagenet-22k --job_dir /path/to/save_results
python run_with_submitit.py --nodes 2 --ngpus 8 --model RepLKNet-31B --drop_path 0.2 --input_size 224 --batch_size 32 --lr 4e-4 --epochs 30 --weight_decay 1e-8 --update_freq 1 --cutmix 0 --mixup 0 --finetune RepLKNet-31B_ImageNet-22K.pth --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 1 --job_dir your_training_dir --layer_decay 0.7
python run_with_submitit.py --nodes 4 --ngpus 8 --model RepLKNet-31B --drop_path 0.3 --input_size 384 --batch_size 16 --lr 4e-4 --epochs 30 --weight_decay 1e-8 --update_freq 1 --cutmix 0 --mixup 0 --finetune RepLKNet-31B_ImageNet-22K.pth --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 1 --job_dir your_training_dir --layer_decay 0.7 --min_lr 3e-4
python run_with_submitit.py --nodes 16 --ngpus 8 --model RepLKNet-31L --drop_path 0.1 --batch_size 32 --lr 4e-3 --update_freq 1 --warmup_epochs 5 --epochs 90 --data_set image_folder --nb_classes 21841 --disable_eval true --data_path /path/to/imagenet-22k --job_dir /path/to/save_results
python run_with_submitit.py --nodes 4 --ngpus 8 --model RepLKNet-31L --drop_path 0.3 --input_size 384 --batch_size 16 --lr 4e-4 --epochs 30 --weight_decay 1e-8 --update_freq 1 --cutmix 0 --mixup 0 --finetune RepLKNet-31L_ImageNet-22K.pth --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k --warmup_epochs 1 --job_dir your_training_dir --layer_decay 0.7 --min_lr 3e-4
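The fine-tuning commands above pass --layer_decay 0.7, which gives lower layers smaller learning rates. The sketch below computes per-layer multipliers assuming the BEiT-style rule adopted in ConvNeXt's code, decay ** (num_layers + 1 - layer_id), where layer 0 is the stem and the last id is the head. The function name and num_layers=3 are illustrative; how RepLKNet's stages and blocks map to layer ids is left to the training code.

```python
def layer_lr_scales(num_layers: int, layer_decay: float) -> list:
    """Per-layer LR multipliers in the BEiT/ConvNeXt style (assumed rule).

    layer_id 0 (stem) gets the smallest multiplier and layer_id
    num_layers + 1 (head) gets 1.0.
    """
    return [layer_decay ** (num_layers + 1 - layer_id)
            for layer_id in range(num_layers + 2)]


base_lr = 4e-4  # the --lr used by the fine-tuning commands above
scales = layer_lr_scales(num_layers=3, layer_decay=0.7)
per_layer_lr = [base_lr * s for s in scales]  # geometric, increasing toward the head
```

A decay of 1.0 would give every layer the base learning rate; smaller values shrink the rates of early layers geometrically.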
We use MMSegmentation and MMDetection frameworks. Just clone MMSegmentation or MMDetection, and
- Put segmentation/replknet.py into mmsegmentation/mmseg/models/backbones/ or mmdetection/mmdet/models/backbones/. The only difference between segmentation/replknet.py and replknet.py is the @BACKBONES.register_module decorator.
- Add RepLKNet to mmsegmentation/mmseg/models/backbones/__init__.py or mmdetection/mmdet/models/backbones/__init__.py. That is:
...
from .replknet import RepLKNet
__all__ = ['ResNet', ..., 'RepLKNet']
- Put segmentation/configs/*.py into mmsegmentation/configs/replknet/ or detection/configs/*.py into mmdetection/configs/replknet/
- Download and use our weights. For example, to evaluate RepLKNet-31B + UperNet on Cityscapes
python -m torch.distributed.launch --nproc_per_node=8 tools/test.py configs/replknet/RepLKNet-31B_1Kpretrain_upernet_80k_cityscapes_769.py RepLKNet-31B_ImageNet-1K_UperNet_Cityscapes.pth --launcher pytorch --eval mIoU
or RepLKNet-31B + Cascade Mask R-CNN on COCO
python -m torch.distributed.launch --nproc_per_node=8 tools/test.py configs/replknet/RepLKNet-31B_22Kpretrain_cascade_mask_rcnn_3x_coco.py RepLKNet-31B_ImageNet-22K_CascMaskRCNN_COCO.pth --eval bbox --launcher pytorch
- Or you may finetune our released pretrained weights (see the tips below about the batch size and number of iterations)
python -m torch.distributed.launch --nproc_per_node=8 tools/train.py configs/replknet/some_config.py --launcher pytorch --options model.backbone.pretrained=some_pretrained_weights.pth
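Both the decorator and the __init__.py import are needed because registration is a side effect of importing the module: the decorator runs when replknet.py is imported, and only then can a config entry such as type='RepLKNet' be resolved. The registry below is a hypothetical, stripped-down stand-in written to illustrate that mechanism; it is not mmcv's Registry, whose API has more options.

```python
class Registry:
    """Hypothetical minimal registry mapping class names to classes."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            # Executed when the defining module is imported.
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's config is not mutated
        cls = self._modules[cfg.pop("type")]  # KeyError if never registered
        return cls(**cfg)


BACKBONES = Registry("backbone")


@BACKBONES.register_module()
class RepLKNet:  # placeholder standing in for the real backbone class
    def __init__(self, **kwargs):
        self.kwargs = kwargs


backbone = BACKBONES.build(dict(type="RepLKNet", drop_path_rate=0.5))
```

If replknet.py were never imported, the build call would fail with a KeyError, which is the usual symptom of skipping the __init__.py step.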
We have released all the Cityscapes/ADE20K/COCO model weights.
Single-scale (ss) and multi-scale (ms) mIoU tested with UperNet (FLOPs is computed with 2048×512 for the ImageNet-1K pretrained models and 2560×640 for the 22K and MegData73M pretrained models, following Swin):
backbone | pretraining | dataset | train schedule | mIoU (ss) | mIoU (ms) | #params | FLOPs | download |
---|---|---|---|---|---|---|---|---|
RepLKNet-31B | ImageNet-1K | Cityscapes | 80k | 83.1 | 83.5 | 110M | 2315G | Google Drive, Baidu |
RepLKNet-31B | ImageNet-1K | ADE20K | 160k | 49.9 | 50.6 | 112M | 1170G | Google Drive, Baidu |
RepLKNet-31B | ImageNet-22K | ADE20K | 160k | 51.5 | 52.3 | 112M | 1829G | Google Drive, Baidu |
RepLKNet-31L | ImageNet-22K | ADE20K | 160k | 52.4 | 52.7 | 207M | 2404G | Google Drive, Baidu |
RepLKNet-XL | MegData73M | ADE20K | 160k | 55.2 | 56.0 | 374M | 3431G | Google Drive, Baidu |
Cascade Mask R-CNN on COCO (FLOPs is computed with 1280x800):
backbone | pretraining | method | train schedule | AP_box | AP_mask | #params | FLOPs | download |
---|---|---|---|---|---|---|---|---|
RepLKNet-31B | ImageNet-1K | FCOS | 2x | 47.0 | - | 87M | 437G | Google Drive, Baidu |
RepLKNet-31B | ImageNet-1K | Cascade Mask RCNN | 3x | 52.2 | 45.2 | 137M | 965G | Google Drive, Baidu |
RepLKNet-31B | ImageNet-22K | Cascade Mask RCNN | 3x | 53.0 | 46.0 | 137M | 965G | Google Drive, Baidu |
RepLKNet-31L | ImageNet-22K | Cascade Mask RCNN | 3x | 53.9 | 46.5 | 229M | 1321G | Google Drive, Baidu |
RepLKNet-XL | MegData73M | Cascade Mask RCNN | 3x | 55.5 | 48.0 | 392M | 1958G | Google Drive, Baidu |
- The mean/std values on MegData73M are different from ImageNet, so we used mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5] for pretraining RepLKNet-XL on MegData73M and finetuning on ImageNet-1K. Accordingly, set img_norm_cfg = dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) in MMSegmentation and MMDetection. Please check here and here. For other models, we use the default ImageNet mean/std.
- For RepLKNet-XL on ADE20K and COCO, we batch-normalize the intermediate feature maps before feeding them into the heads. Just use RepLKNet(..., norm_intermediate_features=True). We did not try this design on the other models, so we are not sure whether it is significant.
- For RepLKNet-31B/L on Cityscapes and ADE20K, we used 4 or 8 2080Ti nodes, each with 8 GPUs. The batch size per GPU was smaller than the default (4 per GPU, see here), but the global batch size was larger. Accordingly, we reduced the number of iterations to keep the total number of training samples the same. Please check the comments in the config files. If you wish to train with our config files, please set the batch size and number of iterations according to your own situation.
- Lowering the learning rate for lower-level layers may improve the performance when finetuning on ImageNet-1K or downstream tasks, just like ConvNeXt and BEiT. We are not sure whether the improvements would be significant. For ImageNet, our implementation simply follows ConvNeXt and BEiT. For MMSegmentation and MMDetection, please raise an issue if you need a showcase.
- Tips on the drop_path_rate: bigger model, higher drop_path; bigger pretraining data, lower drop_path.
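The iteration adjustment described for Cityscapes/ADE20K keeps the total number of training samples (global batch size times iterations) fixed. A minimal sketch follows; the 4-images-per-GPU default and the 160k schedule come from the text and tables above, while the 8-GPU default setup and the 32-GPU target are hypothetical numbers chosen for illustration.

```python
def scaled_iters(default_iters: int, default_batch_per_gpu: int,
                 default_gpus: int, batch_per_gpu: int, gpus: int) -> int:
    """Iterations that keep (global batch size x iterations) unchanged.

    Rounded down when the new global batch does not divide evenly.
    """
    total_samples = default_iters * default_batch_per_gpu * default_gpus
    return total_samples // (batch_per_gpu * gpus)


# Hypothetical example: a 160k schedule defined for 8 GPUs x 4 images
# (global batch 32), run on 4 nodes x 8 GPUs with 2 images per GPU (global 64).
print(scaled_iters(160_000, 4, 8, batch_per_gpu=2, gpus=32))  # 80000
```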
We have released our script to visualize and analyze the Effective Receptive Field (ERF). For example, to automatically download the ResNet-101 from torchvision and obtain the aggregated contribution score matrix,
python erf/visualize_erf.py --model resnet101 --data_path /path/to/imagenet-1k --save_path resnet101_erf_matrix.npy
Then calculate the high-contribution area ratio and visualize the ERF by
python erf/analyze_erf.py --source resnet101_erf_matrix.npy --heatmap_save resnet101_heatmap.png
Note that this plotting script works with matplotlib 3.3. If you use a higher version of matplotlib, see the comments here.
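As a rough illustration of a "high-contribution area ratio", the NumPy sketch below grows a square centered on the score matrix until it holds more than a threshold fraction of the total contribution, then reports the square's area relative to the whole matrix. This definition and the function name are assumptions for illustration; erf/analyze_erf.py is the reference for the exact computation, including any preprocessing of the scores.

```python
import numpy as np


def high_contribution_area_ratio(scores: np.ndarray, thresh: float) -> float:
    """Area ratio of the smallest centered square whose scores exceed
    `thresh` of the total (assumed definition, non-negative 2-D scores)."""
    h, w = scores.shape
    total = scores.sum()
    ch, cw = h // 2, w // 2
    for r in range(max(h, w)):
        # Square of side 2r + 1 around the center, clipped at the borders.
        window = scores[max(ch - r, 0):ch + r + 1, max(cw - r, 0):cw + r + 1]
        if window.sum() / total > thresh:
            return window.size / scores.size
    return 1.0  # only reached if thresh >= 1


uniform = np.ones((5, 5))
print(high_contribution_area_ratio(uniform, 0.3))  # 3x3 window out of 5x5
```

A small ratio means the contribution is concentrated near the center (a compact ERF); a large ratio means it is spread over the input.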
To visualize your own model, first define a model that outputs the last feature map rather than the logits (following this example), add the code for building model and loading weights here, then
python erf/visualize_erf.py --model your_model --weights /path/to/your/weights --data_path /path/to/imagenet-1k --save_path your_model_erf_matrix.npy
To reproduce the results in the paper, please download the RepLKNet-13 (Google Drive, Baidu) and RepLKNet-31 (Google Drive, Baidu) models trained for 120 epochs.
- Install https://github.com/bethgelab/model-vs-human
- Add your code for building model and loading weights in this file. For example
@register_model("pytorch")
def replknet(model_name, *args):
    model = ...
    model.load_state_dict(...)
    return model
- Modify examples/evaluate.py (models = ['replknet']) and examples/plotting_definition.py (decision_makers.append(DecisionMaker(name_pattern="replknet", ...))), following the existing examples there.
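For reference, the shape-bias metric of Geirhos et al. that model-vs-human reports is the fraction of shape decisions among cue-conflict trials where the model picks either the shape class or the texture class. The sketch below computes it from per-trial predictions; the function name and the trial data are hypothetical.

```python
def shape_bias(predictions, shape_labels, texture_labels):
    """Fraction of shape decisions among cue-conflict trials in which the
    prediction matches either the shape label or the texture label."""
    shape_hits = texture_hits = 0
    for pred, shape, texture in zip(predictions, shape_labels, texture_labels):
        if pred == shape:
            shape_hits += 1
        elif pred == texture:
            texture_hits += 1
    decided = shape_hits + texture_hits
    return shape_hits / decided if decided else float("nan")


# Hypothetical trials: prediction, shape class, texture class.
preds = ["cat", "elephant", "dog", "car"]
shapes = ["cat", "bear", "dog", "boat"]
textures = ["elephant", "elephant", "clock", "bottle"]
bias = shape_bias(preds, shapes, textures)  # the "car" trial is ignored
```

Trials where the prediction matches neither cue do not count, so a value near 1 means the model mostly decides by shape.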
The released PyTorch training script is based on the code of ConvNeXt, which was built using the timm library, DeiT and BEiT repositories.
This project is released under the MIT license. Please see the LICENSE file for more information.
xiaohding@gmail.com (The original Tsinghua mailbox dxh17@mails.tsinghua.edu.cn will expire in several months)
Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en
Homepage: https://dingxiaohan.xyz/
My open-sourced papers and repos:
The Structural Re-parameterization Universe:
- RepLKNet (CVPR 2022): Powerful efficient architecture with very large kernels (31x31) and guidelines for using large kernels in modern CNNs.
Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (code)
- RepOptimizer uses Gradient Re-parameterization to train powerful models efficiently. The training-time model is as simple as the inference-time model. It also addresses the problem of quantization.
Re-parameterizing Your Optimizers rather than Architectures (code)
- RepVGG (CVPR 2021): A super simple and powerful VGG-style ConvNet architecture. Up to 84.16% ImageNet top-1 accuracy!
RepVGG: Making VGG-style ConvNets Great Again (code)
- RepMLP (CVPR 2022): MLP-style building block and architecture.
RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (code)
- ResRep (ICCV 2021): State-of-the-art channel pruning (Res50, 55% FLOPs reduction, 76.15% acc).
ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting (code)
- ACB (ICCV 2019) is a CNN component without any inference-time costs. The first work of our Structural Re-parameterization Universe.
ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks (code)
- DBB (CVPR 2021) is a CNN component with higher performance than ACB and still no inference-time costs. Sometimes I call it ACNet v2 because "DBB" is 2 bits larger than "ACB" in ASCII (lol).
Diverse Branch Block: Building a Convolution as an Inception-like Unit (code)
Model compression and acceleration:
- (CVPR 2019) Channel pruning: Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure (code)
- (ICML 2019) Channel pruning: Approximated Oracle Filter Pruning for Destructive CNN Width Optimization (code)
- (NeurIPS 2019) Unstructured pruning: Global Sparse Momentum SGD for Pruning Very Deep Neural Networks (code)