-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add Dsnas Algorithm (#226)
* [tmp] Update Dsnas * [tmp] refactor arch_loss & flops_loss * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs. * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage. * remove old evaluators * remove old evaluators * update param_scheduler config * merge dev-1.x into gy/estimator * add flops_loss in Dsnas using ResourcesEstimator * get resources before mutator.prepare_from_supernet * delete unness broadcast api from gml * broadcast spec_modules_resources when estimating * update early fix mechanism for Dsnas * fix merge * update units in estimator * minor change * fix data_preprocessor api * add flops_loss_coef * remove DsnasOptimWrapper * fix bn eps and data_preprocessor * fix bn weight decay bug * add betas for mutator optimizer * set diff_rank_seed=True for dsnas * fix start_factor of lr when warm up * remove .module in non-ddp mode * add GlobalAveragePoolingWithDropout * add UT for dsnas * remove unness channel adjustment for shufflenetv2 * update supernet configs * delete unness dropout * delete unness part with minor change on dsnas * minor change on the flag of search stage * update README and subnet configs * add UT for OneHotMutableOP
- Loading branch information
Showing
18 changed files
with
1,187 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
norm_cfg = dict(type='BN', eps=0.01) | ||
|
||
_STAGE_MUTABLE = dict( | ||
type='mmrazor.OneHotMutableOP', | ||
fix_threshold=0.3, | ||
candidates=dict( | ||
shuffle_3x3=dict( | ||
type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg), | ||
shuffle_5x5=dict( | ||
type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg), | ||
shuffle_7x7=dict( | ||
type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg), | ||
shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg))) | ||
|
||
arch_setting = [ | ||
# Parameters to build layers. 3 parameters are needed to construct a | ||
# layer, from left to right: channel, num_blocks, mutable_cfg. | ||
[64, 4, _STAGE_MUTABLE], | ||
[160, 4, _STAGE_MUTABLE], | ||
[320, 8, _STAGE_MUTABLE], | ||
[640, 4, _STAGE_MUTABLE] | ||
] | ||
|
||
nas_backbone = dict( | ||
type='mmrazor.SearchableShuffleNetV2', | ||
widen_factor=1.0, | ||
arch_setting=arch_setting, | ||
norm_cfg=norm_cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# dataset settings | ||
dataset_type = 'mmcls.ImageNet' | ||
data_preprocessor = dict( | ||
type='mmcls.ClsDataPreprocessor', | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
to_rgb=True, | ||
) | ||
|
||
train_pipeline = [ | ||
dict(type='mmcls.LoadImageFromFile'), | ||
dict(type='mmcls.RandomResizedCrop', scale=224), | ||
dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), | ||
dict(type='mmcls.PackClsInputs'), | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='mmcls.LoadImageFromFile'), | ||
dict(type='mmcls.ResizeEdge', scale=256, edge='short'), | ||
dict(type='mmcls.CenterCrop', crop_size=224), | ||
dict(type='mmcls.PackClsInputs'), | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=128, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='data/imagenet', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type='mmcls.DefaultSampler', shuffle=True), | ||
persistent_workers=True, | ||
) | ||
|
||
val_dataloader = dict( | ||
batch_size=128, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='data/imagenet', | ||
ann_file='meta/val.txt', | ||
data_prefix='val', | ||
pipeline=test_pipeline), | ||
sampler=dict(type='mmcls.DefaultSampler', shuffle=False), | ||
persistent_workers=True, | ||
) | ||
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5)) | ||
|
||
# If you want standard test, please manually configure the test dataset | ||
test_dataloader = val_dataloader | ||
test_evaluator = val_evaluator | ||
|
||
# optimizer | ||
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0) | ||
|
||
optim_wrapper = dict( | ||
constructor='mmrazor.SeparateOptimWrapperConstructor', | ||
architecture=dict( | ||
optimizer=dict( | ||
type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5), | ||
paramwise_cfg=paramwise_cfg), | ||
mutator=dict( | ||
optimizer=dict( | ||
type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5, | ||
0.999)))) | ||
|
||
search_epochs = 85 | ||
# learning policy | ||
param_scheduler = dict( | ||
architecture=[ | ||
dict( | ||
type='mmcls.LinearLR', | ||
end=5, | ||
start_factor=0.2, | ||
by_epoch=True, | ||
convert_to_iter_based=True), | ||
dict( | ||
type='mmcls.CosineAnnealingLR', | ||
T_max=240, | ||
begin=5, | ||
end=search_epochs, | ||
by_epoch=True, | ||
convert_to_iter_based=True), | ||
dict( | ||
type='mmcls.CosineAnnealingLR', | ||
T_max=160, | ||
begin=search_epochs, | ||
end=240, | ||
eta_min=0.0, | ||
by_epoch=True, | ||
convert_to_iter_based=True) | ||
], | ||
mutator=[]) | ||
|
||
# train, val, test setting | ||
train_cfg = dict(by_epoch=True, max_epochs=240) | ||
val_cfg = dict() | ||
test_cfg = dict() |
20 changes: 20 additions & 0 deletions
20
configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
backbone.layers.0.0: shuffle_3x3 | ||
backbone.layers.0.1: shuffle_3x3 | ||
backbone.layers.0.2: shuffle_xception | ||
backbone.layers.0.3: shuffle_3x3 | ||
backbone.layers.1.0: shuffle_xception | ||
backbone.layers.1.1: shuffle_7x7 | ||
backbone.layers.1.2: shuffle_3x3 | ||
backbone.layers.1.3: shuffle_3x3 | ||
backbone.layers.2.0: shuffle_xception | ||
backbone.layers.2.1: shuffle_xception | ||
backbone.layers.2.2: shuffle_7x7 | ||
backbone.layers.2.3: shuffle_xception | ||
backbone.layers.2.4: shuffle_xception | ||
backbone.layers.2.5: shuffle_xception | ||
backbone.layers.2.6: shuffle_7x7 | ||
backbone.layers.2.7: shuffle_3x3 | ||
backbone.layers.3.0: shuffle_3x3 | ||
backbone.layers.3.1: shuffle_xception | ||
backbone.layers.3.2: shuffle_xception | ||
backbone.layers.3.3: shuffle_3x3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# DSNAS | ||
|
||
> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Most existing NAS methods require two-stage parameter optimization. | ||
However, performance of the same architecture in the two stages correlates poorly. | ||
Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining. | ||
|
||
![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg) | ||
|
||
## Results and models | ||
|
||
### Supernet | ||
|
||
| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks | | ||
| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: | | ||
| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched | | ||
|
||
**Note**: | ||
|
||
1. There **might be (not in all cases)** some small differences in our experiments in order to be consistent with other repos in OpenMMLab. For example, | ||
images are normalized in data preprocessing; resizing is done by cv2 rather than PIL in training; dropout is not used in the network. **Please refer to the corresponding config for details.** | ||
2. We convert the official searched checkpoint DSNASsearch240.pth into mmrazor-style and evaluate it with pytorch1.8_cuda11.0; Top-1 is 74.1 and Top-5 is 91.51. | ||
3. The implementation of ShuffleNetV2 in official DSNAS is different from OpenMMLab's, and we follow the structure design in OpenMMLab. Note that with the | ||
original ShuffleNetV2 design in official DSNAS, the Top-1 is 73.92 and Top-5 is 91.59. | ||
4. The finetune stage in our implementation refers to the 'search-from-search' stage mentioned in official DSNAS. | ||
5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may differ from the original repo. | ||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{hu2020dsnas, | ||
title={Dsnas: Direct neural architecture search without parameter retraining}, | ||
author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={12084--12092}, | ||
year={2020} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
_base_ = ['./dsnas_supernet_8xb128_in1k.py'] | ||
|
||
# NOTE: Replace this with the mutable_cfg searched by yourself. | ||
fix_subnet = { | ||
'backbone.layers.0.0': 'shuffle_3x3', | ||
'backbone.layers.0.1': 'shuffle_7x7', | ||
'backbone.layers.0.2': 'shuffle_3x3', | ||
'backbone.layers.0.3': 'shuffle_5x5', | ||
'backbone.layers.1.0': 'shuffle_3x3', | ||
'backbone.layers.1.1': 'shuffle_3x3', | ||
'backbone.layers.1.2': 'shuffle_3x3', | ||
'backbone.layers.1.3': 'shuffle_7x7', | ||
'backbone.layers.2.0': 'shuffle_xception', | ||
'backbone.layers.2.1': 'shuffle_3x3', | ||
'backbone.layers.2.2': 'shuffle_3x3', | ||
'backbone.layers.2.3': 'shuffle_5x5', | ||
'backbone.layers.2.4': 'shuffle_3x3', | ||
'backbone.layers.2.5': 'shuffle_5x5', | ||
'backbone.layers.2.6': 'shuffle_7x7', | ||
'backbone.layers.2.7': 'shuffle_7x7', | ||
'backbone.layers.3.0': 'shuffle_xception', | ||
'backbone.layers.3.1': 'shuffle_3x3', | ||
'backbone.layers.3.2': 'shuffle_7x7', | ||
'backbone.layers.3.3': 'shuffle_3x3', | ||
} | ||
|
||
model = dict(fix_subnet=fix_subnet) | ||
|
||
find_unused_parameters = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
_base_ = [ | ||
'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py', | ||
'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py', | ||
'mmcls::_base_/default_runtime.py', | ||
] | ||
|
||
# model | ||
model = dict( | ||
type='mmrazor.Dsnas', | ||
architecture=dict( | ||
type='ImageClassifier', | ||
data_preprocessor=_base_.data_preprocessor, | ||
backbone=_base_.nas_backbone, | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=1024, | ||
loss=dict( | ||
type='LabelSmoothLoss', | ||
num_classes=1000, | ||
label_smooth_val=0.1, | ||
mode='original', | ||
loss_weight=1.0), | ||
topk=(1, 5))), | ||
mutator=dict(type='mmrazor.DiffModuleMutator'), | ||
pretrain_epochs=15, | ||
finetune_epochs=_base_.search_epochs, | ||
) | ||
|
||
model_wrapper_cfg = dict( | ||
type='mmrazor.DsnasDDP', | ||
broadcast_buffers=False, | ||
find_unused_parameters=True) | ||
|
||
randomness = dict(seed=48, diff_rank_seed=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .autoslim import AutoSlim, AutoSlimDDP | ||
from .darts import Darts, DartsDDP | ||
from .dsnas import Dsnas, DsnasDDP | ||
from .spos import SPOS | ||
|
||
__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP'] | ||
__all__ = [ | ||
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP' | ||
] |
Oops, something went wrong.