[Feature] Add Autoformer algorithm (#315)
* update candidates
* update subnet_sampler_loop
* update candidate
* add readme
* rename variable
* rename variable
* clean
* update
* add doc string
* Revert "[Improvement] Support for candidate multiple dimensional search constraints."
* [Improvement] Update Candidate with multi-dim search constraints. (#322)
* update doc
* add support type
* clean code
* update candidates
* clean
* xx
* set_resource -> set_score
* fix ci bug
* py36 lint
* fix bug
* fix check constrain
* py36 ci
* redesign candidate
* fix pre-commit
* update cfg
* add build_resource_estimator
* fix ci bug
* remove runner.epoch in testcase
* [Feature] Autoformer architecture and dynamicOPs (#327)
* add DynamicSequential
* dynamiclayernorm
* add dynamic_pathchembed
* add DynamicMultiheadAttention and DynamicRelativePosition2D
* add channel-level dynamicOP
* add autoformer algo
* clean notes
* adapt channel_mutator
* vit fly
* fix import
* mutable init
* remove annotation
* add DynamicInputResizer
* add unittest for mutables
* add OneShotMutableChannelUnit_VIT
* clean code
* reset unit for vit
* remove attr
* add autoformer backbone UT
* add valuemutator UT
* clean code
* add autoformer algo UT
* update classifier UT
* fix test error
* ignore
* make lint
* update
* fix lint
* mutable_attrs
* fix test
* fix error
* remove DynamicInputResizer
* fix test ci
* remove InputResizer
* rename variables
* modify type
* Continued improvements of ChannelUnit
* fix lint
* fix lint
* remove OneShotMutableChannelUnit
* adjust derived type
* combination mixins
* clean code
* fix sample subnet
* search loop fly
* more annotations
* avoid counter warning and modify batch_augment cfg by gy
* restore
* source_value_mutables restriction
* simply arch_setting api
* update
* clean
* fix ut
Yue Sun authored on Nov 14, 2022
1 parent 9c567e4, commit fb42405
Showing 68 changed files with 3,598 additions and 260 deletions.
@@ -0,0 +1,180 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
preprocess_cfg = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]

# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
rand_increasing_policies = [
    dict(type='mmcls.AutoContrast'),
    dict(type='mmcls.Equalize'),
    dict(type='mmcls.Invert'),
    dict(type='mmcls.Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(type='mmcls.Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
    dict(type='mmcls.Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
    dict(
        type='mmcls.SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110)),
    dict(
        type='mmcls.ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Contrast',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Brightness',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Sharpness',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.9)),
    dict(
        type='mmcls.Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='horizontal'),
    dict(
        type='mmcls.Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='vertical'),
    dict(
        type='mmcls.Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        direction='horizontal'),
    dict(
        type='mmcls.Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.45),
        direction='vertical')
]

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(
        type='mmcls.RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='mmcls.RandAugment',
        policies=rand_increasing_policies,
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='mmcls.RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(
        type='mmcls.ResizeEdge',
        scale=248,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs')
]

train_dataloader = dict(
    batch_size=64,
    num_workers=6,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.RepeatAugSampler'),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=256,
    num_workers=6,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# optimizer
paramwise_cfg = dict(
    bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)

optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=0.002,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999)),
    # specific to vit pretrain
    paramwise_cfg=dict(custom_keys={
        '.cls_token': dict(decay_mult=0.0),
        '.pos_embed': dict(decay_mult=0.0)
    }))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        start_factor=1e-3,
        by_epoch=True,
        begin=0,
        # about 10000 iterations for ImageNet-1k
        end=20,
        # update by iter
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(
        type='CosineAnnealingLR',
        T_max=500,
        eta_min=1e-5,
        by_epoch=True,
        begin=20,
        end=500,
        convert_to_iter_based=True),
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=500)
val_cfg = dict()
test_cfg = dict()

auto_scale_lr = dict(base_batch_size=2048)
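
The `auto_scale_lr` entry follows MMEngine's linear scaling rule: when automatic LR scaling is enabled at launch time, the configured learning rate is multiplied by the ratio of the actual total batch size to `base_batch_size`. A minimal sketch of that arithmetic follows; the GPU count below is an illustrative assumption, not a value taken from this config.

```python
# Minimal sketch of MMEngine's linear LR scaling rule used with auto_scale_lr.
base_lr = 0.002          # optim_wrapper.optimizer.lr above
base_batch_size = 2048   # auto_scale_lr.base_batch_size above

num_gpus = 8             # hypothetical launch setup, for illustration only
samples_per_gpu = 64     # train_dataloader.batch_size above

actual_batch_size = num_gpus * samples_per_gpu
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(f'scaled lr: {scaled_lr:g}')  # 0.0005 for 8 GPUs x 64 images per GPU
```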
@@ -0,0 +1,66 @@
# AutoFormer

> [AutoFormer: Searching Transformers for Visual Recognition](https://arxiv.org/abs/2107.00651)

<!-- [ALGORITHM] -->

## Abstract

Recently, pure transformer-based models have shown great potential for vision tasks such as image classification and detection. However, the design of transformer networks is challenging. It has been observed that the depth, embedding dimension, and number of heads can largely affect the performance of vision transformers. Previous models configure these dimensions based upon manual crafting. In this work, we propose a new one-shot architecture search framework, namely AutoFormer, dedicated to vision transformer search. AutoFormer entangles the weights of different blocks in the same layers during supernet training. Benefiting from this strategy, the trained supernet allows thousands of subnets to be very well trained. Specifically, the performance of these subnets with weights inherited from the supernet is comparable to that of subnets retrained from scratch. Besides, the searched models, which we refer to as AutoFormers, surpass recent state-of-the-art models such as ViT and DeiT. In particular, AutoFormer-tiny/small/base achieve 74.7%/81.7%/82.4% top-1 accuracy on ImageNet with 5.7M/22.9M/53.7M parameters, respectively. Lastly, we verify the transferability of AutoFormer by providing the performance on downstream benchmarks and distillation experiments.

![pipeline](/docs/en/imgs/model_zoo/autoformer/pipeline.png)

## Introduction

### Supernet pre-training on ImageNet

```bash
python ./tools/train.py \
  configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py \
  --work-dir $WORK_DIR
```

### Search for a subnet on the trained supernet

```bash
python ./tools/train.py \
  configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py \
  $STEP1_CKPT \
  --work-dir $WORK_DIR
```

## Results and models

| Dataset  | Supernet | Subnet | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Remarks |
| :------: | :------: | :----: | :-------: | :------: | :-------: | :-------: | :----: | :------- | :-----: |
| ImageNet | vit | [mutable](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/spos_shufflenetv2_subnet_8xb128_in1k/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-454627be_mutable_cfg.yaml?versionId=CAEQHxiBgICw5b6I7xciIGY5MjVmNWFhY2U5MjQzN2M4NDViYzI2YWRmYWE1YzQx) | 52.472 | 10.2 | 82.48 | 95.99 | [config](./autoformer_supernet_32xb256_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/x.pth) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/nas/spos/x.log.json) | MMRazor searched |

**Note**:

1. Our settings differ slightly from those of the original repo in order to stay consistent with the MMRazor codebase. For example, we set the maximum embed_channels to 624, whereas the original repo sets it to 640. However, the original repo only searches over embed_channels of 528, 576 and 624, so a maximum of 624 still reproduces the result of the original paper. The search space is sketched after these notes.
2. The original paper reports 82.4 top-1 accuracy with 53.7M parameters, while we obtain 82.48 top-1 accuracy with 52.47M parameters.
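
For reference, this is the search space as defined by `arch_setting` in the supernet config added in this commit (`autoformer_supernet_32xb256_in1k.py`); the snippet below simply restates those choices:

```python
# Search space used by the AutoFormer supernet config in this commit.
# The embedding dimension is chosen from {528, 576, 624}, so a maximum of
# 624 (rather than 640) covers every value the original repo searches.
arch_setting = dict(
    mlp_ratios=[3.0, 3.5, 4.0],
    num_heads=[8, 9, 10],
    depth=[14, 15, 16],
    embed_dims=[528, 576, 624])
```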

## Citation

```latex
@inproceedings{chen2021autoformer,
  title={AutoFormer: Searching Transformers for Visual Recognition},
  author={Chen, Minghao and Peng, Houwen and Fu, Jianlong and Ling, Haibin},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2021}
}
```
configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py (17 additions & 0 deletions)
@@ -0,0 +1,17 @@
_base_ = ['./autoformer_supernet_32xb256_in1k.py']

custom_hooks = None

train_cfg = dict(
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator,
    max_epochs=20,
    num_candidates=20,
    top_k=10,
    num_mutation=5,
    num_crossover=5,
    mutate_prob=0.2,
    constraints_range=dict(params=(0, 55)),
    score_key='accuracy/top1')
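
The `constraints_range=dict(params=(0, 55))` entry bounds candidate subnets to at most 55M parameters; during search, the loop estimates each candidate's resources and discards candidates outside this range. As a rough illustration only (not the resource estimator the loop actually uses), the bounded quantity can be sanity-checked with plain PyTorch:

```python
# Rough sketch: count trainable parameters in millions, the quantity that
# constraints_range=dict(params=(0, 55)) bounds. The search loop itself relies
# on MMRazor's resource estimation rather than this helper.
import torch.nn as nn


def count_params_in_millions(model: nn.Module) -> float:
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6


# Example with a toy module; a sampled AutoFormer subnet would be used instead.
toy = nn.Linear(624, 1000)
assert 0 <= count_params_in_millions(toy) <= 55
```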
configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py (79 additions & 0 deletions)
@@ -0,0 +1,79 @@
_base_ = [
    'mmrazor::_base_/settings/imagenet_bs2048_AdamW.py',
    'mmcls::_base_/default_runtime.py',
]

# data preprocessor
data_preprocessor = dict(
    _scope_='mmcls',
    type='ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
    num_classes=1000,
    batch_augments=dict(
        augments=[
            dict(type='Mixup', alpha=0.2),
            dict(type='CutMix', alpha=1.0)
        ],
        probs=[0.5, 0.5]))

arch_setting = dict(
    mlp_ratios=[3.0, 3.5, 4.0],
    num_heads=[8, 9, 10],
    depth=[14, 15, 16],
    embed_dims=[528, 576, 624])

supernet = dict(
    _scope_='mmrazor',
    type='SearchableImageClassifier',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        _scope_='mmrazor',
        type='AutoformerBackbone',
        arch_setting=arch_setting),
    neck=None,
    head=dict(
        type='DynamicLinearClsHead',
        num_classes=1000,
        in_channels=624,
        loss=dict(
            type='mmcls.LabelSmoothLoss',
            mode='original',
            num_classes=1000,
            label_smooth_val=0.1,
            loss_weight=1.0),
        topk=(1, 5)),
    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
)

model = dict(
    type='mmrazor.Autoformer',
    architecture=supernet,
    fix_subnet=None,
    mutators=dict(
        channel_mutator=dict(
            type='mmrazor.OneShotChannelMutator',
            channel_unit_cfg={
                'type': 'OneShotMutableChannelUnit',
                'default_args': {
                    'unit_predefined': True
                }
            },
            parse_cfg={'type': 'Predefined'}),
        value_mutator=dict(type='mmrazor.DynamicValueMutator')))

# runtime setting
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]

# checkpoint saving
_base_.default_hooks.checkpoint = dict(
    type='CheckpointHook',
    interval=2,
    by_epoch=True,
    save_best='accuracy/top1',
    max_keep_ckpts=3)

find_unused_parameters = True
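
As a quick sanity check, the assembled config can be loaded and inspected with MMEngine before launching training. This is a minimal sketch, assuming mmengine, mmcls and mmrazor are installed and the command is run from the repository root:

```python
# Minimal sketch: load the supernet config and inspect the pieces defined
# above. Assumes mmengine, mmcls and mmrazor are installed and the working
# directory is the repository root.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/nas/mmcls/autoformer/autoformer_supernet_32xb256_in1k.py')

print(cfg.model.type)                  # 'mmrazor.Autoformer'
print(cfg.arch_setting['embed_dims'])  # [528, 576, 624]
print(cfg.train_cfg.max_epochs)        # 500, inherited from the base setting
```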