diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 5d15599f2..84e74de48 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -29,46 +29,8 @@ jobs:
strategy:
matrix:
python-version: [3.7]
- torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0]
+ torch: [1.12.0]
include:
- - torch: 1.6.0
- torch_version: 1.6
- torchvision: 0.7.0
- - torch: 1.7.0
- torch_version: 1.7
- torchvision: 0.8.1
- - torch: 1.7.0
- torch_version: 1.7
- torchvision: 0.8.1
- python-version: 3.8
- - torch: 1.8.0
- torch_version: 1.8
- torchvision: 0.9.0
- - torch: 1.8.0
- torch_version: 1.8
- torchvision: 0.9.0
- python-version: 3.8
- - torch: 1.9.0
- torch_version: 1.9
- torchvision: 0.10.0
- - torch: 1.9.0
- torch_version: 1.9
- torchvision: 0.10.0
- python-version: 3.8
- - torch: 1.10.0
- torch_version: 1.10
- torchvision: 0.11.0
- - torch: 1.10.0
- torch_version: 1.10
- torchvision: 0.11.0
- python-version: 3.8
- - torch: 1.11.0
- torch_version: 1.11
- torchvision: 0.12.0
- - torch: 1.11.0
- torch_version: 1.11
- torchvision: 0.12.0
- python-version: 3.8
- torch: 1.12.0
torch_version: 1.12
torchvision: 0.13.0
diff --git a/README.md b/README.md
index c440fda98..6a96d0372 100644
--- a/README.md
+++ b/README.md
@@ -187,6 +187,7 @@ This project is released under the [Apache 2.0 license](LICENSE).
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
+- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
diff --git a/README_zh-CN.md b/README_zh-CN.md
index cf5fd0a5d..169181941 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -160,6 +160,7 @@ MMRazor 是一款由来自不同高校和企业的研发人员共同参与贡献
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
+- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO 系列工具箱与测试基准
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具箱
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
diff --git a/configs/distill/mmcls/deit/README.md b/configs/distill/mmcls/deit/README.md
new file mode 100644
index 000000000..1057c81c2
--- /dev/null
+++ b/configs/distill/mmcls/deit/README.md
@@ -0,0 +1,45 @@
+# DeiT
+
+> [](https://arxiv.org/abs/2012.12877)
+> Training data-efficient image transformers & distillation through attention
+
+
+
+## Abstract
+
+Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
+
+
+
+
+
+## Results and models
+
+### Classification
+
+| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
+| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |
+
+```{warning}
+Before training, please first install `timm`.
+
+pip install timm
+or
+git clone https://github.com/rwightman/pytorch-image-models
+cd pytorch-image-models && pip install -e .
+```
+
+## Citation
+
+```
+@InProceedings{pmlr-v139-touvron21a,
+ title = {Training data-efficient image transformers & distillation through attention},
+ author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
+ booktitle = {International Conference on Machine Learning},
+ pages = {10347--10357},
+ year = {2021},
+ volume = {139},
+ month = {July}
+}
+```
diff --git a/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
new file mode 100644
index 000000000..c2cfaf56a
--- /dev/null
+++ b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
@@ -0,0 +1,64 @@
+_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']
+
+# student settings
+student = _base_.model
+student.backbone.type = 'DistilledVisionTransformer'
+student.head = dict(
+ type='mmrazor.DeiTClsHead',
+ num_classes=1000,
+ in_channels=768,
+ loss=dict(
+ type='mmcls.LabelSmoothLoss',
+ label_smooth_val=0.1,
+ mode='original',
+ loss_weight=0.5))
+
+data_preprocessor = dict(
+ type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)
+
+# teacher settings
+checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth' # noqa: E501
+teacher = dict(
+ _scope_='mmcls',
+ type='ImageClassifier',
+ backbone=dict(
+ type='TIMMBackbone', model_name='regnety_160', pretrained=True),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=3024,
+ loss=dict(
+ type='LabelSmoothLoss',
+ label_smooth_val=0.1,
+ mode='original',
+ loss_weight=0.5),
+ topk=(1, 5),
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))
+
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='SingleTeacherDistill',
+ architecture=student,
+ teacher=teacher,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
+ teacher_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ distill_losses=dict(
+ loss_distill=dict(
+ type='CrossEntropyLoss',
+ loss_weight=0.5,
+ )),
+ loss_forward_mappings=dict(
+ loss_distill=dict(
+ preds_S=dict(from_student=True, recorder='fc'),
+ preds_T=dict(from_student=False, recorder='fc')))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmcls/deit/metafile.yml b/configs/distill/mmcls/deit/metafile.yml
new file mode 100644
index 000000000..d46a91b64
--- /dev/null
+++ b/configs/distill/mmcls/deit/metafile.yml
@@ -0,0 +1,34 @@
+Collections:
+ - Name: DEIT
+ Metadata:
+ Training Data:
+ - ImageNet-1k
+ Paper:
+ URL: https://arxiv.org/abs/2012.12877
+ Title: Training data-efficient image transformers & distillation through attention
+ README: configs/distill/mmcls/deit/README.md
+
+Models:
+ - Name: deit-base_regnety160_pt-16xb64_in1k
+ In Collection: DEIT
+ Metadata:
+ Student:
+ Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
+ Metrics:
+ Top 1 Accuracy: 81.76
+ Top 5 Accuracy: 95.81
+ Teacher:
+ Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
+ Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
+ Metrics:
+ Top 1 Accuracy: 82.83
+ Top 5 Accuracy: 96.42
+ Results:
+ - Task: Classification
+ Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.24
+ Top 5 Accuracy: 96.33
+ Weights: https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
+ Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
diff --git a/configs/distill/mmcls/kd/README.md b/configs/distill/mmcls/kd/README.md
index 0dcde2dd3..0fbe7bd9a 100644
--- a/configs/distill/mmcls/kd/README.md
+++ b/configs/distill/mmcls/kd/README.md
@@ -14,9 +14,11 @@ A very simple way to improve the performance of almost any machine learning algo
### Classification
-| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
-| :------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 71.54 | 73.62 | 69.90 | [config](./wsld_cls_head_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth?versionId=CAEQHxiBgMC6memK7xciIGMzMDFlYTA4YzhlYTRiMTNiZWU0YTVhY2I5NjVkMjY2) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_20211221_181516.log.json?versionId=CAEQHxiBgIDLmemK7xciIGNkM2FiN2Y4N2E5YjRhNDE4NDVlNmExNDczZDIxN2E5) |
+| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
+| :------: | :------: | :-----------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet18_8xb32_in1k.py) | 71.81 | 73.62 | 69.90 | [config](./kd_logits_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.pth?versionId=CAEQThiBgID1_Me0oBgiIDE3NTk3MDgxZmU2YjRlMjVhMzg1ZTQwMmRhNmYyNGU2) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.json?versionId=CAEQThiBgMDx_se0oBgiIDQxNTM2MWZjZGRhNjRhZDZiZTIzY2Y0NDU3NDA4ODBl) |
+| logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [mobilenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py) | 73.56 | 76.55 | 71.86 | [config](./kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_mbv2_w3t1/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221025_212407-6ea9e2a5.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_mbv2_w3t1/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221025_212407-6ea9e2a5.json) |
+| logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [shufflenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py) | 70.87 | 76.55 | 69.55 | [config](./kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_shuffle_w3t1/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221025_224424-5d748c1b.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_shuffle_w3t1/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221025_224424-5d748c1b.json) |
## Citation
diff --git a/configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py b/configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py
index 6bf8f0f19..35921c03b 100644
--- a/configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py
+++ b/configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py
@@ -4,6 +4,8 @@
'mmcls::_base_/default_runtime.py'
]
+teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa: E501
+
model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
@@ -17,8 +19,8 @@
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
- cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
- teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
+ cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False),
+ teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
@@ -26,7 +28,7 @@
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
- loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)),
+ loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
diff --git a/configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py b/configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
new file mode 100644
index 000000000..4f82fb3b0
--- /dev/null
+++ b/configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
@@ -0,0 +1,37 @@
+_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py']
+
+student = _base_.model
+
+teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
+
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='SingleTeacherDistill',
+ data_preprocessor=dict(
+ type='ImgDataPreprocessor',
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ bgr_to_rgb=True),
+ architecture=student,
+ teacher=dict(
+ cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ teacher_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ distill_losses=dict(
+ loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
+ loss_forward_mappings=dict(
+ loss_kl=dict(
+ preds_S=dict(from_student=True, recorder='fc'),
+ preds_T=dict(from_student=False, recorder='fc')))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py b/configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
new file mode 100644
index 000000000..fe9dd5891
--- /dev/null
+++ b/configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
@@ -0,0 +1,37 @@
+_base_ = ['mmcls::shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py']
+
+student = _base_.model
+
+teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
+
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='SingleTeacherDistill',
+ data_preprocessor=dict(
+ type='ImgDataPreprocessor',
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ bgr_to_rgb=True),
+ architecture=student,
+ teacher=dict(
+ cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ teacher_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ distill_losses=dict(
+ loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
+ loss_forward_mappings=dict(
+ loss_kl=dict(
+ preds_S=dict(from_student=True, recorder='fc'),
+ preds_T=dict(from_student=False, recorder='fc')))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmcls/kd/metafile.yml b/configs/distill/mmcls/kd/metafile.yml
index 89e7fc385..a208de783 100644
--- a/configs/distill/mmcls/kd/metafile.yml
+++ b/configs/distill/mmcls/kd/metafile.yml
@@ -7,9 +7,7 @@ Collections:
URL: https://arxiv.org/abs/1503.02531
Title: Distilling the Knowledge in a Neural Network
README: configs/distill/mmcls/kd/README.md
- Code:
- URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/losses/weighted_soft_label_distillation.py
- Version: v0.1.0
+
Models:
- Name: kd_logits_resnet34_resnet18_8xb32_in1k
In Collection: KD
@@ -31,6 +29,54 @@ Models:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
- Top 1 Accuracy: 71.54
+ Top 1 Accuracy: 71.81
Config: configs/distill/mmcls/kd/kd_logits_resnet34_resnet18_8xb32_in1k.py
- Weights: https://download.openmmlab.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth
+ Weights: https://download.openmmlab.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.pth?versionId=CAEQThiBgID1_Me0oBgiIDE3NTk3MDgxZmU2YjRlMjVhMzg1ZTQwMmRhNmYyNGU2
+
+ - Name: kd_logits_resnet50_mobilenet-v2_8xb32_in1k
+ In Collection: KD
+ Metadata:
+ Location: logits
+ Student:
+ Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
+ Metrics:
+ Top 1 Accuracy: 71.86
+ Top 5 Accuracy: 90.42
+ Teacher:
+ Config: mmcls::resnet/resnet50_8xb32_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
+ Metrics:
+ Top 1 Accuracy: 76.55
+ Top 5 Accuracy: 93.06
+ Results:
+ - Task: Image Classification
+ Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 73.56
+ Config: configs/distill/mmcls/kd/kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/kd/kl_mbv2_w3t1/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221025_212407-6ea9e2a5.pth
+
+ - Name: kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k
+ In Collection: KD
+ Metadata:
+ Location: logits
+ Student:
+ Config: mmcls::shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth
+ Metrics:
+ Top 1 Accuracy: 69.55
+ Top 5 Accuracy: 88.92
+ Teacher:
+ Config: mmcls::resnet/resnet50_8xb32_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
+ Metrics:
+ Top 1 Accuracy: 76.55
+ Top 5 Accuracy: 93.06
+ Results:
+ - Task: Image Classification
+ Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 70.87
+ Config: configs/distill/mmcls/kd/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/kd/kl_shuffle_w3t1/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221025_224424-5d748c1b.pth
diff --git a/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py b/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py
new file mode 100644
index 000000000..13947952a
--- /dev/null
+++ b/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py
@@ -0,0 +1,21 @@
+_base_ = ['./cwd_fpn_retina_r101_retina_r50_1x_coco.py']
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', interval=-1),
+ visualization=dict(
+ _scope_='mmrazor',
+ type='RazorVisualizationHook',
+ enabled=True,
+ recorders=dict(
+ # todo: Maybe it is hard for users to understand why to add a
+ # prefix `architecture.`
+ neck=dict(
+ _scope_='mmrazor',
+ type='ModuleOutputs',
+ source='architecture.neck')),
+ mappings=dict(
+ p3=dict(recorder='neck', data_idx=0),
+ p4=dict(recorder='neck', data_idx=1),
+ p5=dict(recorder='neck', data_idx=2),
+ p6=dict(recorder='neck', data_idx=3)),
+ out_dir='retina_vis'))
diff --git a/configs/distill/mmdet/pkd/README.md b/configs/distill/mmdet/pkd/README.md
new file mode 100644
index 000000000..a25dc5ae8
--- /dev/null
+++ b/configs/distill/mmdet/pkd/README.md
@@ -0,0 +1,34 @@
+# PKD
+
+> [PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient](https://arxiv.org/abs/2207.02039)
+
+
+
+## Abstract
+
+Knowledge distillation(KD) is a widely-used technique to train compact models in object detection. However, there is still a lack of study on how to distill between heterogeneous detectors. In this paper, we empirically find that better FPN features from a heterogeneous teacher detector can help the student although their detection heads and label assignments are different. However, directly aligning the feature maps to distill detectors suffers from two problems. First, the difference in feature magnitude between the teacher and the student could enforce overly strict constraints on the student. Second, the FPN stages and channels with large feature magnitude from the teacher model could dominate the gradient of distillation loss, which will overwhelm the effects of other features in KD and introduce much noise. To address the above issues, we propose to imitate features with Pearson Correlation Coefficient to focus on the relational information from the teacher and relax constraints on the magnitude of the features. Our method consistently outperforms the existing detection KD methods and works for both homogeneous and heterogeneous student-teacher pairs. Furthermore, it converges faster. With a powerful MaskRCNN-Swin detector as the teacher, ResNet-50 based RetinaNet and FCOS achieve 41.5% and 43.9% mAP on COCO2017, which are 4.1% and 4.8% higher than the baseline, respectively.
+
+![pipeline](https://user-images.githubusercontent.com/41630003/197719796-76fa5f33-1d54-4927-8a08-86f5c6e33879.png)
+
+## Results and models
+
+### Detection
+
+| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
+| :------: | :-----: | :--------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------: | :-----: | :--: | :----: | :----: | :-----------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| FPN | COCO | [FCOS-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py) | 1x | 40.3 | 42.6 | 36.5 | [config](pkd_fpn_fcos_x101_retina_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.pth?versionId=CAEQThiBgMCLyNC0oBgiIDBjY2FkY2JlNGFiYzRmM2RiZGUyYzM1NjQxYzQxODA4) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.json?versionId=CAEQThiBgMDA0dS0oBgiIDM4ZjZlZmVkMzc4MjQxMGJiN2FlMDFlOTA2NGIzZGQ4) |
+| FPN | COCO | [Faster-Rcnn-R101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py) | [Faster-rcnn-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py) | 2x | 40.3 | 39.8 | 38.4 | [config](pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.pth?versionId=CAEQThiBgMDQr9C0oBgiIDMyZWE1Y2ZlMDA2ZDQ2ZGNhZmQ3NzMxODk3YzgzYWFl) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.json?versionId=CAEQThiBgICYsNC0oBgiIDdhNWY5ZjZlYjUyNzRjMGU4NGFhYzk4NzQwZDAxY2Rj) |
+| FPN | COCO | [Mask-Rcnn-Swin](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 41.5 | 48.2 | 37.4 | [config](pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.pth?versionId=CAEQThiBgIDWqNC0oBgiIDViOGE0ZDU4ODgxNzQ5YmE5OGU3MzRkMjFiZGRjZmRm) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.json?versionId=CAEQThiBgIDVqdC0oBgiIDU3YzFjOWRmNWY3NTRmYjFhMDdmNzU2ODE3MzdlZThk) |
+| FPN | COCO | [Reppoints-X101-dcn](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py) | [Reppoints-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py) | 2x | 42.3 | 44.2 | 38.6 | [config](pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.pth?versionId=CAEQThiBgIC8rNC0oBgiIGU2N2IxM2NkMjNlMjQyN2E4YmVlNmViNGI2MDY3OTE5) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.json?versionId=CAEQThiBgICordC0oBgiIDJhMjBjOGZiN2UxNjQxYmI5MzE3NWVhZDgxZDE2NmJm) |
+| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 40.8 | 41.0 | 37.4 | [config](pkd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.pth?versionId=CAEQThiBgIDQqdC0oBgiIGFmZjNmZmE4NDFiMDQ4MzhiMzdjOGI2NzI4MTQxMjFi) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.json?versionId=CAEQThiBgMC2qdC0oBgiIGRkMTIzODYwMzliMDQ3M2JiYjNlYjA5N2I4Y2QzMGFl) |
+
+## Citation
+
+```latex
+@article{cao2022pkd,
+ title={PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient},
+ author={Cao, Weihan and Zhang, Yifan and Gao, Jianfei and Cheng, Anda and Cheng, Ke and Cheng, Jian},
+ journal={arXiv preprint arXiv:2207.02039},
+ year={2022}
+}
+```
diff --git a/configs/distill/mmdet/pkd/metafile.yml b/configs/distill/mmdet/pkd/metafile.yml
new file mode 100644
index 000000000..6ea2347b5
--- /dev/null
+++ b/configs/distill/mmdet/pkd/metafile.yml
@@ -0,0 +1,110 @@
+Models:
+ - Name: pkd_fpn_fcos_x101_retina_r50_1x_coco
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 36.5
+ Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
+ Teacher:
+ Metrics:
+ box AP: 42.6
+ Config: mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 40.3
+ Config: configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.pth?versionId=CAEQThiBgMCLyNC0oBgiIDBjY2FkY2JlNGFiYzRmM2RiZGUyYzM1NjQxYzQxODA4
+
+ - Name: pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 38.4
+ Config: mmdet::faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth
+ Teacher:
+ Metrics:
+ box AP: 39.8
+ Config: mmdet::faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 40.4
+ Config: configs/distill/mmdet/pkd/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.pth?versionId=CAEQThiBgMDQr9C0oBgiIDMyZWE1Y2ZlMDA2ZDQ2ZGNhZmQ3NzMxODk3YzgzYWFl
+
+ - Name: pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 37.4
+ Config: mmdet::retinanet/retinanet_r50_fpn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth
+ Teacher:
+ Metrics:
+ box AP: 48.2
+ Config: mmdet::swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 41.5
+ Config: configs/distill/mmdet/pkd/pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.pth?versionId=CAEQThiBgIDWqNC0oBgiIDViOGE0ZDU4ODgxNzQ5YmE5OGU3MzRkMjFiZGRjZmRm
+
+ - Name: pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 38.6
+ Config: mmdet::reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_r50_fpn_gn-neck%2Bhead_2x_coco/reppoints_moment_r50_fpn_gn-neck%2Bhead_2x_coco_20200329-91babaa2.pth
+ Teacher:
+ Metrics:
+ box AP: 44.2
+ Config: mmdet::reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 42.3
+ Config: configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.pth?versionId=CAEQThiBgIC8rNC0oBgiIGU2N2IxM2NkMjNlMjQyN2E4YmVlNmViNGI2MDY3OTE5
+
+ - Name: pkd_fpn_retina_x101_retina_r50_2x_coco
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 37.4
+ Config: mmdet::retinanet/retinanet_r50_fpn_2x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth
+ Teacher:
+ Metrics:
+ box AP: 41.0
+ Config: mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 40.8
+ Config: configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.pth?versionId=CAEQThiBgIDQqdC0oBgiIGFmZjNmZmE4NDFiMDQ4MzhiMzdjOGI2NzI4MTQxMjFi
diff --git a/configs/distill/mmdet/pkd/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py b/configs/distill/mmdet/pkd/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py
new file mode 100644
index 000000000..c9496792d
--- /dev/null
+++ b/configs/distill/mmdet/pkd/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py
@@ -0,0 +1,45 @@
+_base_ = [
+ 'mmdet::_base_/datasets/coco_detection.py',
+ 'mmdet::_base_/schedules/schedule_2x.py',
+ 'mmdet::_base_/default_runtime.py'
+]
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth' # noqa: E501
+
+model = dict(
+ _scope_='mmrazor',
+ type='FpnTeacherDistill',
+ architecture=dict(
+ cfg_path='mmdet::faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py',
+ pretrained=False),
+ teacher=dict(
+ cfg_path='mmdet::faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py',
+ pretrained=False),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ distill_losses=dict(
+ loss_pkd_fpn0=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn1=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn2=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn3=dict(type='PKDLoss', loss_weight=6)),
+ loss_forward_mappings=dict(
+ loss_pkd_fpn0=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=0)),
+ loss_pkd_fpn1=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=1)),
+ loss_pkd_fpn2=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=2)),
+ loss_pkd_fpn3=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
+ preds_T=dict(from_student=False, recorder='fpn',
+ data_idx=3)))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py b/configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py
new file mode 100644
index 000000000..3a8059acc
--- /dev/null
+++ b/configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py
@@ -0,0 +1,27 @@
+_base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py']
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth' # noqa: E501
+
+model = dict(
+ architecture=dict(
+ cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'),
+ teacher=dict(
+ cfg_path= # noqa: E251
+ 'mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py'),
+ teacher_ckpt=teacher_ckpt)
+
+# training schedule for 1x
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=12,
+ by_epoch=True,
+ milestones=[8, 11],
+ gamma=0.1)
+]
diff --git a/configs/distill/mmdet/pkd/pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py b/configs/distill/mmdet/pkd/pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py
new file mode 100644
index 000000000..3ba6727f5
--- /dev/null
+++ b/configs/distill/mmdet/pkd/pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py
@@ -0,0 +1,53 @@
+_base_ = [
+ 'mmdet::_base_/datasets/coco_instance.py',
+ 'mmdet::_base_/schedules/schedule_2x.py',
+ 'mmdet::_base_/default_runtime.py'
+]
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco_20210903_104808-b92c91f1.pth' # noqa: E501
+
+model = dict(
+ _scope_='mmrazor',
+ type='FpnTeacherDistill',
+ architecture=dict(
+ cfg_path='mmdet::retinanet/retinanet_r50_fpn_2x_coco.py',
+ pretrained=False),
+ teacher=dict(
+ cfg_path= # noqa: E251
+ 'mmdet::swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py',
+ pretrained=False),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ distill_losses=dict(
+ loss_pkd_fpn0=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn1=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn2=dict(type='PKDLoss', loss_weight=6),
+ loss_pkd_fpn3=dict(type='PKDLoss', loss_weight=6)),
+ loss_forward_mappings=dict(
+ loss_pkd_fpn0=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=0)),
+ loss_pkd_fpn1=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=1)),
+ loss_pkd_fpn2=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=2)),
+ loss_pkd_fpn3=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
+ preds_T=dict(from_student=False, recorder='fpn',
+ data_idx=3)))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(lr=0.01))
+
+# dataset
+val_evaluator = dict(metric=['bbox'])
+test_evaluator = val_evaluator
diff --git a/configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py b/configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py
new file mode 100644
index 000000000..ecf06bd37
--- /dev/null
+++ b/configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py
@@ -0,0 +1,13 @@
+_base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py']
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth' # noqa: E501
+
+model = dict(
+ architecture=dict(
+ cfg_path= # noqa: E251
+ 'mmdet::reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py'),
+ teacher=dict(
+ cfg_path= # noqa: E251
+ 'mmdet::reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py' # noqa: E501
+ ),
+ teacher_ckpt=teacher_ckpt)
diff --git a/configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py b/configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py
new file mode 100644
index 000000000..40172c04d
--- /dev/null
+++ b/configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py
@@ -0,0 +1,19 @@
+_base_ = ['./pkd_fpn_frcnn_r101_frcnn_r50_2x_coco.py']
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501
+
+model = dict(
+ architecture=dict(
+ cfg_path='mmdet::retinanet/retinanet_r50_fpn_2x_coco.py'),
+ teacher=dict(
+ cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py'),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ distill_losses=dict(
+ loss_pkd_fpn0=dict(loss_weight=10),
+ loss_pkd_fpn1=dict(loss_weight=10),
+ loss_pkd_fpn2=dict(loss_weight=10),
+ loss_pkd_fpn3=dict(loss_weight=10))))
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(lr=0.01))
diff --git a/configs/distill/mmdet3d/pkd/README.md b/configs/distill/mmdet3d/pkd/README.md
new file mode 100644
index 000000000..fdd191f69
--- /dev/null
+++ b/configs/distill/mmdet3d/pkd/README.md
@@ -0,0 +1,30 @@
+# PKD
+
+> [PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient](https://arxiv.org/abs/2207.02039)
+
+
+
+## Abstract
+
+Knowledge distillation(KD) is a widely-used technique to train compact models in object detection. However, there is still a lack of study on how to distill between heterogeneous detectors. In this paper, we empirically find that better FPN features from a heterogeneous teacher detector can help the student although their detection heads and label assignments are different. However, directly aligning the feature maps to distill detectors suffers from two problems. First, the difference in feature magnitude between the teacher and the student could enforce overly strict constraints on the student. Second, the FPN stages and channels with large feature magnitude from the teacher model could dominate the gradient of distillation loss, which will overwhelm the effects of other features in KD and introduce much noise. To address the above issues, we propose to imitate features with Pearson Correlation Coefficient to focus on the relational information from the teacher and relax constraints on the magnitude of the features. Our method consistently outperforms the existing detection KD methods and works for both homogeneous and heterogeneous student-teacher pairs. Furthermore, it converges faster. With a powerful MaskRCNN-Swin detector as the teacher, ResNet-50 based RetinaNet and FCOS achieve 41.5% and 43.9% mAP on COCO2017, which are 4.1% and 4.8% higher than the baseline, respectively.
+
+![pipeline](https://user-images.githubusercontent.com/88702197/187424502-d8efb7a3-c40c-4e53-a36c-bd947de464a4.png)
+
+## Results and models
+
+### Detection
+
+| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
+| :------: | :--------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------: | :-----: | :--: | :----: | :----: | :------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| FPN | nus-mono3d | [FCOS3d-R101](https://github.com/open-mmlab/mmdetection3d/blob/dev-1.x/configs/fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d_finetune.py) | [FCOS3d-R50](<>) | 1x | 29.3 | 32.1 | 26.8 | [config](pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py) | [teacher](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos3d_w10/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d_20220928_234557-0b51b62e.pth?versionId=CAEQThiBgMC8sdC0oBgiIDAwOWE2OWUyNDU1NTQ1MjBhZTY1NmNjODZmMDZkZTM2) \| [log](https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos3d_w10/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d_20220928_234557-0b51b62e.json?versionId=CAEQThiBgIDrvdC0oBgiIDNmNGNkNDZhM2RmNjQ1MmI4ZDM0OGNmYmFkYjk5ZjFi) |
+
+## Citation
+
+```latex
+@article{cao2022pkd,
+ title={PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient},
+ author={Cao, Weihan and Zhang, Yifan and Gao, Jianfei and Cheng, Anda and Cheng, Ke and Cheng, Jian},
+ journal={arXiv preprint arXiv:2207.02039},
+ year={2022}
+}
+```
diff --git a/configs/distill/mmdet3d/pkd/metafile.yml b/configs/distill/mmdet3d/pkd/metafile.yml
new file mode 100644
index 000000000..1f60cffd7
--- /dev/null
+++ b/configs/distill/mmdet3d/pkd/metafile.yml
@@ -0,0 +1,22 @@
+Models:
+ - Name: pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d
+ In Collection: PKD
+ Metadata:
+ Location: FPN
+ Student:
+ Metrics:
+ box AP: 26.8
+ Config:
+ Weights:
+ Teacher:
+ Metrics:
+ box AP: 32.1
+ Config: mmdet3d::fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d_finetune.py
+ Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 29.3
+ Config: configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py
+ Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos3d_w10/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d_20220928_234557-0b51b62e.json?versionId=CAEQThiBgIDrvdC0oBgiIDNmNGNkNDZhM2RmNjQ1MmI4ZDM0OGNmYmFkYjk5ZjFi
diff --git a/configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py b/configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py
new file mode 100644
index 000000000..cbbedf7e0
--- /dev/null
+++ b/configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py
@@ -0,0 +1,49 @@
+_base_ = [
+ 'mmdet3d::fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d.py',
+]
+
+train_dataloader = dict(num_workers=4)
+
+student = _base_.model
+student.backbone.depth = 50 # using original ResNet50
+student.backbone.dcn = None # no dcn in backbone
+student.backbone.stage_with_dcn = (False, False, False, False)
+student.backbone.init_cfg.checkpoint = 'open-mmlab://detectron2/resnet50_caffe'
+
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth' # noqa: E501
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='FpnTeacherDistill',
+ architecture=student,
+ teacher=dict(
+ cfg_path= # noqa: E251
+ 'mmdet3d::fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d_finetune.py', # noqa: E501
+ pretrained=False),
+ teacher_ckpt=teacher_ckpt,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
+ distill_losses=dict(
+ loss_pkd_fpn0=dict(type='PKDLoss', loss_weight=10),
+ loss_pkd_fpn1=dict(type='PKDLoss', loss_weight=10),
+ loss_pkd_fpn2=dict(type='PKDLoss', loss_weight=10),
+ loss_pkd_fpn3=dict(type='PKDLoss', loss_weight=10)),
+ loss_forward_mappings=dict(
+ loss_pkd_fpn0=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=0)),
+ loss_pkd_fpn1=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=1)),
+ loss_pkd_fpn2=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
+ preds_T=dict(from_student=False, recorder='fpn', data_idx=2)),
+ loss_pkd_fpn3=dict(
+ preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
+ preds_T=dict(from_student=False, recorder='fpn',
+ data_idx=3)))))
+
+find_unused_parameters = True
+train_cfg = dict(val_interval=12)
diff --git a/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml b/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml
index d9faa6497..c56cfe46e 100644
--- a/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml
+++ b/configs/nas/mmcls/darts/DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml
@@ -1,56 +1,80 @@
normal_n2:
- - normal_n2_p0
- - normal_n2_p1
+ chosen:
+ - normal_n2_p0
+ - normal_n2_p1
normal_n2_p0:
- - sep_conv_3x3
+ chosen:
+ - sep_conv_3x3
normal_n2_p1:
- - sep_conv_3x3
+ chosen:
+ - sep_conv_3x3
normal_n3:
- - normal_n3_p0
- - normal_n3_p1
+ chosen:
+ - normal_n3_p0
+ - normal_n3_p1
normal_n3_p0:
- - skip_connect
+ chosen:
+ - skip_connect
normal_n3_p1:
- - sep_conv_5x5
+ chosen:
+ - sep_conv_5x5
normal_n4:
- - normal_n4_p0
- - normal_n4_p1
+ chosen:
+ - normal_n4_p0
+ - normal_n4_p1
normal_n4_p0:
- - sep_conv_3x3
+ chosen:
+ - sep_conv_3x3
normal_n4_p1:
- - skip_connect
+ chosen:
+ - skip_connect
normal_n5:
- - normal_n5_p0
- - normal_n5_p1
+ chosen:
+ - normal_n5_p0
+ - normal_n5_p1
normal_n5_p0:
- - skip_connect
+ chosen:
+ - skip_connect
normal_n5_p1:
- - skip_connect
+ chosen:
+ - skip_connect
reduce_n2:
- - reduce_n2_p0
- - reduce_n2_p1
+ chosen:
+ - reduce_n2_p0
+ - reduce_n2_p1
reduce_n2_p0:
- - max_pool_3x3
+ chosen:
+ - max_pool_3x3
reduce_n2_p1:
- - sep_conv_3x3
+ chosen:
+ - sep_conv_3x3
reduce_n3:
- - reduce_n3_p0
- - reduce_n3_p2
+ chosen:
+ - reduce_n3_p0
+ - reduce_n3_p2
reduce_n3_p0:
- - max_pool_3x3
+ chosen:
+ - max_pool_3x3
reduce_n3_p2:
- - dil_conv_5x5
+ chosen:
+ - dil_conv_5x5
reduce_n4:
- - reduce_n4_p0
- - reduce_n4_p2
+ chosen:
+ - reduce_n4_p0
+ - reduce_n4_p2
reduce_n4_p0:
- - max_pool_3x3
+ chosen:
+ - max_pool_3x3
reduce_n4_p2:
- - skip_connect
+ chosen:
+ - skip_connect
reduce_n5:
- - reduce_n5_p0
- - reduce_n5_p2
+ chosen:
+ - reduce_n5_p0
+ - reduce_n5_p2
reduce_n5_p0:
- - max_pool_3x3
+ chosen:
+ - max_pool_3x3
reduce_n5_p2:
- - skip_connect
+ chosen:
+ - skip_connect
diff --git a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
index d2fa294d3..0c35c01b5 100644
--- a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
+++ b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
@@ -1,20 +1,40 @@
-backbone.layers.0.0: shuffle_3x3
-backbone.layers.0.1: shuffle_3x3
-backbone.layers.0.2: shuffle_xception
-backbone.layers.0.3: shuffle_3x3
-backbone.layers.1.0: shuffle_xception
-backbone.layers.1.1: shuffle_7x7
-backbone.layers.1.2: shuffle_3x3
-backbone.layers.1.3: shuffle_3x3
-backbone.layers.2.0: shuffle_xception
-backbone.layers.2.1: shuffle_xception
-backbone.layers.2.2: shuffle_7x7
-backbone.layers.2.3: shuffle_xception
-backbone.layers.2.4: shuffle_xception
-backbone.layers.2.5: shuffle_xception
-backbone.layers.2.6: shuffle_7x7
-backbone.layers.2.7: shuffle_3x3
-backbone.layers.3.0: shuffle_3x3
-backbone.layers.3.1: shuffle_xception
-backbone.layers.3.2: shuffle_xception
-backbone.layers.3.3: shuffle_3x3
+backbone.layers.0.0:
+ chosen: shuffle_3x3
+backbone.layers.0.1:
+ chosen: shuffle_7x7
+backbone.layers.0.2:
+ chosen: shuffle_3x3
+backbone.layers.0.3:
+ chosen: shuffle_5x5
+backbone.layers.1.0:
+ chosen: shuffle_3x3
+backbone.layers.1.1:
+ chosen: shuffle_3x3
+backbone.layers.1.2:
+ chosen: shuffle_3x3
+backbone.layers.1.3:
+ chosen: shuffle_7x7
+backbone.layers.2.0:
+ chosen: shuffle_xception
+backbone.layers.2.1:
+ chosen: shuffle_3x3
+backbone.layers.2.2:
+ chosen: shuffle_3x3
+backbone.layers.2.3:
+ chosen: shuffle_5x5
+backbone.layers.2.4:
+ chosen: shuffle_3x3
+backbone.layers.2.5:
+ chosen: shuffle_5x5
+backbone.layers.2.6:
+ chosen: shuffle_7x7
+backbone.layers.2.7:
+ chosen: shuffle_7x7
+backbone.layers.3.0:
+ chosen: shuffle_xception
+backbone.layers.3.1:
+ chosen: shuffle_3x3
+backbone.layers.3.2:
+ chosen: shuffle_7x7
+backbone.layers.3.3:
+ chosen: shuffle_3x3
diff --git a/configs/nas/mmcls/dsnas/README.md b/configs/nas/mmcls/dsnas/README.md
index 6a085eb78..3c1501fd2 100644
--- a/configs/nas/mmcls/dsnas/README.md
+++ b/configs/nas/mmcls/dsnas/README.md
@@ -16,9 +16,9 @@ Based on this observation, DSNAS proposes a task-specific end-to-end differentia
### Supernet
-| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks |
-| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: |
-| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched |
+| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks |
+| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------: |
+| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/dsnas/dsnas_supernet_8xb128_in1k_20220926_171954-29b87e3a.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/dsnas/dsnas_supernet_8xb128_in1k_20220926_171954-29b87e3a.log) | MMRazor searched |
**Note**:
diff --git a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py
index ca30a5946..a96c81f82 100644
--- a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py
+++ b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py
@@ -1,28 +1,7 @@
_base_ = ['./dsnas_supernet_8xb128_in1k.py']
# NOTE: Replace this with the mutable_cfg searched by yourself.
-fix_subnet = {
- 'backbone.layers.0.0': 'shuffle_3x3',
- 'backbone.layers.0.1': 'shuffle_7x7',
- 'backbone.layers.0.2': 'shuffle_3x3',
- 'backbone.layers.0.3': 'shuffle_5x5',
- 'backbone.layers.1.0': 'shuffle_3x3',
- 'backbone.layers.1.1': 'shuffle_3x3',
- 'backbone.layers.1.2': 'shuffle_3x3',
- 'backbone.layers.1.3': 'shuffle_7x7',
- 'backbone.layers.2.0': 'shuffle_xception',
- 'backbone.layers.2.1': 'shuffle_3x3',
- 'backbone.layers.2.2': 'shuffle_3x3',
- 'backbone.layers.2.3': 'shuffle_5x5',
- 'backbone.layers.2.4': 'shuffle_3x3',
- 'backbone.layers.2.5': 'shuffle_5x5',
- 'backbone.layers.2.6': 'shuffle_7x7',
- 'backbone.layers.2.7': 'shuffle_7x7',
- 'backbone.layers.3.0': 'shuffle_xception',
- 'backbone.layers.3.1': 'shuffle_3x3',
- 'backbone.layers.3.2': 'shuffle_7x7',
- 'backbone.layers.3.3': 'shuffle_3x3',
-}
+fix_subnet = 'configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml'
model = dict(fix_subnet=fix_subnet)
diff --git a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
index ea821da40..50d11dee2 100644
--- a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
+++ b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
@@ -6,7 +6,7 @@
# model
model = dict(
- type='mmrazor.Dsnas',
+ type='mmrazor.DSNAS',
architecture=dict(
type='ImageClassifier',
data_preprocessor=_base_.data_preprocessor,
@@ -29,7 +29,7 @@
)
model_wrapper_cfg = dict(
- type='mmrazor.DsnasDDP',
+ type='mmrazor.DSNASDDP',
broadcast_buffers=False,
find_unused_parameters=True)
diff --git a/configs/nas/mmcls/spos/SPOS_SUBNET.yaml b/configs/nas/mmcls/spos/SPOS_SUBNET.yaml
new file mode 100644
index 000000000..ba809da1d
--- /dev/null
+++ b/configs/nas/mmcls/spos/SPOS_SUBNET.yaml
@@ -0,0 +1,40 @@
+backbone.layers.0.0:
+ chosen: shuffle_7x7
+backbone.layers.0.1:
+ chosen: shuffle_3x3
+backbone.layers.0.2:
+ chosen: shuffle_7x7
+backbone.layers.0.3:
+ chosen: shuffle_3x3
+backbone.layers.1.0:
+ chosen: shuffle_xception
+backbone.layers.1.1:
+ chosen: shuffle_5x5
+backbone.layers.1.2:
+ chosen: shuffle_5x5
+backbone.layers.1.3:
+ chosen: shuffle_3x3
+backbone.layers.2.0:
+ chosen: shuffle_3x3
+backbone.layers.2.1:
+ chosen: shuffle_5x5
+backbone.layers.2.2:
+ chosen: shuffle_3x3
+backbone.layers.2.3:
+ chosen: shuffle_5x5
+backbone.layers.2.4:
+ chosen: shuffle_3x3
+backbone.layers.2.5:
+ chosen: shuffle_xception
+backbone.layers.2.6:
+ chosen: shuffle_5x5
+backbone.layers.2.7:
+ chosen: shuffle_7x7
+backbone.layers.3.0:
+ chosen: shuffle_7x7
+backbone.layers.3.1:
+ chosen: shuffle_3x3
+backbone.layers.3.2:
+ chosen: shuffle_5x5
+backbone.layers.3.3:
+ chosen: shuffle_xception
diff --git a/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py
index ff7c3bf8c..1243d16b2 100644
--- a/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py
+++ b/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py
@@ -1,7 +1,7 @@
_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
-# FIXME: you may replace this with the mutable_cfg searched by yourself
-fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20220715-aa94d5ef_subnet_cfg_v1.yaml' # noqa: E501
+# FIXME: you may replace this with the searched by yourself
+fix_subnet = 'configs/nas/mmcls/spos/SPOS_SUBNET.yaml'
model = dict(fix_subnet=fix_subnet)
diff --git a/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml b/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml
new file mode 100644
index 000000000..c7bcab916
--- /dev/null
+++ b/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml
@@ -0,0 +1,40 @@
+backbone.layers.0.0:
+ chosen: shuffle_5x5
+backbone.layers.0.1:
+ chosen: shuffle_3x3
+backbone.layers.0.2:
+ chosen: shuffle_3x3
+backbone.layers.0.3:
+ chosen: shuffle_3x3
+backbone.layers.1.0:
+ chosen: shuffle_xception
+backbone.layers.1.1:
+ chosen: shuffle_3x3
+backbone.layers.1.2:
+ chosen: shuffle_xception
+backbone.layers.1.3:
+ chosen: shuffle_7x7
+backbone.layers.2.0:
+ chosen: shuffle_7x7
+backbone.layers.2.1:
+ chosen: shuffle_7x7
+backbone.layers.2.2:
+ chosen: shuffle_xception
+backbone.layers.2.3:
+ chosen: shuffle_xception
+backbone.layers.2.4:
+ chosen: shuffle_3x3
+backbone.layers.2.5:
+ chosen: shuffle_7x7
+backbone.layers.2.6:
+ chosen: shuffle_5x5
+backbone.layers.2.7:
+ chosen: shuffle_xception
+backbone.layers.3.0:
+ chosen: shuffle_7x7
+backbone.layers.3.1:
+ chosen: shuffle_7x7
+backbone.layers.3.2:
+ chosen: shuffle_7x7
+backbone.layers.3.3:
+ chosen: shuffle_5x5
diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py
index 43d0f4983..8334c78b8 100644
--- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py
+++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py
@@ -1,7 +1,7 @@
_base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py']
# FIXME: you may replace this with the searched by yourself
-fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml' # noqa: E501
+fix_subnet = 'configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml'
model = dict(fix_subnet=fix_subnet)
diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/adaround.py
new file mode 100644
index 000000000..389575dc6
--- /dev/null
+++ b/configs/quantization/ptq/adaround.py
@@ -0,0 +1,47 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
+
+test_cfg = dict(
+ _delete_=True,
+ type='mmrazor.PTQLoop',
+ dataloader=_base_.test_dataloader,
+ evaluator=_base_.test_evaluator,
+ calibrate_dataloader=_base_.train_dataloader,
+ batch_num=32,
+ # reconstruction_cfg=dict(
+ # pattern='layer',
+ # loss=dict(
+ # type='mmrazor.AdaRoundLoss',
+ # iters=20000
+ # )
+ # )
+)
+
+model = dict(
+ _delete_=True,
+ type='mmrazor.GeneralQuant',
+ architecture=_base_.model,
+ quantizer=dict(
+ type='mmrazor.CustomQuantizer',
+ is_qat=False,
+ skipped_methods=[
+ 'mmcls.models.heads.ClsHead._get_loss',
+ 'mmcls.models.heads.ClsHead._get_predictions'
+ ],
+ qconfig=dict(
+ qtype='affine',
+ w_observer=dict(type='mmrazor.MSEObserver'),
+ a_observer=dict(type='mmrazor.EMAMSEObserver'),
+ w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
+ a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+ w_qscheme=dict(
+ bit=2,
+ is_symmetry=False,
+ is_per_channel=True,
+ is_pot_scale=False,
+ ),
+ a_qscheme=dict(
+ bit=4,
+ is_symmetry=False,
+ is_per_channel=False,
+ is_pot_scale=False),
+ )))
diff --git a/configs/quantization/ptq/demo.py b/configs/quantization/ptq/demo.py
new file mode 100644
index 000000000..af6a0a5df
--- /dev/null
+++ b/configs/quantization/ptq/demo.py
@@ -0,0 +1 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
diff --git a/configs/quantization/qat/demo.py b/configs/quantization/qat/demo.py
new file mode 100644
index 000000000..be3ec6013
--- /dev/null
+++ b/configs/quantization/qat/demo.py
@@ -0,0 +1 @@
+_base_ = ['./lsq_resnet50_8xb16_cifar10.py']
diff --git a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py
new file mode 100644
index 000000000..a246bc265
--- /dev/null
+++ b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py
@@ -0,0 +1,37 @@
+_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']
+
+train_cfg = dict(
+ _delete_=True,
+ type='mmrazor.QATEpochBasedLoop',
+ max_epochs=_base_.train_cfg.max_epochs,
+)
+
+model = dict(
+ _delete_=True,
+ _scope_='mmrazor',
+ type='GeneralQuant',
+ architecture={{_base_.model}},
+ quantizer=dict(
+ type='TensorRTQuantizer',
+ skipped_methods=[
+ 'mmcls.models.heads.ClsHead._get_loss',
+ 'mmcls.models.heads.ClsHead._get_predictions'
+ ],
+ qconfig=dict(
+ qtype='affine',
+ w_observer=dict(type='mmrazor.MinMaxObserver'),
+ a_observer=dict(type='mmrazor.EMAMinMaxObserver'),
+ w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+ a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+ w_qscheme=dict(
+ bit=2,
+ is_symmetry=False,
+ is_per_channel=True,
+ is_pot_scale=False,
+ ),
+ a_qscheme=dict(
+ bit=4,
+ is_symmetry=False,
+ is_per_channel=False,
+ is_pot_scale=False),
+ )))
diff --git a/docs/en/advanced_guides/tutorials/how_to_prune_your_model.md b/docs/en/advanced_guides/tutorials/how_to_prune_your_model.md
new file mode 100644
index 000000000..0efbc7cde
--- /dev/null
+++ b/docs/en/advanced_guides/tutorials/how_to_prune_your_model.md
@@ -0,0 +1,134 @@
+# How to prune your model
+
+## Overview
+
+This section will introduce you to pruning your model. Before that, we suggest you read the document [User Guides: Pruning Framework](../../user_guides/pruning_user_guide.md) to have an overview of our pruning framework.
+
+First, we suppose your model is defined and trained using one openmmlab repo.
+Our pruning algorithms work as a wrapper of a model. To prune your model, you need to replace your model config with our algorithm config, which has a parameter 'architecture' to store your original model. The pipeline is shown below.
+
+
+
+After this replacement, the algorithm will prune your model during your training process.
+
+## How to Config an Algorithm
+
+All pruning algorithms are defined in mmrazor.models.algorithms.pruning. All algorithms have some shared pruning-related arguments, some specific arguments, and some shared mmengine.BaseModel arguments.
+
+Here we take pruning resnet34 using the l1-norm algorithm as an example. We use "mmcls::resnet/resnet34_8xb32_in1k.py" as a base config. Then we override the model config and use the original model config as the architecture of 'ItePruneAlgorithm'.
+
+```python
+_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']
+
+stage_ratio_1 = 0.7
+stage_ratio_2 = 0.7
+stage_ratio_3 = 0.7
+stage_ratio_4 = 1.0
+
+target_pruning_ratio = {
+ 'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_1,
+ 'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
+ 'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,
+ 'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1,
+ 'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_2,
+ 'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2,
+ 'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,
+ 'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,
+ 'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_2,
+ 'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_3,
+ 'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,
+ 'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4,
+ 'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,
+ 'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4
+}
+
+architecture = _base_.model
+
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='ItePruneAlgorithm',
+ architecture=architecture,
+ mutator_cfg=dict(
+ type='BaseChannelMutator',
+ channel_unit_cfg=dict(
+ type='L1MutableChannelUnit',
+ default_args=dict(choice_mode='ratio'))
+ parse_cfg=dict(
+ type='BackwardTracer',
+ loss_calculator=dict(type='ImageClassifierPseudoLoss')),
+ target_pruning_ratio=target_pruning_ratio,
+ step_epoch=1,
+ prune_times=1,
+ data_preprocessor=None,
+ init_cfg=None
+)
+```
+
+**Shared pruning-related arguments**: All pruning algorithms have two shared pruning-related arguments.
+
+- Architecture
+ - Architecture defines the model to be pruned. Usually, you need to pass your original model config to the argument.
+- mutator_cfg
+ - The config of a mutator to manage the structure of your model. Usually, each algorithm has a frequently-used mutator. Please refer to the next section for more detail.
+
+**Specific arguments**:
+A algorithm may have its specific arguments. You need to read their documents to know how to config. Here, we only introduce the specific arguments of ItePruneAlgorithm.
+
+- target_pruning_ratio: target_pruning_ratio is a dict that uses the name of units as keys and the choice values as values.. It indicates how many channels remain after pruning. You can use python ./tools/get_channel_units.py --choice {config_file} to get the choice template. Please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
+- step_epoch: the step between two pruning operations.
+- prune_times: the times to prune to reach the pruning target. Here, we prune resnet34 once, so we set it to 1.
+
+**Shared BaseModel arguments**:
+Our algorithms inherit from BaseModel, so each algorithm has shared arguments from BaseModel.
+
+- data_preprocessor: Used for pre-processing data sampled by dataloader to the format accepted by :meth:`forward`.
+- init_cfg: Initialization config dict
+
+## How to Config A Mutator
+
+A mutator is used to manage the structure of a model.
+
+Mutators have two augments:
+
+- channel_unit_cfg: config of channel units. The config should follow the template below.
+
+ ```python
+ channel_unit_cfg = dict(
+ # type of used MutableChannelUnit
+ type ='XxxMutableChannelUnit',
+ # default args for MutableChananelUnit
+ default_args={},
+ units = {
+ # config of a unit
+ "xxx_unit_name": {
+ "init_args":{},
+ "channels":{},
+ },
+ ...
+ }
+ ),
+ ```
+
+ MutableChannelUnit decides how to generate a channel choice. It's important to choose the right MutableChannelUnit. Here, we choose 'L1MutableChannelUnit' to apply the l1-norm algorithm.
+
+- parse_cfg: parse_cfg defines the method to parse the model and get channel units.
+ There are three ways used in BaseChannelMutator to parse a model and get MutableChannelUnits.
+
+ 1. Using tracer. It needs parse_cfg to be the config of a tracer.
+ 2. Using config. When parse_cfg\['type'\]='Config'. It needs that channel_unit_cfg\['unit'\]\['xxx_unit_name\] to have a key 'channels' to indicate channel units.
+ 3. Using the model with pre-defined DynamicOps and MutableChannels: When parse_cfg\['type'\]='Predefined', the mutator will parse the dynamic ops in the model and get channel units.
+
+In the example above, we directly use a tracer to parse the model.
+We also provide a tool for you to configure the mutator, please refer to [How to Use our Config Tool for Pruning](./how_to_use_config_tool_of_pruning.md).
+Besides, please refer to [ChannelMutator](../../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb) for more details about ChannelMutator.
+
+## End
+
+After configuring the algorithm, you can rerun the config file with a pretrained checkpoint to prune your model.
diff --git a/docs/en/advanced_guides/tutorials/how_to_use_config_tool_of_pruning.md b/docs/en/advanced_guides/tutorials/how_to_use_config_tool_of_pruning.md
new file mode 100644
index 000000000..a41204c22
--- /dev/null
+++ b/docs/en/advanced_guides/tutorials/how_to_use_config_tool_of_pruning.md
@@ -0,0 +1,239 @@
+# How to Use our Config Tool for Pruning
+
+## How We Get MutableChannelUnits Automatically
+
+Our pruning framework can automatically parse a model and get MutableChannelUnits.
+It makes it easy to prune new models.
+
+The parsing process is placed in ChannelUnitMutator.prepare_from_supernet. We first trace the model and get a graph, then we parse the graph and get MutableChannelUnits.
+
+
+
+## How to Get ChannelUnit Config Template
+
+To make the configuration of ChannelUnit easy, we provide an interface to get the config template: ChannelMutator.config_template(). It returns a config dict. The config\['channel_unit_cfg'\]\['units\] store all parsed MutableChannelUnits.
+
+```python
+def config_template(self,
+ only_mutable_units=False,
+ with_unit_init_args=False,
+ with_channels=False):
+ """Config template of the mutator.
+
+ Args:
+ only_mutable_units (bool, optional): Whether only return config of
+ prunable units. It can omit unmutable MutableChannelUnits
+ to decrease the length of the config. Defaults to False.
+ with_unit_init_args (bool, optional): Whether return init_args of
+ units. Let it be true, when you want to change the init
+ args of units. Defaults to False.
+ with_channels (bool, optional): Whether return channel info.
+ The channel info can initialization the units without
+ tracer. When you want to prune your model without a
+ tracer next time, let it be true. Defaults to False.
+
+ Example:
+ dict(
+ channel_unit_cfg = dict(
+ # type of used MutableChannelUnit
+ type ='XxxMutableChannelUnit',
+ # default args for MutableChananelUnit
+ default_args={},
+ # config of units
+ units = {
+ # config of a unit
+ "xxx_unit_name": {
+ 'init_args':{}, # if with_unit_init_args
+ 'channels':{} # if with_channels
+ },
+ ...
+ }
+ ),
+ # config of tracer
+ parse_cfg={}
+ )
+
+
+ About the detail of the config of each unit, please refer to
+ MutableChannelUnit.config_template()
+ """
+```
+
+Note the name of a unit is generated automatically according to their content, avoid to change the name in config.
+
+Here, we give an example of getting a config template using code.
+
+```python
+from mmrazor.models.mutators import ChannelMutator
+from torchvision.models import resnet34
+model = resnet34()
+# initialize a ChannelMutator object
+mutator = ChannelMutator(
+ channel_unit_cfg=dict(
+ type='SequentialMutableChannelUnit',
+ default_args=dict(choice_mode='ratio'),
+ units={},
+ ),
+ parse_cfg=dict(
+ type='BackwardTracer',
+ loss_calculator=dict(type='ImageClassifierPseudoLoss')))
+# init the ChannelMutator object with a model
+mutator.prepare_from_supernet(model)
+config=mutator.config_template(with_unit_init_args=True)
+print(config)
+# {
+# 'type': 'ChannelMutator',
+# 'channel_unit_cfg': {
+# 'type': 'SequentialMutableChannelUnit',
+# 'default_args': {
+# 'choice_mode': 'ratio'
+# },
+# 'units': {
+# 'conv1_(0, 3)_3': {
+# 'init_args': {
+# 'num_channels': 3,
+# 'choice_mode': 'ratio',
+# ...
+# },
+# 'choice': 1.0
+# },
+# ...
+# }
+# },
+# 'parse_cfg': {
+# 'type': 'BackwardTracer',
+# 'loss_calculator': {
+# 'type': 'ImageClassifierPseudoLoss'
+# }
+# }
+# }
+```
+
+Besides, it's also easy to initialize a new mutator using the config dict.
+
+```python
+# follow the code above
+from mmrazor.registry import MODELS
+mutator2=MODELS.build(config)
+mutator2.prepare_from_supernet(resnet34())
+```
+
+To make your development more fluent, we provide a command tool to parse a model and return the config template.
+
+```shell
+$ python ./tools/get_channel_units.py -h
+
+usage: get_channel_units.py [-h] [-c] [-i] [--choice] [-o OUTPUT_PATH] config
+
+Get channel unit of a model.
+
+positional arguments:
+ config config of the model
+
+optional arguments:
+ -h, --help show this help message and exit
+ -c, --with-channel output with channel config
+ -i, --with-init-args output with init args
+ --choice output choices template. When this flag is activated, -c and -i will be ignored
+ -o OUTPUT_PATH, --output-path OUTPUT_PATH
+ the file path to store channel unit info
+```
+
+Take the algorithm Slimmable Network as an example.
+
+```shell
+python ./tools/get_channel_units.py ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
+
+# {
+# "type":"SlimmableChannelMutator",
+# "channel_unit_cfg":{
+# "type":"SlimmableChannelUnit",
+# "default_args":{},
+# "units":{
+# "backbone.conv1.conv_(0, 3)_3":{
+# "choice":3
+# },
+# "backbone.conv1.conv_(0, 48)_48":{
+# "choice":32
+# },
+ ...
+# }
+# },
+# "parse_cfg":{
+# "type":"BackwardTracer",
+# "loss_calculator":{
+# "type":"ImageClassifierPseudoLoss"
+# }
+# }
+# }
+```
+
+The '-i' flag will return the config with the initialization arguments.
+
+```shell
+python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
+
+# {
+# "type":"SlimmableChannelMutator",
+# "channel_unit_cfg":{
+# "type":"SlimmableChannelUnit",
+# "default_args":{},
+# "units":{
+# "backbone.conv1.conv_(0, 3)_3":{
+# "init_args":{
+# "num_channels":3,
+# "divisor":1,
+# "min_value":1,
+# "min_ratio":0.9,
+# "candidate_choices":[
+# 3
+# ],
+# "choice_mode":"number"
+# },
+# "choice":3
+# },
+# ...
+# }
+# },
+# "parse_cfg":{
+# "type":"BackwardTracer",
+# "loss_calculator":{
+# "type":"ImageClassifierPseudoLoss"
+# }
+# }
+# }
+```
+
+With "--choice" flag, it will return the choice template, a dict which uses unit_name as key, and use the choice value as value.
+
+```shell
+python ./tools/get_channel_units.py -i ./configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py --choice
+
+# {
+# "backbone.conv1.conv_(0, 48)_48":32,
+# "backbone.layer1.0.conv.1.conv_(0, 24)_24":16,
+# "backbone.layer2.0.conv.0.conv_(0, 144)_144":144,
+# "backbone.layer2.0.conv.2.conv_(0, 40)_40":24,
+# "backbone.layer2.1.conv.0.conv_(0, 240)_240":176,
+# "backbone.layer3.0.conv.0.conv_(0, 240)_240":192,
+# "backbone.layer3.0.conv.2.conv_(0, 48)_48":48,
+# "backbone.layer3.1.conv.0.conv_(0, 288)_288":240,
+# "backbone.layer3.2.conv.0.conv_(0, 288)_288":144,
+# "backbone.layer4.0.conv.0.conv_(0, 288)_288":264,
+# "backbone.layer4.0.conv.2.conv_(0, 96)_96":88,
+# "backbone.layer4.1.conv.0.conv_(0, 576)_576":288,
+# "backbone.layer4.2.conv.0.conv_(0, 576)_576":336,
+# "backbone.layer4.3.conv.0.conv_(0, 576)_576":432,
+# "backbone.layer5.0.conv.0.conv_(0, 576)_576":576,
+# "backbone.layer5.0.conv.2.conv_(0, 144)_144":144,
+# "backbone.layer5.1.conv.0.conv_(0, 864)_864":576,
+# "backbone.layer5.2.conv.0.conv_(0, 864)_864":648,
+# "backbone.layer6.0.conv.0.conv_(0, 864)_864":864,
+# "backbone.layer6.0.conv.2.conv_(0, 240)_240":240,
+# "backbone.layer6.1.conv.0.conv_(0, 1440)_1440":1440,
+# "backbone.layer6.2.conv.0.conv_(0, 1440)_1440":1440,
+# "backbone.layer7.0.conv.0.conv_(0, 1440)_1440":1440,
+# "backbone.layer7.0.conv.2.conv_(0, 480)_480":480,
+# "backbone.conv2.conv_(0, 1920)_1920":1920
+# }
+```
diff --git a/docs/en/imgs/pruning/draw-config.png b/docs/en/imgs/pruning/draw-config.png
new file mode 100644
index 000000000..b733b7622
Binary files /dev/null and b/docs/en/imgs/pruning/draw-config.png differ
diff --git a/docs/en/imgs/pruning/framework-ChanelMutator.png b/docs/en/imgs/pruning/framework-ChanelMutator.png
new file mode 100644
index 000000000..0b44b4c30
Binary files /dev/null and b/docs/en/imgs/pruning/framework-ChanelMutator.png differ
diff --git a/docs/en/imgs/pruning/framework-algorithm.png b/docs/en/imgs/pruning/framework-algorithm.png
new file mode 100644
index 000000000..e373dc40c
Binary files /dev/null and b/docs/en/imgs/pruning/framework-algorithm.png differ
diff --git a/docs/en/imgs/pruning/framework-framework.png b/docs/en/imgs/pruning/framework-framework.png
new file mode 100644
index 000000000..4344fccc9
Binary files /dev/null and b/docs/en/imgs/pruning/framework-framework.png differ
diff --git a/docs/en/imgs/pruning/framework-graph.png b/docs/en/imgs/pruning/framework-graph.png
new file mode 100644
index 000000000..72ef9e632
Binary files /dev/null and b/docs/en/imgs/pruning/framework-graph.png differ
diff --git a/docs/en/imgs/pruning/framework-op.png b/docs/en/imgs/pruning/framework-op.png
new file mode 100644
index 000000000..04513d4d2
Binary files /dev/null and b/docs/en/imgs/pruning/framework-op.png differ
diff --git a/docs/en/imgs/pruning/pruning_framework.png b/docs/en/imgs/pruning/pruning_framework.png
new file mode 100644
index 000000000..ad954953b
Binary files /dev/null and b/docs/en/imgs/pruning/pruning_framework.png differ
diff --git a/docs/en/imgs/pruning/unit.png b/docs/en/imgs/pruning/unit.png
new file mode 100644
index 000000000..ec2830c66
Binary files /dev/null and b/docs/en/imgs/pruning/unit.png differ
diff --git a/docs/en/user_guides/pruning_user_guide.md b/docs/en/user_guides/pruning_user_guide.md
new file mode 100644
index 000000000..a1600c887
--- /dev/null
+++ b/docs/en/user_guides/pruning_user_guide.md
@@ -0,0 +1,148 @@
+# User Guides: Pruning Framework
+
+## Background
+
+// TODO
+
+## Pruning Framework
+
+This document introduces the pruning framework in mmrazor. Our pruning framework can help you prune a model automatically, making it easy to extend new algorithms.
+
+The pruning framework consists of five modules: Algorithm, ChanelMutator, MutableChannelUnit, MutableChannel, and DynamicOp. Their main features are detailed below:
+
+| Module | Features |
+| ------------------ | --------------------------------------------------------------------- |
+| Algorithm | Controls training process. |
+| ChanelMutator | Manages the pruning structure of the model. |
+| MutableChannelUnit | Makes pruning decisions. |
+| MutableChannel | Manage a channel mask. |
+| DynamicOp | Forwards with mutable number of channels, and exports pruned modules. |
+
+
+
+
+
+## Algorithm
+
+
+
+
+
+Algorithms inherit from BaseAlgorithm. They control the training process, like deciding when to prune the model in the training/finetune process.
+
+For example, IteAlgorithm prunes the model iteratively by certain epochs.
+
+Here is an example of how to use PruneAlgoritm.
+
+```python
+from mmrazor.models.algorithms import IteAlgorithm
+from mmengine.model import BaseModel
+import torch.nn as nn
+
+class Model(BaseModel):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(3, 8, 3, 1, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+model = Model()
+algorithm = IteAlgorithm(model,
+ mutator_cfg=dict(
+ type='ChannelMutator',
+ channl_unit_cfg=dict(type='L1ChannelUnit')),)
+print(algorithm)
+# IteAlgorithm(
+# (data_preprocessor): BaseDataPreprocessor()
+# (architecture): Model(
+# (data_preprocessor): BaseDataPreprocessor()
+# (conv): DynamicConv2d(
+# 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
+# (mutable_attrs): ModuleDict(
+# (in_channels): MutableChannelContainer(name=, num_channels=3, activated_channels: 3
+# (out_channels): MutableChannelContainer(name=, num_channels=8, activated_channels: 8
+# )
+# )
+# )
+# (mutator): BaseChannelMutator()
+# )
+
+```
+
+## ChanelMutator
+
+
+
+A ChanelMutator controls the pruning structure of a model. In other words, ChanelMutator decides how many channels each layer prunes. Usually, given a pruning target, such as a flops, latency, or pruning ratio target, the ChannelUnitMutator will output a pruning structure for the model. The pruning structure is variable. The default definition is the remaining channel ratio, and it's also easy to extend to the number of channels or channel buckets.
+
+As some layers' channels are related, the related layers share one pruning decision. We put these associated layers into a MutableChannelUnit. Therefore, the ChanelMutator directly decides the pruning ratio of each MutableChannelUnit.
+
+```python
+from mmrazor.models.mutators import BaseChannelMutator
+from mmengine.model import BaseModel
+import torch.nn as nn
+
+class Model(BaseModel):
+ def __init__(self):
+ super().__init__()
+ self.feature = nn.Sequential(
+ nn.Conv2d(3, 8, 3, 2, 1),
+ nn.Conv2d(8, 16, 3, 2, 1)
+ )
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.head = nn.Linear(16, 1000)
+
+ def forward(self, x):
+ x_ = self.pool(self.feature(x))
+ return self.head(x_.flatten(1))
+
+model = Model()
+mutator = BaseChannelMutator()
+mutator.prepare_from_supernet(model)
+print(mutator.sample_choices())
+# {
+# 'feature.0_(0, 8)_out_1_in_1': 0.5,
+# 'feature.1_(0, 16)_out_1_in_1': 0.5625
+# }
+```
+
+Please refer to [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb) for more details.
+
+## MutableChannelUnit
+
+
+
+Because some layers' channels are related, the related layers are collected and put in a MutableChannelUnit.
+
+Each MutableChannelUnit accepts a pruning ratio and generates a channel mask for all related layers.
+
+All related layers are divided into two types: output_related and input_related.
+
+- The output channels of output-related layers are in the MutableChannelUnit.
+- The input channels of input-related layers are in the MutableChannelUnit.
+
+Please refer to [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb) for more details.
+
+Besides, basic PyTorch modules are converted to DynamicOps, which can deal with a mutable number of channels with MutableChannels.
+
+## DynamicOP && MutableChannel
+
+
+
+**MutableChannel**: Each MutableChannel manages a channel mask for a model. They help DynamicOps to deal with mutable numbers of channels. Please refer to [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md) for more details.
+
+**DynamicOp**: DynamicOps inherit from basic torch modules, like nn.Conv2d or nn.Linear. They can forward with mutable numbers of channels and export pruned torch modules.
+Compared with basic torch modules, each DynamicOp has two MutableChannel modules, which control the input and output channels.
+
+## More Documents about Pruning
+
+Please refer to the following documents for more details.
+
+- Development tutorials
+ - [How to prune your model](../advanced_guides/tutorials/how_to_prune_your_model.md)
+ - [How to use config tool of pruning](../advanced_guides/tutorials/how_to_use_config_tool_of_pruning.md)
+- READMEs
+ - [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md)
+ - [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
+ - [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
diff --git a/docs/zh_cn/user_guides/visualization.md b/docs/zh_cn/user_guides/visualization.md
new file mode 100644
index 000000000..4642139d5
--- /dev/null
+++ b/docs/zh_cn/user_guides/visualization.md
@@ -0,0 +1,158 @@
+## 可视化
+
+## 特征图可视化
+
+
+
+
+可视化可以给深度学习的模型训练和测试过程提供直观解释。
+
+MMRazor 中,将使用 MMEngine 提供的 `Visualizer` 可视化器搭配 MMRazor 自带的 `Recorder`组件的数据记录功能,进行特征图可视化,其具备如下功能:
+
+- 支持基础绘图接口以及特征图可视化。
+- 支持选择模型中的任意位点来得到特征图,包含 `pixel_wise_max` ,`squeeze_mean` , `select_max` , `topk` 四种显示方式,用户还可以使用 `arrangement` 自定义特征图显示的布局方式。
+
+## 特征图绘制
+
+你可以调用 `tools/visualizations/vis_configs/feature_visualization.py` 来简单快捷地得到单张图片单个模型的可视化结果。
+
+为了方便理解,将其主要参数的功能梳理如下:
+
+- `img`:选择要用于特征图可视化的图片,支持单张图片或者图片路径列表。
+
+- `config`:选择算法的配置文件。
+
+- `vis_config`:可视化功能需借助可配置的 `Recorder` 组件获取模型中用户自定义位点的特征图,
+ 用户可以将 `Recorder` 相关配置文件放入 `vis_config` 中。 MMRazor提供了对backbone及neck
+ 输出进行可视化对应的config文件,详见 `configs/visualizations`
+
+- `checkpoint`:选择对应算法的权重文件。
+
+- `--out-file`:将得到的特征图保存到本地,并指定路径和文件名。若没有选定,则会直接显示特征图。
+
+- `--device`:指定用于推理图片的硬件,`--device cuda:0` 表示使用第 1 张 GPU 推理,`--device cpu` 表示用 CPU 推理。
+
+- `--repo`:模型对应的算法库。`--repo mmdet` 表示模型为检测模型。
+
+- `--use-norm`:是否将获取的特征图进行batch normalization后再显示。
+
+- `--overlaid`:是否将特征图覆盖在原图之上。若设为True,考虑到输入的特征图通常非常小,函数默认将特征图进行上采样后方便进行可视化。
+
+- `--channel-reduction`:输入的 Tensor 一般是包括多个通道的,`channel_reduction` 参数可以将多个通道压缩为单通道,然后和图片进行叠加显示,有以下三个参数可以设置:
+
+ - `pixel_wise_max`:将输入的 C 维度采用 max 函数压缩为一个通道,输出维度变为 (1, H, W)。
+ - `squeeze_mean`:将输入的 C 维度采用 mean 函数压缩为一个通道,输出维度变成 (1, H, W)。
+ - `select_max`:从输入的 C 维度中先在空间维度 sum,维度变成 (C, ),然后选择值最大的通道。
+ - `None`:表示不需要压缩,此时可以通过 `topk` 参数可选择激活度最高的 `topk` 个特征图显示。
+
+- `--topk`:只有在 `channel_reduction` 参数为 `None` 的情况下, `topk` 参数才会生效,其会按照激活度排序选择 `topk` 个通道,然后和图片进行叠加显示,并且此时会通过 `--arrangement` 参数指定显示的布局,该参数表示为一个数组,两个数字需要以空格分开,例如: `--topk 5 --arrangement 2 3` 表示以 `2行 3列` 显示激活度排序最高的 5 张特征图, `--topk 7 --arrangement 3 3` 表示以 `3行 3列` 显示激活度排序最高的 7 张特征图。
+
+ - 如果 topk 不是 -1,则会按照激活度排序选择 topk 个通道显示。
+ - 如果 topk = -1,此时通道 C 必须是 1 或者 3 表示输入数据是图片,否则报错提示用户应该设置 `channel_reduction` 来压缩通道。
+
+- `--arrangement`:特征图的排布。当 `channel_reduction` 不是None且topk > 0时才会有用。
+
+- `--resize-shape`:当`--overlaid`为True时,是否需要将原图和特征图resize为某一尺寸。
+
+- `--cfg-options`:由于不同算法库的visualizer拥有特例化的add_datasample方法,如mmdet的visualizer
+ 拥有 `pred_score_thr` 作为输入参数,可以在`--cfg-options`加入一些特例化的设置。
+
+类似的,用户可以通过调用 `tools/visualizations/vis_configs/feature_diff_visualization.py` 来得到
+单张图片两个模型的特征差异可视化结果,用法与上述类似,差异为:
+
+- `config1` / `config2`:选择算法1/2的配置文件。
+- `checkpoint1` / `checkpoint2`:选择对应算法1/2的权重文件。
+
+## 用法示例
+
+以预训练好的 RetinaNet-r101 与 RetinaNet-r50 模型为例:
+
+请提前下载 RetinaNet-r101 与 RetinaNet-r50 模型权重到本仓库根路径下:
+
+```shell
+cd mmrazor
+wget https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_2x_coco/retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth
+wget https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth
+```
+
+(1) 将多通道特征图采用 `pixel_wise_max` 参数压缩为单通道并显示, 通过提取 `neck` 层输出进行特征图可视化(这里只显示了前4个stage的特征图):
+
+```shell
+python tools/visualizations/feature_visualization.py \
+ tools/visualizations/demo.jpg \
+ PATH/TO/THE/CONFIG \
+ tools/visualizations/vis_configs/fpn_feature_visualization.py \
+ retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth \
+ --repo mmdet --use-norm --overlaid
+ --channel-reduction pixel_wise_max
+```
+
+
+
+
+
+(2) 将多通道特征图采用 `select_max` 参数压缩为单通道并显示, 通过提取 `neck` 层输出进行特征图可视化(这里只显示了前4个stage的特征图):
+
+```shell
+python tools/visualizations/feature_visualization.py \
+ tools/visualizations/demo.jpg \
+ PATH/TO/THE/CONFIG \
+ tools/visualizations/vis_configs/fpn_feature_visualization.py \
+ retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth \
+ --repo mmdet --overlaid
+ --channel-reduction select_max
+```
+
+
+
+
+
+(3) 将多通道特征图采用 `squeeze_mean` 参数压缩为单通道并显示, 通过提取 `neck` 层输出进行特征图可视化(这里只显示了前4个stage的特征图):
+
+```shell
+python tools/visualizations/feature_visualization.py \
+ tools/visualizations/demo.jpg \
+ PATH/TO/THE/CONFIG \
+ tools/visualizations/vis_configs/fpn_feature_visualization.py \
+ retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth \
+ --repo mmdet --overlaid
+ --channel-reduction squeeze_mean
+```
+
+
+
+
+
+(4) 将多通道特征图采用 `squeeze_mean` 参数压缩为单通道并显示, 通过提取 `neck` 层输出进行特征图可视化(这里只显示了前4个stage的特征图):
+
+```shell
+python tools/visualizations/feature_visualization.py \
+ tools/visualizations/demo.jpg \
+ PATH/TO/THE/CONFIG \
+ tools/visualizations/vis_configs/fpn_feature_visualization.py \
+ retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth \
+ --repo mmdet --overlaid
+ --channel-reduction squeeze_mean
+```
+
+
+
+
+
+(5) 将多通道的两个模型的特征图差异采用 `pixel_wise_max` 参数压缩为单通道并显示, 这里只显示了前4个stage的特征图差异:
+
+```shell
+python tools/visualizations/feature_diff_visualization.py \
+ tools/visualizations/demo.jpg \
+ PATH/TO/THE/CONFIG1 \
+ PATH/TO/THE/CONFIG2 \
+ tools/visualizations/vis_configs/fpn_feature_diff_visualization.py.py \
+ retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth \
+ retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth \
+ --repo mmdet --use-norm --overlaid
+ --channel-reduction pixel_wise_max
+```
+
+
+
+
diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py
index f2df86a83..64f86bc1c 100644
--- a/mmrazor/engine/__init__.py
+++ b/mmrazor/engine/__init__.py
@@ -3,13 +3,14 @@
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
- GreedySamplerTrainLoop, SelfDistillValLoop,
- SingleTeacherDistillValLoop, SlimmableValLoop)
+ GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop,
+ SelfDistillValLoop, SingleTeacherDistillValLoop,
+ SlimmableValLoop)
__all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
- 'SelfDistillValLoop'
+ 'SelfDistillValLoop', 'QATEpochBasedLoop', 'PTQLoop'
]
diff --git a/mmrazor/engine/hooks/__init__.py b/mmrazor/engine/hooks/__init__.py
index 2fc3cc12f..d25c7c993 100644
--- a/mmrazor/engine/hooks/__init__.py
+++ b/mmrazor/engine/hooks/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
+from .visualization_hook import RazorVisualizationHook
-__all__ = ['DumpSubnetHook', 'EstimateResourcesHook']
+__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook']
diff --git a/mmrazor/engine/hooks/visualization_hook.py b/mmrazor/engine/hooks/visualization_hook.py
new file mode 100644
index 000000000..a52145b10
--- /dev/null
+++ b/mmrazor/engine/hooks/visualization_hook.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from typing import List, Optional, Union
+
+import mmcv
+import torch
+from mmcv.transforms import Compose
+from mmengine.dist import master_only
+from mmengine.fileio import FileClient
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+from mmengine.utils import mkdir_or_exist
+from mmengine.visualization import Visualizer
+
+from mmrazor.models.task_modules import RecorderManager
+from mmrazor.registry import HOOKS
+from mmrazor.visualization.local_visualizer import modify
+
+
+def norm(feat):
+ assert len(feat.shape) == 4
+ N, C, H, W = feat.shape
+ feat = feat.permute(1, 0, 2, 3).reshape(C, -1)
+ mean = feat.mean(dim=-1, keepdim=True)
+ std = feat.std(dim=-1, keepdim=True)
+ centered = (feat - mean) / (std + 1e-6)
+ centered = centered.reshape(C, N, H, W).permute(1, 0, 2, 3)
+ return centered
+
+
+@HOOKS.register_module()
+class RazorVisualizationHook(Hook):
+ """Razor Visualization Hook. Used to visualize training process immediate
+ feature maps.
+
+ 1. If ``show`` is True, it means that only the immediate feature maps are
+ visualized without storing data, so ``vis_backends`` needs to
+ be excluded.
+ 2. If ``out_dir`` is specified, it means that the immediate feature maps
+ need to be saved to ``out_dir``. In order to avoid vis_backends
+ also storing data, so ``vis_backends`` needs to be excluded.
+ 3. ``vis_backends`` takes effect if the user does not specify ``show``
+ and `out_dir``. You can set ``vis_backends`` to WandbVisBackend or
+ TensorboardVisBackend to store the immediate feature maps in Wandb or
+ Tensorboard.
+
+ Args:
+ recorders (dict): All recorders' config.
+ mappings: (Dict[str, Dict]): The mapping between feature names and
+ records.
+ enabled (bool): Whether to draw immediate feature maps. If it is False,
+ it means that no drawing will be done. Defaults to False.
+ interval (int): The interval of visualization. Defaults to 1.
+ show (bool): Whether to display the drawn image. Default to False.
+ wait_time (float): The interval of show (s). Defaults to 0.
+ out_dir (str, optional): directory where painted images
+ will be saved in testing process.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmengine.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ is_overlaid (bool): If `is_overlaid` is True, the final output image
+ will be the weighted sum of img and featmap. Defaults to True.
+ visualization_cfg (dict): Configs for visualization.
+ use_norm (bool): Whether to apply Batch Normalization over the
+ feature map. Defaults to False.
+ """
+
+ def __init__(self,
+ recorders: dict,
+ mappings: dict,
+ enabled: bool = False,
+ data_idx: Union[int, List] = 0,
+ interval: int = 1,
+ show: bool = False,
+ wait_time: float = 0.1,
+ out_dir: Optional[str] = None,
+ file_client_args: dict = dict(backend='disk'),
+ is_overlaid: bool = True,
+ visualization_cfg=dict(
+ channel_reduction='pixel_wise_max',
+ topk=20,
+ arrangement=(4, 5),
+ resize_shape=None,
+ alpha=0.5),
+ use_norm: bool = False):
+ self.enabled = enabled
+ self._visualizer: Visualizer = Visualizer.get_current_instance()
+ self._visualizer.draw_featmap = modify
+ if isinstance(data_idx, int):
+ data_idx = [data_idx]
+ self.data_idx = data_idx
+ self.show = show
+ if self.show:
+ # No need to think about vis backends.
+ self._visualizer._vis_backends = {}
+ warnings.warn('The show is True, it means that only '
+ 'the prediction results are visualized '
+ 'without storing data, so vis_backends '
+ 'needs to be excluded.')
+
+ self.wait_time = wait_time
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.out_dir = out_dir
+ self.interval = interval
+
+ self.is_overlaid = is_overlaid
+ self.visualization_cfg = visualization_cfg
+ self.use_norm = use_norm
+
+ self.recorder_manager = RecorderManager(recorders)
+ self.mappings = mappings
+
+ self._step = 0 # Global step value to record
+
+ @master_only
+ def before_run(self, runner) -> None:
+ model = runner.model
+ if is_model_wrapper(model):
+ self.recorder_manager.initialize(model.module)
+ else:
+ self.recorder_manager.initialize(model)
+
+ @master_only
+ def before_train(self, runner):
+ if not self.enabled or runner.epoch % self.interval != 0:
+ return
+ self._visualize(runner, 'before_run')
+
+ @master_only
+ def after_train_epoch(self, runner) -> None:
+ if not self.enabled or runner.epoch % self.interval != 0:
+ return
+ self._visualize(runner, f'epoch_{runner.epoch}')
+
+ def _visualize(self, runner, stage):
+ if self.out_dir is not None:
+ self.out_dir = osp.join(runner.work_dir, runner.timestamp,
+ self.out_dir)
+ mkdir_or_exist(self.out_dir)
+
+ if self.file_client is None:
+ self.file_client = FileClient(**self.file_client_args)
+
+ cfg = runner.cfg.copy()
+ test_pipeline = cfg.test_dataloader.dataset.pipeline
+ new_test_pipeline = []
+ for pipeline in test_pipeline:
+ if pipeline['type'] != 'LoadAnnotations' and pipeline[
+ 'type'] != 'LoadPanopticAnnotations':
+ new_test_pipeline.append(pipeline)
+
+ test_pipeline = Compose(new_test_pipeline)
+ dataset = runner.val_loop.dataloader.dataset
+
+ for idx in self.data_idx:
+ data_info = dataset.get_data_info(idx)
+ img_path = data_info['img_path']
+ data_ = dict(img_path=img_path, img_id=0)
+ data_ = test_pipeline(data_)
+
+ data_['inputs'] = [data_['inputs']]
+ data_['data_samples'] = [data_['data_samples']]
+
+ with torch.no_grad(), self.recorder_manager:
+ runner.model.test_step(data_)
+
+ if self.is_overlaid:
+ img_bytes = self.file_client.get(img_path)
+ overlaid_image = mmcv.imfrombytes(
+ img_bytes, channel_order='rgb')
+ else:
+ overlaid_image = None
+
+ for name, record in self.mappings.items():
+ recorder = self.recorder_manager.get_recorder(record.recorder)
+ record_idx = getattr(record, 'record_idx', 0)
+ data_idx = getattr(record, 'data_idx', None)
+ feats = recorder.get_record_data(record_idx, data_idx)
+ if isinstance(feats, torch.Tensor):
+ feats = (feats, )
+
+ for i, feat in enumerate(feats):
+ if self.use_norm:
+ feat = norm(feat)
+ drawn_img = self._visualizer.draw_featmap(
+ feat[0], overlaid_image, **self.visualization_cfg)
+
+ out_file = None
+ if self.out_dir is not None:
+ out_file = f'{stage}_data_idx_{idx}_{name}_{i}.jpg'
+ out_file = osp.join(self.out_dir, out_file)
+
+ self._visualizer.add_datasample(
+ f'{stage}_data_idx_{idx}_{name}_{i}',
+ drawn_img,
+ draw_gt=False,
+ draw_pred=False,
+ show=self.show,
+ wait_time=0.1,
+ # TODO: Supported in mmengine's Viusalizer.
+ out_file=out_file,
+ step=self._step)
+ self._step += 1
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 9715a4e6b..ba80113fa 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -3,11 +3,13 @@
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
from .evolution_search_loop import EvolutionSearchLoop
+from .quantization_loops import PTQLoop, QATEpochBasedLoop
from .slimmable_val_loop import SlimmableValLoop
from .subnet_sampler_loop import GreedySamplerTrainLoop
__all__ = [
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
- 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+ 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+ 'QATEpochBasedLoop', 'PTQLoop'
]
diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py
new file mode 100644
index 000000000..2f15f5deb
--- /dev/null
+++ b/mmrazor/engine/runner/quantization_loops.py
@@ -0,0 +1,298 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.evaluator import Evaluator
+from mmengine.registry import MODELS
+from mmengine.runner import EpochBasedTrainLoop, TestLoop
+from torch.utils.data import DataLoader
+
+from mmrazor.models.task_modules import (ModuleInputsRecorder,
+ ModuleOutputsRecorder,
+ RecorderManager)
+from mmrazor.registry import LOOPS
+from .utils import extract_blocks, extract_layers, extract_subgraph
+
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+@LOOPS.register_module()
+class QATEpochBasedLoop(EpochBasedTrainLoop):
+ """`EpochBasedLoop` for `QuantizationAwareTraining`
+
+ Args:
+ runner (Runner): A reference of runner
+ dataloader (Dataloader or dict): An iterator to generate one batch of
+ dataset each iteration.
+ max_epochs (int): Total training epochs.
+ calibrate_dataloader (Dataloader or dict, optional): A dataloader
+ object or a dict to build a dataloader for calibration. Defaults
+ to None.
+ val_begin (int): The epoch that begins validating.
+ Defaults to 1.
+ val_interval (int): Validation interval. Defaults to 1.
+ dynamic_intervals (List[Tuple[int, int]], optional): The
+ first element in the tuple is a milestone and the second
+ element is a interval. The interval is used after the
+ corresponding milestone. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ runner,
+ dataloader: Union[DataLoader, Dict],
+ max_epochs: int,
+ calibrate_dataloader: Union[DataLoader, Dict] = None,
+ val_begin: int = 1,
+ val_interval: int = 1,
+ dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
+ super().__init__(runner, dataloader, max_epochs, val_begin,
+ val_interval, dynamic_intervals)
+ if isinstance(calibrate_dataloader, dict):
+ # Determine whether or not different ranks use different seed.
+ diff_rank_seed = runner._randomness_cfg.get(
+ 'diff_rank_seed', False)
+ self.calibrate_dataloader = runner.build_dataloader(
+ calibrate_dataloader,
+ seed=runner.seed,
+ diff_rank_seed=diff_rank_seed)
+ else:
+ self.calibrate_dataloader = calibrate_dataloader
+
+ self.is_calibrate = True if calibrate_dataloader is not None else False
+
+ if self.runner.distributed:
+ self.model = runner.model.module
+ else:
+ self.model = runner.model
+
+ def calibrate(self, calibrate_dataloader) -> None:
+ self.model.eval()
+ with torch.no_grad():
+ for batch_data in calibrate_dataloader:
+ self.model(batch_data)
+
+ def run(self) -> None:
+ """Launch training."""
+ self.runner.call_hook('before_train')
+
+ self.model.prepare()
+
+ if self.is_calibrate:
+ self.model.state = (1, 0)
+ self.calibrate(self.calibrate_dataloader)
+
+ self.model.state = (1, 1)
+
+ while self._epoch < self._max_epochs:
+ self.run_epoch()
+
+ self._decide_current_val_interval()
+ if (self.runner.val_loop is not None
+ and self._epoch >= self.val_begin
+ and self._epoch % self.val_interval == 0):
+ self.runner.val_loop.run()
+
+ self.model.convert()
+
+ # self.runner.val_loop.run()
+
+ self.runner.call_hook('after_train')
+
+
+@LOOPS.register_module()
+class PTQLoop(TestLoop):
+ """`TestLoop` for Post Training Quantization.
+
+ Args:
+ runner (Runner): A reference of runner
+ dataloader (Dataloader or dict): An iterator to generate one batch of
+ dataset each iteration.
+ evaluator (Evaluator or dict or list): Used for computing metrics.
+ calibrate_dataloader (Dataloader or dict, optional): A dataloader
+ object or a dict to build a dataloader for calibration. Defaults
+ to None.
+ batch_num (Optional[int], optional): Total calibration batches.
+ Defaults to None.
+ reconstruction_cfg (Optional[Dict], optional): Model reconstruction
+ configuration. Defaults to None.
+ fp16 (bool, optional): Enable FP16 training mode. Defaults to False.
+ """
+
+ def __init__(self,
+ runner,
+ dataloader: Union[DataLoader, Dict],
+ evaluator: Union[Evaluator, Dict, List],
+ calibrate_dataloader: Optional[Union[DataLoader,
+ Dict]] = None,
+ batch_num: Optional[int] = None,
+ reconstruction_cfg: Optional[Dict] = None,
+ fp16: bool = False):
+ super().__init__(runner, dataloader, evaluator, fp16)
+ if isinstance(calibrate_dataloader, dict):
+ # Determine whether or not different ranks use different seed.
+ diff_rank_seed = runner._randomness_cfg.get(
+ 'diff_rank_seed', False)
+ self.calibrate_dataloader = runner.build_dataloader(
+ calibrate_dataloader,
+ seed=runner.seed,
+ diff_rank_seed=diff_rank_seed)
+ else:
+ self.calibrate_dataloader = calibrate_dataloader
+
+ self.is_calibrate = True if calibrate_dataloader is not None else False
+
+ if self.runner.distributed:
+ self.model = runner.model.module
+ else:
+ self.model = runner.model
+
+ self.batch_num = batch_num
+ self.config = reconstruction_cfg
+
+ def calibrate(self, calibrate_dataloader) -> None:
+ self.model.eval()
+ with torch.no_grad():
+ for i, batch_data in enumerate(calibrate_dataloader):
+ if self.batch_num and i >= self.batch_num:
+ break
+ self.model.calib_step(batch_data)
+
+ def _save_inter_result(self,
+ model,
+ dataloader,
+ slices,
+ store_input=True,
+ store_output=True):
+ recorders = {}
+ for s in slices:
+ node_l, node_r = s[:2]
+ if store_input:
+ recorders[node_l.target + '_input'] = ModuleInputsRecorder(
+ node_l.target)
+ if store_output:
+ recorders[node_r.target + '_output'] = ModuleOutputsRecorder(
+ node_r.target)
+ manager = RecorderManager(recorders)
+ manager.initialize(model)
+
+ with torch.no_grad():
+ with manager:
+ for i, batch_data in enumerate(dataloader):
+ if self.batch_num and i >= self.batch_num:
+ break
+ batch_data = self.model.data_preprocessor(
+ batch_data, False)
+ model(**batch_data)
+ return manager
+
+ def sub_reconstruction(self, graphmodule, input_recorder, output_recorder,
+ config):
+ w_para = []
+ for layer in graphmodule.modules():
+ # import pdb
+ # pdb.set_trace()
+ if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+ weight_fake_quant = layer.weight_fake_quant
+ weight_fake_quant.init(layer.weight.data)
+ w_para += [weight_fake_quant.alpha]
+
+ w_opt = torch.optim.Adam(w_para)
+ loss_func = MODELS.build(config.loss)
+
+ for _ in range(config.loss.iters):
+ w_opt.zero_grad()
+
+ data_size = len(input_recorder.data_buffer)
+ data_index = np.random.randint(0, data_size)
+ out_quant = graphmodule(
+ input_recorder.get_recorder_data(data_index))
+ out_fp = output_recorder.get_recorder_data(data_index)
+ err = loss_func(graphmodule, out_quant, out_fp)
+ err.backward()
+ w_opt.step()
+
+ for layer in graphmodule.modules():
+ if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+ weight_fake_quant = layer.weight_fake_quant
+ layer.weight.data = weight_fake_quant.get_hard_value(
+ layer.weight.data)
+ weight_fake_quant.adaround = False
+ if isinstance(layer, torch.quantization.FakeQuantize) and hasattr(
+ layer, 'prob'):
+ # recover to promise that drop activation quantization only
+ # occurs at reconstruction phase
+ layer.prob = 1.0
+
+ def reconstruction(self, graphmodule, calibrate_dataloader, config):
+ assert isinstance(graphmodule, torch.fx.GraphModule)
+ graphmodule_fp = graphmodule
+ graphmodule_quant = copy.deepcopy(graphmodule)
+
+ # get layers/blocks need to reconstructe
+ slices = []
+ if config.pattern == 'layer':
+ slices = extract_layers(
+ graphmodule, layer_types=_ADAROUND_SUPPORT_TYPE)
+ elif config.pattern == 'block':
+ slices = extract_blocks(graphmodule)
+ else:
+ # TODO: add remind
+ raise NotImplementedError
+
+ # save fp inputs and outputs of each layers
+ manager_fp = self._save_inter_result(graphmodule_fp,
+ self.calibrate_dataloader, slices)
+
+ # extract subgraph_module
+ for s in slices:
+ sub_graphmodule = extract_subgraph(graphmodule_quant, s)
+ manager_quant = self._save_inter_result(
+ graphmodule_quant,
+ self.calibrate_dataloader, [s],
+ store_output=False)
+ input_index = s[0].target + '_input'
+ output_index = s[1].target + '_output'
+ input_recorder = manager_quant.get_recorder(input_index)
+ output_recorder = manager_fp.get_recorder(output_index)
+ self.sub_reconstruction(sub_graphmodule, input_recorder,
+ output_recorder, config)
+
+ return graphmodule_quant
+
+ def run(self) -> None:
+ """Launch test."""
+ self.runner.call_hook('before_test')
+ self.runner.call_hook('before_test_epoch')
+
+ self.model.eval()
+ self.model.prepare()
+
+ if self.is_calibrate:
+ self.model.state = (1, 0)
+ self.calibrate(self.calibrate_dataloader)
+
+ self.model.state = (1, 1)
+
+ if self.config is not None:
+ self.model.architecture = self.reconstruction(
+ self.model.architecture, self.calibrate_dataloader,
+ self.config)
+
+ self.model.convert()
+
+ self.model.eval()
+ from torch.onnx import OperatorExportTypes
+ dummy_input = torch.randn([1, 3, 224, 224])
+ onnx_path = os.path.join(self.runner.work_dir, 'quantizied.onnx')
+ torch.onnx.export(
+ self.model.architecture,
+ dummy_input,
+ onnx_path,
+ opset_version=11,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
+
+ self.runner.call_hook('after_test')
diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py
index ec2f2cb29..7f55bef0a 100644
--- a/mmrazor/engine/runner/utils/__init__.py
+++ b/mmrazor/engine/runner/utils/__init__.py
@@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .check import check_subnet_flops
from .genetic import crossover
+from .state import set_quant_state
+from .subgraph import extract_blocks, extract_layers, extract_subgraph
-__all__ = ['crossover', 'check_subnet_flops']
+__all__ = [
+ 'crossover', 'check_subnet_flops', 'extract_subgraph', 'extract_blocks',
+ 'extract_layers', 'set_quant_state'
+]
diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py
index e2fdcfcc6..4b7078b4f 100644
--- a/mmrazor/engine/runner/utils/check.py
+++ b/mmrazor/engine/runner/utils/check.py
@@ -4,7 +4,7 @@
import torch.nn as nn
-from mmrazor.models import ResourceEstimator
+from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.structures import export_fix_subnet, load_fix_subnet
from mmrazor.utils import SupportRandomSubnet
diff --git a/mmrazor/engine/runner/utils/state.py b/mmrazor/engine/runner/utils/state.py
new file mode 100644
index 000000000..2f6d602a5
--- /dev/null
+++ b/mmrazor/engine/runner/utils/state.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.logging import print_log
+from torch.ao.quantization import FakeQuantize
+
+
+# TODO: may be removed
+def set_quant_state(model, enable_observer=True, enable_fake_quant=True):
+ for name, submodule in model.named_modules():
+ if isinstance(submodule, FakeQuantize):
+ if enable_observer:
+ submodule.enable_observer()
+ else:
+ submodule.disable_observer()
+ if enable_fake_quant:
+ submodule.enable_fake_quant()
+ else:
+ submodule.disable_fake_quant()
+ print_log(f'Enable observer: {enable_observer}; \
+ Enable fake quant: {enable_fake_quant}')
diff --git a/mmrazor/engine/runner/utils/subgraph.py b/mmrazor/engine/runner/utils/subgraph.py
new file mode 100644
index 000000000..ea0f8837f
--- /dev/null
+++ b/mmrazor/engine/runner/utils/subgraph.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.fx as fx
+
+
+def extract_subgraph(graphmodule, block_slice):
+ subgraph = copy.deepcopy(graphmodule.graph)
+ block_start, block_end = block_slice[:2]
+ for node in subgraph.nodes:
+ if node.name == 'inputs':
+ input_node = node
+ if node.name == block_start.name:
+ node.replace_input_with(node.prev, input_node)
+ if node.name == block_end.name:
+ output_node = node
+ if node.op == 'output':
+ node.replace_input_with(node.prev, output_node)
+ subgraph.lint()
+ subgraph_module = fx.GraphModule(graphmodule, subgraph)
+ subgraph_module.graph.eliminate_dead_code()
+ subgraph_module.recompile()
+ return subgraph_module
+
+
+def extract_blocks(graph, key_word='layer'):
+ block_slices = []
+ block_slice = []
+ pre_stage_index, pre_block_index = 0, 0
+ cur_stage_index, cur_block_index = 0, 0
+ for node in graph.nodes:
+ if key_word not in node.name:
+ continue
+ else:
+ items = node.name.split('_')
+ for i, item in enumerate(items):
+ if key_word in item:
+ cur_stage_index = int(item[5:])
+ cur_block_index = int(items[i + 1])
+ break
+ if (cur_block_index != pre_block_index) or (cur_stage_index !=
+ pre_stage_index):
+ block_slice.append(node.prev)
+ if len(block_slice) == 2:
+ block_slices.append(block_slice)
+ block_slice = []
+ block_slice.append(node)
+
+ pre_stage_index, pre_block_index = cur_stage_index, cur_block_index
+
+ return block_slices
+
+
+def extract_layers(graphmodule, layer_types):
+ layer_slices = []
+ for node in graphmodule.graph.nodes:
+ if node.op == 'call_module':
+ m = graphmodule.get_submodule(node.target)
+ if isinstance(m, layer_types):
+ layer_slices.append((node, node))
+ return layer_slices
diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py
index f5295aa9e..e5b9ec451 100644
--- a/mmrazor/models/__init__.py
+++ b/mmrazor/models/__init__.py
@@ -2,7 +2,11 @@
from .algorithms import * # noqa: F401,F403
from .architectures import * # noqa: F401,F403
from .distillers import * # noqa: F401,F403
+from .fake_quants import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .mutables import * # noqa: F401,F403
from .mutators import * # noqa: F401,F403
+from .observers import * # noqa: F401,F403
+from .quantizers import * # noqa: F401,F403
from .task_modules import * # noqa: F401,F403
+from .utils import * # noqa: F401,F403
diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py
index e6258b012..1bcf89629 100644
--- a/mmrazor/models/algorithms/__init__.py
+++ b/mmrazor/models/algorithms/__init__.py
@@ -3,26 +3,33 @@
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, OverhaulFeatureDistillation,
SelfDistill, SingleTeacherDistill)
-from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
+from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
from .pruning.ite_prune_algorithm import ItePruneAlgorithm
+from .quantization import GeneralQuant
__all__ = [
- 'SingleTeacherDistill',
+ # base
'BaseAlgorithm',
+ # distill
+ 'DAFLDataFreeDistillation',
+ 'DataFreeDistillation',
'FpnTeacherDistill',
+ 'OverhaulFeatureDistillation',
+ 'SelfDistill',
+ 'SingleTeacherDistill',
+ # nas
+ 'DSNAS',
+ 'DSNASDDP',
'SPOS',
- 'SlimmableNetwork',
- 'SlimmableNetworkDDP',
'AutoSlim',
'AutoSlimDDP',
'Darts',
'DartsDDP',
- 'SelfDistill',
- 'DataFreeDistillation',
- 'DAFLDataFreeDistillation',
- 'OverhaulFeatureDistillation',
+ # pruning
+ 'SlimmableNetwork',
+ 'SlimmableNetworkDDP',
'ItePruneAlgorithm',
- 'Dsnas',
- 'DsnasDDP',
+ # quantization
+ 'GeneralQuant'
]
diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py
index 17eab7e86..b290afa0a 100644
--- a/mmrazor/models/algorithms/nas/__init__.py
+++ b/mmrazor/models/algorithms/nas/__init__.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim import AutoSlim, AutoSlimDDP
from .darts import Darts, DartsDDP
-from .dsnas import Dsnas, DsnasDDP
+from .dsnas import DSNAS, DSNASDDP
from .spos import SPOS
__all__ = [
- 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
+ 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'DSNAS', 'DSNASDDP'
]
diff --git a/mmrazor/models/algorithms/nas/dsnas.py b/mmrazor/models/algorithms/nas/dsnas.py
index 62c2c7f04..5434ce0ac 100644
--- a/mmrazor/models/algorithms/nas/dsnas.py
+++ b/mmrazor/models/algorithms/nas/dsnas.py
@@ -23,7 +23,7 @@
@MODELS.register_module()
-class Dsnas(BaseAlgorithm):
+class DSNAS(BaseAlgorithm):
"""Implementation of `DSNAS `_
Args:
@@ -272,7 +272,7 @@ def handle_grads(self):
@MODEL_WRAPPERS.register_module()
-class DsnasDDP(MMDistributedDataParallel):
+class DSNASDDP(MMDistributedDataParallel):
def __init__(self,
*,
diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
index cca03a71f..4b592740a 100644
--- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
+++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
@@ -7,6 +7,7 @@
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement
+from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS
from ..base import BaseAlgorithm
@@ -107,7 +108,7 @@ def __init__(self,
channel_unit_cfg=dict(
type='SequentialMutableChannelUnit')),
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
- target_pruning_ratio={},
+ target_pruning_ratio: Optional[Dict[str, float]] = None,
step_epoch=1,
prune_times=1,
init_cfg: Optional[Dict] = None) -> None:
@@ -118,15 +119,49 @@ def __init__(self,
self.mutator: ChannelMutator = MODELS.build(mutator_cfg)
self.mutator.prepare_from_supernet(self.architecture)
+ if target_pruning_ratio is None:
+ group_target_ratio = self.mutator.current_choices
+ else:
+ group_target_ratio = self.group_target_pruning_ratio(
+ target_pruning_ratio, self.mutator.search_groups)
+
# config_manager
- self.check_prune_targe(target_pruning_ratio)
self.prune_config_manager = ItePruneConfigManager(
- target_pruning_ratio,
- self.mutator.choice_template,
+ group_target_ratio,
+ self.mutator.current_choices,
step_epoch,
times=prune_times)
- def check_prune_targe(self, config: Dict):
+ def group_target_pruning_ratio(
+ self, target: Dict[str, float],
+ search_groups: Dict[int,
+ List[MutableChannelUnit]]) -> Dict[int, float]:
+ """According to the target pruning ratio of each unit, set the target
+ ratio of each search group."""
+ group_target: Dict[int, float] = dict()
+ for group_id, units in search_groups.items():
+ for unit in units:
+ unit_name = unit.name
+ # The config of target pruning ratio does not
+ # contain all units.
+ if unit_name not in target:
+ continue
+ if group_id in group_target:
+ unit_target = target[unit_name]
+ if unit_target != group_target[group_id]:
+ group_names = [u.name for u in units]
+ raise ValueError(
+ f"'{unit_name}' target ratio is different from "
+ f'other units in the same group {group_names}. '
+ 'Pls check your target pruning ratio config.')
+ else:
+ unit_target = target[unit_name]
+ assert isinstance(unit_target, (float, int))
+ group_target[group_id] = unit_target
+
+ return group_target
+
+ def check_prune_target(self, config: Dict):
"""Check if the prune-target is supported."""
for value in config.values():
assert isinstance(value, int) or isinstance(value, float)
@@ -141,7 +176,9 @@ def forward(self,
self._iteration):
config = self.prune_config_manager.prune_at(self._epoch)
+
self.mutator.set_choices(config)
+
logger = MMLogger.get_current_instance()
logger.info(f'The model is pruned at {self._epoch}th epoch once.')
diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py
new file mode 100644
index 000000000..84c25bbc0
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import GeneralQuant
+
+__all__ = ['GeneralQuant']
diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py
new file mode 100644
index 000000000..718b08725
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/base.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from mmengine.structures import BaseDataElement
+from torch.fx import GraphModule
+
+from mmrazor.registry import MODELS
+from ..base import BaseAlgorithm
+
+LossResults = Dict[str, torch.Tensor]
+TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
+PredictResults = List[BaseDataElement]
+ForwardResults = Union[LossResults, TensorResults, PredictResults]
+
+
+@MODELS.register_module()
+class GeneralQuant(BaseAlgorithm):
+ """General quantization.
+
+ Args:
+ Args:
+ architecture (dict | :obj:`BaseModel`): The config of
+ :class:`BaseModel` or built model.
+ quantizer (dict | :obj:`BaseModel`): The config of
+ :class:`BaseQuantizer` or built model.
+ data_preprocessor (dict | torch.nn.Module | None): The pre-process
+ config of :class:`BaseDataPreprocessor`. Defaults to None.
+ init_cfg (dict): The weight initialized config for
+ :class:`BaseModule`.
+ """
+
+ def __init__(self,
+ architecture,
+ quantizer,
+ data_preprocessor=None,
+ init_cfg=None):
+ if data_preprocessor is None:
+ data_preprocessor = {}
+ # The build process is in MMEngine, so we need to add scope here.
+ data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
+
+ super().__init__(architecture, data_preprocessor, init_cfg)
+ self.quantizer = MODELS.build(quantizer)
+ self.observers_enabled = True
+ self.fake_quants_enabled = True
+ self.gen_graphs(self.architecture)
+
+ def gen_graphs(self, model):
+ self.quantizer._swap_ff_with_fxff(model)
+ tracer = self.quantizer.tracer
+ for mode in ['tensor', 'loss', 'predict']:
+ concrete_args = {'mode': mode}
+ if mode == 'tensor':
+ self.graph_tensor = GraphModule(
+ model, tracer.trace(model, concrete_args=concrete_args))
+ if mode == 'loss':
+ self.graph_loss = GraphModule(
+ model, tracer.trace(model, concrete_args=concrete_args))
+ if mode == 'predict':
+ self.graph_predict = GraphModule(
+ model, tracer.trace(model, concrete_args=concrete_args))
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[List[BaseDataElement]] = None,
+ mode: str = 'tensor') -> ForwardResults:
+
+ if mode == 'loss':
+ return self.graph_loss(inputs, data_samples, mode)
+ elif mode == 'tensor':
+ return self.graph_tensor(inputs, data_samples, mode)
+ elif mode == 'predict':
+ return self.graph_predict(inputs, data_samples, mode)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}". '
+ 'Only supports loss, predict and tensor mode')
+
+ def calib_step(self, data):
+ data = self.data_preprocessor(data, False)
+ return self._run_forward(data, mode='tensor')
+
+ def prepare(self, mode='tensor'):
+ assert mode in ['tensor', 'loss', 'predict']
+ if mode == 'tensor':
+ graph = self.graph_tensor
+ elif mode == 'loss':
+ graph = self.graph_loss
+ else:
+ graph = self.graph_predict
+ self.architecture = self.quantizer.prepare(self.architecture, graph)
+
+ def convert(self):
+ self.architecture = self.quantizer.convert(self.architecture)
+
+ @property
+ def state(self):
+ return (self.observers_enabled, self.fake_quants_enabled)
+
+ @state.setter
+ def state(self, state):
+ observers_enabled, fake_quants_enabled = state
+ for name, submodule in self.architecture.named_modules():
+ if isinstance(submodule, torch.quantization.FakeQuantize):
+ if observers_enabled:
+ submodule.enable_observer()
+ else:
+ submodule.disable_observer()
+
+ if fake_quants_enabled:
+ submodule.enable_fake_quant()
+ else:
+ submodule.disable_fake_quant()
+
+ self.observers_enabled = observers_enabled
+ self.fake_quants_enabled = fake_quants_enabled
diff --git a/mmrazor/models/architectures/heads/__init__.py b/mmrazor/models/architectures/heads/__init__.py
index de84c30d5..0d7da475d 100644
--- a/mmrazor/models/architectures/heads/__init__.py
+++ b/mmrazor/models/architectures/heads/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_subnet_head import DartsSubnetClsHead
+from .deit_head import DeiTClsHead
-__all__ = ['DartsSubnetClsHead']
+__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
diff --git a/mmrazor/models/architectures/heads/deit_head.py b/mmrazor/models/architectures/heads/deit_head.py
new file mode 100644
index 000000000..61d587d93
--- /dev/null
+++ b/mmrazor/models/architectures/heads/deit_head.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import MODELS
+
+try:
+ from mmcls.models import VisionTransformerClsHead
+except ImportError:
+ from mmrazor.utils import get_placeholder
+ VisionTransformerClsHead = get_placeholder('mmcls')
+
+
+@MODELS.register_module()
+class DeiTClsHead(VisionTransformerClsHead):
+ """Distilled Vision Transformer classifier head.
+
+ Comparing with the :class:`DeiTClsHead` in mmcls, this head support to
+ train the distilled version DeiT.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ hidden_dim (int, optional): Number of the dimensions for hidden layer.
+ Defaults to None, which means no extra hidden layer.
+ act_cfg (dict): The activation config. Only available during
+ pre-training. Defaults to ``dict(type='Tanh')``.
+ init_cfg (dict): The extra initialization configs. Defaults to
+ ``dict(type='Constant', layer='Linear', val=0)``.
+ """
+
+ def _init_layers(self):
+ """"Init extra hidden linear layer to handle dist token if exists."""
+ super(DeiTClsHead, self)._init_layers()
+ if self.hidden_dim is None:
+ head_dist = nn.Linear(self.in_channels, self.num_classes)
+ else:
+ head_dist = nn.Linear(self.hidden_dim, self.num_classes)
+ self.layers.add_module('head_dist', head_dist)
+
+ def pre_logits(
+ self, feats: Tuple[List[torch.Tensor]]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """The process before the final classification head.
+
+ The input ``feats`` is a tuple of list of tensor, and each tensor is
+ the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
+ feature of the last stage and forward in hidden layer if exists.
+ """
+ _, cls_token, dist_token = feats[-1]
+ if self.hidden_dim is None:
+ return cls_token, dist_token
+ else:
+ cls_token = self.layers.act(self.layers.pre_logits(cls_token))
+ dist_token = self.layers.act(self.layers.pre_logits(dist_token))
+ return cls_token, dist_token
+
+ def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+ """The forward process."""
+ cls_token, dist_token = self.pre_logits(feats)
+ # The final classification head.
+ cls_score = self.layers.head(cls_token)
+ # Forward so that the corresponding recorder can record the output
+ # of the distillation token
+ _ = self.layers.head_dist(dist_token)
+ return cls_score
diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py
new file mode 100644
index 000000000..cea7708a2
--- /dev/null
+++ b/mmrazor/models/fake_quants/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adaround import AdaRoundFakeQuantize
+from .base import FakeQuantize
+from .lsq import LearnableFakeQuantize
+from .qdrop import QDropFakeQuantize
+
+__all__ = [
+ 'FakeQuantize', 'AdaRoundFakeQuantize', 'QDropFakeQuantize',
+ 'LearnableFakeQuantize'
+]
diff --git a/mmrazor/models/fake_quants/adaround.py b/mmrazor/models/fake_quants/adaround.py
new file mode 100644
index 000000000..9388f1aa4
--- /dev/null
+++ b/mmrazor/models/fake_quants/adaround.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from .base import FakeQuantize
+
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
+
+@MODELS.register_module()
+class AdaRoundFakeQuantize(FakeQuantize):
+
+ def __init__(self, observer, **observer_kwargs):
+ super().__init__(observer, **observer_kwargs)
+ self.adaround = False
+
+ def init(self, weight_tensor: torch.Tensor):
+ self.adaround = True
+ self.observer_enabled[0] = 0
+ self.fake_quant_enabled[0] = 1
+
+ # self.soft_targets = False # delete this
+ self.gamma = -0.1
+ self.zeta = 1.1
+ self.init_alpha(x=weight_tensor.data.clone())
+
+ def init_alpha(self, x: torch.Tensor):
+ if self.ch_axis != -1:
+ new_shape = [1] * len(x.shape)
+ new_shape[self.ch_axis] = x.shape[self.ch_axis]
+ scale = self.scale.data.reshape(new_shape)
+ else:
+ scale = self.scale.data
+ x_floor = torch.floor(x / scale)
+ rest = (x / scale) - x_floor # rest of rounding [0, 1)
+ alpha = -torch.log((self.zeta - self.gamma) /
+ (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
+ self.alpha = Parameter(alpha)
+
+ def rectified_sigmoid(self):
+ """Function to generate rounding mask.
+
+ Args:
+ x (torch.Tensor):
+ zeta (torch.Tensor):
+ gamma (torch.Tensor):
+ Returns:
+ torch.Tensor:
+ """
+ return ((self.zeta - self.gamma) * torch.sigmoid(self.alpha) +
+ self.gamma).clamp(0, 1)
+
+ def adaround_forward(self, x, hard_value=False):
+ if self.ch_axis != -1:
+ new_shape = [1] * len(x.shape)
+ new_shape[self.ch_axis] = x.shape[self.ch_axis]
+ scale = self.scale.reshape(new_shape)
+ zero_point = self.zero_point.reshape(new_shape)
+ x = torch.floor(x / scale)
+ if hard_value:
+ x += (self.alpha >= 0).float()
+ else:
+ x += self.rectified_sigmoid(self.alpha, self.zeta, self.gamma)
+ x += zero_point
+ x = torch.clamp(x, self.quant_min, self.quant_max)
+ x = (x - zero_point) * scale
+ return x
+
+ def forward(self, X):
+ if self.observer_enabled[0] == 1:
+ self.activation_post_process(X.detach())
+ _scale, _zero_point = self.calculate_qparams()
+ _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+ self.zero_point.device)
+ if self.scale.shape != _scale.shape:
+ self.scale.resize_(_scale.shape)
+ self.zero_point.resize_(_zero_point.shape)
+ self.scale.copy_(_scale)
+ self.zero_point.copy_(_zero_point)
+
+ if self.fake_quant_enabled[0] == 1:
+ if not self.adaround:
+ if self.is_per_channel:
+ X = torch.fake_quantize_per_channel_affine(
+ X, self.scale,
+ self.zero_point.long()
+ if _version_under_1100 else self.zero_point,
+ self.ch_axis, self.quant_min, self.quant_max)
+ else:
+ X = torch.fake_quantize_per_tensor_affine(
+ X, self.scale.item(), int(self.zero_point.item()),
+ self.quant_min, self.quant_max)
+ else:
+ if not hasattr(self, 'alpha'):
+ raise NotImplementedError
+ X = self.adaround_forward(X)
+ return X
diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py
new file mode 100644
index 000000000..13f8a1e43
--- /dev/null
+++ b/mmrazor/models/fake_quants/base.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.ao.quantization import FakeQuantizeBase
+
+from mmrazor.models.utils import (_is_float_qparams, _is_per_channel,
+ _is_per_tensor, _is_symmetric_quant)
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class FakeQuantize(FakeQuantizeBase):
+
+ scale: torch.Tensor
+ zero_point: torch.Tensor
+
+ def __init__(self, observer, **observer_kwargs):
+ super().__init__()
+ self.activation_post_process = observer(**observer_kwargs)
+ self.quant_min = self.activation_post_process.quant_min
+ self.quant_max = self.activation_post_process.quant_max
+ if _is_float_qparams(self.activation_post_process.qscheme):
+ zero_point_dtype = torch.float
+ else:
+ zero_point_dtype = torch.int
+ self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
+ self.register_buffer('zero_point',
+ torch.tensor([0], dtype=zero_point_dtype))
+ self.dtype = self.activation_post_process.dtype
+ self.qscheme = self.activation_post_process.qscheme
+ self.ch_axis = self.activation_post_process.ch_axis \
+ if hasattr(self.activation_post_process, 'ch_axis') else -1
+ assert _is_per_channel(self.qscheme) or \
+ _is_per_tensor(self.qscheme), \
+ 'Only per channel and per tensor quantization are supported in ' \
+ 'fake quantize' + ' got qscheme: ' + str(self.qscheme)
+ self.is_per_channel = _is_per_channel(self.qscheme)
+
+ bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double()
+ self.bitwidth = int(torch.log2(bitrange).item())
+ self.is_pot_scale = self.activation_post_process.is_pot_scale
+ self.is_symmetric_quant = _is_symmetric_quant(self.qscheme)
+
+ @torch.jit.export
+ def calculate_qparams(self):
+ return self.activation_post_process.calculate_qparams()
+
+ def forward(self, X):
+ if self.observer_enabled[0] == 1:
+ self.activation_post_process(X.detach())
+ _scale, _zero_point = self.calculate_qparams()
+ _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+ self.zero_point.device)
+ if self.scale.shape != _scale.shape:
+ self.scale.resize_(_scale.shape)
+ self.zero_point.resize_(_zero_point.shape)
+ self.scale.copy_(_scale)
+ self.zero_point.copy_(_zero_point)
+
+ if self.fake_quant_enabled[0] == 1:
+ if self.is_per_channel:
+ X = torch.fake_quantize_per_channel_affine(
+ X, self.scale, self.zero_point, self.ch_axis,
+ self.activation_post_process.quant_min,
+ self.activation_post_process.quant_max)
+ else:
+ X = torch.fake_quantize_per_tensor_affine(
+ X, self.scale, self.zero_point,
+ self.activation_post_process.quant_min,
+ self.activation_post_process.quant_max)
+ return X
+
+ @torch.jit.export
+ def extra_repr(self):
+ return 'fake_quant_enabled={}, observer_enabled={}, ' \
+ 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ' \
+ 'ch_axis={}, scale={}, zero_point={}'.format(
+ self.fake_quant_enabled, self.observer_enabled,
+ self.activation_post_process.quant_min,
+ self.activation_post_process.quant_max, self.dtype,
+ self.qscheme, self.ch_axis, self.scale, self.zero_point)
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ # We cannot currently register scalar values as buffers, so need to
+ # manually specify serialization here.
+ super(FakeQuantize, self)._save_to_state_dict(destination, prefix,
+ keep_vars)
+ destination[prefix + 'scale'] = self.scale
+ destination[prefix + 'zero_point'] = self.zero_point
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ # Removing this function throws an error that the the size of the
+ # loaded tensor does not match the original size i.e., These buffers
+ # start out with numel 0 and become numel 1 once they have their
+ # first forward pass.
+ local_state = ['scale', 'zero_point']
+ for name in local_state:
+ key = prefix + name
+ if key in state_dict:
+ val = state_dict[key]
+ # Custom handling to allow loading scale and zero_point
+ # of size N into uninitialized buffers of size 0. The
+ # buffers are resized here, and the values are copied in
+ # the default state_dict loading code of the parent.
+ if name == 'scale':
+ self.scale.resize_(val.shape)
+ else:
+ assert name == 'zero_point'
+ self.zero_point.resize_(val.shape)
+ # For torchscript module we need to update the attributes here
+ # since we do not call the `_load_from_state_dict` function
+ # defined module.py
+ if torch.jit.is_scripting():
+ if name == 'scale':
+ self.scale.copy_(val)
+ else:
+ assert name == 'zero_point'
+ self.zero_point.copy_(val)
+ elif strict:
+ missing_keys.append(key)
+ super(FakeQuantize,
+ self)._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys,
+ unexpected_keys, error_msgs)
diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py
new file mode 100644
index 000000000..10970a6a3
--- /dev/null
+++ b/mmrazor/models/fake_quants/lsq.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from ..utils import PerChannelLoadHook, _is_symmetric_quant, is_tracing_state
+from .base import FakeQuantize
+
+
+@MODELS.register_module()
+class LearnableFakeQuantize(FakeQuantize):
+ r""" This is an extension of the FakeQuantize module in fake_quantize.py,
+ which supports more generalized lower-bit quantization and support learning
+ of the scale and zero point parameters through backpropagation. For
+ literature references, please see the class
+ `_LearnableFakeQuantizePerTensorOp`. In addition to the attributes in the
+ original FakeQuantize module, the `_LearnableFakeQuantize` module also
+ includes the following attributes to support quantization parameter
+ learning.
+ """
+
+ def __init__(self,
+ observer,
+ scale=1.,
+ zero_point=0.,
+ use_grad_scaling=True,
+ **observer_kwargs):
+ super(LearnableFakeQuantize, self).__init__(observer,
+ **observer_kwargs)
+ self.use_grad_scaling = use_grad_scaling
+ self.scale = Parameter(torch.tensor([scale]))
+ self.zero_point = Parameter(torch.tensor([zero_point]))
+ self.register_buffer('eps',
+ torch.tensor([torch.finfo(torch.float32).eps]))
+ # Check whether the module will load a state dict;
+ # Initialize the shape of per-channel 'scale' and
+ # 'zero-point' before copying values
+ self.load_state_dict_hook = PerChannelLoadHook(self)
+
+ @torch.jit.export
+ def extra_repr(self):
+ return 'fake_quant_enabled={}, observer_enabled={}, ' \
+ 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={},'\
+ 'scale={}, zero_point={}'.format(
+ self.fake_quant_enabled, self.observer_enabled,
+ self.quant_min, self.quant_max,
+ self.dtype, self.qscheme, self.ch_axis,
+ self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape), # noqa: E501
+ self.zero_point if self.ch_axis == -1 else 'List')
+
+ def forward(self, X):
+ # Learnable fake quantize have to zero_point.float()
+ # to make it learnable.
+ if self.observer_enabled[0] == 1:
+ self.activation_post_process(X.detach())
+ _scale, _zero_point = \
+ self.activation_post_process.calculate_qparams()
+ _scale = _scale.to(self.scale.device)
+ _zero_point = _zero_point.to(self.zero_point.device)
+
+ if self.ch_axis != -1:
+ self.scale.data = torch.ones_like(_scale)
+ self.zero_point.data = torch.zeros_like(_zero_point.float())
+
+ self.scale.data.copy_(_scale)
+ self.zero_point.data.copy_(_zero_point.float())
+ else:
+ self.scale.data.abs_()
+ self.scale.data.clamp_(min=self.eps.item())
+
+ if self.fake_quant_enabled[0] == 1:
+ if _is_symmetric_quant(self.qscheme):
+ self.zero_point.data.zero_()
+ else:
+ self.zero_point.data.clamp_(self.quant_min,
+ self.quant_max).float()
+
+ if self.is_per_channel:
+ if self.use_grad_scaling:
+ grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] *
+ self.quant_max)**0.5
+ else:
+ grad_factor = 1.0
+ if is_tracing_state():
+ X = FakeQuantizeLearnablePerchannelAffine.apply(
+ X, self.scale, self.zero_point, self.ch_axis,
+ self.quant_min, self.quant_max, grad_factor)
+ else:
+ X = _fake_quantize_learnable_per_channel_affine_training(
+ X, self.scale, self.zero_point, self.ch_axis,
+ self.quant_min, self.quant_max, grad_factor)
+ else:
+ if self.use_grad_scaling:
+ grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5
+ else:
+ grad_factor = 1.0
+ X = torch._fake_quantize_learnable_per_tensor_affine(
+ X, self.scale, self.zero_point, self.quant_min,
+ self.quant_max, grad_factor)
+ return X
+
+
+def _fake_quantize_learnable_per_channel_affine_training(
+ x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
+ zero_point = (zero_point.round() - zero_point).detach() + zero_point
+ new_shape = [1] * len(x.shape)
+ new_shape[ch_axis] = x.shape[ch_axis]
+ scale = grad_scale(scale, grad_factor).reshape(new_shape)
+ zero_point = grad_scale(zero_point, grad_factor).reshape(new_shape)
+ x = x / scale + zero_point
+ x = (x.round() - x).detach() + x
+ x = torch.clamp(x, quant_min, quant_max)
+ return (x - zero_point) * scale
+
+
+def grad_scale(t, scale):
+ return (t - (t * scale)).detach() + (t * scale)
+
+
+class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max,
+ grad_factor):
+ return _fake_quantize_learnable_per_channel_affine_training(
+ x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor)
+
+ @staticmethod
+ def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max,
+ grad_factor):
+ return g.op(
+ '::FakeQuantizeLearnablePerchannelAffine',
+ x,
+ scale,
+ zero_point,
+ quant_min_i=quant_min,
+ quant_max_i=quant_max)
diff --git a/mmrazor/models/fake_quants/qdrop.py b/mmrazor/models/fake_quants/qdrop.py
new file mode 100644
index 000000000..e2e13bfc0
--- /dev/null
+++ b/mmrazor/models/fake_quants/qdrop.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from .base import FakeQuantize
+
+
+@MODELS.register_module()
+class QDropFakeQuantize(FakeQuantize):
+
+ def __init__(self, observer, **observer_kwargs):
+ super().__init__(observer, **observer_kwargs)
+ self.scale = Parameter(torch.tensor([1.0], dtype=torch.float))
+ self.prob = 1.0
+
+ def forward(self, X):
+ if self.observer_enabled[0] == 1:
+ self.activation_post_process(X.detach())
+ _scale, _zero_point = self.calculate_qparams()
+ _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+ self.zero_point.device)
+ if self.scale.shape != _scale.shape:
+ self.scale.resize_(_scale.shape)
+ self.zero_point.resize_(_zero_point.shape)
+ self.scale.copy_(_scale)
+ self.zero_point.copy_(_zero_point)
+
+ if self.fake_quant_enabled[0] == 1:
+ x_orig = X
+ if self.is_per_channel:
+ X = torch.fake_quantize_per_channel_affine(
+ X, self.scale, self.zero_point, self.ch_axis,
+ self.activation_post_process.quant_min,
+ self.activation_post_process.quant_max)
+ else:
+ X = torch.fake_quantize_per_tensor_affine(
+ X, self.scale, self.zero_point,
+ self.activation_post_process.quant_min,
+ self.activation_post_process.quant_max)
+ if self.prob < 1.0:
+ x_prob = torch.where(torch.rand_like(X) < self.prob, X, x_orig)
+ return x_prob
+ return X
diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
index a145ba914..42327e564 100644
--- a/mmrazor/models/losses/__init__.py
+++ b/mmrazor/models/losses/__init__.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
+from .adaround_loss import AdaRoundLoss
from .at_loss import ATLoss
from .crd_loss import CRDLoss
+from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
@@ -12,6 +14,7 @@
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .ofd_loss import OFDLoss
+from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD
@@ -19,5 +22,6 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
- 'L1Loss', 'FBKDLoss', 'CRDLoss'
+ 'L1Loss', 'FBKDLoss', 'CRDLoss', 'AdaRoundLoss', 'CrossEntropyLoss',
+ 'PKDLoss'
]
diff --git a/mmrazor/models/losses/adaround_loss.py b/mmrazor/models/losses/adaround_loss.py
new file mode 100644
index 000000000..76c97977d
--- /dev/null
+++ b/mmrazor/models/losses/adaround_loss.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.logging import print_log
+
+from mmrazor.registry import MODELS
+
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+@MODELS.register_module()
+class AdaRoundLoss(nn.Module):
+ r'''loss function to calculate mse reconstruction loss and relaxation loss
+ use some tempdecay to balance the two losses.
+ '''
+
+ def __init__(self,
+ weight: float = 1.,
+ iters: int = 10000,
+ beta_range: tuple = (20, 2),
+ warm_up: float = 0.0,
+ p: float = 2.):
+ self.weight = weight
+ self.loss_start = iters * warm_up
+ self.p = p
+
+ self.temp_decay = LinearTempDecay(
+ iters,
+ warm_up=warm_up,
+ start_beta=beta_range[0],
+ end_beta=beta_range[1])
+ self.count = 0
+
+ def forward(self, subgraph, pred, tgt):
+ """Compute the total loss for adaptive rounding: rec_loss is the
+ quadratic output reconstruction loss, round_loss is a regularization
+ term to optimize the rounding policy.
+
+ :param pred: output from quantized model
+ :param tgt: output from FP model
+ :return: total loss function
+ """
+
+ def lp_loss(pred, tgt, p=2.0):
+ """loss function measured in L_p Norm."""
+ return (pred - tgt).abs().pow(p).sum(1).mean()
+
+ self.count += 1
+ rec_loss = lp_loss(pred, tgt, p=self.p)
+
+ beta = self.temp_decay(self.count)
+ if self.count < self.loss_start:
+ round_loss = 0
+ else:
+ round_loss = 0
+ for layer in subgraph.modules():
+ if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+ round_vals = layer.weight_fake_quant.rectified_sigmoid()
+ round_loss += self.weight * (1 - (
+ (round_vals - .5).abs() * 2).pow(beta)).sum()
+
+ total_loss = rec_loss + round_loss
+ if self.count % 500 == 0:
+ print_log('Total loss:\t{:.3f} (rec_loss:{:.3f}, '
+ 'round_loss:{:.3f})\tbeta={:.2f}\tcount={}'.format(
+ float(total_loss), float(rec_loss),
+ float(round_loss), beta, self.count))
+ return total_loss
+
+
+class LinearTempDecay:
+
+ def __init__(self, t_max=10000, warm_up=0.2, start_beta=20, end_beta=2):
+ self.t_max = t_max
+ self.start_decay = warm_up * t_max
+ self.start_beta = start_beta
+ self.end_beta = end_beta
+
+ def __call__(self, t):
+ if t < self.start_decay:
+ return self.start_beta
+ elif t > self.t_max:
+ return self.end_beta
+ else:
+ rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+ return self.end_beta + (self.start_beta - self.end_beta) * \
+ max(0.0, (1 - rel_t))
diff --git a/mmrazor/models/losses/cross_entropy_loss.py b/mmrazor/models/losses/cross_entropy_loss.py
new file mode 100644
index 000000000..685748092
--- /dev/null
+++ b/mmrazor/models/losses/cross_entropy_loss.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class CrossEntropyLoss(nn.Module):
+ """Cross entropy loss.
+
+ Args:
+ loss_weight (float): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ self.loss_weight = loss_weight
+
+ def forward(self, preds_S, preds_T):
+ preds_T = preds_T.detach()
+ loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1))
+ return loss * self.loss_weight
diff --git a/mmrazor/models/losses/pkd_loss.py b/mmrazor/models/losses/pkd_loss.py
new file mode 100644
index 000000000..febc05c36
--- /dev/null
+++ b/mmrazor/models/losses/pkd_loss.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class PKDLoss(nn.Module):
+ """PyTorch version of `PKD: General Distillation Framework for Object
+ Detectors via Pearson Correlation Coefficient.
+
+ `_.
+
+ Args:
+ loss_weight (float): Weight of loss. Defaults to 1.0.
+ resize_stu (bool): If True, we'll down/up sample the features of the
+ student model to the spatial size of those of the teacher model if
+ their spatial sizes are different. And vice versa. Defaults to
+ True.
+ """
+
+ def __init__(self, loss_weight=1.0, resize_stu=True):
+ super(PKDLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.resize_stu = resize_stu
+
+ def norm(self, feat: torch.Tensor) -> torch.Tensor:
+ """Normalize the feature maps to have zero mean and unit variances.
+
+ Args:
+ feat (torch.Tensor): The original feature map with shape
+ (N, C, H, W).
+ """
+ assert len(feat.shape) == 4
+ N, C, H, W = feat.shape
+ feat = feat.permute(1, 0, 2, 3).reshape(C, -1)
+ mean = feat.mean(dim=-1, keepdim=True)
+ std = feat.std(dim=-1, keepdim=True)
+ feat = (feat - mean) / (std + 1e-6)
+ return feat.reshape(C, N, H, W).permute(1, 0, 2, 3)
+
+ def forward(self, preds_S: Union[torch.Tensor, Tuple],
+ preds_T: Union[torch.Tensor, Tuple]) -> torch.Tensor:
+ """Forward computation.
+
+ Args:
+ preds_S (torch.Tensor | Tuple[torch.Tensor]): The student model
+ prediction. If tuple, it should be several tensors with shape
+ (N, C, H, W).
+ preds_T (torch.Tensor | Tuple[torch.Tensor]): The teacher model
+ prediction. If tuple, it should be several tensors with shape
+ (N, C, H, W).
+
+ Return:
+ torch.Tensor: The calculated loss value.
+ """
+ if isinstance(preds_S, torch.Tensor):
+ preds_S, preds_T = (preds_S, ), (preds_T, )
+
+ loss = 0.
+
+ for pred_S, pred_T in zip(preds_S, preds_T):
+ size_S, size_T = pred_S.shape[2:], pred_T.shape[2:]
+ if size_S[0] != size_T[0]:
+ if self.resize_stu:
+ pred_S = F.interpolate(pred_S, size_T, mode='bilinear')
+ else:
+ pred_T = F.interpolate(pred_T, size_S, mode='bilinear')
+ assert pred_S.shape == pred_T.shape
+
+ norm_S, norm_T = self.norm(pred_S), self.norm(pred_T)
+
+ # First conduct feature normalization and then calculate the
+ # MSE loss. Methematically, it is equivalent to firstly calculate
+ # the Pearson Correlation Coefficient (r) between two feature
+ # vectors, and then use 1-r as the new feature imitation loss.
+ loss += F.mse_loss(norm_S, norm_T) / 2
+
+ return loss * self.loss_weight
diff --git a/mmrazor/models/mutables/base_mutable.py b/mmrazor/models/mutables/base_mutable.py
index b0df98d4f..2b5972d9f 100644
--- a/mmrazor/models/mutables/base_mutable.py
+++ b/mmrazor/models/mutables/base_mutable.py
@@ -1,14 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Optional
from mmengine.model import BaseModule
-CHOICE_TYPE = TypeVar('CHOICE_TYPE')
-CHOSEN_TYPE = TypeVar('CHOSEN_TYPE')
+from mmrazor.utils.typing import DumpChosen
-class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
+class BaseMutable(BaseModule, ABC):
"""Base Class for mutables. Mutable means a searchable module widely used
in Neural Architecture Search(NAS).
@@ -17,13 +16,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):
All subclass should implement the following APIs:
- - ``forward()``
- ``fix_chosen()``
- - ``choices()``
+ - ``dump_chosen()``
+ - ``current_choice.setter()``
+ - ``current_choice.getter()``
Args:
- module_kwargs (dict[str, dict], optional): Module initialization named
- arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
@@ -38,19 +36,18 @@ def __init__(self,
self.alias = alias
self._is_fixed = False
- self._current_choice: Optional[CHOICE_TYPE] = None
- @property
- def current_choice(self) -> Optional[CHOICE_TYPE]:
+ @property # type: ignore
+ @abstractmethod
+ def current_choice(self):
"""Current choice will affect :meth:`forward` and will be used in
:func:`mmrazor.core.subnet.utils.export_fix_subnet` or mutator.
"""
- return self._current_choice
- @current_choice.setter
- def current_choice(self, choice: Optional[CHOICE_TYPE]) -> None:
+ @current_choice.setter # type: ignore
+ @abstractmethod
+ def current_choice(self, choice) -> None:
"""Current choice setter will be executed in mutator."""
- self._current_choice = choice
@property
def is_fixed(self) -> bool:
@@ -76,22 +73,22 @@ def is_fixed(self, is_fixed: bool) -> None:
self._is_fixed = is_fixed
@abstractmethod
- def fix_chosen(self, chosen: CHOSEN_TYPE) -> None:
- """Fix mutable with choice. This function would fix the choice of
- Mutable. The :attr:`is_fixed` will be set to True and only the selected
+ def fix_chosen(self, chosen) -> None:
+ """Fix mutable with chosen. This function would fix the chosen of
+ mutable. The :attr:`is_fixed` will be set to True and only the selected
operations can be retained. All subclasses must implement this method.
Note:
This operation is irreversible.
"""
+ raise NotImplementedError()
- # TODO
- # type hint
@abstractmethod
- def dump_chosen(self) -> CHOSEN_TYPE:
- ...
+ def dump_chosen(self) -> DumpChosen:
+ """Save the current state of the mutable as a dictionary.
- @property
- @abstractmethod
- def num_choices(self) -> int:
- pass
+ ``DumpChosen`` has ``chosen`` and ``meta`` fields. ``chosen`` is
+ necessary, ``fix_chosen`` will use the ``chosen`` . ``meta`` is used to
+ store some non-essential information.
+ """
+ raise NotImplementedError()
diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py
index 98f680ee9..ddbf6adeb 100644
--- a/mmrazor/models/mutables/derived_mutable.py
+++ b/mmrazor/models/mutables/derived_mutable.py
@@ -15,8 +15,9 @@
from mmengine.logging import print_log
from torch import Tensor
+from mmrazor.utils.typing import DumpChosen
from ..utils import make_divisible
-from .base_mutable import CHOICE_TYPE, BaseMutable
+from .base_mutable import BaseMutable
class MutableProtocol(Protocol): # pragma: no cover
@@ -172,8 +173,7 @@ def derive_concat_mutable(
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
-class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE],
- DerivedMethodMixin):
+class DerivedMutable(BaseMutable, DerivedMethodMixin):
"""Class for derived mutable.
A derived mutable is a mutable derived from other mutables that has
@@ -242,7 +242,7 @@ def __init__(self,
# TODO
# has no effect
- def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
+ def fix_chosen(self, chosen) -> None:
"""Fix mutable with subnet config.
Warning:
@@ -253,7 +253,7 @@ def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
'which will have no effect.',
level=logging.WARNING)
- def dump_chosen(self) -> CHOICE_TYPE:
+ def dump_chosen(self) -> DumpChosen:
"""Dump information of chosen.
Returns:
@@ -263,6 +263,9 @@ def dump_chosen(self) -> CHOICE_TYPE:
'Trying to dump chosen for derived mutable, '
'but its value depend on the source mutables.',
level=logging.WARNING)
+ return DumpChosen(chosen=self.export_chosen(), meta=None)
+
+ def export_chosen(self):
return self.current_choice
@property
@@ -314,12 +317,12 @@ def num_choices(self) -> int:
return 1
@property
- def current_choice(self) -> CHOICE_TYPE:
+ def current_choice(self):
"""Current choice of derived mutable."""
return self.choice_fn()
@current_choice.setter
- def current_choice(self, choice: CHOICE_TYPE) -> None:
+ def current_choice(self, choice) -> None:
"""Setter of current choice.
Raises:
diff --git a/mmrazor/models/mutables/mutable_channel/MutableChannel.md b/mmrazor/models/mutables/mutable_channel/MutableChannel.md
new file mode 100644
index 000000000..20b3db816
--- /dev/null
+++ b/mmrazor/models/mutables/mutable_channel/MutableChannel.md
@@ -0,0 +1,36 @@
+# MutableChannels
+
+MutableChannels are used to deal with mutable number of channels in DynamicOps.
+
+```
+|-----------------------------------------|
+| mutable_in_channel(BaseMutableChannel) |
+| --------------------------------------- |
+| DynamicOp |
+| --------------------------------------- |
+| mutable_out_channel(BaseMutableChannel) |
+| --------------------------------------- |
+```
+
+\`
+All MutableChannels inherit from BaseMutableChannel. Each MutableChannel has to implement two property.
+
+- current_choice: get and set the choice of the MutableChannel.
+- current_mask: get the channel mask according to the current_choice.
+
+## MutableChannelContainer
+
+Here, we introduce a special MutableChannel: MutableChannelContainer. As the channels of a DynamicOp may belong to different MutableChannelUnits, we use MutableChannelContainers to store multiple MutableChannels as below.
+
+```
+-----------------------------------------------------------
+| MutableChannelContainer |
+-----------------------------------------------------------
+|MutableChannel1| MutableChannel2 |MutableChannel3|
+-----------------------------------------------------------
+```
+
+MutableChannelContainer has an method to register MutableChannels.
+
+- register_mutable: register/store BaseMutableChannel in the
+ MutableChannelContainer
diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
index 28f1e4854..65d5a44d6 100644
--- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
+++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
@@ -4,6 +4,7 @@
import torch
+from mmrazor.utils.typing import DumpChosen
from ..base_mutable import BaseMutable
from ..derived_mutable import DerivedMethodMixin
@@ -20,9 +21,9 @@ class BaseMutableChannel(BaseMutable, DerivedMethodMixin):
|mutable_out_channel(BaseMutableChannel)|
|---------------------------------------|
- All subclasses should implement the following APIs:
+ All subclasses should implement the following APIs and the other
+ abstract method in ``BaseMutable``
- - ``current_choice``
- ``current_mask``
Args:
@@ -34,20 +35,6 @@ def __init__(self, num_channels: int, **kwargs):
self.name = ''
self.num_channels = num_channels
- # choice
-
- @property # type: ignore
- @abstractmethod
- def current_choice(self):
- """get current choice."""
- raise NotImplementedError()
-
- @current_choice.setter # type: ignore
- @abstractmethod
- def current_choice(self):
- """set current choice."""
- raise NotImplementedError()
-
@property # type: ignore
@abstractmethod
def current_mask(self) -> torch.Tensor:
@@ -73,9 +60,15 @@ def fix_chosen(self, chosen=None):
self.is_fixed = True
- def dump_chosen(self):
- """dump current choice to a dict."""
- raise NotImplementedError()
+ def dump_chosen(self) -> DumpChosen:
+ """Dump chosen."""
+ meta = dict(max_channels=self.mask.size(0))
+ chosen = self.export_chosen()
+
+ return DumpChosen(chosen=chosen, meta=meta)
+
+ def export_chosen(self) -> int:
+ return self.activated_channels
def num_choices(self) -> int:
"""Number of available choices."""
diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py
index eae559d41..9b891e349 100644
--- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py
+++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py
@@ -69,10 +69,6 @@ def fix_chosen(self, chosen=...):
self.current_choice = chosen
self.is_fixed = True
- def dump_chosen(self):
- """Dump chosen."""
- return self.current_choice
-
def __rmul__(self, other) -> DerivedMutable:
return self * other
diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
index 576412ec0..e494b4018 100644
--- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
@@ -149,9 +149,10 @@ class ChannelUnit(BaseModule):
def __init__(self, num_channels: int, **kwargs):
super().__init__()
+
self.num_channels = num_channels
- self.output_related: nn.ModuleList = nn.ModuleList()
- self.input_related: nn.ModuleList = nn.ModuleList()
+ self.output_related: List[nn.Module] = list()
+ self.input_related: List[nn.Module] = list()
self.init_args: Dict = {
} # is used to generate new channel unit with same args
@@ -208,14 +209,14 @@ def init_from_graph(cls,
def init_from_base_channel_unit(base_channel_unit: BaseChannelUnit):
unit = cls(len(base_channel_unit.channel_elems), **unit_args)
- unit.input_related = nn.ModuleList([
+ unit.input_related = [
Channel.init_from_base_channel(channel)
for channel in base_channel_unit.input_related
- ])
- unit.output_related = nn.ModuleList([
+ ]
+ unit.output_related = [
Channel.init_from_base_channel(channel)
for channel in base_channel_unit.output_related
- ])
+ ]
return unit
unit_graph = ChannelGraph.copy_from(graph,
@@ -239,6 +240,11 @@ def name(self) -> str:
name = f'{first_module_name}_{self.num_channels}'
return name
+ @property
+ def alias(self) -> str:
+ """str: alias of the unit"""
+ return self.name
+
def config_template(self,
with_init_args=False,
with_channels=False) -> Dict:
diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb
new file mode 100644
index 000000000..5af2d496b
--- /dev/null
+++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb
@@ -0,0 +1,419 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MutableChannelUnit"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each MutableChannelUnit is a basic unit for pruning. It records all channels which are dependent on each other.\n",
+ "Below, we will introduce you about:\n",
+ "1. The data structure of MutableChannelUnit.\n",
+ "2. How to prune the model with a MutableChannelUnit.\n",
+ "3. How to get MutableChannelUnits.\n",
+ "4. How to develop a new MutableChannelUnit for a new pruning algorithm.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The Data Structure of MutableChannelUnit"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, let's parse a model and get several MutableChannelUnits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define a model\n",
+ "from mmengine.model import BaseModel\n",
+ "from torch import nn\n",
+ "import torch\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "class MyModel(BaseModel):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__(None, None)\n",
+ " self.net = nn.Sequential(\n",
+ " OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n",
+ " ('relu', nn.ReLU()),\n",
+ " ('conv1', nn.Conv2d(8, 16, 3, 1, 1))]))\n",
+ " self.pool = nn.AdaptiveAvgPool2d(1)\n",
+ " self.head = nn.Linear(16, 1000)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " feature = self.net(x)\n",
+ " pool = self.pool(feature).flatten(1)\n",
+ " return self.head(pool)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "This model has 4 MutableChannelUnit(SequentialMutableChannelUnit).\n"
+ ]
+ }
+ ],
+ "source": [
+ "# There are multiple types of MutableChannelUnits. Here, We take SequentialMutableChannelUnit as the example.\n",
+ "from mmrazor.models.mutables.mutable_channel.units import SequentialMutableChannelUnit\n",
+ "from mmrazor.structures.graph import ModuleGraph\n",
+ "from typing import List\n",
+ "\n",
+ "model = MyModel()\n",
+ "graph = ModuleGraph.init_from_backward_tracer(model)\n",
+ "units: List[\n",
+ " SequentialMutableChannelUnit] = SequentialMutableChannelUnit.init_from_graph(graph) # type: ignore\n",
+ "print(\n",
+ " f'This model has {len(units)} MutableChannelUnit(SequentialMutableChannelUnit).'\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SequentialMutableChannelUnit(\n",
+ " name=net.conv0_(0, 8)_8\n",
+ " (output_related): ModuleList(\n",
+ " (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)\n",
+ " )\n",
+ " (input_related): ModuleList(\n",
+ " (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)\n",
+ " )\n",
+ " (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "unit1=units[1]\n",
+ "print(unit1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As shown above, each MutableChannelUnit has four important attributes: \n",
+ "1. name: str\n",
+ "2. output_related: ModuleList\n",
+ "3. input_related: ModuleList\n",
+ "4. mutable_channel: BaseMutableChannel\n",
+ "\n",
+ "\"name\" is the identifier of the MutableChannelUnit. It's automatically generated usually.\n",
+ "\n",
+ "\"output_related\" and \"input_related\" are two ModuleLists. They store all Channels with channel dependency.\n",
+ "The difference is that the \"output_related\" includes output channels and the \"input_related\" includes input channels.\n",
+ "All these channels\n",
+ "\n",
+ "\"mutable_channel\" is a BaseMutableChannel used to control the channel mask of modules. The mutable_channel is registered to the modules whose channels are stored in \"output_related\" and \"input_related\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## How to prune the model with a MutableChannelUnit."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There are three steps to prune the model using a MutableChannelUnit:\n",
+ "1. replace modules, whose channel are stored in the \"output_related\" and \"input_related\", with dynamic ops which are able to deal with mutable number of channels.\n",
+ "2. register the \"mutable_channel\" to the replaced dynamic ops.\n",
+ "3. change the choice of the \"mutable_channel\".\n",
+ "\n",
+ "For simplicity, we run step 1 and 2 with one method \"prepare_for_pruning\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The current choice of unit1 is 8.\n",
+ "DynamicConv2d(\n",
+ " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n",
+ " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=8)\n",
+ " )\n",
+ ")\n",
+ "DynamicConv2d(\n",
+ " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=8)\n",
+ " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# We run \"prepare_for_pruning\" once before pruning to run step 1 and 2 above.\n",
+ "unit1.prepare_for_pruning(model)\n",
+ "print(f'The current choice of unit1 is {unit1.current_choice}.')\n",
+ "print(model.net.conv0)\n",
+ "print(model.net.conv1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We prune the model by changing the current_choice of the MutableChannelUnits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "We get a sampled choice 2.\n",
+ "DynamicConv2d(\n",
+ " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n",
+ " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=2)\n",
+ " )\n",
+ ")\n",
+ "DynamicConv2d(\n",
+ " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=2)\n",
+ " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "sampled_choice=unit1.sample_choice()\n",
+ "print(f'We get a sampled choice {sampled_choice}.')\n",
+ "unit1.current_choice=sampled_choice\n",
+ "print(model.net.conv0)\n",
+ "print(model.net.conv1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Besides, different types of MutableChannelUnit may have different types of choices. Please read documents for more details."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## How to get MutableChannelUnits."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There are three ways to get MutableChannelUnits.\n",
+ "1. Using a tracer.\n",
+ " This way, firstly, converts a model to a graph, then converts the graph to MutableChannelUnits. It automatically returns all available MutableChannelUnits.\n",
+ "2. Using a config.\n",
+ " This way uses a config to initialize a MutableChannelUnit.\n",
+ "3. Using a predefined model.\n",
+ " This way parses a predefined model with dynamic ops. It returns all available MutableChannelUnits.\n",
+ "\n",
+ "All these three ways have corresponding documents in the README of ChannelMutator."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The model has 4 MutableChannelUnits.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 1. using tracer\n",
+ "def get_mutable_channel_units_using_tracer(model):\n",
+ " graph = ModuleGraph.init_from_backward_tracer(model)\n",
+ " units = SequentialMutableChannelUnit.init_from_graph(graph)\n",
+ " return units\n",
+ "\n",
+ "\n",
+ "model = MyModel()\n",
+ "units = get_mutable_channel_units_using_tracer(model)\n",
+ "print(f'The model has {len(units)} MutableChannelUnits.')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SequentialMutableChannelUnit(\n",
+ " name=net.conv0_(0, 8)_8\n",
+ " (output_related): ModuleList(\n",
+ " (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)\n",
+ " )\n",
+ " (input_related): ModuleList(\n",
+ " (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)\n",
+ " )\n",
+ " (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 2. using config\n",
+ "config = {\n",
+ " 'init_args': {\n",
+ " 'num_channels': 8,\n",
+ " },\n",
+ " 'channels': {\n",
+ " 'input_related': [{\n",
+ " 'name': 'net.conv1',\n",
+ " }],\n",
+ " 'output_related': [{\n",
+ " 'name': 'net.conv0',\n",
+ " }]\n",
+ " },\n",
+ " 'choice': 8\n",
+ "}\n",
+ "unit=SequentialMutableChannelUnit.init_from_cfg(model, config)\n",
+ "print(unit)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The model has 2 MutableChannelUnits.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 3. using predefined model\n",
+ "\n",
+ "from mmrazor.models.architectures.dynamic_ops import DynamicConv2d, DynamicLinear\n",
+ "from mmrazor.models.mutables import MutableChannelUnit, MutableChannelContainer,SquentialMutableChannel\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "class MyDynamicModel(BaseModel):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__(None, None)\n",
+ " self.net = nn.Sequential(\n",
+ " OrderedDict([('conv0', DynamicConv2d(3, 8, 3, 1, 1)),\n",
+ " ('relu', nn.ReLU()),\n",
+ " ('conv1', DynamicConv2d(8, 16, 3, 1, 1))]))\n",
+ " self.pool = nn.AdaptiveAvgPool2d(1)\n",
+ " self.head = DynamicLinear(16, 1000)\n",
+ "\n",
+ " # register MutableChannelContainer\n",
+ " MutableChannelUnit._register_channel_container(\n",
+ " self, MutableChannelContainer)\n",
+ " self._register_mutables()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " feature = self.net(x)\n",
+ " pool = self.pool(feature).flatten(1)\n",
+ " return self.head(pool)\n",
+ "\n",
+ " def _register_mutables(self):\n",
+ " mutable1 = SquentialMutableChannel(8)\n",
+ " mutable2 = SquentialMutableChannel(16)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv0, mutable1, is_to_output_channel=True)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv1, mutable1, is_to_output_channel=False)\n",
+ "\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv1, mutable2, is_to_output_channel=True)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.head, mutable2, is_to_output_channel=False)\n",
+ "model=MyDynamicModel()\n",
+ "units=SequentialMutableChannelUnit.init_from_predefined_model(model) \n",
+ "print(f'The model has {len(units)} MutableChannelUnits.')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.12 ('mmlab')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "feec882ee78c63cb8d4b485f1b52bbb873bb9a7b094435863200c7afba202382"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py
index 5e44c330d..e524ec67c 100644
--- a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py
+++ b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py
@@ -9,13 +9,13 @@
from torch import Tensor
from mmrazor.registry import MODELS
-from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE
+from mmrazor.utils.typing import DumpChosen
from .mutable_module import MutableModule
PartialType = Callable[[Any, Optional[nn.Parameter]], Any]
-class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
+class DiffMutableModule(MutableModule):
"""Base class for differentiable mutables.
Args:
@@ -34,9 +34,12 @@ class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
- def forward(self,
- x: Any,
- arch_param: Optional[nn.Parameter] = None) -> Any:
+ @abstractmethod
+ def sample_choice(self, arch_param: Tensor):
+ """Sample choice according arch parameters."""
+ raise NotImplementedError
+
+ def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None):
"""Calls either :func:`forward_fixed` or :func:`forward_arch_param`
depending on whether :func:`is_fixed` is ``True`` and whether
:func:`arch_param` is None.
@@ -60,27 +63,17 @@ def forward(self,
if self.is_fixed:
return self.forward_fixed(x)
else:
- return self.forward_arch_param(x, arch_param=arch_param)
+ if arch_param is None:
+ return self.forward_all(x)
+ else:
+ return self.forward_arch_param(x, arch_param=arch_param)
def compute_arch_probs(self, arch_param: nn.Parameter) -> Tensor:
"""compute chosen probs according to architecture params."""
return F.softmax(arch_param, -1)
@abstractmethod
- def forward_fixed(self, x: Any) -> Any:
- """Forward when the mutable is fixed.
-
- All subclasses must implement this method.
- """
-
- @abstractmethod
- def forward_all(self, x: Any) -> Any:
- """Forward all choices."""
-
- @abstractmethod
- def forward_arch_param(self,
- x: Any,
- arch_param: Optional[nn.Parameter] = None) -> Any:
+ def forward_arch_param(self, x, arch_param: nn.Parameter):
"""Forward when the mutable is not fixed.
All subclasses must implement this method.
@@ -94,7 +87,7 @@ def set_forward_args(self, arch_param: nn.Parameter) -> None:
@MODELS.register_module()
-class DiffMutableOP(DiffMutableModule[str, str]):
+class DiffMutableOP(DiffMutableModule):
"""A type of ``MUTABLES`` for differentiable architecture search, such as
DARTS. Search the best module by learnable parameters `arch_param`.
@@ -159,7 +152,7 @@ def _build_ops(candidates: Dict[str, Dict],
ops[name] = MODELS.build(op_cfg)
return ops
- def forward_fixed(self, x: Any) -> Tensor:
+ def forward_fixed(self, x) -> Tensor:
"""Forward when the mutable is in `fixed` mode.
Args:
@@ -171,10 +164,7 @@ def forward_fixed(self, x: Any) -> Tensor:
"""
return sum(self._candidates[choice](x) for choice in self._chosen)
- def forward_arch_param(self,
- x: Any,
- arch_param: Optional[nn.Parameter] = None
- ) -> Tensor:
+ def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor:
"""Forward with architecture parameters.
Args:
@@ -187,21 +177,19 @@ def forward_arch_param(self,
Returns:
Tensor: the result of forward with ``arch_param``.
"""
- if arch_param is None:
- return self.forward_all(x)
- else:
- # compute the probs of choice
- probs = self.compute_arch_probs(arch_param=arch_param)
- # forward based on probs
- outputs = list()
- for prob, module in zip(probs, self._candidates.values()):
- if prob > 0.:
- outputs.append(prob * module(x))
+ # compute the probs of choice
+ probs = self.compute_arch_probs(arch_param=arch_param)
- return sum(outputs)
+ # forward based on probs
+ outputs = list()
+ for prob, module in zip(probs, self._candidates.values()):
+ if prob > 0.:
+ outputs.append(prob * module(x))
- def forward_all(self, x: Any) -> Tensor:
+ return sum(outputs)
+
+ def forward_all(self, x) -> Tensor:
"""Forward all choices. Used to calculate FLOPs.
Args:
@@ -240,12 +228,16 @@ def fix_chosen(self, chosen: Union[str, List[str]]) -> None:
self._chosen = chosen
self.is_fixed = True
- def sample_choice(self, arch_param):
+ def sample_choice(self, arch_param: Tensor) -> str:
"""Sample choice based on arch_parameters."""
return self.choices[torch.argmax(arch_param).item()]
- def dump_chosen(self):
- """Dump current choice."""
+ def dump_chosen(self) -> DumpChosen:
+ chosen = self.export_chosen()
+ meta = dict(all_choices=self.choices)
+ return DumpChosen(chosen=chosen, meta=meta)
+
+ def export_chosen(self) -> str:
assert self.current_choice is not None
return self.current_choice
@@ -297,10 +289,11 @@ def sample_weights(self,
m = D.one_hot_categorical.OneHotCategorical(probs=probs)
return m.sample()
- def forward_arch_param(self,
- x: Any,
- arch_param: Optional[nn.Parameter] = None
- ) -> Tensor:
+ def forward_arch_param(
+ self,
+ x: Any,
+ arch_param: nn.Parameter,
+ ) -> Tensor:
"""Forward with architecture parameters.
Args:
@@ -312,39 +305,35 @@ def forward_arch_param(self,
Returns:
Tensor: the result of forward with ``arch_param``.
"""
- if arch_param is None:
- return self.forward_all(x)
- else:
- # compute the probs of choice
- probs = self.compute_arch_probs(arch_param=arch_param)
-
- if not self.is_fixed:
- self.arch_weights = self.sample_weights(arch_param, probs)
- sorted_param = torch.topk(probs, 2)
- index = (
- sorted_param[0][0] - sorted_param[0][1] >=
- self.fix_threshold)
- if index:
- self.fix_chosen(self.choices[index])
-
- if self.is_fixed:
- index = self.choices.index(self._chosen[0])
- self.arch_weights.data.zero_()
- self.arch_weights.data[index].fill_(1.0)
- self.arch_weights.requires_grad_()
-
- # forward based on self.arch_weights
- outputs = list()
- for prob, module in zip(self.arch_weights,
- self._candidates.values()):
- if prob > 0.:
- outputs.append(prob * module(x))
-
- return sum(outputs)
+
+ # compute the probs of choice
+ probs = self.compute_arch_probs(arch_param=arch_param)
+
+ if not self.is_fixed:
+ self.arch_weights = self.sample_weights(arch_param, probs)
+ sorted_param = torch.topk(probs, 2)
+ index = (
+ sorted_param[0][0] - sorted_param[0][1] >= self.fix_threshold)
+ if index:
+ self.fix_chosen(self.choices[index])
+
+ if self.is_fixed:
+ index = self.choices.index(self._chosen[0])
+ self.arch_weights.data.zero_()
+ self.arch_weights.data[index].fill_(1.0)
+ self.arch_weights.requires_grad_()
+
+ # forward based on self.arch_weights
+ outputs = list()
+ for prob, module in zip(self.arch_weights, self._candidates.values()):
+ if prob > 0.:
+ outputs.append(prob * module(x))
+
+ return sum(outputs)
@MODELS.register_module()
-class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
+class DiffChoiceRoute(DiffMutableModule):
"""A type of ``MUTABLES`` for Neural Architecture Search, which can select
inputs from different edges in a differentiable or non-differentiable way.
It is commonly used in DARTS.
@@ -404,6 +393,35 @@ def __init__(
self._candidates: nn.ModuleDict = edges
self.num_chosen = num_chosen
+ def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None):
+ """Calls either :func:`forward_fixed` or :func:`forward_arch_param`
+ depending on whether :func:`is_fixed` is ``True`` and whether
+ :func:`arch_param` is None.
+
+ To reduce the coupling between `Mutable` and `Mutator`, the
+ `arch_param` is generated by the `Mutator` and is passed to the
+ forward function as an argument.
+
+ Note:
+ :meth:`forward_fixed` is called when in `fixed` mode.
+ :meth:`forward_arch_param` is called when in `unfixed` mode.
+
+ Args:
+ x (Any): input data for forward computation.
+ arch_param (nn.Parameter, optional): the architecture parameters
+ for ``DiffMutableModule``.
+
+ Returns:
+ Any: the result of forward
+ """
+ if self.is_fixed:
+ return self.forward_fixed(x)
+ else:
+ if arch_param is not None and self._with_arch_param:
+ return self.forward_arch_param(x, arch_param=arch_param)
+ else:
+ return self.forward_all(x)
+
def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor:
"""Forward when the mutable is in `fixed` mode.
@@ -424,10 +442,7 @@ def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor:
outputs.append(self._candidates[choice](x))
return sum(outputs)
- def forward_arch_param(
- self,
- x: Union[List[Any], Tuple[Any]],
- arch_param: Optional[nn.Parameter] = None) -> Tensor:
+ def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor:
"""Forward with architecture parameters.
Args:
@@ -443,21 +458,17 @@ def forward_arch_param(
f'Length of `edges` {len(self._candidates)} should be ' \
f'same as the length of inputs {len(x)}.'
- if self._with_arch_param:
- probs = self.compute_arch_probs(arch_param=arch_param)
+ probs = self.compute_arch_probs(arch_param=arch_param)
- outputs = list()
- for prob, module, input in zip(probs, self._candidates.values(),
- x):
- if prob > 0:
- # prob may equal to 0 in gumbel softmax.
- outputs.append(prob * module(input))
+ outputs = list()
+ for prob, module, input in zip(probs, self._candidates.values(), x):
+ if prob > 0:
+ # prob may equal to 0 in gumbel softmax.
+ outputs.append(prob * module(input))
- return sum(outputs)
- else:
- return self.forward_all(x)
+ return sum(outputs)
- def forward_all(self, x: Any) -> Tensor:
+ def forward_all(self, x):
"""Forward all choices.
Args:
@@ -500,16 +511,20 @@ def fix_chosen(self, chosen: List[str]) -> None:
self.is_fixed = True
@property
- def choices(self) -> List[CHOSEN_TYPE]:
+ def choices(self) -> List[str]:
"""list: all choices. """
return list(self._candidates.keys())
- def dump_chosen(self):
- """dump current choice."""
+ def dump_chosen(self) -> DumpChosen:
+ chosen = self.export_chosen()
+ meta = dict(all_choices=self.choices)
+ return DumpChosen(chosen=chosen, meta=meta)
+
+ def export_chosen(self) -> str:
assert self.current_choice is not None
return self.current_choice
- def sample_choice(self, arch_param):
+ def sample_choice(self, arch_param: Tensor) -> List[str]:
"""sample choice based on `arch_param`."""
sort_idx = torch.argsort(-arch_param).cpu().numpy().tolist()
choice_idx = sort_idx[:self.num_chosen]
diff --git a/mmrazor/models/mutables/mutable_module/mutable_module.py b/mmrazor/models/mutables/mutable_module/mutable_module.py
index 8840fd783..c71f1a969 100644
--- a/mmrazor/models/mutables/mutable_module/mutable_module.py
+++ b/mmrazor/models/mutables/mutable_module/mutable_module.py
@@ -2,20 +2,22 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional
-from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable
+from ..base_mutable import BaseMutable
-class MutableModule(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
+class MutableModule(BaseMutable):
"""Base Class for mutables. Mutable means a searchable module widely used
in Neural Architecture Search(NAS).
It mainly consists of some optional operations, and achieving
searchable function by handling choice with ``MUTATOR``.
- All subclass should implement the following APIs:
+ All subclass should implement the following APIs and the other
+ abstract method in ``BaseMutable``:
- ``forward()``
- - ``fix_chosen()``
+ - ``forward_all()``
+ - ``forward_fix()``
- ``choices()``
Args:
@@ -30,20 +32,48 @@ class MutableModule(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):
def __init__(self,
module_kwargs: Optional[Dict[str, Dict]] = None,
- **kwargs) -> None:
- super().__init__(**kwargs)
+ alias: Optional[str] = None,
+ init_cfg: Optional[Dict] = None) -> None:
+ super().__init__(alias, init_cfg)
self.module_kwargs = module_kwargs
+ self._current_choice = None
+
+ @property
+ def current_choice(self):
+ """Current choice will affect :meth:`forward` and will be used in
+ :func:`mmrazor.core.subnet.utils.export_fix_subnet` or mutator.
+ """
+ return self._current_choice
+
+ @current_choice.setter
+ def current_choice(self, choice) -> None:
+ """Current choice setter will be executed in mutator."""
+ self._current_choice = choice
@property
@abstractmethod
- def choices(self) -> List[CHOICE_TYPE]:
+ def choices(self) -> List[str]:
"""list: all choices. All subclasses must implement this method."""
@abstractmethod
def forward(self, x: Any) -> Any:
"""Forward computation."""
+ @abstractmethod
+ def forward_fixed(self, x):
+ """Forward with the fixed mutable.
+
+ All subclasses must implement this method.
+ """
+
+ @abstractmethod
+ def forward_all(self, x):
+ """Forward all choices.
+
+ All subclasses must implement this method.
+ """
+
@property
def num_choices(self) -> int:
"""Number of choices."""
diff --git a/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py b/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py
index f04c61eb5..434b05079 100644
--- a/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py
+++ b/mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py
@@ -8,30 +8,20 @@
from torch import Tensor
from mmrazor.registry import MODELS
-from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE
+from mmrazor.utils.typing import DumpChosen
from .mutable_module import MutableModule
-class OneShotMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
+class OneShotMutableModule(MutableModule):
"""Base class for one shot mutable module. A base type of ``MUTABLES`` for
single path supernet such as Single Path One Shot.
- All subclass should implement the following APIs:
+ All subclass should implement the following APIs and the other
+ abstract method in ``MutableModule``:
- ``sample_choice()``
- - ``forward_fixed()``
- - ``forward_all()``
- ``forward_choice()``
- Args:
- module_kwargs (dict[str, dict], optional): Module initialization named
- arguments. Defaults to None.
- alias (str, optional): alias of the `MUTABLE`.
- init_cfg (dict, optional): initialization configuration dict for
- ``BaseModule``. OpenMMLab has implement 5 initializer including
- `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
- and `Pretrained`.
-
Note:
:meth:`forward_all` is called when calculating FLOPs.
"""
@@ -63,29 +53,15 @@ def forward(self, x: Any) -> Any:
return self.forward_choice(x, choice=self.current_choice)
@abstractmethod
- def sample_choice(self) -> CHOICE_TYPE:
+ def sample_choice(self) -> str:
"""Sample random choice.
Returns:
- CHOICE_TYPE: the chosen key in ``MUTABLE``.
- """
-
- @abstractmethod
- def forward_fixed(self, x: Any) -> Any:
- """Forward with the fixed mutable.
-
- All subclasses must implement this method.
- """
-
- @abstractmethod
- def forward_all(self, x: Any) -> Any:
- """Forward all choices.
-
- All subclasses must implement this method.
+ str: the chosen key in ``MUTABLE``.
"""
@abstractmethod
- def forward_choice(self, x: Any, choice: CHOICE_TYPE) -> Any:
+ def forward_choice(self, x, choice: str):
"""Forward with the unfixed mutable and current_choice is not None.
All subclasses must implement this method.
@@ -93,7 +69,7 @@ def forward_choice(self, x: Any, choice: CHOICE_TYPE) -> Any:
@MODELS.register_module()
-class OneShotMutableOP(OneShotMutableModule[str, str]):
+class OneShotMutableOP(OneShotMutableModule):
"""A type of ``MUTABLES`` for single path supernet, such as Single Path One
Shot. In single path supernet, each choice block only has one choice
invoked at the same time. A path is obtained by sampling all the choice
@@ -117,7 +93,6 @@ class OneShotMutableOP(OneShotMutableModule[str, str]):
>>> candidates = nn.ModuleDict({
... 'conv3x3': nn.Conv2d(32, 32, 3, 1, 1),
... 'conv5x5': nn.Conv2d(32, 32, 5, 1, 2),
- ... 'conv7x7': nn.Conv2d(32, 32, 7, 1, 3)})
>>> input = torch.randn(1, 32, 64, 64)
>>> op = OneShotMutableOP(candidates)
@@ -214,7 +189,7 @@ def forward_fixed(self, x: Any) -> Tensor:
"""
return self._candidates[self._chosen](x)
- def forward_choice(self, x: Any, choice: str) -> Tensor:
+ def forward_choice(self, x, choice: str) -> Tensor:
"""Forward with the `unfixed` mutable and current choice is not None.
Args:
@@ -228,7 +203,7 @@ def forward_choice(self, x: Any, choice: str) -> Tensor:
assert isinstance(choice, str) and choice in self.choices
return self._candidates[choice](x)
- def forward_all(self, x: Any) -> Tensor:
+ def forward_all(self, x) -> Tensor:
"""Forward all choices. Used to calculate FLOPs.
Args:
@@ -263,9 +238,13 @@ def fix_chosen(self, chosen: str) -> None:
self._chosen = chosen
self.is_fixed = True
- def dump_chosen(self) -> str:
- assert self.current_choice is not None
+ def dump_chosen(self) -> DumpChosen:
+ chosen = self.export_chosen()
+ meta = dict(all_choices=self.choices)
+ return DumpChosen(chosen=chosen, meta=meta)
+ def export_chosen(self) -> str:
+ assert self.current_choice is not None
return self.current_choice
def sample_choice(self) -> str:
@@ -277,10 +256,6 @@ def choices(self) -> List[str]:
"""list: all choices. """
return list(self._candidates.keys())
- @property
- def num_choices(self):
- return len(self.choices)
-
@MODELS.register_module()
class OneShotProbMutableOP(OneShotMutableOP):
diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py
index 49a0c870f..20055287d 100644
--- a/mmrazor/models/mutables/mutable_value/mutable_value.py
+++ b/mmrazor/models/mutables/mutable_value/mutable_value.py
@@ -3,12 +3,15 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from mmrazor.registry import MODELS
+from mmrazor.utils.typing import DumpChosen
from ..base_mutable import BaseMutable
from ..derived_mutable import DerivedMethodMixin, DerivedMutable
+Value = Union[int, float]
+
@MODELS.register_module()
-class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin):
+class MutableValue(BaseMutable, DerivedMethodMixin):
"""Base class for mutable value.
A mutable value is actually a mutable that adds some functionality to a
@@ -26,7 +29,7 @@ class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin):
"""
def __init__(self,
- value_list: List[Any],
+ value_list: List[Value],
default_value: Optional[Any] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
@@ -59,7 +62,7 @@ def choices(self) -> List[Any]:
"""List of choices."""
return self._value_list
- def fix_chosen(self, chosen: Dict[str, Any]) -> None:
+ def fix_chosen(self, chosen: Value) -> None:
"""Fix mutable value with subnet config.
Args:
@@ -68,24 +71,23 @@ def fix_chosen(self, chosen: Dict[str, Any]) -> None:
if self.is_fixed:
raise RuntimeError('MutableValue can not be fixed twice')
- all_choices = chosen['all_choices']
- current_choice = chosen['current_choice']
+ assert chosen in self.choices
- assert all_choices == self.choices, \
- f'Expect choices to be: {self.choices}, but got: {all_choices}'
- assert current_choice in self.choices
-
- self.current_choice = current_choice
+ self.current_choice = chosen
self.is_fixed = True
- def dump_chosen(self) -> Dict[str, Any]:
+ def dump_chosen(self) -> DumpChosen:
"""Dump information of chosen.
Returns:
Dict[str, Any]: Dumped information.
"""
- return dict(
- current_choice=self.current_choice, all_choices=self.choices)
+ chosen = self.export_chosen()
+ meta = dict(all_choices=self.choices)
+ return DumpChosen(chosen=chosen, meta=meta)
+
+ def export_chosen(self):
+ return self.current_choice
@property
def num_choices(self) -> int:
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb
new file mode 100644
index 000000000..307ffc669
--- /dev/null
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ChannelMutator\n",
+ "A channel mutator is a manager of the channel structure of a model. In other words, it manages all MutableChannelUnits of a model. \n",
+ "ChannelMutator is the simplest channel mutator. All other channel mutators should inherit from ChannelMutator class. We take ChannelMutator as an example."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## How to Construct a ChannelMutator"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Suppose we have a model archtecture defineed below"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define a model\n",
+ "from mmengine.model import BaseModel\n",
+ "from torch import nn\n",
+ "import torch\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "class MyModel(BaseModel):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__(None, None)\n",
+ " self.net = nn.Sequential(\n",
+ " OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n",
+ " ('relu', nn.ReLU()),\n",
+ " ('conv1', nn.Conv2d(8, 16, 3, 1, 1))]))\n",
+ " self.pool = nn.AdaptiveAvgPool2d(1)\n",
+ " self.head = nn.Linear(16, 1000)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " feature = self.net(x)\n",
+ " pool = self.pool(feature).flatten(1)\n",
+ " return self.head(pool)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There are two steps to fully constructing a ChannelMutator object as below. \n",
+ "1. we need to initialize a ChannelMutator object.\n",
+ "2. Then we need to init the ChannelMutator object with a model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The mutator has 2 mutable channel units.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from mmrazor.models.mutators import ChannelMutator\n",
+ "\n",
+ "model = MyModel()\n",
+ "# initialize a ChannelMutator object\n",
+ "mutator = ChannelMutator(\n",
+ " channel_unit_cfg=dict(\n",
+ " type='SequentialMutableChannelUnit',\n",
+ " default_args=dict(choice_mode='ratio'),\n",
+ " units={},\n",
+ " ),\n",
+ " parse_cfg=dict(\n",
+ " type='BackwardTracer',\n",
+ " loss_calculator=dict(type='ImageClassifierPseudoLoss')))\n",
+ "# init the ChannelMutator object with a model\n",
+ "mutator.prepare_from_supernet(model)\n",
+ "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "ChannelMutator has two arguments:\n",
+ "1. channel_unit_cfg: config of the MutableChannelUnit to use in the ChannelMutator.\n",
+ "2. parse_cfg: the way to parse the model and get MutableChannelUnits."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There are there ways to parse model and get MutableChannelUnits.\n",
+ "1. Use a tracer to get MutableChannelUnits automatically.\n",
+ "2. Use config dicts to indicate MutableChannelUnits.\n",
+ "3. Predefine MutableChannels in the model archtecture.\n",
+ " \n",
+ "The example of method 1 has been post above. We post the examples of method 2 and method 3 below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The mutator has 2 mutable channel units.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 2. use config dicts to indicate MutableChannelUnits.\n",
+ "from mmrazor.models.mutators import ChannelMutator\n",
+ "\n",
+ "model = MyModel()\n",
+ "# initialize a ChannelMutator object\n",
+ "mutator = ChannelMutator(\n",
+ " channel_unit_cfg=dict(\n",
+ " type='SequentialMutableChannelUnit',\n",
+ " default_args=dict(choice_mode='ratio'),\n",
+ " units={\n",
+ " 'net.conv0_(0, 8)_8': {\n",
+ " 'init_args': {\n",
+ " 'num_channels': 8,\n",
+ " },\n",
+ " 'channels': {\n",
+ " 'input_related': [{\n",
+ " 'name': 'net.conv1',\n",
+ " }],\n",
+ " 'output_related': [{\n",
+ " 'name': 'net.conv0',\n",
+ " }]\n",
+ " },\n",
+ " 'choice': 1.0\n",
+ " },\n",
+ " 'net.conv1_(0, 16)_16': {\n",
+ " 'init_args': {\n",
+ " 'num_channels': 16,\n",
+ " },\n",
+ " 'channels': {\n",
+ " 'input_related': [{\n",
+ " 'name': 'head',\n",
+ " }],\n",
+ " 'output_related': [{\n",
+ " 'name': 'net.conv1',\n",
+ " }]\n",
+ " },\n",
+ " 'choice': 1.0\n",
+ " }\n",
+ " }),\n",
+ " parse_cfg=dict(type='Config'))\n",
+ "# init the ChannelMutator object with a model\n",
+ "mutator.prepare_from_supernet(model)\n",
+ "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The mutator has 2 mutable channel units.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 3. Predefine MutableChannels in the model archtecture.\n",
+ "\n",
+ "from mmrazor.models.architectures.dynamic_ops import DynamicConv2d, DynamicLinear\n",
+ "from mmrazor.models.mutables import MutableChannelUnit, MutableChannelContainer, SquentialMutableChannel\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "class MyDynamicModel(BaseModel):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__(None, None)\n",
+ " self.net = nn.Sequential(\n",
+ " OrderedDict([('conv0', DynamicConv2d(3, 8, 3, 1, 1)),\n",
+ " ('relu', nn.ReLU()),\n",
+ " ('conv1', DynamicConv2d(8, 16, 3, 1, 1))]))\n",
+ " self.pool = nn.AdaptiveAvgPool2d(1)\n",
+ " self.head = DynamicLinear(16, 1000)\n",
+ "\n",
+ " # register MutableChannelContainer\n",
+ " MutableChannelUnit._register_channel_container(\n",
+ " self, MutableChannelContainer)\n",
+ " self._register_mutables()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " feature = self.net(x)\n",
+ " pool = self.pool(feature).flatten(1)\n",
+ " return self.head(pool)\n",
+ "\n",
+ " def _register_mutables(self):\n",
+ " mutable1 = SquentialMutableChannel(8)\n",
+ " mutable2 = SquentialMutableChannel(16)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv0, mutable1, is_to_output_channel=True)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv1, mutable1, is_to_output_channel=False)\n",
+ "\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.net.conv1, mutable2, is_to_output_channel=True)\n",
+ " MutableChannelContainer.register_mutable_channel_to_module(\n",
+ " self.head, mutable2, is_to_output_channel=False)\n",
+ "\n",
+ "\n",
+ "model = MyDynamicModel()\n",
+ "# initialize a ChannelMutator object\n",
+ "mutator = ChannelMutator(\n",
+ " channel_unit_cfg=dict(\n",
+ " type='SequentialMutableChannelUnit',\n",
+ " default_args=dict(choice_mode='ratio'),\n",
+ " units={},\n",
+ " ),\n",
+ " parse_cfg=dict(type='Predefined'))\n",
+ "# init the ChannelMutator object with a model\n",
+ "mutator.prepare_from_supernet(model)\n",
+ "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## How to Change the Structure of a Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The structure of a model is represented by a dict where the key is the name of a MutableChannelUnit and the value is a structure choice."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{0: 8, 1: 16}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(mutator.current_choices)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can change the dict to prune the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MyDynamicModel(\n",
+ " (data_preprocessor): BaseDataPreprocessor()\n",
+ " (net): Sequential(\n",
+ " (conv0): DynamicConv2d(\n",
+ " 3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n",
+ " (out_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n",
+ " )\n",
+ " )\n",
+ " (relu): ReLU()\n",
+ " (conv1): DynamicConv2d(\n",
+ " 8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n",
+ " (out_channels): MutableChannelContainer(num_channels=16, activated_channels=8)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pool): AdaptiveAvgPool2d(output_size=1)\n",
+ " (head): DynamicLinear(\n",
+ " in_features=16, out_features=1000, bias=True\n",
+ " (mutable_attrs): ModuleDict(\n",
+ " (in_features): MutableChannelContainer(num_channels=16, activated_channels=8)\n",
+ " (out_features): MutableChannelContainer(num_channels=1000, activated_channels=1000)\n",
+ " )\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "mutator.set_choices(\n",
+ " {0: 4, 1: 8}\n",
+ ")\n",
+ "print(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Please refer to our documents for more choices related methods."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.13 ('lab2max')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
index c4ce92e96..7a19f1c72 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
-from typing import Dict, Generic, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union
from mmengine import fileio
-from torch.nn import Module
+from torch.nn import Module, ModuleList
from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin
from mmrazor.models.mutables import (ChannelUnitType, MutableChannelUnit,
@@ -13,6 +13,7 @@
from mmrazor.registry import MODELS
from mmrazor.structures.graph import ModuleGraph
from ..base_mutator import BaseMutator
+from ..group_mixin import GroupMixin
def is_dynamic_op_for_fx_tracer(module, name):
@@ -20,7 +21,7 @@ def is_dynamic_op_for_fx_tracer(module, name):
@MODELS.register_module()
-class ChannelMutator(BaseMutator, Generic[ChannelUnitType]):
+class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin):
"""ChannelMutator manages the pruning structure of a model.
Args:
@@ -48,6 +49,10 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]):
dict( type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss')).
+ custom_groups (list[list[str]], optional): User-defined search groups.
+ All searchable modules that are not in ``custom_group`` will be
+ grouped separately.
+
init_cfg (dict, optional): initialization configuration dict for
BaseModule.
@@ -70,6 +75,7 @@ def __init__(self,
parse_cfg: Dict = dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss')),
+ custom_groups: Optional[List[List[str]]] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg)
@@ -83,7 +89,7 @@ def __init__(self,
# units
self._name2unit: Dict[str, ChannelUnitType] = {}
- self.units: List[ChannelUnitType] = []
+ self.units: ModuleList[ChannelUnitType] = ModuleList()
# unit config
self.channel_unit_cfg = channel_unit_cfg
@@ -91,6 +97,10 @@ def __init__(self,
self._parse_channel_unit_cfg(
channel_unit_cfg)
+ if custom_groups is None:
+ custom_groups = []
+ self._custom_groups = custom_groups
+
def prepare_from_supernet(self, supernet: Module) -> None:
"""Prepare from a model for pruning.
@@ -113,7 +123,11 @@ def prepare_from_supernet(self, supernet: Module) -> None:
for unit in units:
unit.prepare_for_pruning(supernet)
self._name2unit[unit.name] = unit
- self.units = units
+ self.units = ModuleList(units)
+
+ self._search_groups = self.build_search_groups(
+ ModuleList(self.mutable_units), self.mutable_class_type,
+ self._custom_groups)
# ~
@@ -129,13 +143,16 @@ def config_template(self,
"""Config template of the mutator.
Args:
- only_mutable_units (bool, optional): If only return config of
- prunable units. Defaults to False.
- with_unit_init_args (bool, optional): If return init_args of
- units. Defaults to False.
- with_channels (bool, optional): if return channel info.
- Defaults to False.
-
+ only_mutable_units (bool, optional): Whether only return config of
+ prunable units. It can omit unmutable MutableChannelUnits
+ to decrease the length of the config. Defaults to False.
+ with_unit_init_args (bool, optional): Whether return init_args of
+ units. Let it be true, when you want to change the init
+ args of units. Defaults to False.
+ with_channels (bool, optional): Whether return channel info.
+ The channel info can initialization the units without
+ tracer. When you want to prune your model without a
+ tracer next time, let it be true. Defaults to False.
Example:
dict(
channel_unit_cfg = dict(
@@ -190,23 +207,40 @@ def fix_channel_mutables(self):
@property
def current_choices(self) -> Dict:
"""Get current choices."""
- config = self.choice_template
- for unit in self.mutable_units:
- config[unit.name] = unit.current_choice
- return config
-
- def set_choices(self, config: Dict[str, Union[int, float]]):
- """Set choices."""
- for name, choice in config.items():
- unit = self._name2unit[name]
- unit.current_choice = choice
-
- def sample_choices(self) -> Dict[str, Union[int, float]]:
- """Sample choices(pruning structure)."""
- template = self.choice_template
- for key in template:
- template[key] = self._name2unit[key].sample_choice()
- return template
+ current_choices = dict()
+ for group_id, modules in self.search_groups.items():
+ current_choices[group_id] = modules[0].current_choice
+
+ return current_choices
+
+ def sample_choices(self) -> Dict[int, Any]:
+ """Sampling by search groups.
+
+ The sampling result of the first mutable of each group is the sampling
+ result of this group.
+
+ Returns:
+ Dict[int, Any]: Random choices dict.
+ """
+ random_choices = dict()
+ for group_id, modules in self.search_groups.items():
+ random_choices[group_id] = modules[0].sample_choice()
+
+ return random_choices
+
+ def set_choices(self, choices: Dict[int, Any]) -> None:
+ """Set mutables' current choice according to choices sample by
+ :func:`sample_choices`.
+
+ Args:
+ choices (Dict[int, Any]): Choices dict. The key is group_id in
+ search groups, and the value is the sampling results
+ corresponding to this group.
+ """
+ for group_id, modules in self.search_groups.items():
+ choice = choices[group_id]
+ for module in modules:
+ module.current_choice = choice
@property
def choice_template(self) -> Dict:
@@ -223,12 +257,24 @@ def choice_template(self) -> Dict:
template[unit.name] = unit.current_choice
return template
- # implementation of abstract functions
+ @property
+ def search_groups(self) -> Dict[int, List]:
+ """Search group of the supernet.
- def search_groups(self) -> Dict:
- return self._name2unit
+ Note:
+ Search group is different from search space. The key of search
+ group is called ``group_id``, and the value is corresponding
+ searchable modules. The searchable modules will have the same
+ search space if they are in the same group.
+ Returns:
+ dict: Search group.
+ """
+ return self._search_groups
+
+ @property
def mutable_class_type(self) -> Type[ChannelUnitType]:
+ """Mutable class type supported by this mutator."""
return self.unit_class
# private methods
diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py
index a5350ab2b..1f8e3496b 100644
--- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py
@@ -28,14 +28,16 @@ def __init__(self,
def min_choices(self) -> Dict:
"""Return the minimal pruning subnet(structure)."""
- template = self.choice_template
- for key in template:
- template[key] = self._name2unit[key].min_choice
- return template
+ min_choices = dict()
+ for group_id, modules in self.search_groups.items():
+ min_choices[group_id] = modules[0].min_choice
+
+ return min_choices
def max_choices(self) -> Dict:
"""Return the maximal pruning subnet(structure)."""
- template = self.choice_template
- for key in template:
- template[key] = self._name2unit[key].max_choice
- return template
+ max_choices = dict()
+ for group_id, modules in self.search_groups.items():
+ max_choices[group_id] = modules[0].max_choice
+
+ return max_choices
diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
index 7c0d24fa6..9f5eb0075 100644
--- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
@@ -28,10 +28,20 @@ def __init__(self,
loss_calculator=dict(type='ImageClassifierPseudoLoss')),
init_cfg: Optional[Dict] = None) -> None:
- super().__init__(channel_unit_cfg, parse_cfg, init_cfg)
+ super().__init__(channel_unit_cfg, parse_cfg, None, init_cfg)
self.subnets = self._prepare_subnets(self.units_cfg)
+ def set_choices(self, config: Dict[str, float]): # type: ignore[override]
+ """Set choices."""
+ for name, choice in config.items():
+ unit = self._name2unit[name]
+ unit.current_choice = choice
+
+ def sample_choices(self):
+ """Sample choices(pruning structure)."""
+ raise RuntimeError
+
# private methods
def _prepare_subnets(self, unit_cfg: Dict) -> List[Dict[str, int]]:
diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py
new file mode 100644
index 000000000..7e735b263
--- /dev/null
+++ b/mmrazor/models/mutators/group_mixin.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from collections import Counter
+from typing import Dict, List, Type
+
+from torch.nn import Module
+
+from ..mutables import BaseMutable
+
+
+class GroupMixin():
+ """A mixin for :class:`BaseMutator`, which can group mutables by
+ ``custom_group`` and ``alias``(see more information in
+ :class:`BaseMutable`). Grouping by alias and module name are both
+ supported.
+
+ Note:
+ Apart from user-defined search group, all other searchable
+ modules(mutable) will be grouped separately.
+
+ The main difference between using alias and module name for
+ grouping is that the alias is One-to-Many while the module
+ name is One-to-One.
+
+ When using both alias and module name in `custom_group`, the
+ priority of alias is higher than that of module name.
+
+ If alias is set in `custom_group`, then its corresponding module
+ name should not be in the `custom_group`.
+
+ Moreover, there should be no duplicate keys in the `custom_group`.
+
+ Example:
+ >>> import torch
+ >>> from mmrazor.models import DiffModuleMutator
+
+ >>> # Assume that a toy model consists of three mutables
+ >>> # whose name are op1,op2,op3. The corresponding
+ >>> # alias names of the three mutables are a1, a1, a2.
+ >>> model = ToyModel()
+
+ >>> # Using alias for grouping
+ >>> mutator = DiffModuleMutator(custom_group=[['a1'], ['a2']])
+ >>> mutator.prepare_from_supernet(model)
+ >>> mutator.search_groups
+ {0: [op1, op2], 1: [op3]}
+
+ >>> # Using module name for grouping
+ >>> mutator = DiffModuleMutator(custom_group=[['op1', 'op2'], ['op3']])
+
+ >>> # Using module name for grouping
+ >>> mutator.prepare_from_supernet(model)
+ >>> mutator.search_groups
+ {0: [op1, op2], 1: [op3]}
+
+ >>> # Using both alias and module name for grouping
+ >>> mutator = DiffModuleMutator(custom_group=[['a2'], ['op2']])
+ >>> mutator.prepare_from_supernet(model)
+ >>> # The last operation would be grouped
+ >>> mutator.search_groups
+ {0: [op3], 1: [op2], 2: [op1]}
+
+ """
+
+ def _build_name_mutable_mapping(
+ self, supernet: Module,
+ support_mutables: Type) -> Dict[str, BaseMutable]:
+ """Mapping module name to mutable."""
+ name2mutable: Dict[str, BaseMutable] = dict()
+ for name, module in supernet.named_modules():
+ if isinstance(module, support_mutables):
+ name2mutable[name] = module
+ self._name2mutable = name2mutable
+
+ return name2mutable
+
+ def _build_alias_names_mapping(
+ self, supernet: Module,
+ support_mutables: Type) -> Dict[str, List[str]]:
+ """Mapping alias to module names."""
+ alias2mutable_names: Dict[str, List[str]] = dict()
+ for name, module in supernet.named_modules():
+ if isinstance(module, support_mutables):
+
+ if module.alias is not None:
+ if module.alias not in alias2mutable_names:
+ alias2mutable_names[module.alias] = [name]
+ else:
+ alias2mutable_names[module.alias].append(name)
+
+ return alias2mutable_names
+
+ def build_search_groups(self, supernet: Module, support_mutables: Type,
+ custom_groups: List[List[str]]) -> Dict[int, List]:
+ """Build search group with ``custom_group`` and ``alias``(see more
+ information in :class:`BaseMutable`). Grouping by alias and module name
+ are both supported.
+
+ Args:
+ supernet (:obj:`torch.nn.Module`): The supernet to be searched
+ in your algorithm.
+ support_mutables (Type): Mutable type that can be grouped.
+ custom_group (list, optional): User-defined search groups.
+ All searchable modules that are not in ``custom_group`` will be
+ grouped separately.
+ """
+ name2mutable: Dict[str,
+ BaseMutable] = self._build_name_mutable_mapping(
+ supernet, support_mutables)
+ alias2mutable_names = self._build_alias_names_mapping(
+ supernet, support_mutables)
+
+ # Check whether the custom group is valid
+ if len(custom_groups) > 0:
+ self._check_valid_groups(alias2mutable_names, name2mutable,
+ custom_groups)
+
+ # Construct search_groups based on user-defined group
+ search_groups: Dict[int, List[BaseMutable]] = dict()
+
+ current_group_nums = 0
+ grouped_mutable_names: List[str] = list()
+ grouped_alias: List[str] = list()
+ for group in custom_groups:
+ group_mutables = list()
+ for item in group:
+ if item in alias2mutable_names:
+ # if the item is from alias name
+ mutable_names: List[str] = alias2mutable_names[item]
+ grouped_alias.append(item)
+ group_mutables.extend(
+ [name2mutable[n] for n in mutable_names])
+ grouped_mutable_names.extend(mutable_names)
+ else:
+ # if the item is in name2mutable
+ group_mutables.append(name2mutable[item])
+ grouped_mutable_names.append(item)
+
+ search_groups[current_group_nums] = group_mutables
+ current_group_nums += 1
+
+ # Construct search_groups based on alias
+ for alias, mutable_names in alias2mutable_names.items():
+ if alias not in grouped_alias:
+ # Check whether all current names are already grouped
+ flag_all_grouped = True
+ for mutable_name in mutable_names:
+ if mutable_name not in grouped_mutable_names:
+ flag_all_grouped = False
+
+ # If not all mutables are already grouped
+ if not flag_all_grouped:
+ search_groups[current_group_nums] = []
+ for mutable_name in mutable_names:
+ if mutable_name not in grouped_mutable_names:
+ search_groups[current_group_nums].append(
+ name2mutable[mutable_name])
+ grouped_mutable_names.append(mutable_name)
+ current_group_nums += 1
+
+ # check whether all the mutable objects are in the search_groups
+ for name, module in supernet.named_modules():
+ if isinstance(module, support_mutables):
+ if name in grouped_mutable_names:
+ continue
+ else:
+ search_groups[current_group_nums] = [module]
+ current_group_nums += 1
+
+ grouped_counter = Counter(grouped_mutable_names)
+
+ # find duplicate keys
+ duplicate_keys = list()
+ for key, count in grouped_counter.items():
+ if count > 1:
+ duplicate_keys.append(key)
+
+ assert len(grouped_mutable_names) == len(
+ list(set(grouped_mutable_names))), \
+ 'There are duplicate keys in grouped mutable names. ' \
+ f'The duplicate keys are {duplicate_keys}. ' \
+ 'Please check if there are duplicate keys in the `custom_group`.'
+
+ return search_groups
+
+ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
+ name2mutable: Dict[str, BaseMutable],
+ custom_group: List[List[str]]) -> None:
+
+ aliases = [*alias2mutable_names.keys()]
+ module_names = [*name2mutable.keys()]
+
+ # check if all keys are legal
+ expanded_custom_group: List[str] = [
+ _ for group in custom_group for _ in group
+ ]
+ legal_keys: List[str] = [*aliases, *module_names]
+
+ for key in expanded_custom_group:
+ if key not in legal_keys:
+ raise AssertionError(
+ f'The key: {key} in `custom_group` is not legal. '
+ f'Legal keys are: {legal_keys}. '
+ 'Make sure that the keys are either alias or mutable name')
+
+ # when the mutable has alias attribute, the corresponding module
+ # name should not be used in `custom_group`.
+ used_aliases = list()
+ for group in custom_group:
+ for key in group:
+ if key in aliases:
+ used_aliases.append(key)
+
+ for alias_key in used_aliases:
+ mutable_names: List = alias2mutable_names[alias_key]
+ # check whether module name is in custom group
+ for mutable_name in mutable_names:
+ if mutable_name in expanded_custom_group:
+ raise AssertionError(
+ f'When a mutable is set alias attribute :{alias_key},'
+ f'the corresponding module name {mutable_name} should '
+ f'not be used in `custom_group` {custom_group}.')
diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
index ac6358049..1f639ed28 100644
--- a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
+++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
@@ -25,9 +25,9 @@ class DiffModuleMutator(ModuleMutator):
"""
def __init__(self,
- custom_group: Optional[List[List[str]]] = None,
+ custom_groups: Optional[List[List[str]]] = None,
init_cfg: Optional[Dict] = None) -> None:
- super().__init__(custom_group=custom_group, init_cfg=init_cfg)
+ super().__init__(custom_groups=custom_groups, init_cfg=init_cfg)
def build_arch_param(self, num_choices) -> nn.Parameter:
"""Build learnable architecture parameters."""
diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py
index dc045932b..f30e933e0 100644
--- a/mmrazor/models/mutators/module_mutator/module_mutator.py
+++ b/mmrazor/models/mutators/module_mutator/module_mutator.py
@@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
-from collections import Counter
from typing import Dict, List, Optional, Type
from torch.nn import Module
from ..base_mutator import MUTABLE_TYPE, BaseMutator
+from ..group_mixin import GroupMixin
-class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
+class ModuleMutator(BaseMutator[MUTABLE_TYPE], GroupMixin):
"""The base class for mutable based mutator.
All subclass should implement the following APIS:
@@ -16,19 +16,19 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
- ``mutable_class_type``
Args:
- custom_group (list[list[str]], optional): User-defined search groups.
+ custom_groups (list[list[str]], optional): User-defined search groups.
All searchable modules that are not in ``custom_group`` will be
grouped separately.
"""
def __init__(self,
- custom_group: Optional[List[List[str]]] = None,
+ custom_groups: Optional[List[List[str]]] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg)
- if custom_group is None:
- custom_group = []
- self._custom_group = custom_group
+ if custom_groups is None:
+ custom_groups = []
+ self._custom_groups = custom_groups
self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None
# TODO
@@ -52,7 +52,9 @@ def prepare_from_supernet(self, supernet: Module) -> None:
supernet (:obj:`torch.nn.Module`): The supernet to be searched
in your algorithm.
"""
- self._build_search_groups(supernet)
+ self._search_groups = self.build_search_groups(supernet,
+ self.mutable_class_type,
+ self._custom_groups)
@property
def name2mutable(self) -> Dict[str, MUTABLE_TYPE]:
@@ -90,196 +92,3 @@ def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]:
raise RuntimeError(
'Call `prepare_from_supernet` before access search group!')
return self._search_groups
-
- def _build_name_mutable_mapping(
- self, supernet: Module) -> Dict[str, MUTABLE_TYPE]:
- """Mapping module name to mutable."""
- name2mutable: Dict[str, MUTABLE_TYPE] = dict()
- for name, module in supernet.named_modules():
- if isinstance(module, self.mutable_class_type):
- name2mutable[name] = module
- self._name2mutable = name2mutable
-
- return name2mutable
-
- def _build_alias_names_mapping(self,
- supernet: Module) -> Dict[str, List[str]]:
- """Mapping alias to module names."""
- alias2mutable_names: Dict[str, List[str]] = dict()
- for name, module in supernet.named_modules():
- if isinstance(module, self.mutable_class_type):
- if module.alias is not None:
- if module.alias not in alias2mutable_names:
- alias2mutable_names[module.alias] = [name]
- else:
- alias2mutable_names[module.alias].append(name)
-
- return alias2mutable_names
-
- def _build_search_groups(self, supernet: Module) -> None:
- """Build search group with ``custom_group`` and ``alias``(see more
- information in :class:`BaseMutable`). Grouping by alias and module name
- are both supported.
-
- Note:
- Apart from user-defined search group, all other searchable
- modules(mutable) will be grouped separately.
-
- The main difference between using alias and module name for
- grouping is that the alias is One-to-Many while the module
- name is One-to-One.
-
- When using both alias and module name in `custom_group`, the
- priority of alias is higher than that of module name.
-
- If alias is set in `custom_group`, then its corresponding module
- name should not be in the `custom_group`.
-
- Moreover, there should be no duplicate keys in the `custom_group`.
-
- Example:
- >>> import torch
- >>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP
-
- >>> # Assume that a toy model consists of three mutables
- >>> # whose name are op1,op2,op3. The corresponding
- >>> # alias names of the three mutables are a1, a1, a2.
- >>> model = ToyModel()
-
- >>> # Using alias for grouping
- >>> mutator = DiffMutableOP(custom_group=[['a1'], ['a2']])
- >>> mutator.prepare_from_supernet(model)
- >>> mutator.search_groups
- {0: [op1, op2], 1: [op3]}
-
- >>> # Using module name for grouping
- >>> mutator = DiffMutableOP(custom_group=[['op1', 'op2'], ['op3']])
- >>> mutator.prepare_from_supernet(model)
- >>> mutator.search_groups
- {0: [op1, op2], 1: [op3]}
-
- >>> # Using both alias and module name for grouping
- >>> mutator = DiffMutableOP(custom_group=[['a2'], ['op2']])
- >>> mutator.prepare_from_supernet(model)
- >>> # The last operation would be grouped
- >>> mutator.search_groups
- {0: [op3], 1: [op2], 2: [op1]}
-
-
- Args:
- supernet (:obj:`torch.nn.Module`): The supernet to be searched
- in your algorithm.
- """
- name2mutable = self._build_name_mutable_mapping(supernet)
- alias2mutable_names = self._build_alias_names_mapping(supernet)
-
- # Check whether the custom group is valid
- if len(self._custom_group) > 0:
- self._check_valid_groups(alias2mutable_names, name2mutable,
- self._custom_group)
-
- # Construct search_groups based on user-defined group
- search_groups: Dict[int, List[MUTABLE_TYPE]] = dict()
-
- current_group_nums = 0
- grouped_mutable_names: List[str] = list()
- grouped_alias: List[str] = list()
- for group in self._custom_group:
- group_mutables = list()
- for item in group:
- if item in alias2mutable_names:
- # if the item is from alias name
- mutable_names: List[str] = alias2mutable_names[item]
- grouped_alias.append(item)
- group_mutables.extend(
- [name2mutable[n] for n in mutable_names])
- grouped_mutable_names.extend(mutable_names)
- else:
- # if the item is in name2mutable
- group_mutables.append(name2mutable[item])
- grouped_mutable_names.append(item)
-
- search_groups[current_group_nums] = group_mutables
- current_group_nums += 1
-
- # Construct search_groups based on alias
- for alias, mutable_names in alias2mutable_names.items():
- if alias not in grouped_alias:
- # Check whether all current names are already grouped
- flag_all_grouped = True
- for mutable_name in mutable_names:
- if mutable_name not in grouped_mutable_names:
- flag_all_grouped = False
-
- # If not all mutables are already grouped
- if not flag_all_grouped:
- search_groups[current_group_nums] = []
- for mutable_name in mutable_names:
- if mutable_name not in grouped_mutable_names:
- search_groups[current_group_nums].append(
- name2mutable[mutable_name])
- grouped_mutable_names.append(mutable_name)
- current_group_nums += 1
-
- # check whether all the mutable objects are in the search_groups
- for name, module in supernet.named_modules():
- if isinstance(module, self.mutable_class_type):
- if name in grouped_mutable_names:
- continue
- else:
- search_groups[current_group_nums] = [module]
- current_group_nums += 1
-
- grouped_counter = Counter(grouped_mutable_names)
-
- # find duplicate keys
- duplicate_keys = list()
- for key, count in grouped_counter.items():
- if count > 1:
- duplicate_keys.append(key)
-
- assert len(grouped_mutable_names) == len(
- list(set(grouped_mutable_names))), \
- 'There are duplicate keys in grouped mutable names. ' \
- f'The duplicate keys are {duplicate_keys}. ' \
- 'Please check if there are duplicate keys in the `custom_group`.'
-
- self._search_groups = search_groups
-
- def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
- name2mutable: Dict[str, MUTABLE_TYPE],
- custom_group: List[List[str]]) -> None:
-
- aliases = [*alias2mutable_names.keys()]
- module_names = [*name2mutable.keys()]
-
- # check if all keys are legal
- expanded_custom_group: List[str] = [
- _ for group in custom_group for _ in group
- ]
- legal_keys: List[str] = [*aliases, *module_names]
-
- for key in expanded_custom_group:
- if key not in legal_keys:
- raise AssertionError(
- f'The key: {key} in `custom_group` is not legal. '
- f'Legal keys are: {legal_keys}. '
- 'Make sure that the keys are either alias or mutable name')
-
- # when the mutable has alias attribute, the corresponding module
- # name should not be used in `custom_group`.
- used_aliases = list()
- for group in custom_group:
- for key in group:
- if key in aliases:
- used_aliases.append(key)
-
- for alias_key in used_aliases:
- mutable_names: List = alias2mutable_names[alias_key]
- # check whether module name is in custom group
- for mutable_name in mutable_names:
- if mutable_name in expanded_custom_group:
- raise AssertionError(
- f'When a mutable is set alias attribute :{alias_key},'
- f'the corresponding module name {mutable_name} should '
- f'not be used in `custom_group` {custom_group}.')
diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py
new file mode 100644
index 000000000..22af9bae9
--- /dev/null
+++ b/mmrazor/models/observers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .minmax import EMAMinMaxObserver, MinMaxObserver
+from .mse import MSEObserver
+
+__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver']
diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py
new file mode 100644
index 000000000..e10738664
--- /dev/null
+++ b/mmrazor/models/observers/base.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch.ao.quantization.observer import UniformQuantizationObserverBase
+
+from mmrazor.models.utils import pot_quantization, sync_tensor
+
+# from mmengine.model import BaseModule
+
+
+class BaseObserver(UniformQuantizationObserverBase):
+ """Modified torch quantization observer.
+
+ Args:
+ dtype: dtype argument to the `quantize` node needed to implement the
+ reference model spec.
+ qscheme: Quantization scheme to be used.
+ reduce_range: Reduces the range of the quantized data type by 1 bit.
+ This is sometimes required to avoid instruction overflow.
+ quant_min: Minimum quantization value. If unspecified, it will follow
+ the 8-bit setup.
+ quant_max: Maximum quantization value. If unspecified, it will follow
+ the 8-bit setup.
+ ch_axis (int, optional): Channel axis index. Defaults to -1.
+ is_pot_scale (bool, optional): Indicate whether scale is power of two.
+ Defaults to False.
+ eps: Epsilon value for float32.
+ Defaults to `torch.finfo(torch.float32).eps`.
+ """
+
+ min_val: torch.Tensor
+ max_val: torch.Tensor
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ factory_kwargs=None,
+ eps=torch.finfo(torch.float32).eps) -> None:
+ super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
+ factory_kwargs, eps)
+ factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+ self.register_buffer('min_val',
+ torch.tensor(float('inf'), **factory_kwargs))
+ self.register_buffer('max_val',
+ torch.tensor(float('-inf'), **factory_kwargs))
+ self.ch_axis = ch_axis
+ self.is_pot_scale = is_pot_scale
+
+ @torch.jit.export
+ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Calculates the quantization parameters."""
+ scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
+ scale.data = sync_tensor(scale).data
+ zero_point.data = sync_tensor(zero_point).data
+ if self.is_pot_scale:
+ scale = pot_quantization(scale)
+ return scale, zero_point
+
+ @torch.jit.export
+ def extra_repr(self):
+ return 'min_val={}, max_val={} ch_axis={} is_pot_scale={}'.format(
+ self.min_val, self.max_val, self.ch_axis, self.is_pot_scale)
+
+ @torch.jit.export
+ def reset_min_max_vals(self):
+ """Resets the min/max values."""
+ self.min_val.copy_(torch.tensor(float('inf')))
+ self.max_val.copy_(torch.tensor(float('-inf')))
diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py
new file mode 100644
index 000000000..d9b96d7a8
--- /dev/null
+++ b/mmrazor/models/observers/lsq_observer.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+
+from mmrazor.registry import MODELS
+from ..utils import _is_symmetric_quant, pot_quantization, sync_tensor
+from .base import BaseObserver
+
+
+@MODELS.register_module()
+class LSQObserver(BaseObserver):
+ """Observer for `LEARNED STEP SIZE QUANTIZATION`"""
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ factory_kwargs=None,
+ eps=torch.finfo(torch.float32).eps) -> None:
+ super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
+ ch_axis, is_pot_scale, factory_kwargs, eps)
+
+ self.tensor_norm = None
+
+ def forward(self, x_orig):
+ if x_orig.numel() == 0:
+ return x_orig
+ x = x_orig.to(self.min_val.dtype)
+ if self.ch_axis == -1:
+ self.tensor_norm = x.abs().mean()
+ self.min_val, self.max_val = torch._aminmax(x)
+ else:
+ # compute channel-wise mean
+ x_dim = x.size()
+ new_axis_list = [i for i in range(len(x_dim))]
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ y = x.permute(new_axis_list)
+ y = torch.flatten(y, start_dim=1)
+ self.tensor_norm = y.abs().mean(1)
+ self.min_val, self.max_val = torch._aminmax(y, 1)
+
+ return x
+
+ def calculate_qparams(self):
+ scale = 2 * self.tensor_norm / math.sqrt(self.quant_max)
+ zero_point = torch.zeros_like(self.tensor_norm)
+ sync_tensor(scale)
+ sync_tensor(zero_point)
+ if self.is_pot_scale:
+ scale = pot_quantization(scale)
+ if not _is_symmetric_quant(self.qscheme):
+ zero_point = self.quant_min - torch.round(self.min_val / scale)
+ return scale, zero_point
diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py
new file mode 100644
index 000000000..099296536
--- /dev/null
+++ b/mmrazor/models/observers/minmax.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmrazor.registry import MODELS
+from .base import BaseObserver
+
+
+@MODELS.register_module()
+class MinMaxObserver(BaseObserver):
+ """Min max observer."""
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ factory_kwargs=None,
+ eps=torch.finfo(torch.float32).eps) -> None:
+ super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range,
+ quant_min, quant_max, ch_axis,
+ is_pot_scale, factory_kwargs, eps)
+ if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range
+ and self.dtype == torch.quint8):
+ raise NotImplementedError('Cannot reduce range for symmetric \
+ quantization for quint8')
+
+ def forward(self, x_orig):
+ r"""Records the running minimum and maximum of ``x``."""
+ if x_orig.numel() == 0:
+ return x_orig
+ x = x_orig.detach() # avoid keeping autograd tape
+ x = x.to(self.min_val.dtype)
+ if self.ch_axis == -1:
+ min_val_cur, max_val_cur = torch._aminmax(x)
+ else:
+ x_dim = x.size()
+ new_axis_list = [i for i in range(len(x_dim))]
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ y = x.permute(new_axis_list)
+ y = torch.flatten(y, start_dim=1)
+ min_val_cur, max_val_cur = torch._aminmax(y, 1)
+ min_val = torch.min(self.min_val, min_val_cur)
+ max_val = torch.max(self.max_val, max_val_cur)
+ self.min_val.copy_(min_val)
+ self.max_val.copy_(max_val)
+
+ return x
+
+
+@MODELS.register_module()
+class EMAMinMaxObserver(BaseObserver):
+ """Moving average min/max among batches."""
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ ema_ratio=0.9,
+ factory_kwargs=None):
+ super(EMAMinMaxObserver,
+ self).__init__(dtype, qscheme, reduce_range, quant_min,
+ quant_max, ch_axis, is_pot_scale, factory_kwargs)
+ self.ema_ratio = ema_ratio
+
+ def forward(self, x_orig):
+ r"""Records the running minimum and maximum of ``x``."""
+ if x_orig.numel() == 0:
+ return x_orig
+ x = x_orig.to(self.min_val.dtype)
+ if self.ch_axis == -1:
+ min_val_cur, max_val_cur = torch._aminmax(x)
+ else:
+ x_dim = x.size()
+ new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ y = x.permute(new_axis_list)
+ y = torch.flatten(y, start_dim=1)
+ min_val_cur, max_val_cur = torch._aminmax(y, 1)
+
+ if self.max_val.numel() <= 1 and self.max_val.isinf():
+ self.min_val = min_val_cur
+ self.max_val = max_val_cur
+ else:
+ self.min_val = self.min_val * self.ema_ratio + min_val_cur * (
+ 1.0 - self.ema_ratio)
+ self.max_val = self.max_val * self.ema_ratio + max_val_cur * (
+ 1.0 - self.ema_ratio)
+ return x
diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py
new file mode 100644
index 000000000..f85abd902
--- /dev/null
+++ b/mmrazor/models/observers/mse.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmrazor.registry import MODELS
+from .base import BaseObserver
+
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
+
+@MODELS.register_module()
+class MSEObserver(BaseObserver):
+ """MSE observer."""
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ p=2.0,
+ factory_kwargs=None,
+ eps=torch.finfo(torch.float32).eps) -> None:
+ super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
+ ch_axis, is_pot_scale, factory_kwargs, eps)
+ self.p = p
+
+ def lp_loss(self, pred, tgt, dim=None):
+ """loss function measured in L_p Norm."""
+ return (pred - tgt).abs().pow(
+ self.p).mean(dim) if dim else (pred -
+ tgt).abs().pow(self.p).mean()
+
+ def mse(self,
+ x: torch.Tensor,
+ x_min: torch.Tensor,
+ x_max: torch.Tensor,
+ iter=80):
+ best_score = 1e+10
+ best_min, best_max = torch.tensor(
+ [1.0], dtype=torch.float), torch.tensor([1.0], dtype=torch.float)
+ best_min.copy_(x_min)
+ best_max.copy_(x_max)
+ for i in range(iter):
+ new_min = x_min * (1.0 - (i * 0.01))
+ new_max = x_max * (1.0 - (i * 0.01))
+ scale, zero_point = self._calculate_qparams(new_min, new_max)
+ x_q = torch.fake_quantize_per_tensor_affine(
+ x, scale.item(), int(zero_point.item()), self.quant_min,
+ self.quant_max)
+ score = self.lp_loss(x_q, x)
+ if score < best_score:
+ best_score = score
+ best_min, best_max = new_min, new_max
+ return best_min, best_max
+
+ def mse_perchannel(self,
+ x: torch.Tensor,
+ x_min: torch.Tensor,
+ x_max: torch.Tensor,
+ iter=80,
+ ch_axis=0):
+ assert x_min.shape == x_max.shape
+ assert ch_axis >= 0, f'{ch_axis}'
+ best_score = 1e+10 * torch.ones_like(x_min)
+ best_min, best_max = x_min.clone(), x_max.clone()
+ reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis])
+ for i in range(iter):
+ new_min = x_min * (1.0 - (i * 0.01))
+ new_max = x_max * (1.0 - (i * 0.01))
+ scale, zero_point = self._calculate_qparams(new_min, new_max)
+ x_q = torch.fake_quantize_per_channel_affine(
+ x, scale,
+ zero_point.long() if _version_under_1100 else zero_point,
+ ch_axis, self.quant_min, self.quant_max)
+ score = self.lp_loss(x_q, x, reduce_dim)
+ update_idx = (score < best_score)
+ best_score[update_idx] = score[update_idx]
+ best_min[update_idx] = new_min[update_idx]
+ best_max[update_idx] = new_max[update_idx]
+ return best_min, best_max
+
+ def forward(self, x_orig):
+ r"""Records the running minimum and maximum of ``x``."""
+ if x_orig.numel() == 0:
+ return x_orig
+ x = x_orig.clone().detach().to(self.min_val.dtype)
+ if self.ch_axis == -1:
+ min_val_cur, max_val_cur = torch._aminmax(x)
+ min_val_cur, max_val_cur = self.mse(
+ x, min_val_cur, max_val_cur, iter=95)
+ else:
+ x_dim = x.size()
+ new_axis_list = [i for i in range(len(x_dim))]
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ x_channel = x.permute(new_axis_list)
+ y = torch.flatten(x_channel, start_dim=1)
+ min_val_cur, max_val_cur = torch._aminmax(y, 1)
+ min_val_cur, max_val_cur = self.mse_perchannel(
+ x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis)
+
+ self.min_val = torch.min(self.min_val, min_val_cur)
+ self.max_val = torch.max(self.max_val, max_val_cur)
+ return x
+
+
+@MODELS.register_module()
+class EMAMSEObserver(MSEObserver):
+
+ def __init__(self,
+ dtype=torch.quint8,
+ qscheme=torch.per_tensor_affine,
+ reduce_range=False,
+ quant_min=None,
+ quant_max=None,
+ ch_axis=-1,
+ is_pot_scale=False,
+ p=2.0,
+ ema_ratio=0.9,
+ factory_kwargs=None,
+ eps=torch.finfo(torch.float32).eps) -> None:
+ super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
+ ch_axis, is_pot_scale, p, factory_kwargs, eps)
+ self.ema_ratio = ema_ratio
+
+ def forward(self, x_orig):
+ r"""Records the running minimum and maximum of ``x``."""
+ if x_orig.numel() == 0:
+ return x_orig
+ x = x_orig.clone().detach().to(self.min_val.dtype)
+ if self.ch_axis == -1:
+ min_val_cur, max_val_cur = torch._aminmax(x)
+ min_val_cur, max_val_cur = self.mse(
+ x, min_val_cur, max_val_cur, iter=95)
+ else:
+ x_dim = x.size()
+ new_axis_list = [i for i in range(len(x_dim))]
+ new_axis_list[self.ch_axis] = 0
+ new_axis_list[0] = self.ch_axis
+ x_channel = x.permute(new_axis_list)
+ y = torch.flatten(x_channel, start_dim=1)
+ min_val_cur, max_val_cur = torch._aminmax(y, 1)
+ min_val_cur, max_val_cur = self.mse_perchannel(
+ x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis)
+
+ if self.max_val.numel() <= 1 and self.max_val.isinf():
+ self.min_val = min_val_cur
+ self.max_val = max_val_cur
+ else:
+ self.min_val = self.min_val * self.ema_ratio + min_val_cur * (
+ 1.0 - self.ema_ratio)
+ self.max_val = self.max_val * self.ema_ratio + max_val_cur * (
+ 1.0 - self.ema_ratio)
+ return x
diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py
new file mode 100644
index 000000000..e56902eba
--- /dev/null
+++ b/mmrazor/models/quantizers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import CustomQuantizer
+from .trt_quantizer import TensorRTQuantizer
+
+__all__ = ['CustomQuantizer', 'TensorRTQuantizer']
diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py
new file mode 100644
index 000000000..ab4cf190a
--- /dev/null
+++ b/mmrazor/models/quantizers/base.py
@@ -0,0 +1,194 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
+import torch
+from mmengine.model import BaseModule
+from torch.ao.quantization import QConfig
+from torch.ao.quantization.fx import prepare
+from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx
+
+from mmrazor.models.task_modules.tracer import CustomTracer
+from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict,
+ check_is_valid_prepare_custom_config_dict,
+ check_is_valid_qconfig_dict,
+ get_custom_module_class_keys)
+from mmrazor.registry import MODELS
+from mmrazor.structures.quantization import (CheckArgs, DefalutQconfigs,
+ QuantizeScheme, SupportQtypes)
+
+
+@MODELS.register_module()
+class CustomQuantizer(BaseModule):
+ """Configurable quantizer, base class of quantizers.
+
+ Args:
+ qconfig (Dict, optional): QConfig. Defaults to DefalutQconfigs['default']. # noqa: E501
+ is_qat (bool, optional): Is QAT ro not. Defaults to True.
+ skipped_methods (List, optional): Skipped methods list for tracer.
+ Defaults to None.
+ prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig`
+ from `torch.quantization.fx`. Defaults to None.
+ convert_custom_config_dict (Dict, optional): `ConvertCustomConfig`
+ from `torch.quantization.fx`. Defaults to None.
+ equalization_qconfig_dict (Dict, optional): Custom `QConfig` effects
+ on all modules. Defaults to None.
+ _remove_qconfig (Dict, optional): Remove qconfig at the end of
+ `_convert_fx`. Defaults to True.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ qconfig: Dict = DefalutQconfigs['default'],
+ is_qat: bool = True,
+ skipped_methods: List = None,
+ prepare_custom_config_dict: Dict = None,
+ convert_custom_config_dict: Dict = None,
+ equalization_qconfig_dict: Dict = None,
+ _remove_qconfig: bool = True,
+ init_cfg: Dict = None):
+ super().__init__(init_cfg)
+ if self.check_qconfig(qconfig):
+ qconfig = self.qconfig_convert(qconfig)
+ self.qconfig_dict = {'': qconfig}
+ else:
+ raise ValueError('qconfig is incorrect!')
+
+ if prepare_custom_config_dict is None:
+ self.prepare_custom_config_dict = {}
+ else:
+ self.prepare_custom_config_dict = prepare_custom_config_dict
+ if convert_custom_config_dict is None:
+ self.convert_custom_config_dict = {}
+ else:
+ self.convert_custom_config_dict = convert_custom_config_dict
+ if equalization_qconfig_dict is None:
+ self.equalization_qconfig_dict = {}
+ else:
+ self.equalization_qconfig_dict = equalization_qconfig_dict
+
+ check_is_valid_qconfig_dict(self.qconfig_dict)
+ check_is_valid_prepare_custom_config_dict(
+ self.prepare_custom_config_dict)
+ check_is_valid_convert_custom_config_dict(
+ self.convert_custom_config_dict)
+ check_is_valid_qconfig_dict(self.equalization_qconfig_dict)
+
+ self.is_qat = is_qat
+ self.skipped_methods = skipped_methods
+ self._remove_qconfig = _remove_qconfig
+ self.tracer = self.build_tracer()
+
+ def prepare(self, model, graph_module):
+
+ preserved_attributes = self.prepare_custom_config_dict.get(
+ 'preserved_attributes', [])
+ for attr_name in preserved_attributes:
+ setattr(graph_module, attr_name, getattr(model, attr_name))
+
+ graph_module = self.fuse_model(graph_module)
+
+ prepared = prepare(
+ graph_module,
+ self.qconfig_dict,
+ self.is_qat,
+ self.tracer.node_name_to_scope,
+ prepare_custom_config_dict=self.prepare_custom_config_dict,
+ equalization_qconfig_dict=self.equalization_qconfig_dict
+ ) # type: ignore[operator]
+
+ for attr_name in preserved_attributes:
+ setattr(prepared, attr_name, getattr(model, attr_name))
+ return prepared
+
+ def convert(self, graph_module):
+ quantized = _convert_fx(
+ graph_module,
+ is_reference=False,
+ convert_custom_config_dict=self.convert_custom_config_dict,
+ _remove_qconfig=self._remove_qconfig,
+ qconfig_dict=self.qconfig_dict)
+ return quantized
+
+ def check_qconfig(self, qconfig):
+ is_pass = True
+ for arg in CheckArgs:
+ if arg == 'qtype':
+ if qconfig[arg] in SupportQtypes and arg in qconfig.keys():
+ continue
+ else:
+ is_pass = False
+ break
+ else:
+ if isinstance(qconfig[arg], dict) and arg in qconfig.keys():
+ continue
+ else:
+ is_pass = False
+ break
+ return is_pass
+
+ def qconfig_convert(self, qconfig):
+ self.w_qscheme = QuantizeScheme(**qconfig['w_qscheme'])
+ self.a_qscheme = QuantizeScheme(**qconfig['a_qscheme'])
+ w_observer = MODELS.get(qconfig['w_observer']['type'])
+ w_observer_kwargs = self.w_qscheme.to_observer_params()
+ a_observer = MODELS.get(qconfig['a_observer']['type'])
+ a_observer_kwargs = self.a_qscheme.to_observer_params()
+ self.w_observer = MODELS.get(qconfig['w_observer']['type']).with_args(
+ **self.w_qscheme.to_observer_params())
+ self.a_observer = MODELS.get(qconfig['a_observer']['type']).with_args(
+ **self.a_qscheme.to_observer_params())
+ self.w_fake_quant = MODELS.get(
+ qconfig['w_fake_quant']['type']).with_args(
+ observer=w_observer, **w_observer_kwargs)
+ self.a_fake_quant = MODELS.get(
+ qconfig['a_fake_quant']['type']).with_args(
+ observer=a_observer, **a_observer_kwargs)
+
+ torch_qconfig = QConfig(
+ weight=self.w_fake_quant, activation=self.a_fake_quant)
+ return torch_qconfig
+
+ def _swap_ff_with_fxff(self, model: torch.nn.Module) -> None:
+ r""" Swap FloatFunctional with FXFloatFunctional
+ """
+ modules_to_swap = []
+ for name, module in model.named_children():
+ if isinstance(module, torch.nn.quantized.FloatFunctional):
+ modules_to_swap.append(name)
+ else:
+ self._swap_ff_with_fxff(module)
+
+ for name in modules_to_swap:
+ del model._modules[name]
+ model._modules[name] = torch.nn.quantized.FXFloatFunctional()
+
+ def build_tracer(self):
+ skipped_module_names = self.prepare_custom_config_dict.get(
+ 'non_traceable_module_name', [])
+ skipped_module_classes = self.prepare_custom_config_dict.get(
+ 'non_traceable_module_class', [])
+ standalone_module_name_configs = self.prepare_custom_config_dict.get(
+ 'standalone_module_name', [])
+ skipped_module_names += [
+ config[0] for config in standalone_module_name_configs
+ ]
+
+ standalone_module_class_configs = self.prepare_custom_config_dict.get(
+ 'standalone_module_class', [])
+ skipped_module_classes += [
+ config[0] for config in standalone_module_class_configs
+ ]
+ float_custom_module_classes = get_custom_module_class_keys(
+ self.prepare_custom_config_dict,
+ 'float_to_observed_custom_module_class')
+ skipped_module_classes += float_custom_module_classes
+ tracer = CustomTracer(self.skipped_methods, skipped_module_names,
+ skipped_module_classes)
+ # tracer = QuantizationTracer(skipped_module_names,
+ # skipped_module_classes)
+ return tracer
+
+ def fuse_model(self, graph_module):
+ graph_module = _fuse_fx(graph_module, self.is_qat,
+ self.prepare_custom_config_dict)
+ return graph_module
diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py
new file mode 100644
index 000000000..cc8532a53
--- /dev/null
+++ b/mmrazor/models/quantizers/trt_quantizer.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.registry import MODELS
+from mmrazor.structures.quantization import DefalutQconfigs
+from .base import CustomQuantizer
+
+
+@MODELS.register_module()
+class TensorRTQuantizer(CustomQuantizer):
+ """Quantizer for TensorRT backend."""
+
+ def __init__(self,
+ qconfig=DefalutQconfigs['tensorrt'],
+ is_qat=True,
+ skipped_methods=None,
+ prepare_custom_config_dict=None,
+ convert_custom_config_dict=None,
+ equalization_qconfig_dict=None,
+ _remove_qconfig=True,
+ init_cfg=None):
+ super().__init__(qconfig, is_qat, skipped_methods,
+ prepare_custom_config_dict,
+ convert_custom_config_dict, equalization_qconfig_dict,
+ _remove_qconfig, init_cfg)
diff --git a/mmrazor/models/task_modules/delivery/distill_delivery.py b/mmrazor/models/task_modules/delivery/distill_delivery.py
index dcd56f388..d8c335f00 100644
--- a/mmrazor/models/task_modules/delivery/distill_delivery.py
+++ b/mmrazor/models/task_modules/delivery/distill_delivery.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
-from queue import Queue
+from collections import deque
from typing import Callable
@@ -33,7 +33,7 @@ class DistillDelivery(metaclass=ABCMeta):
def __init__(self, max_keep_data: int = 1) -> None:
self._override_data = False
- self.data_queue: Queue = Queue(maxsize=max_keep_data)
+ self.data_queue: deque = deque([], maxlen=max_keep_data)
self.max_keep_data = max_keep_data
@property
diff --git a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
index 19aadc2e9..15c361e38 100644
--- a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
+++ b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py
@@ -78,23 +78,7 @@ def __init__(self, func_path: str, max_keep_data: int):
super().__init__(max_keep_data)
self._check_valid_path(func_path)
- module_path = self._get_module_path(func_path)
- try:
- module = import_modules_from_strings(module_path)
- except ImportError:
- raise ImportError(f'{module_path} is not imported correctly.')
- self.module = module
-
- func_name = self._get_func_name(func_path)
- assert hasattr(module, func_name), \
- f'{func_name} is not in {module_path}.'
- self.func_name = func_name
-
- origin_func = getattr(module, func_name)
- if not isinstance(origin_func, FunctionType):
- raise TypeError(f'{func_name} should be a FunctionType '
- f'instance, but got {type(origin_func)}')
- self.origin_func = origin_func
+ self.func_path = func_path
@staticmethod
def _check_valid_path(func_path: str) -> None:
@@ -121,6 +105,24 @@ def __enter__(self) -> None:
Wrap the origin function.
"""
+ module_path = self._get_module_path(self.func_path)
+ try:
+ module = import_modules_from_strings(module_path)
+ except ImportError:
+ raise ImportError(f'{module_path} is not imported correctly.')
+ self.module = module
+
+ func_name = self._get_func_name(self.func_path)
+ assert hasattr(module, func_name), \
+ f'{func_name} is not in {module_path}.'
+ self.func_name = func_name
+
+ origin_func = getattr(module, func_name)
+ if not isinstance(origin_func, FunctionType):
+ raise TypeError(f'{func_name} should be a FunctionType '
+ f'instance, but got {type(origin_func)}')
+ self.origin_func = origin_func
+
wrapped_func = self.deliver_wrapper(self.origin_func)
setattr(self.module, self.func_name, wrapped_func)
@@ -131,6 +133,11 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
setattr(self.module, self.func_name, self.origin_func)
+ # self.module and self.origin_func can not be pickled.
+ # Delete these two attributes to avoid errors when ema model is used.
+ del self.module
+ del self.origin_func
+
def deliver_wrapper(self, origin_func: Callable) -> Callable:
"""Wrap the specific function to make the intermediate results of the
model can be delivered."""
@@ -139,12 +146,13 @@ def deliver_wrapper(self, origin_func: Callable) -> Callable:
def wrap_func(*args, **kwargs):
if self.override_data:
- assert not self.data_queue.empty(), 'pop from an empty queue'
- outputs = self.data_queue.get()
+ assert len(self.data_queue) > 0, 'pop from an empty queue'
+ outputs = self.data_queue.popleft()
else:
- assert not self.data_queue.full(), 'push into an full queue'
+ assert len(self.data_queue) < self.data_queue.maxlen,\
+ 'push into an full queue'
outputs = origin_func(*args, **kwargs)
- self.data_queue.put(outputs)
+ self.data_queue.append(outputs)
return outputs
return wrap_func
diff --git a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
index dcaae2fd8..fa9f6c4a4 100644
--- a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
+++ b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py
@@ -143,12 +143,13 @@ def deliver_wrapper(self, origin_method: Callable) -> Callable:
def wrap_method(*args, **kwargs):
if self.override_data:
- assert not self.data_queue.empty(), 'pop from an empty queue'
- outputs = self.data_queue.get()
+ assert len(self.data_queue) > 0, 'pop from an empty queue'
+ outputs = self.data_queue.popleft()
else:
- assert not self.data_queue.full(), 'push into an full queue'
+ assert len(self.data_queue) < self.data_queue.maxlen,\
+ 'push into an full queue'
outputs = origin_method(*args, **kwargs)
- self.data_queue.put(outputs)
+ self.data_queue.append(outputs)
return outputs
return wrap_method
diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
index f31208248..df0c867c6 100644
--- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from functools import partial
-from typing import Dict
+from typing import Dict, List
+import mmcv
import torch
import torch.nn as nn
@@ -497,9 +498,29 @@ def add_flops_params_counter_variable_or_reset(module):
module.__params__ = 0
-def get_counter_type(module):
- """Get counter type of the module based on the module class name."""
- return module.__class__.__name__ + 'Counter'
+def get_counter_type(module) -> str:
+ """Get counter type of the module based on the module class name.
+
+ If the current module counter_type is not in TASK_UTILS._module_dict,
+ it will search the base classes of the module to see if it matches any
+ base class counter_type.
+
+ Returns:
+ str: Counter type (or the base counter type) of the current module.
+ """
+ counter_type = module.__class__.__name__ + 'Counter'
+ if counter_type not in TASK_UTILS._module_dict.keys():
+ old_counter_type = counter_type
+ assert nn.Module in module.__class__.mro()
+ for base_cls in module.__class__.mro():
+ if base_cls in get_modules_list():
+ counter_type = base_cls.__name__ + 'Counter'
+ from mmengine import MMLogger
+ logger = MMLogger.get_current_instance()
+ logger.warning(f'`{old_counter_type}` not in op_counters. '
+ f'Using `{counter_type}` instead.')
+ break
+ return counter_type
def is_supported_instance(module):
@@ -518,3 +539,54 @@ def remove_flops_params_counter_hook_function(module):
del module.__flops__
if hasattr(module, '__params__'):
del module.__params__
+
+
+def get_modules_list() -> List:
+ return [
+ # convolutions
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ mmcv.cnn.bricks.Conv2d,
+ mmcv.cnn.bricks.Conv3d,
+ # activations
+ nn.ReLU,
+ nn.PReLU,
+ nn.ELU,
+ nn.LeakyReLU,
+ nn.ReLU6,
+ # poolings
+ nn.MaxPool1d,
+ nn.AvgPool1d,
+ nn.AvgPool2d,
+ nn.MaxPool2d,
+ nn.MaxPool3d,
+ nn.AvgPool3d,
+ mmcv.cnn.bricks.MaxPool2d,
+ mmcv.cnn.bricks.MaxPool3d,
+ nn.AdaptiveMaxPool1d,
+ nn.AdaptiveAvgPool1d,
+ nn.AdaptiveMaxPool2d,
+ nn.AdaptiveAvgPool2d,
+ nn.AdaptiveMaxPool3d,
+ nn.AdaptiveAvgPool3d,
+ # normalizations
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
+ nn.GroupNorm,
+ nn.InstanceNorm1d,
+ nn.InstanceNorm2d,
+ nn.InstanceNorm3d,
+ nn.LayerNorm,
+ # FC
+ nn.Linear,
+ mmcv.cnn.bricks.Linear,
+ # Upscale
+ nn.Upsample,
+ nn.UpsamplingNearest2d,
+ nn.UpsamplingBilinear2d,
+ # Deconvolution
+ nn.ConvTranspose2d,
+ mmcv.cnn.bricks.ConvTranspose2d,
+ ]
diff --git a/mmrazor/models/task_modules/recorder/__init__.py b/mmrazor/models/task_modules/recorder/__init__.py
index 6d1858f0b..8af399126 100644
--- a/mmrazor/models/task_modules/recorder/__init__.py
+++ b/mmrazor/models/task_modules/recorder/__init__.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .function_inputs_recorder import FunctionInputsRecorder
from .function_outputs_recorder import FunctionOutputsRecorder
+from .method_inputs_recorder import MethodInputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
@@ -9,5 +11,5 @@
__all__ = [
'FunctionOutputsRecorder', 'MethodOutputsRecorder',
'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
- 'ModuleInputsRecorder'
+ 'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
]
diff --git a/mmrazor/models/task_modules/recorder/function_inputs_recorder.py b/mmrazor/models/task_modules/recorder/function_inputs_recorder.py
new file mode 100644
index 000000000..e7bbdd896
--- /dev/null
+++ b/mmrazor/models/task_modules/recorder/function_inputs_recorder.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from inspect import signature
+from typing import Callable, List
+
+from mmrazor.registry import TASK_UTILS
+from .function_outputs_recorder import FunctionOutputsRecorder
+
+
+@TASK_UTILS.register_module()
+class FunctionInputsRecorder(FunctionOutputsRecorder):
+ """Recorder for intermediate results which are ``FunctionType``'s inputs.
+
+ Notes:
+ The form of `source` needs special attention. For example,
+ `anchor_inside_flags` is a function in mmdetection to check whether the
+ anchors are inside the border. This function is in
+ `mmdet/core/anchor/utils.py` and used in
+ `mmdet/models/dense_heads/anchor_head.py`. Then the source should be
+ `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
+ `mmdet.core.anchor.utils.anchor_inside_flags`.
+
+
+ Examples:
+ >>> # Below code in toy_module.py
+ >>> import random
+ >>> def toy_func(a, b):
+ ... return a, b
+ >>> def execute_toy_func(a, b):
+ ... toy_func(a, b)
+
+ >>> # Below code in main.py
+ >>> # Now, we want to get teacher's inputs by recorder.
+
+ >>> from toy_module import execute_toy_func
+ >>> r1 = FunctionInputsRecorder('toy_module.toy_func')
+ >>> r1.initialize()
+ >>> with r1:
+ ... execute_toy_func(1, 2)
+ ... execute_toy_func(1, b=2)
+ ... execute_toy_func(b=2, a=1)
+
+ >>> r1.data_buffer
+ [[1, 2], [1, 2], [1, 2]]
+ """
+
+ def func_record_wrapper(self, origin_func: Callable,
+ data_buffer: List) -> Callable:
+ """Save the function's inputs.
+
+ Args:
+ origin_func (FunctionType): The method whose inputs need to be
+ recorded.
+ data_buffer (list): A list of data.
+ """
+
+ func_input_params = signature(origin_func).parameters.keys()
+
+ @functools.wraps(origin_func)
+ def wrap_func(*args, **kwargs):
+ outputs = origin_func(*args, **kwargs)
+ inputs = list(args)
+ for keyword in func_input_params:
+ if keyword in kwargs:
+ inputs.append(kwargs[keyword])
+ # assume a func execute N times, there will be N inputs need to
+ # save.
+ data_buffer.append(inputs)
+ return outputs
+
+ return wrap_func
diff --git a/mmrazor/models/task_modules/recorder/function_outputs_recorder.py b/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
index c6ab5228f..706c1a8f7 100644
--- a/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
+++ b/mmrazor/models/task_modules/recorder/function_outputs_recorder.py
@@ -65,28 +65,8 @@ class FunctionOutputsRecorder(BaseRecorder):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
-
self._check_valid_source(self.source)
- # import the function corrosponding module
- try:
- mod = import_modules_from_strings(self.module_string)
- except ImportError:
- raise ImportError(
- f'{self.module_string} is not imported correctly.')
-
- self.imported_module: ModuleType = mod
-
- assert hasattr(mod, self.func_name), \
- f'{self.func_name} is not in {self.module_string}.'
-
- origin_func = getattr(mod, self.func_name)
- if not isinstance(origin_func, FunctionType):
- raise TypeError(f'{self.func_name} should be a FunctionType '
- f'instance, but got {type(origin_func)}')
-
- self.origin_func: Callable = origin_func
-
@staticmethod
def _check_valid_source(source):
"""Check if the source's format is valid."""
@@ -118,8 +98,7 @@ def func_record_wrapper(self, origin_func: Callable,
Args:
origin_func (FunctionType): The method whose outputs need to be
recorded.
- buffer_key (str): The key of the function's outputs saved in
- ``data_buffer``.
+ data_buffer (list): A list of data.
"""
@functools.wraps(origin_func)
@@ -136,8 +115,25 @@ def __enter__(self):
"""Enter the context manager."""
super().__enter__()
- mod = self.imported_module
- origin_func = self.origin_func
+ # import the function corrosponding module
+ try:
+ mod = import_modules_from_strings(self.module_string)
+ except ImportError:
+ raise ImportError(
+ f'{self.module_string} is not imported correctly.')
+
+ self.imported_module: ModuleType = mod
+
+ assert hasattr(mod, self.func_name), \
+ f'{self.func_name} is not in {self.module_string}.'
+
+ origin_func = getattr(mod, self.func_name)
+ if not isinstance(origin_func, FunctionType):
+ raise TypeError(f'{self.func_name} should be a FunctionType '
+ f'instance, but got {type(origin_func)}')
+
+ self.origin_func: Callable = origin_func
+
# add record wrapper to origin function.
record_func = self.func_record_wrapper(origin_func, self.data_buffer)
@@ -159,3 +155,8 @@ def __exit__(self, exc_type, exc_value, traceback):
# restore the origin function
setattr(mod, self.func_name, origin_func)
+
+ # self.imported_module and self.origin_func can not be pickled.
+ # Delete these two attributes to avoid errors when ema model is used.
+ del self.imported_module
+ del self.origin_func
diff --git a/mmrazor/models/task_modules/recorder/method_inputs_recorder.py b/mmrazor/models/task_modules/recorder/method_inputs_recorder.py
new file mode 100644
index 000000000..44cb41843
--- /dev/null
+++ b/mmrazor/models/task_modules/recorder/method_inputs_recorder.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from inspect import signature
+from typing import Callable, List
+
+from mmrazor.registry import TASK_UTILS
+from .method_outputs_recorder import MethodOutputsRecorder
+
+
+@TASK_UTILS.register_module()
+class MethodInputsRecorder(MethodOutputsRecorder):
+ """Recorder for intermediate results which are ``MethodType``'s inputs.
+
+ Note:
+ Different from ``FunctionType``, ``MethodType`` is the type of methods
+ of class instances.
+
+ Examples:
+ >>> # Below code in toy_module.py
+ >>> import random
+ >>> class Toy():
+ ... def toy_func(self, x, y=0):
+ ... return x + y
+
+ >>> # Below code in main.py
+ >>> # Now, we want to get teacher's inputs by recorder.
+
+ >>> from toy_module import Toy
+ >>> toy = Toy()
+ >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
+ >>> r1.initialize()
+ >>> with r1:
+ ... _ = toy.toy_func(1, 2)
+
+ >>> r1.data_buffer
+ [[1, 2]]
+ >>> r1.get_record_data(record_idx=0, data_idx=0)
+ 1
+ >>> r1.get_record_data(record_idx=0, data_idx=1)
+ 2
+
+ >>> from toy_module import Toy
+ >>> toy = Toy()
+ >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
+ >>> r1.initialize()
+ >>> with r1:
+ ... _ = toy.toy_func(1, 2)
+ ... _ = toy.toy_func(y=2, x=1)
+
+ >>> r1.data_buffer
+ [[1, 2], [1, 2]]
+ >>> r1.get_record_data(record_idx=1, data_idx=0)
+ 1
+ >>> r1.get_record_data(record_idx=1, data_idx=1)
+ 2
+ """
+
+ def method_record_wrapper(self, orgin_method: Callable,
+ data_buffer: List) -> Callable:
+ """Save the method's inputs.
+
+ Args:
+ origin_method (MethodType): The method whose inputs need to be
+ recorded.
+ data_buffer (list): A list of data.
+ """
+
+ method_input_params = signature(orgin_method).parameters.keys()
+
+ @functools.wraps(orgin_method)
+ def wrap_method(*args, **kwargs):
+ outputs = orgin_method(*args, **kwargs)
+ # the first element of a class method is the class itself
+ inputs = list(args[1:])
+ for keyword in method_input_params:
+ if keyword in kwargs:
+ inputs.append(kwargs[keyword])
+ # Assume a func execute N times, there will be N inputs need to
+ # save.
+ data_buffer.append(inputs)
+ return outputs
+
+ return wrap_method
diff --git a/mmrazor/models/task_modules/recorder/method_outputs_recorder.py b/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
index 266750726..6d3fb6593 100644
--- a/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
+++ b/mmrazor/models/task_modules/recorder/method_outputs_recorder.py
@@ -130,8 +130,7 @@ def method_record_wrapper(self, orgin_method: Callable,
Args:
origin_method (MethodType): The method whose outputs need to be
recorded.
- buffer_key (str): The key of the method's outputs saved in
- ``data_buffer``.
+ data_buffer (list): A list of data.
"""
@functools.wraps(orgin_method)
diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py
index a9a6fde52..920d43477 100644
--- a/mmrazor/models/task_modules/tracer/__init__.py
+++ b/mmrazor/models/task_modules/tracer/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer
+from .fx import CustomTracer, UntracedMethodRegistry, custom_symbolic_trace
from .loss_calculator import * # noqa: F401,F403
from .parsers import * # noqa: F401,F403
from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode,
@@ -7,5 +8,6 @@
__all__ = [
'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode',
- 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode'
+ 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode',
+ 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace'
]
diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py
new file mode 100644
index 000000000..29c93f83a
--- /dev/null
+++ b/mmrazor/models/task_modules/tracer/fx/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .custom_tracer import (CustomTracer, UntracedMethodRegistry,
+ custom_symbolic_trace)
+
+__all__ = ['CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace']
diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
new file mode 100644
index 000000000..f69ec2269
--- /dev/null
+++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
@@ -0,0 +1,281 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from types import FunctionType, MethodType
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import torch
+from mmengine.utils import import_modules_from_strings
+from torch._C import ScriptObject # type: ignore[attr-defined]
+from torch.ao.quantization.quantize_fx import QuantizationTracer
+from torch.fx import GraphModule, Tracer
+from torch.fx._symbolic_trace import (Graph, _autowrap_check,
+ _patch_wrapped_functions, _Patcher)
+from torch.fx.proxy import Proxy
+
+_orig_module_call: Callable = torch.nn.Module.__call__
+_orig_module_getattr: Callable = torch.nn.Module.__getattr__
+# _orig_module_forward_train: Callable = models.BaseDenseHead.forward_train
+
+
+class UntracedMethodRegistry:
+ """A `Descriptor` class which records untraced methods."""
+ method_dict: Dict = dict()
+ tracer = None
+
+ def __init__(self, method):
+ """_summary_
+
+ Args:
+ method (FunctionType): Function to be registered.
+ """
+ self.method = method
+ self.instances: Dict = dict()
+ self.owner = None
+
+ def __set_name__(self, owner, name):
+ self.owner = owner
+ self.name = name
+ wrapped = self.method_wrapper()
+ self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.method
+ return MethodType(self.method, instance)
+
+ def method_wrapper(self):
+
+ @functools.wraps(self.method)
+ def wrapped_method(mod, *args, **kwargs):
+
+ def method(*args, **kwargs):
+ return self.method(mod, *args, **kwargs)
+
+ return self.tracer.call_method(mod, self.name, method, args,
+ kwargs)
+
+ return wrapped_method
+
+
+def custom_symbolic_trace(
+ root: Union[torch.nn.Module, Callable[..., Any]],
+ concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
+ """Modified `symbolic_trace` function.
+
+ Args:
+ root (Union[torch.nn.Module, Callable]): Module or function to be
+ traced and converted into a Graph representation.
+ concrete_args (Optional[Dict[str, any]]): Inputs to be partially
+ specialized.
+
+ Returns:
+ _type_: _description_
+ """
+ tracer = CustomTracer()
+ graph = tracer.trace(root, concrete_args)
+ name = root.__class__.__name__ if isinstance(
+ root, torch.nn.Module) else root.__name__
+ return GraphModule(tracer.root, graph, name)
+
+
+class CustomTracer(QuantizationTracer):
+
+ def __init__(self,
+ skipped_methods: List[str] = [],
+ skipped_module_names: List[str] = [],
+ skipped_module_classes: List[Callable] = [],
+ *args,
+ **kwargs):
+ """_summary_
+
+ Args:
+ skipped_methods (List[str], optional): Methods to be skipped while
+ tracing. Defaults to None.
+ skipped_module_names (List[str], optional): Modules to be skipped
+ while tracing. Defaults to None.
+ skipped_module_classes (List[str], optional): Class to be skipped
+ while tracing. Defaults to None.
+ """
+ super(CustomTracer, self).__init__(skipped_module_names,
+ skipped_module_classes)
+ UntracedMethodRegistry.tracer = self # type: ignore
+ self.skipped_methods = skipped_methods
+ if self.skipped_methods:
+ self.register_skipped_methods()
+
+ @staticmethod
+ def _check_valid_source(source):
+ """Check if the source's format is valid."""
+ if not isinstance(source, str):
+ raise TypeError(f'source should be a str '
+ f'instance, but got {type(source)}')
+
+ assert len(source.split('.')) > 1, \
+ 'source must have at least one `.`'
+
+ def register_skipped_methods(self):
+ if not isinstance(self.skipped_methods, list):
+ self.skipped_methods = [self.skipped_methods]
+ for s_method in self.skipped_methods:
+ self._check_valid_source(s_method)
+ mod_str = '.'.join(s_method.split('.')[:-2])
+ cls_str = s_method.split('.')[-2]
+ method_str = s_method.split('.')[-1]
+
+ try:
+ mod = import_modules_from_strings(mod_str)
+ except ImportError:
+ raise ImportError(f'{mod_str} is not imported correctly.')
+
+ imported_cls: type = getattr(mod, cls_str)
+ if not isinstance(imported_cls, type):
+ raise TypeError(f'{cls_str} should be a type '
+ f'instance, but got {type(imported_cls)}')
+ assert hasattr(imported_cls, method_str), \
+ f'{method_str} is not in {mod_str}.'
+
+ method = getattr(imported_cls, method_str)
+
+ method_registry = UntracedMethodRegistry(method)
+ method_registry.__set_name__(imported_cls, method_str)
+
+ def call_method(self, m: torch.nn.Module, name, method, args, kwargs):
+ """Method that specifies the behavior of this ``Tracer`` when it
+ encounters a call to an ``nn.Module`` instance.
+
+ By default, the behavior is to check if the called module is a leaf
+ module via ``is_leaf_module``. If it is, emit a ``call_module``
+ node referring to ``m`` in the ``Graph``. Otherwise, call the
+ ``Module`` normally, tracing through the operations in its ``forward``
+ function.
+
+ This method can be overridden to--for example--create nested traced
+ GraphModules, or any other behavior you would want while tracing across
+ ``Module`` boundaries.
+
+ Args:
+
+ m (Module): The module for which a call is being emitted
+ forward (Callable): The forward() method of the ``Module`` to be
+ invoked
+ args (Tuple): args of the module callsite
+ kwargs (Dict): kwargs of the module callsite
+
+ Return:
+
+ The return value from the Module call. In the case that a
+ ``call_module`` node was emitted, this is a ``Proxy`` value.
+ Otherwise, it is whatever value was returned from the ``Module``
+ invocation.
+ """
+ # module_qualified_name = self.path_of_module(m)
+ if not self.is_skipped_method(m):
+ return method(*args, **kwargs)
+ args = list(args)
+ args.insert(0, m)
+ args = tuple(args)
+ return self.create_proxy('call_method', name, args, kwargs)
+
+ def trace(self, root, concrete_args=None):
+ if isinstance(root, torch.nn.Module):
+ self.root = root
+ fn = type(root).forward
+ self.submodule_paths = {
+ mod: name
+ for name, mod in root.named_modules()
+ }
+ else:
+ self.root = torch.nn.Module()
+ fn = root
+
+ tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
+ self.graph = Graph(tracer_cls=tracer_cls)
+
+ # When we encounter a Tensor value that's not a parameter, we look if
+ # it is some other attribute on the model. Construct a dict mapping
+ # Tensor values to the qualified name here for efficiency. This is
+ # used downstream in create_arg
+ self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}
+
+ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
+ for k, v in m.__dict__.items():
+ if isinstance(v, (torch.Tensor, ScriptObject)):
+ self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
+ for k, v in m.named_children():
+ collect_tensor_attrs(v, prefix_atoms + [k])
+
+ collect_tensor_attrs(self.root, [])
+
+ assert isinstance(fn, FunctionType)
+
+ fn_globals = fn.__globals__ # run before it gets patched
+ fn, args = self.create_args_for_root(fn,
+ isinstance(root, torch.nn.Module),
+ concrete_args)
+
+ # Reduce number of get_attr calls
+ parameter_proxy_cache: Dict[str, Proxy] = {}
+
+ # Method dispatch on parameters is not recorded unless it's directly
+ # used. Thus, we need to insert a proxy when __getattr__ requests a
+ # parameter.
+ @functools.wraps(_orig_module_getattr)
+ def module_getattr_wrapper(mod, attr):
+ attr_val = _orig_module_getattr(mod, attr)
+ return self._module_getattr(attr, attr_val, parameter_proxy_cache)
+
+ @functools.wraps(_orig_module_call)
+ def module_call_wrapper(mod, *args, **kwargs):
+
+ def forward(*args, **kwargs):
+ return _orig_module_call(mod, *args, **kwargs)
+
+ _autowrap_check(
+ patcher,
+ getattr(getattr(mod, 'forward', mod), '__globals__', {}),
+ self._autowrap_function_ids)
+ return self.call_module(mod, forward, args, kwargs)
+
+ with _Patcher() as patcher:
+ # allow duplicate patches to support the case of nested calls
+ patcher.patch_method(
+ torch.nn.Module,
+ '__getattr__',
+ module_getattr_wrapper,
+ deduplicate=False)
+ patcher.patch_method(
+ torch.nn.Module,
+ '__call__',
+ module_call_wrapper,
+ deduplicate=False)
+
+ for name, value in UntracedMethodRegistry.method_dict.items():
+ wrapped = value['wrapped']
+ patcher.patch_method(
+ value['mod'], name, wrapped, deduplicate=False)
+
+ _patch_wrapped_functions(patcher)
+ _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
+ for module in self._autowrap_search:
+ _autowrap_check(patcher, module.__dict__,
+ self._autowrap_function_ids)
+ self.create_node(
+ 'output',
+ 'output', (self.create_arg(fn(*args)), ), {},
+ type_expr=fn.__annotations__.get('return', None))
+
+ self.submodule_paths = None
+
+ return self.graph
+
+ def is_skipped_method(self, m):
+ mods = tuple(value['mod']
+ for value in UntracedMethodRegistry.method_dict.values())
+ custom = isinstance(m, mods)
+ return custom
+
+ def is_leaf_module(self, m: torch.nn.Module,
+ module_qualified_name: str) -> bool:
+ # return super().is_leaf_module(m, module_qualified_name)
+ leaf = super().is_leaf_module(m, module_qualified_name)
+ return leaf
diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py
index db07845a8..d37ebd8e8 100644
--- a/mmrazor/models/utils/__init__.py
+++ b/mmrazor/models/utils/__init__.py
@@ -2,9 +2,26 @@
from .make_divisible import make_divisible
from .misc import add_prefix
from .optim_wrapper import reinitialize_optim_wrapper_count_status
+# yapf:disable
+from .quantization_util import (PerChannelLoadHook, _is_float_qparams,
+ _is_per_channel, _is_per_tensor,
+ _is_symmetric_quant,
+ check_is_valid_convert_custom_config_dict,
+ check_is_valid_fuse_custom_config_dict,
+ check_is_valid_prepare_custom_config_dict,
+ check_is_valid_qconfig_dict,
+ get_custom_module_class_keys, is_tracing_state,
+ pot_quantization, sync_tensor)
+# yapf:enable
from .utils import get_module_device, set_requires_grad
__all__ = [
- 'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible',
- 'get_module_device', 'set_requires_grad'
+ 'add_prefix', 'check_is_valid_convert_custom_config_dict',
+ 'check_is_valid_fuse_custom_config_dict',
+ 'check_is_valid_prepare_custom_config_dict', 'check_is_valid_qconfig_dict',
+ 'get_module_device', 'get_custom_module_class_keys', 'make_divisible',
+ 'pot_quantization', 'reinitialize_optim_wrapper_count_status',
+ 'set_requires_grad', 'sync_tensor', '_is_per_channel', '_is_per_tensor',
+ '_is_symmetric_quant', '_is_float_qparams', 'is_tracing_state',
+ 'PerChannelLoadHook'
]
diff --git a/mmrazor/models/utils/make_divisible.py b/mmrazor/models/utils/make_divisible.py
index 5056aeb15..5fda15591 100644
--- a/mmrazor/models/utils/make_divisible.py
+++ b/mmrazor/models/utils/make_divisible.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
+warn_once = False
+
def make_divisible(value: int,
divisor: int,
@@ -24,6 +26,18 @@ def make_divisible(value: int,
if min_value is None:
min_value = divisor
+ if min_value < divisor:
+ global warn_once
+ if warn_once is False:
+ from mmengine import MMLogger
+ MMLogger.get_current_instance().warning(
+ (f'min_value=={min_value} should greater or equal to '
+ f'divisor=={divisor}, '
+ 'so we make min_value equal divisor.'))
+ warn_once = True
+
+ min_value = divisor
+
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
if new_value < min_ratio * value:
diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py
new file mode 100644
index 000000000..376096b67
--- /dev/null
+++ b/mmrazor/models/utils/quantization_util.py
@@ -0,0 +1,217 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+from typing import Any, Dict, List, Optional, Set
+
+import torch
+
+
+class PerChannelLoadHook:
+
+ def __init__(self, module, hook_param=['scale', 'zero_point']):
+ self.hook = module._register_load_state_dict_pre_hook(
+ partial(self.hook_fn, module=module))
+ self.hook_param = hook_param
+
+ def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys,
+ unexpected_keys, error_msgs, module):
+ if module.ch_axis == -1:
+ # no per-channel parameters
+ return
+ for module_key, param in module._parameters.items():
+ if module_key not in self.hook_param:
+ continue
+ candidate = prefix + module_key
+ if candidate in state_dict:
+ input_param = state_dict[candidate]
+ if param.shape != input_param.shape:
+ param.data = torch.ones_like(
+ input_param, dtype=param.dtype, device=param.device)
+ for module_key, param in module._buffers.items():
+ if module_key not in self.hook_param:
+ continue
+ candidate = prefix + module_key
+ if candidate in state_dict:
+ input_param = state_dict[candidate]
+ if param.shape != input_param.shape:
+ param.data = torch.ones_like(
+ input_param, dtype=param.dtype, device=param.device)
+
+ def close(self):
+ self.hook.remove()
+
+
+USE_LINK = False
+USE_DDP = False
+
+try:
+ import spring.linklink as link
+ assert link.is_initialized()
+ USE_LINK = True
+except (ModuleNotFoundError, AssertionError):
+ import torch.distributed as dist
+ if torch.distributed.is_initialized():
+ USE_DDP = True
+
+
+def sync_tensor(tensor):
+ if USE_LINK:
+ if tensor.is_cuda is True:
+ tensor.data = tensor.data / link.get_world_size()
+ link.allreduce(tensor.data)
+ elif USE_DDP:
+ tensor.data = tensor.data / dist.get_world_size()
+ dist.all_reduce(tensor.data)
+ return tensor
+
+
+def pot_quantization(tensor: torch.Tensor, mode='round'):
+ log2t = torch.log2(tensor)
+ if mode == 'round':
+ log2t = (torch.round(log2t) - log2t).detach() + log2t
+ else:
+ assert mode == 'floor'
+ log2t = (torch.floor(log2t) - log2t).detach() + log2t
+ return 2**log2t
+
+
+def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [
+ torch.per_channel_symmetric, torch.per_channel_affine,
+ torch.per_channel_affine_float_qparams
+ ]
+
+
+def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
+
+
+def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
+
+
+def is_tracing_state():
+ return torch._C._get_tracing_state()
+
+
+def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
+ return qscheme in [
+ torch.per_channel_affine_float_qparams,
+ ]
+
+
+def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str],
+ dict_name: str) -> None:
+ r""" Checks if the given config_dict has the correct keys
+ Args:
+ `config_dict`: dictionary whose keys we want to check
+ """
+
+ for k in config_dict.keys():
+ if k not in allowed_keys:
+ raise ValueError('Expected ' + dict_name +
+ ' to have the following keys: ' +
+ str(allowed_keys) + '. But found \'' + k +
+ '\' instead.')
+
+
+def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None:
+ r""" Checks if the given qconfig_dict has the correct keys
+ Args:
+ `qconfig_dict`: dictionary whose keys we want to check
+ """
+
+ qconfig_dict_allowed_keys = {
+ '', 'object_type', 'module_name_regex', 'module_name',
+ 'module_name_object_type_order'
+ }
+ check_is_valid_config_dict(qconfig_dict, qconfig_dict_allowed_keys,
+ 'qconfig_dict')
+
+
+def check_is_valid_prepare_custom_config_dict(
+ prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> None:
+ r""" Checks if the given prepare_custom_config_dict has the correct keys
+ Args:
+ `prepare_custom_config_dict`: customization configuration dictionary for
+ quantization tool
+ """
+ if not prepare_custom_config_dict:
+ return
+
+ prepare_custom_config_dict_allowed_keys = {
+ 'standalone_module_name', 'standalone_module_class',
+ 'float_to_observed_custom_module_class', 'non_traceable_module_name',
+ 'non_traceable_module_class', 'input_quantized_idxs',
+ 'output_quantized_idxs', 'preserved_attributes'
+ }
+ check_is_valid_config_dict(prepare_custom_config_dict,
+ prepare_custom_config_dict_allowed_keys,
+ 'prepare_custom_config_dict')
+
+
+def check_is_valid_convert_custom_config_dict(
+ convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> None:
+ r""" Checks if the given convert_custom_config_dict has the correct keys
+ Args:
+ `convert_custom_config_dict`: dictionary for custom configurations for
+ convert function
+ """
+ if not convert_custom_config_dict:
+ return
+
+ convert_custom_config_dict_allowed_keys = {
+ 'observed_to_quantized_custom_module_class', 'preserved_attributes'
+ }
+ check_is_valid_config_dict(convert_custom_config_dict,
+ convert_custom_config_dict_allowed_keys,
+ 'convert_custom_config_dict')
+
+
+def check_is_valid_fuse_custom_config_dict(
+ fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> None:
+ r""" Checks if the given fuse_custom_config_dict has the correct keys
+ Args:
+ `fuse_custom_config_dict`: dictionary for custom configurations for
+ fuse_fx
+ """
+ if not fuse_custom_config_dict:
+ return
+
+ fuse_custom_config_dict_allowed_keys = {'preserved_attributes'}
+ check_is_valid_config_dict(fuse_custom_config_dict,
+ fuse_custom_config_dict_allowed_keys,
+ 'fuse_custom_config_dict')
+
+
+def get_custom_module_class_keys(custom_config_dict,
+ custom_config_dict_key) -> List[Any]:
+ r""" Get all the unique custom module keys in the custom config dict
+ e.g.
+ Input:
+ custom_config_dict = {
+ "float_to_observed_custom_module_class": {
+ "static": {
+ CustomModule1: ObservedCustomModule
+ },
+ "dynamic": {
+ CustomModule2: DynamicObservedCustomModule
+ },
+ "weight_only": {
+ CustomModule3: WeightOnlyObservedCustomModule
+ },
+ },
+ }
+ Output:
+ # extract all the keys in "static", "dynamic" and "weight_only" dict
+ [CustomModule1, CustomModule2, CustomModule3]
+ """
+ # using set to dedup
+ float_custom_module_classes: Set[Any] = set()
+ custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {})
+ for quant_mode in ['static', 'dynamic', 'weight_only']:
+ quant_mode_custom_module_config = custom_module_mapping.get(
+ quant_mode, {})
+ quant_mode_custom_module_classes = set(
+ quant_mode_custom_module_config.keys())
+ float_custom_module_classes |= quant_mode_custom_module_classes
+ return list(float_custom_module_classes)
diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py
new file mode 100644
index 000000000..fc2133bf2
--- /dev/null
+++ b/mmrazor/structures/quantization/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backend_default_qconfigs import CheckArgs, DefalutQconfigs, SupportQtypes
+from .qscheme import QuantizeScheme
+
+__all__ = ['QuantizeScheme', 'DefalutQconfigs', 'SupportQtypes', 'CheckArgs']
diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py
new file mode 100644
index 000000000..6a1fde183
--- /dev/null
+++ b/mmrazor/structures/quantization/backend_default_qconfigs.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+SupportQtypes = ('affine')
+CheckArgs = [
+ 'qtype', 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant',
+ 'w_observer', 'a_observer'
+]
+
+Default = dict(
+ qtype='affine', # noqa: E241
+ w_qscheme=dict(
+ is_symmetry=True,
+ is_per_channel=True,
+ is_pot_scale=False,
+ bit=8,
+ symmetric_range=True),
+ a_qscheme=dict(
+ is_symmetry=True,
+ is_per_channel=False,
+ is_pot_scale=False,
+ bit=8,
+ symmetric_range=True),
+ w_fake_quant=dict(type='BaseFakeQuantize'),
+ a_fake_quant=dict(type='BaseFakeQuantize'),
+ w_observer=dict(type='MinMaxObserver'),
+ a_observer=dict(type='MinMaxObserver'))
+
+TensorRT = dict(
+ qtype='affine', # noqa: E241
+ w_qscheme=dict(
+ is_symmetry=True,
+ is_per_channel=True,
+ is_pot_scale=False,
+ bit=8,
+ symmetric_range=True),
+ a_qscheme=dict(
+ is_symmetry=True,
+ is_per_channel=False,
+ is_pot_scale=False,
+ bit=8,
+ symmetric_range=True),
+ w_fake_quant=dict(type='LearnableFakeQuantize'),
+ a_fake_quant=dict(type='LearnableFakeQuantize'),
+ w_observer=dict(type='MinMaxObserver'),
+ a_observer=dict(type='EMAMinMaxObserver'))
+
+DefalutQconfigs = dict(default=Default, tensorrt=TensorRT)
diff --git a/mmrazor/structures/quantization/qscheme.py b/mmrazor/structures/quantization/qscheme.py
new file mode 100644
index 000000000..24c41832e
--- /dev/null
+++ b/mmrazor/structures/quantization/qscheme.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+class QuantizeScheme(object):
+ """Custom QScheme. Refer to:
+ https://github.com/pytorch/pytorch/blob/master/c10/core/QScheme.h.
+
+ Args:
+ bit (int, optional): Bit number. Defaults to 8.
+ is_symmetry (bool, optional): Is symmetry quantization or not. Defaults
+ to True.
+ is_per_channel (bool, optional): Is per-channel quantization or not.
+ Defaults to False.
+ is_pot_scale (bool, optional): Indicate whether scale is power of two.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ bit=8,
+ is_symmetry=True,
+ is_per_channel=False,
+ is_pot_scale=False,
+ **kwargs):
+ self.bit = bit
+ self.is_symmetry = is_symmetry
+ self.is_per_channel = is_per_channel
+ self.is_pot_scale = is_pot_scale
+
+ if self.is_per_channel:
+ self.torch_qscheme = torch.per_channel_symmetric \
+ if self.is_symmetry else torch.per_channel_affine
+ else:
+ self.torch_qscheme = torch.per_tensor_symmetric \
+ if self.is_symmetry else torch.per_tensor_affine
+ if 'is_symmetric_range' in kwargs:
+ self.is_symmetric_range = kwargs['is_symmetric_range']
+ del kwargs['is_symmetric_range']
+ else:
+ self.is_symmetric_range = False
+ self.kwargs = kwargs
+
+ def to_observer_params(self):
+ quant_min = 0
+ quant_max = 2**self.bit - 1
+ if self.is_symmetry:
+ quant_max = 2**(self.bit - 1) - 1
+ if self.is_symmetric_range:
+ quant_min = -2**(self.bit - 1) + 1
+ else:
+ quant_min = -2**(self.bit - 1)
+
+ naive_para = {
+ 'quant_min': quant_min,
+ 'quant_max': quant_max,
+ 'dtype': torch.qint8 if self.is_symmetry else torch.quint8,
+ 'is_pot_scale': self.is_pot_scale,
+ 'qscheme': self.torch_qscheme,
+ 'reduce_range': False,
+ 'ch_axis': 0 if self.is_per_channel else -1
+ }
+ naive_para.update(self.kwargs)
+ return naive_para
+
+ def __str__(self):
+ return f'bit: {self.bit} / is_symmetry: {self.is_symmetry} / \
+ is_per_channel: {self.is_per_channel} / is_pot_scale: \
+ {self.is_pot_scale} / extra_kwargs: {self.kwargs}'
diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py
index 4eb515371..625e65025 100644
--- a/mmrazor/structures/subnet/fix_subnet.py
+++ b/mmrazor/structures/subnet/fix_subnet.py
@@ -6,6 +6,7 @@
from torch import nn
from mmrazor.utils import FixMutable, ValidFixMutable
+from mmrazor.utils.typing import DumpChosen
def _dynamic_to_static(model: nn.Module) -> None:
@@ -56,6 +57,7 @@ def load_fix_subnet(model: nn.Module,
assert alias in fix_mutable, \
f'The alias {alias} is not in fix_modules, ' \
'please check your `fix_mutable`.'
+ # {chosen=xx, meta=xx)
chosen = fix_mutable.get(alias, None)
else:
mutable_name = name.lstrip(prefix)
@@ -64,8 +66,12 @@ def load_fix_subnet(model: nn.Module,
raise RuntimeError(
f'The module name {mutable_name} is not in '
'fix_mutable, please check your `fix_mutable`.')
+ # {chosen=xx, meta=xx)
chosen = fix_mutable.get(mutable_name, None)
- module.fix_chosen(chosen)
+
+ if not isinstance(chosen, DumpChosen):
+ chosen = DumpChosen(**chosen)
+ module.fix_chosen(chosen.chosen)
# convert dynamic op to static op
_dynamic_to_static(model)
diff --git a/mmrazor/testing/__init__.py b/mmrazor/testing/__init__.py
index 009dd844d..54dfd30ed 100644
--- a/mmrazor/testing/__init__.py
+++ b/mmrazor/testing/__init__.py
@@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403
+from ._fx_models import * # noqa: F401, F403
diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py
new file mode 100644
index 000000000..969c4792d
--- /dev/null
+++ b/mmrazor/testing/_fx_models.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class ConvBNReLU(nn.Module):
+
+ def __init__(
+ self,
+ in_channel: int,
+ out_channel: int,
+ kernel_size: Union[int, Tuple[int, int]] = 1,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: Union[str, bool] = 'auto',
+ conv_cfg: Optional[Dict] = None,
+ norm_cfg: Optional[Dict] = None,
+ act_cfg: Dict = dict(type='ReLU'),
+ inplace: bool = True,
+ with_spectral_norm: bool = False,
+ padding_mode: str = 'zeros',
+ order: tuple = ('conv', 'norm', 'act'),
+ init_cfg: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+ self.conv_module = ConvModule(in_channel, out_channel, kernel_size,
+ stride, padding, dilation, groups, bias,
+ conv_cfg, norm_cfg, act_cfg, inplace,
+ with_spectral_norm, padding_mode, order)
+
+ def forward(self, x):
+ x = self.conv_module.conv(x)
+ x = self.conv_module.norm(x)
+ x = self.conv_module.activate(x)
+ return x
diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py
index 8ac3661c2..a053024ac 100644
--- a/mmrazor/utils/index_dict.py
+++ b/mmrazor/utils/index_dict.py
@@ -6,9 +6,9 @@
class IndexDict(OrderedDict):
- """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. Each
- IndexDict object is a OrderDict object which using index(Tuple[int,int]) as
- key and Any as value.
+ """IndexDict inherits from OrderedDict[Tuple[int, int], VT]. Each IndexDict
+ object is a OrderDict object which using index(Tuple[int,int]) as key and
+ Any as value.
The key type is Tuple[a: int,b: int]. It indicates a range in
the [a,b).
diff --git a/mmrazor/utils/typing.py b/mmrazor/utils/typing.py
index 1166d580f..0d1126f2a 100644
--- a/mmrazor/utils/typing.py
+++ b/mmrazor/utils/typing.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
FixMutable = Dict[str, Any]
ValidFixMutable = Union[str, Path, FixMutable]
@@ -23,3 +23,15 @@
SupportRandomSubnet = Union[SingleMutatorRandomSubnet,
MultiMutatorsRandomSubnet]
+
+Chosen = Union[str, float, List[str]]
+ChosenMeta = Optional[Dict[str, Any]]
+
+
+class DumpChosen(NamedTuple):
+ chosen: Chosen
+ meta: ChosenMeta = None
+
+
+# DumpChosen = NamedTuple('DumpChosen', [('chosen', Chosen),
+# ('meta', ChosenMeta)])
diff --git a/mmrazor/version.py b/mmrazor/version.py
index 7ac2c40a3..e962eccaa 100644
--- a/mmrazor/version.py
+++ b/mmrazor/version.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved
-__version__ = '1.0.0rc0'
+__version__ = '1.0.0rc1'
def parse_version_info(version_str):
diff --git a/mmrazor/visualization/__init__.py b/mmrazor/visualization/__init__.py
new file mode 100644
index 000000000..993202428
--- /dev/null
+++ b/mmrazor/visualization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .local_visualizer import modify
+
+__all__ = ['modify']
diff --git a/mmrazor/visualization/local_visualizer.py b/mmrazor/visualization/local_visualizer.py
new file mode 100644
index 000000000..5834e8eca
--- /dev/null
+++ b/mmrazor/visualization/local_visualizer.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Optional, Tuple
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmengine.dist import master_only
+from mmengine.visualization.utils import (convert_overlay_heatmap,
+ img_from_canvas)
+
+
+@master_only
+def modify(featmap: torch.Tensor,
+ overlaid_image: Optional[np.ndarray] = None,
+ channel_reduction: Optional[str] = 'pixel_wise_max',
+ topk: int = 20,
+ arrangement: Tuple[int, int] = (4, 5),
+ resize_shape: Optional[tuple] = None,
+ alpha: float = 0.5):
+ assert isinstance(featmap,
+ torch.Tensor), (f'`featmap` should be torch.Tensor,'
+ f' but got {type(featmap)}')
+ assert featmap.ndim == 3, f'Input dimension must be 3, ' \
+ f'but got {featmap.ndim}'
+ featmap = featmap.detach().cpu()
+
+ if overlaid_image is not None:
+ if overlaid_image.ndim == 2:
+ overlaid_image = cv2.cvtColor(overlaid_image, cv2.COLOR_GRAY2RGB)
+
+ if overlaid_image.shape[:2] != featmap.shape[1:]:
+ warnings.warn(f'Since the spatial dimensions of '
+ f'overlaid_image: {overlaid_image.shape[:2]} and '
+ f'featmap: {featmap.shape[1:]} are not same, '
+ f'the feature map will be interpolated. '
+ f'This may cause mismatch problems !')
+ if resize_shape is None:
+ overlaid_image_h, overlaid_image_w = overlaid_image.shape[:2]
+ feat_h, feat_w = featmap.shape[-2:]
+ if feat_h / feat_w > overlaid_image_h / overlaid_image_w:
+ feat_h = round(feat_w * overlaid_image_h /
+ overlaid_image_w)
+ else:
+ feat_w = round(feat_h * overlaid_image_w /
+ overlaid_image_h)
+ featmap = featmap[..., :feat_h, :feat_w]
+ featmap = F.interpolate(
+ featmap[None], overlaid_image.shape[:2],
+ mode='bilinear')[0]
+
+ if resize_shape is not None:
+ featmap = F.interpolate(
+ featmap[None], resize_shape, mode='bilinear',
+ align_corners=False)[0]
+ if overlaid_image is not None:
+ overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1])
+
+ if channel_reduction is not None:
+ assert channel_reduction in [
+ 'squeeze_mean', 'select_max', 'pixel_wise_max'], \
+ f'Mode only support "squeeze_mean", "select_max", ' \
+ f'"pixel_wise_max", but got {channel_reduction}'
+ if channel_reduction == 'select_max':
+ sum_channel_featmap = torch.sum(featmap, dim=(1, 2))
+ _, indices = torch.topk(sum_channel_featmap, 1)
+ feat_map = featmap[indices]
+ elif channel_reduction == 'squeeze_mean':
+ feat_map = torch.mean(featmap, dim=0)
+ else:
+ feat_map = torch.max(featmap, dim=0)[0]
+ return convert_overlay_heatmap(feat_map, overlaid_image, alpha)
+ elif topk <= 0:
+ featmap_channel = featmap.shape[0]
+ assert featmap_channel in [
+ 1, 3
+ ], ('The input tensor channel dimension must be 1 or 3 '
+ 'when topk is less than 1, but the channel '
+ f'dimension you input is {featmap_channel}, you can use the'
+ ' channel_reduction parameter or set topk greater than '
+ '0 to solve the error')
+ return convert_overlay_heatmap(featmap, overlaid_image, alpha)
+ else:
+ row, col = arrangement
+ channel, height, width = featmap.shape
+ assert row * col >= topk, 'The product of row and col in ' \
+ 'the `arrangement` is less than ' \
+ 'topk, please set the ' \
+ '`arrangement` correctly'
+
+ # Extract the feature map of topk
+ topk = min(channel, topk)
+ sum_channel_featmap = torch.sum(featmap, dim=(1, 2))
+ _, indices = torch.topk(sum_channel_featmap, topk)
+ topk_featmap = featmap[indices]
+
+ fig = plt.figure(frameon=False)
+ # Set the window layout
+ fig.subplots_adjust(
+ left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
+ dpi = fig.get_dpi()
+ fig.set_size_inches((width * col + 1e-2) / dpi,
+ (height * row + 1e-2) / dpi)
+ for i in range(topk):
+ axes = fig.add_subplot(row, col, i + 1)
+ axes.axis('off')
+ axes.text(2, 15, f'channel: {indices[i]}', fontsize=10)
+ axes.imshow(
+ convert_overlay_heatmap(topk_featmap[i], overlaid_image,
+ alpha))
+ image = img_from_canvas(fig.canvas)
+ plt.close(fig)
+ return image
diff --git a/model-index.yml b/model-index.yml
index b1a321c84..90969c7e3 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -18,3 +18,6 @@ Import:
- configs/nas/mmcls/autoslim/metafile.yml
- configs/nas/mmcls/darts/metafile.yml
- configs/nas/mmdet/detnas/metafile.yml
+ - configs/distill/mmdet/pkd/metafile.yml
+ - configs/distill/mmdet3d/pkd/metafile.yml
+ - configs/distill/mmcls/deit/metafile.yml
diff --git a/tests/data/models.py b/tests/data/models.py
index 60c8a7058..867adc0c9 100644
--- a/tests/data/models.py
+++ b/tests/data/models.py
@@ -513,12 +513,25 @@ def _expand_mask():
def dump_chosen(self):
return super().dump_chosen()
+ def export_chosen(self):
+ return super().export_chosen()
+
def fix_chosen(self, chosen):
return super().fix_chosen(chosen)
def num_choices(self) -> int:
return super().num_choices
+ @property
+ def current_choice(self):
+ return super().current_choice
+
+ @current_choice.setter
+ def current_choice(self, choice):
+ super().current_choice(choice)
+
+
+
class DynamicLinearModel(nn.Module):
"""
diff --git a/tests/data/test_models/test_subnet/mockmodel_subnet.yaml b/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
index 36e3a9ce0..8d92a99b5 100644
--- a/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
+++ b/tests/data/test_models/test_subnet/mockmodel_subnet.yaml
@@ -1,2 +1,4 @@
-mutable1: conv1
-mutable2: conv2
+mutable1:
+ chosen: conv1
+mutable2:
+ chosen: conv2
diff --git a/tests/data/test_registry/registry_subnet_config.py b/tests/data/test_registry/registry_subnet_config.py
index 28ba3a0ea..539a1cdb1 100644
--- a/tests/data/test_registry/registry_subnet_config.py
+++ b/tests/data/test_registry/registry_subnet_config.py
@@ -7,7 +7,7 @@
type='MockAlgorithm',
architecture=supernet,
_fix_subnet_ = {
- 'architecture.mutable1': 'conv1',
- 'architecture.mutable2': 'conv2',
+ 'architecture.mutable1': {'chosen':'conv1'},
+ 'architecture.mutable2': {'chosen':'conv2'},
}
)
diff --git a/tests/test_core/test_delivers/test_function_outputs_deliver.py b/tests/test_core/test_delivers/test_function_outputs_deliver.py
index 8115af411..531e59795 100644
--- a/tests/test_core/test_delivers/test_function_outputs_deliver.py
+++ b/tests/test_core/test_delivers/test_function_outputs_deliver.py
@@ -1,11 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import os.path as osp
+import tempfile
from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+from mmengine.evaluator import Evaluator
+from mmengine.hooks import EMAHook
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModel, ExponentialMovingAverage
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+from torch.utils.data import Dataset
from mmrazor.models.task_modules import FunctionOutputsDelivery
+class ToyModel(BaseModel):
+
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(2, 1)
+ # test FunctionOutputsDelivery when ema_hook is used
+ self.deliver = FunctionOutputsDelivery(
+ max_keep_data=2, func_path='toy_module.toy_func')
+
+ def forward(self, inputs, data_sample, mode='tensor'):
+ labels = torch.stack(data_sample)
+ inputs = torch.stack(inputs)
+ with self.deliver:
+ outputs = self.linear(inputs)
+ if mode == 'tensor':
+ return outputs
+ elif mode == 'loss':
+ loss = (labels - outputs).sum()
+ outputs = dict(loss=loss)
+ return outputs
+ else:
+ return outputs
+
+
+class DummyDataset(Dataset):
+ METAINFO = dict() # type: ignore
+ data = torch.randn(12, 2)
+ label = torch.ones(12)
+
+ @property
+ def metainfo(self):
+ return self.METAINFO
+
+ def __len__(self):
+ return self.data.size(0)
+
+ def __getitem__(self, index):
+ return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
class TestFuncOutputsDeliver(TestCase):
+ def setUp(self):
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ def tearDown(self):
+ # `FileHandler` should be closed in Windows, otherwise we cannot
+ # delete the temporary directory
+ logging.shutdown()
+ MMLogger._instance_dict.clear()
+ self.temp_dir.cleanup()
+
def test_init(self):
with self.assertRaisesRegex(TypeError, 'func_path should be'):
@@ -14,19 +78,25 @@ def test_init(self):
with self.assertRaisesRegex(AssertionError, 'func_path must have at '):
_ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func')
+ def test_context_manager(self):
+ import toy_module
+
+ delivery = FunctionOutputsDelivery(max_keep_data=2, func_path='aaa.bb')
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
- _ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb')
+ with delivery:
+ _ = toy_module.toy_func()
+ delivery = FunctionOutputsDelivery(
+ max_keep_data=1, func_path='toy_module.bb')
with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
- _ = FunctionOutputsDelivery(
- max_keep_data=1, func_path='toy_module.bb')
+ with delivery:
+ _ = toy_module.toy_func()
+ delivery = FunctionOutputsDelivery(
+ max_keep_data=1, func_path='toy_module.TOY_VAR')
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
- _ = FunctionOutputsDelivery(
- max_keep_data=1, func_path='toy_module.TOY_VAR')
-
- def test_context_manager(self):
- import toy_module
+ with delivery:
+ _ = toy_module.toy_func()
delivery = FunctionOutputsDelivery(
max_keep_data=2, func_path='toy_module.toy_func')
@@ -52,3 +122,42 @@ def test_context_manager(self):
with self.assertRaisesRegex(AssertionError, 'pop from an empty queue'):
with delivery:
_ = toy_module.toy_func()
+
+ def test_ema_hook(self):
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ model = ToyModel().to(device)
+ evaluator = Evaluator([])
+ evaluator.evaluate = Mock(return_value=dict(acc=0.5))
+ runner = Runner(
+ model=model,
+ train_dataloader=dict(
+ dataset=DummyDataset(),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ batch_size=3,
+ num_workers=0),
+ val_dataloader=dict(
+ dataset=DummyDataset(),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ batch_size=3,
+ num_workers=0),
+ val_evaluator=evaluator,
+ work_dir=self.temp_dir.name,
+ default_scope='mmrazor',
+ optim_wrapper=OptimWrapper(
+ torch.optim.Adam(ToyModel().parameters())),
+ train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+ val_cfg=dict(),
+ default_hooks=dict(logger=None),
+ custom_hooks=[dict(type='EMAHook', )],
+ experiment_name='test_func_outputs_deliver')
+ runner.train()
+ for hook in runner.hooks:
+ if isinstance(hook, EMAHook):
+ self.assertTrue(
+ isinstance(hook.ema_model, ExponentialMovingAverage))
+
+ self.assertTrue(
+ osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
+ checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+ self.assertTrue('ema_state_dict' in checkpoint)
+ self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
diff --git a/tests/test_core/test_recorders/test_func_inputs_recorder.py b/tests/test_core/test_recorders/test_func_inputs_recorder.py
new file mode 100644
index 000000000..6fa9655a1
--- /dev/null
+++ b/tests/test_core/test_recorders/test_func_inputs_recorder.py
@@ -0,0 +1,138 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import os.path as osp
+import tempfile
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+from mmengine.evaluator import Evaluator
+from mmengine.hooks import EMAHook
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModel, ExponentialMovingAverage
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+from torch.utils.data import Dataset
+
+from mmrazor.models.task_modules import FunctionInputsRecorder, RecorderManager
+
+
+class ToyModel(BaseModel):
+
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(2, 1)
+ # test FunctionInputsRecorder when ema_hook is used
+ recorders_cfg = dict(
+ out=dict(type='FunctionInputs', source='toy_mod.toy_func'))
+ self.recorders = RecorderManager(recorders_cfg)
+ self.recorders.initialize(self)
+
+ def forward(self, inputs, data_sample, mode='tensor'):
+ labels = torch.stack(data_sample)
+ inputs = torch.stack(inputs)
+ with self.recorders:
+ outputs = self.linear(inputs)
+ if mode == 'tensor':
+ return outputs
+ elif mode == 'loss':
+ loss = (labels - outputs).sum()
+ outputs = dict(loss=loss)
+ return outputs
+ else:
+ return outputs
+
+
+class DummyDataset(Dataset):
+ METAINFO = dict() # type: ignore
+ data = torch.randn(12, 2)
+ label = torch.ones(12)
+
+ @property
+ def metainfo(self):
+ return self.METAINFO
+
+ def __len__(self):
+ return self.data.size(0)
+
+ def __getitem__(self, index):
+ return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
+class TestFuncInputsRecorder(TestCase):
+
+ def setUp(self):
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ def tearDown(self):
+ # `FileHandler` should be closed in Windows, otherwise we cannot
+ # delete the temporary directory
+ logging.shutdown()
+ MMLogger._instance_dict.clear()
+ self.temp_dir.cleanup()
+
+ def test_context_manager(self):
+ from toy_mod import execute_toy_func2 as execute_toy_func
+
+ recorder = FunctionInputsRecorder('toy_mod.toy_func2')
+ recorder.initialize()
+
+ with recorder:
+ execute_toy_func(1, 2)
+ execute_toy_func(1, b=2)
+ execute_toy_func(b=2, a=1)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=0, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=0, data_idx=1) == 2)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=1, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=1, data_idx=1) == 2)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=2, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=2, data_idx=1) == 2)
+
+ def test_ema_hook(self):
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ model = ToyModel().to(device)
+ evaluator = Evaluator([])
+ evaluator.evaluate = Mock(return_value=dict(acc=0.5))
+ runner = Runner(
+ model=model,
+ train_dataloader=dict(
+ dataset=DummyDataset(),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ batch_size=3,
+ num_workers=0),
+ val_dataloader=dict(
+ dataset=DummyDataset(),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ batch_size=3,
+ num_workers=0),
+ val_evaluator=evaluator,
+ work_dir=self.temp_dir.name,
+ default_scope='mmrazor',
+ optim_wrapper=OptimWrapper(
+ torch.optim.Adam(ToyModel().parameters())),
+ train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
+ val_cfg=dict(),
+ default_hooks=dict(logger=None),
+ custom_hooks=[dict(type='EMAHook', )],
+ experiment_name='test_func_inputs_recorder')
+ runner.train()
+ for hook in runner.hooks:
+ if isinstance(hook, EMAHook):
+ self.assertTrue(
+ isinstance(hook.ema_model, ExponentialMovingAverage))
+
+ self.assertTrue(
+ osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
+ checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+ self.assertTrue('ema_state_dict' in checkpoint)
+ self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
diff --git a/tests/test_core/test_recorders/test_func_outputs_recorder.py b/tests/test_core/test_recorders/test_func_outputs_recorder.py
index a6c3be2ba..1d6561495 100644
--- a/tests/test_core/test_recorders/test_func_outputs_recorder.py
+++ b/tests/test_core/test_recorders/test_func_outputs_recorder.py
@@ -16,17 +16,26 @@ def test_init(self):
with self.assertRaisesRegex(AssertionError, 'source must have at '):
_ = FunctionOutputsRecorder('aaaaa')
+ def test_context_manager(self):
+ from toy_mod import execute_toy_func
+
+ recorder = FunctionOutputsRecorder('aaa.bbb')
+ recorder.initialize()
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
- _ = FunctionOutputsRecorder('aaa.bbb')
+ with recorder:
+ execute_toy_func(1)
+ recorder = FunctionOutputsRecorder('toy_mod.aaa')
+ recorder.initialize()
with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
- _ = FunctionOutputsRecorder('toy_mod.aaa')
+ with recorder:
+ execute_toy_func(1)
+ recorder = FunctionOutputsRecorder('toy_mod.TOY_VAR')
+ recorder.initialize()
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
- _ = FunctionOutputsRecorder('toy_mod.TOY_VAR')
-
- def test_context_manager(self):
- from toy_mod import execute_toy_func
+ with recorder:
+ execute_toy_func(1)
recorder = FunctionOutputsRecorder('toy_mod.toy_func')
recorder.initialize()
diff --git a/tests/test_core/test_recorders/test_method_inputs_recorder.py b/tests/test_core/test_recorders/test_method_inputs_recorder.py
new file mode 100644
index 000000000..7450a231c
--- /dev/null
+++ b/tests/test_core/test_recorders/test_method_inputs_recorder.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models.task_modules import MethodInputsRecorder
+
+
+class TestFuncOutputsRecorder(TestCase):
+
+ def test_context_manager(self):
+ from toy_mod import ToyClass
+
+ toy = ToyClass()
+
+ recorder = MethodInputsRecorder('toy_mod.ToyClass.func')
+ recorder.initialize()
+
+ with recorder:
+ _ = toy.func(x=1, y=2)
+ _ = toy.func(1, y=2)
+ _ = toy.func(y=2, x=1)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=0, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=0, data_idx=1) == 2)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=1, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=1, data_idx=1) == 2)
+
+ self.assertTrue(
+ recorder.get_record_data(record_idx=2, data_idx=0) == 1)
+ self.assertTrue(
+ recorder.get_record_data(record_idx=2, data_idx=1) == 2)
diff --git a/tests/test_core/test_recorders/toy_mod.py b/tests/test_core/test_recorders/toy_mod.py
index 3cc331476..0df3e2d70 100644
--- a/tests/test_core/test_recorders/toy_mod.py
+++ b/tests/test_core/test_recorders/toy_mod.py
@@ -8,6 +8,10 @@ def toy_func(a):
return a
+def toy_func2(a, b):
+ return a, b
+
+
def toy_list_func(a):
return [a, a, a]
@@ -16,6 +20,10 @@ def execute_toy_func(a):
toy_func(a)
+def execute_toy_func2(a, b):
+ toy_func2(a, b)
+
+
def execute_toy_list_func(a):
toy_list_func(a)
@@ -31,6 +39,9 @@ def toy(self):
self._count += 1
return self._count
+ def func(self, x, y=0):
+ return x + y
+
def __call__(self):
self._count += 1
return self._count
diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py
new file mode 100644
index 000000000..64004843a
--- /dev/null
+++ b/tests/test_engine/test_hooks/test_visualization_hook.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import shutil
+import time
+from os.path import dirname
+from typing import Optional
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+# TODO: The argument `out_file` has not been supported in MMEngine yet.
+# Temporarily, we use `ClsVisualizer` here
+from mmcls.visualization import ClsVisualizer
+from mmengine import ConfigDict
+from mmengine.model import BaseModel
+
+from mmrazor.engine.hooks import RazorVisualizationHook
+
+
+def get_data_info(idx):
+ root_path = dirname(dirname(dirname(dirname(__file__))))
+ return {
+ 'img_path': os.path.join(root_path, 'tools/visualizations/demo.jpg')
+ }
+
+
+class ToyModel(BaseModel):
+
+ def __init__(self):
+ data_preprocessor = dict(
+ type='mmcls.ClsDataPreprocessor',
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+ )
+ super().__init__(data_preprocessor=data_preprocessor)
+ self.op = nn.Conv2d(3, 3, 1)
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[list] = None,
+ mode: str = 'tensor'):
+ out = self.op(inputs)
+ return out
+
+
+class TestVisualizationHook(TestCase):
+
+ def setUp(self) -> None:
+ # TODO: The argument `out_file` has not been supported in MMEngine yet.
+ # Temporarily, we use `ClsVisualizer` here
+ ClsVisualizer.get_instance('visualizer')
+
+ test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+ dict(type='mmcls.PackClsInputs')
+ ]
+
+ self.runner = Mock()
+ self.runner.val_loop.dataloader.dataset.get_data_info = get_data_info
+ self.runner.cfg = ConfigDict(
+ test_dataloader=dict(dataset=dict(pipeline=test_pipeline)))
+ self.runner.model = ToyModel()
+
+ self.recorders = ConfigDict(
+ out=dict(_scope_='mmrazor', type='ModuleOutputs', source='op'))
+ self.mappings = ConfigDict(out=dict(recorder='out'))
+
+ def test_before_run(self):
+ hook = RazorVisualizationHook(self.recorders, self.mappings)
+ hook.before_run(self.runner)
+
+ def test_before_train(self):
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ out_dir = timestamp + '1'
+ self.runner.work_dir = timestamp
+ self.runner.timestamp = '1'
+ self.runner.epoch = 0
+
+ hook = RazorVisualizationHook(
+ self.recorders, self.mappings, out_dir=out_dir, enabled=False)
+ # initialize recorders
+ hook.before_run(self.runner)
+ hook.before_train(self.runner)
+ self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}'))
+
+ hook = RazorVisualizationHook(
+ self.recorders, self.mappings, out_dir=out_dir, enabled=True)
+ # initialize recorders
+ hook.before_run(self.runner)
+ hook.before_train(self.runner)
+ self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}'))
+ shutil.rmtree(f'{timestamp}')
+
+ def test_after_train_epoch(self):
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ out_dir = timestamp + '1'
+ self.runner.work_dir = timestamp
+ self.runner.timestamp = '1'
+
+ hook = RazorVisualizationHook(
+ self.recorders, self.mappings, out_dir=out_dir, enabled=False)
+ # initialize recorders
+ hook.before_run(self.runner)
+ self.runner.epoch = 0
+ hook.after_train_epoch(self.runner)
+ self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}'))
+
+ self.runner.epoch = 1
+ hook = RazorVisualizationHook(
+ self.recorders,
+ self.mappings,
+ out_dir=out_dir,
+ enabled=True,
+ interval=2)
+ # initialize recorders
+ hook.before_run(self.runner)
+ hook.after_train_epoch(self.runner)
+ self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}'))
+
+ self.runner.epoch = 2
+ hook.after_train_epoch(self.runner)
+ self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}'))
+ shutil.rmtree(f'{timestamp}')
diff --git a/tests/test_models/test_algorithms/test_darts.py b/tests/test_models/test_algorithms/test_darts.py
index 7d33fa047..52f5d10e6 100644
--- a/tests/test_models/test_algorithms/test_darts.py
+++ b/tests/test_models/test_algorithms/test_darts.py
@@ -104,7 +104,11 @@ def test_init(self) -> None:
self.assertIsInstance(algo.mutator, DiffModuleMutator)
# initiate darts when `fix_subnet` is not None
- fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
+ fix_subnet = {
+ 'normal': {
+ 'chosen': ['torch_conv2d_3x3', 'torch_conv2d_7x7']
+ }
+ }
algo = Darts(model, mutator, fix_subnet=fix_subnet)
self.assertEqual(algo.architecture.mutable.num_choices, 2)
@@ -124,7 +128,11 @@ def test_forward_loss(self) -> None:
self.assertIsInstance(loss, dict)
# subnet
- fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
+ fix_subnet = {
+ 'normal': {
+ 'chosen': ['torch_conv2d_3x3', 'torch_conv2d_7x7']
+ }
+ }
algo = Darts(model, fix_subnet=fix_subnet)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py
index 9f6dfc902..2b5bbfa49 100644
--- a/tests/test_models/test_algorithms/test_dsnas.py
+++ b/tests/test_models/test_algorithms/test_dsnas.py
@@ -14,8 +14,8 @@
from torch import Tensor
from torch.optim import SGD
-from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP
-from mmrazor.models.algorithms.nas.dsnas import DsnasDDP
+from mmrazor.models import DSNAS, DiffModuleMutator, OneHotMutableOP
+from mmrazor.models.algorithms.nas.dsnas import DSNASDDP
from mmrazor.registry import MODELS
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
@@ -81,29 +81,29 @@ def test_init(self) -> None:
# initiate dsnas when `norm_training` is True.
model = ToyDiffModule()
mutator = DiffModuleMutator()
- algo = Dsnas(architecture=model, mutator=mutator, norm_training=True)
+ algo = DSNAS(architecture=model, mutator=mutator, norm_training=True)
algo.eval()
self.assertTrue(model.bn.training)
# initiate Dsnas with built mutator
model = ToyDiffModule()
mutator = DiffModuleMutator()
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
self.assertIs(algo.mutator, mutator)
# initiate Dsnas with unbuilt mutator
mutator = dict(type='DiffModuleMutator')
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
self.assertIsInstance(algo.mutator, DiffModuleMutator)
# initiate Dsnas when `fix_subnet` is not None
- fix_subnet = {'mutable': 'torch_conv2d_5x5'}
- algo = Dsnas(model, mutator, fix_subnet=fix_subnet)
+ fix_subnet = {'mutable': {'chosen': 'torch_conv2d_5x5'}}
+ algo = DSNAS(model, mutator, fix_subnet=fix_subnet)
self.assertEqual(algo.architecture.mutable.num_choices, 1)
# initiate Dsnas with error type `mutator`
with self.assertRaisesRegex(TypeError, 'mutator should be'):
- Dsnas(model, model)
+ DSNAS(model, model)
def test_forward_loss(self) -> None:
inputs = torch.randn(1, 3, 8, 8)
@@ -112,13 +112,13 @@ def test_forward_loss(self) -> None:
# supernet
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
# subnet
- fix_subnet = {'mutable': 'torch_conv2d_5x5'}
- algo = Dsnas(model, fix_subnet=fix_subnet)
+ fix_subnet = {'mutable': {'chosen': 'torch_conv2d_5x5'}}
+ algo = DSNAS(model, fix_subnet=fix_subnet)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
@@ -135,7 +135,7 @@ def test_search_subnet(self) -> None:
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
subnet = algo.search_subnet()
self.assertIsInstance(subnet, dict)
@@ -146,14 +146,14 @@ def test_dsnas_train_step(self, mock_get_info) -> None:
mutator.prepare_from_supernet(model)
mock_get_info.return_value = 2
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
loss = algo.train_step(data, optim_wrapper)
self.assertTrue(isinstance(loss['loss'], Tensor))
- algo = Dsnas(model, mutator)
+ algo = DSNAS(model, mutator)
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
@@ -173,16 +173,16 @@ def setUpClass(cls) -> None:
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)
- def prepare_model(self, device_ids=None) -> Dsnas:
+ def prepare_model(self, device_ids=None) -> DSNAS:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyDiffModule()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
- algo = Dsnas(model, mutator).to(self.device)
+ algo = DSNAS(model, mutator).to(self.device)
- return DsnasDDP(
+ return DSNASDDP(
module=algo, find_unused_parameters=True, device_ids=device_ids)
@classmethod
@@ -193,7 +193,7 @@ def tearDownClass(cls) -> None:
not torch.cuda.is_available(), reason='cuda device is not avaliable')
def test_init(self) -> None:
ddp_model = self.prepare_model()
- self.assertIsInstance(ddp_model, DsnasDDP)
+ self.assertIsInstance(ddp_model, DSNASDDP)
@patch('mmengine.logging.message_hub.MessageHub.get_info')
def test_dsnasddp_train_step(self, mock_get_info) -> None:
diff --git a/tests/test_models/test_algorithms/test_general_quant.py b/tests/test_models/test_algorithms/test_general_quant.py
new file mode 100644
index 000000000..94a2485bc
--- /dev/null
+++ b/tests/test_models/test_algorithms/test_general_quant.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch.nn as nn
+
+
+class ToyModel(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+ # TODO
+
+
+class TestGeneralQuant(TestCase):
+ """TODO.
+
+ Args:
+ TestCase (_type_): _description_
+ """
+
+ def test_init(self):
+ pass
+
+ def test_prepare(self):
+ pass
+
+ def test_convert(self):
+ pass
+
+ def test_states(self):
+ pass
+
+ def test_forward(self):
+ pass
diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py
index 519407772..3a00e93b9 100644
--- a/tests/test_models/test_algorithms/test_prune_algorithm.py
+++ b/tests/test_models/test_algorithms/test_prune_algorithm.py
@@ -114,7 +114,8 @@ def test_iterative_prune_int(self):
model = MODELS.build(MODEL_CFG)
mutator = MODELS.build(MUTATOR_CONFIG_FLOAT)
mutator.prepare_from_supernet(model)
- prune_target = mutator.sample_choices()
+ mutator.set_choices(mutator.sample_choices())
+ prune_target = mutator.choice_template
epoch = 10
epoch_step = 2
@@ -135,9 +136,11 @@ def test_iterative_prune_int(self):
data['inputs'], data['data_samples'], mode='loss')
current_choices = algorithm.mutator.current_choices
+ group_prune_target = algorithm.group_target_pruning_ratio(
+ prune_target, mutator.search_groups)
for key in current_choices:
self.assertAlmostEqual(
- current_choices[key], prune_target[key], delta=0.1)
+ current_choices[key], group_prune_target[key], delta=0.1)
def test_load_pretrained(self):
epoch_step = 2
@@ -158,7 +161,7 @@ def test_load_pretrained(self):
algorithm = ItePruneAlgorithm(
model_cfg,
mutator_cfg=MUTATOR_CONFIG_NUM,
- target_pruning_ratio={},
+ target_pruning_ratio=None,
step_epoch=epoch_step,
prune_times=times,
).to(DEVICE)
@@ -167,3 +170,43 @@ def test_load_pretrained(self):
# delete checkpoint
os.remove(checkpoint_path)
+
+ def test_group_target_ratio(self):
+
+ model = MODELS.build(MODEL_CFG)
+ mutator = MODELS.build(MUTATOR_CONFIG_FLOAT)
+ mutator.prepare_from_supernet(model)
+ mutator.set_choices(mutator.sample_choices())
+ prune_target = mutator.choice_template
+
+ custom_groups = [[
+ 'backbone.layer1.0.conv1_(0, 64)_64',
+ 'backbone.layer1.1.conv1_(0, 64)_64'
+ ]]
+ mutator_cfg = copy.deepcopy(MUTATOR_CONFIG_FLOAT)
+ mutator_cfg['custom_groups'] = custom_groups
+
+ epoch_step = 2
+ times = 3
+
+ prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1
+ prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.1
+
+ _ = ItePruneAlgorithm(
+ MODEL_CFG,
+ target_pruning_ratio=prune_target,
+ mutator_cfg=mutator_cfg,
+ step_epoch=epoch_step,
+ prune_times=times).to(DEVICE)
+
+ prune_target['backbone.layer1.0.conv1_(0, 64)_64'] = 0.1
+ prune_target['backbone.layer1.1.conv1_(0, 64)_64'] = 0.2
+
+ with self.assertRaises(ValueError):
+
+ _ = ItePruneAlgorithm(
+ MODEL_CFG,
+ target_pruning_ratio=prune_target,
+ mutator_cfg=mutator_cfg,
+ step_epoch=epoch_step,
+ prune_times=times).to(DEVICE)
diff --git a/tests/test_models/test_algorithms/test_spos.py b/tests/test_models/test_algorithms/test_spos.py
index 3392f469c..f73521111 100644
--- a/tests/test_models/test_algorithms/test_spos.py
+++ b/tests/test_models/test_algorithms/test_spos.py
@@ -56,7 +56,7 @@ def test_init(self):
self.assertIsInstance(alg.mutator, OneShotModuleMutator)
# initiate spos when `fix_subnet` is not None.
- fix_subnet = {'mutable': 'conv1'}
+ fix_subnet = {'mutable': {'chosen': 'conv1'}}
alg = SPOS(model, mutator, fix_subnet=fix_subnet)
self.assertEqual(alg.architecture.mutable.num_choices, 1)
@@ -75,7 +75,7 @@ def test_forward_loss(self):
self.assertIsInstance(loss, dict)
# subnet
- fix_subnet = {'mutable': 'conv1'}
+ fix_subnet = {'mutable': {'chosen': 'conv1'}}
alg = SPOS(model, fix_subnet=fix_subnet)
loss = alg(inputs, mode='loss')
self.assertIsInstance(loss, dict)
diff --git a/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py
index acaa84b2b..ba3f5955d 100644
--- a/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py
+++ b/tests/test_models/test_architectures/test_backbones/test_dartsbackbone.py
@@ -55,7 +55,7 @@ def setUp(self) -> None:
self.mutator_cfg = dict(
type='DiffModuleMutator',
- custom_group=None,
+ custom_groups=None,
)
def test_darts_backbone(self):
@@ -81,7 +81,7 @@ def test_darts_backbone_with_auxliary(self):
custom_group = self.generate_key(model)
assert model is not None
- self.mutable_cfg.update(custom_group=custom_group)
+ self.mutable_cfg.update(custom_groups=custom_group)
mutator = MODELS.build(self.mutator_cfg)
assert mutator is not None
mutator.prepare_from_supernet(model)
diff --git a/tests/test_models/test_architectures/test_dynamic_op/utils.py b/tests/test_models/test_architectures/test_dynamic_op/utils.py
index ceb2a5d4f..e448f300e 100644
--- a/tests/test_models/test_architectures/test_dynamic_op/utils.py
+++ b/tests/test_models/test_architectures/test_dynamic_op/utils.py
@@ -2,6 +2,7 @@
from typing import Dict, Optional
from mmrazor.models.architectures.dynamic_ops import DynamicMixin
+from mmrazor.utils.typing import DumpChosen
def fix_dynamic_op(op: DynamicMixin,
@@ -13,4 +14,7 @@ def fix_dynamic_op(op: DynamicMixin,
else:
chosen = mutable.dump_chosen()
- mutable.fix_chosen(chosen)
+ if not isinstance(chosen, DumpChosen):
+ chosen = DumpChosen(**chosen)
+
+ mutable.fix_chosen(chosen.chosen)
diff --git a/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py
new file mode 100644
index 000000000..d6b670bb5
--- /dev/null
+++ b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+
+class TestLearnableFakeQuantize(TestCase):
+
+ def test_init(self):
+ pass
+
+ def test_repr(self):
+ pass
+
+ def test_calculate_qparams(self):
+ pass
+
+ def test_forward(self):
+ pass
+
+ def test_load_state_dict(self):
+ pass
+
+ def test_save_state_dict(self):
+ pass
diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py
index 37fea2baf..15243650c 100644
--- a/tests/test_models/test_losses/test_distillation_losses.py
+++ b/tests/test_models/test_losses/test_distillation_losses.py
@@ -7,7 +7,7 @@
from mmrazor import digit_version
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss,
FBKDLoss, FTLoss, InformationEntropyLoss,
- KDSoftCELoss, OFDLoss, OnehotLikeLoss)
+ KDSoftCELoss, OFDLoss, OnehotLikeLoss, PKDLoss)
class TestLosses(TestCase):
@@ -179,3 +179,28 @@ def test_fbkdloss(self):
fbkd_loss = fbkdloss(s_input, t_input)
self.assertTrue(fbkd_loss.numel() == 1)
+
+ def test_pkdloss(self):
+ pkd_loss = PKDLoss(loss_weight=1.0)
+ feats_S, feats_T = torch.rand(2, 256, 4, 4), torch.rand(2, 256, 4, 4)
+ loss = pkd_loss(feats_S, feats_T)
+ self.assertTrue(loss.numel() == 1)
+ self.assertTrue(0. <= loss <= 1.)
+
+ num_stages = 4
+ feats_S = (torch.rand(2, 256, 4, 4) for _ in range(num_stages))
+ feats_T = (torch.rand(2, 256, 4, 4) for _ in range(num_stages))
+ loss = pkd_loss(feats_S, feats_T)
+ self.assertTrue(loss.numel() == 1)
+ self.assertTrue(0. <= loss <= num_stages * 1.)
+
+ feats_S, feats_T = torch.rand(2, 256, 2, 2), torch.rand(2, 256, 4, 4)
+ loss = pkd_loss(feats_S, feats_T)
+ self.assertTrue(loss.numel() == 1)
+ self.assertTrue(0. <= loss <= 1.)
+
+ pkd_loss = PKDLoss(loss_weight=1.0, resize_stu=False)
+ feats_S, feats_T = torch.rand(2, 256, 2, 2), torch.rand(2, 256, 4, 4)
+ loss = pkd_loss(feats_S, feats_T)
+ self.assertTrue(loss.numel() == 1)
+ self.assertTrue(0. <= loss <= 1.)
diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py
index 3e87b0654..0b5f55e88 100644
--- a/tests/test_models/test_mutables/test_derived_mutable.py
+++ b/tests/test_models/test_mutables/test_derived_mutable.py
@@ -24,9 +24,9 @@ def test_is_fixed(self) -> None:
with pytest.raises(RuntimeError):
derived_mutable.is_fixed = True
- mc.fix_chosen(mc.dump_chosen())
+ mc.fix_chosen(mc.dump_chosen().chosen)
assert not derived_mutable.is_fixed
- mv.fix_chosen(mv.dump_chosen())
+ mv.fix_chosen(mv.dump_chosen().chosen)
assert derived_mutable.is_fixed
def test_fix_dump_chosen(self) -> None:
@@ -34,13 +34,13 @@ def test_fix_dump_chosen(self) -> None:
mv.current_choice = 3
derived_mutable = mv * 2
- assert derived_mutable.dump_chosen() == 6
+ assert derived_mutable.dump_chosen().chosen == 6
mv.current_choice = 4
- assert derived_mutable.dump_chosen() == 8
+ assert derived_mutable.dump_chosen().chosen == 8
# nothing will happen
- derived_mutable.fix_chosen(derived_mutable.dump_chosen())
+ derived_mutable.fix_chosen(derived_mutable.dump_chosen().chosen)
def test_derived_same_mutable(self) -> None:
mc = SquentialMutableChannel(num_channels=3)
diff --git a/tests/test_models/test_mutables/test_diffop.py b/tests/test_models/test_mutables/test_diffop.py
index 702adf8e2..eab9fff2b 100644
--- a/tests/test_models/test_mutables/test_diffop.py
+++ b/tests/test_models/test_mutables/test_diffop.py
@@ -44,9 +44,6 @@ def test_forward_arch_param(self):
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None
- output = op.forward_arch_param(input, arch_param=None)
- assert output is not None
-
# test when some element of arch_param is 0
arch_param = nn.Parameter(torch.ones(op.num_choices))
output = op.forward_arch_param(input, arch_param=arch_param)
diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py
index c93a43842..79c552250 100644
--- a/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py
+++ b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
-import pytest
import torch
from mmrazor.models.mutables import (SimpleMutableChannel,
@@ -31,5 +30,5 @@ def test_SimpleMutableChannel(self):
channel.current_choice = torch.tensor([1, 0, 0, 0]).bool()
self.assertEqual(channel.activated_channels, 1)
channel.fix_chosen()
- with pytest.raises(NotImplementedError):
- channel.dump_chosen()
+ # with pytest.raises(NotImplementedError):
+ # channel.dump_chosen()
diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py
index d7d05b1d5..b33dfcc98 100644
--- a/tests/test_models/test_mutables/test_mutable_value.py
+++ b/tests/test_models/test_mutables/test_mutable_value.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import copy
from unittest import TestCase
import pytest
@@ -42,22 +41,13 @@ def test_init_one_shot_mutable_value(self) -> None:
def test_fix_chosen(self) -> None:
mv = MutableValue([2, 3, 4])
chosen = mv.dump_chosen()
- assert chosen == {
- 'current_choice': mv.current_choice,
- 'all_choices': mv.choices
- }
+ assert chosen.chosen == mv.current_choice
+ assert chosen.meta['all_choices'] == mv.choices
- chosen['current_choice'] = 5
with pytest.raises(AssertionError):
- mv.fix_chosen(chosen)
-
- chosen_copied = copy.deepcopy(chosen)
- chosen_copied['all_choices'] = [1, 2, 3]
- with pytest.raises(AssertionError):
- mv.fix_chosen(chosen_copied)
+ mv.fix_chosen(5)
- chosen['current_choice'] = 3
- mv.fix_chosen(chosen)
+ mv.fix_chosen(3)
assert mv.current_choice == 3
with pytest.raises(RuntimeError):
diff --git a/tests/test_models/test_mutables/test_onehotop.py b/tests/test_models/test_mutables/test_onehotop.py
index 4ace5870d..a3b86d745 100644
--- a/tests/test_models/test_mutables/test_onehotop.py
+++ b/tests/test_models/test_mutables/test_onehotop.py
@@ -44,9 +44,6 @@ def test_forward_arch_param(self):
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None
- output = op.forward_arch_param(input, arch_param=None)
- assert output is not None
-
# test when some element of arch_param is 0
arch_param = nn.Parameter(torch.ones(op.num_choices))
output = op.forward_arch_param(input, arch_param=arch_param)
diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py
index 96908d807..3d6ed7773 100644
--- a/tests/test_models/test_mutators/test_channel_mutator.py
+++ b/tests/test_models/test_mutators/test_channel_mutator.py
@@ -134,3 +134,34 @@ def test_models_with_predefined_dynamic_op(self):
parse_cfg={'type': 'Predefined'})
mutator.prepare_from_supernet(model)
self._test_a_mutator(mutator, model)
+
+ def test_custom_group(self):
+ ARCHITECTURE_CFG = dict(
+ type='mmcls.ImageClassifier',
+ backbone=dict(type='mmcls.MobileNetV2', widen_factor=1.5),
+ neck=dict(type='mmcls.GlobalAveragePooling'),
+ head=dict(
+ type='mmcls.LinearClsHead',
+ num_classes=1000,
+ in_channels=1920,
+ loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
+ topk=(1, 5)))
+ model = MODELS.build(ARCHITECTURE_CFG)
+
+ # generate config
+ model1 = copy.deepcopy(model)
+ mutator1 = ChannelMutator()
+ mutator1.prepare_from_supernet(model1)
+
+ self.assertEqual(len(mutator1.search_groups), 25)
+
+ custom_groups = [[
+ 'backbone.layer2.1.conv.0.conv_(0, 240)_240',
+ 'backbone.layer3.0.conv.0.conv_(0, 240)_240'
+ ]]
+
+ model2 = copy.deepcopy(model)
+ mutator2 = ChannelMutator(custom_groups=custom_groups)
+ mutator2.prepare_from_supernet(model2)
+
+ self.assertEqual(len(mutator2.search_groups), 24)
diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py
index 5e230a202..663637fc9 100644
--- a/tests/test_models/test_mutators/test_diff_mutator.py
+++ b/tests/test_models/test_mutators/test_diff_mutator.py
@@ -98,7 +98,8 @@ def setUp(self):
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
self.MUTATOR_CFG = dict(
- type='DiffModuleMutator', custom_group=[['op1'], ['op2'], ['op3']])
+ type='DiffModuleMutator',
+ custom_groups=[['op1'], ['op2'], ['op3']])
def test_diff_mutator_diffop_layer(self) -> None:
model = SearchableLayer(self.MUTABLE_CFG)
@@ -111,7 +112,7 @@ def test_diff_mutator_diffop_model(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
@@ -128,7 +129,7 @@ def test_diff_mutator_diffop_model_error(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'],
@@ -142,7 +143,7 @@ def test_diff_mutator_diffop_alias(self) -> None:
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [['op1'], ['op2'], ['op3']]
+ mutator_cfg['custom_groups'] = [['op1'], ['op2'], ['op3']]
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
@@ -157,11 +158,11 @@ def test_diff_mutator_alias_module_name(self) -> None:
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [['op1'],
- [
- 'slayer1.op2', 'slayer2.op2',
- 'slayer3.op2'
- ], ['slayer1.op3', 'slayer2.op3']]
+ mutator_cfg['custom_groups'] = [['op1'],
+ [
+ 'slayer1.op2', 'slayer2.op2',
+ 'slayer3.op2'
+ ], ['slayer1.op3', 'slayer2.op3']]
mutator: DiffModuleMutator = MODELS.build(mutator_cfg)
mutator.prepare_from_supernet(model)
@@ -175,7 +176,7 @@ def test_diff_mutator_duplicate_keys(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer2.op3'],
@@ -189,7 +190,7 @@ def test_diff_mutator_duplicate_key_alias(self) -> None:
model = SearchableModelAlias(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
@@ -203,7 +204,7 @@ def test_diff_mutator_illegal_key(self) -> None:
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
@@ -217,7 +218,7 @@ def test_sample_and_set_choices(self):
model = SearchableModel(self.MUTABLE_CFG)
mutator_cfg = self.MUTATOR_CFG.copy()
- mutator_cfg['custom_group'] = [
+ mutator_cfg['custom_groups'] = [
['slayer1.op1', 'slayer2.op1', 'slayer3.op1'],
['slayer1.op2', 'slayer2.op2', 'slayer3.op2'],
['slayer1.op3', 'slayer2.op3', 'slayer3.op3'],
diff --git a/tests/test_models/test_observers/test_observer.py b/tests/test_models/test_observers/test_observer.py
new file mode 100644
index 000000000..ca39ecfbd
--- /dev/null
+++ b/tests/test_models/test_observers/test_observer.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch.nn as nn
+
+
+class ToyModel(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+ # TODO
+
+
+class TestMinMaxObserver(TestCase):
+ """TODO.
+
+ Args:
+ TestCase (_type_): _description_
+ """
+
+ def test_init(self):
+ pass
+
+ def test_prepare(self):
+ pass
+
+ def test_convert(self):
+ pass
+
+ def test_states(self):
+ pass
+
+ def test_forward(self):
+ pass
+
+
+class TestLSQObserver(TestMinMaxObserver):
+ pass
diff --git a/tests/test_models/test_quantizers/test_trt_quantizer.py b/tests/test_models/test_quantizers/test_trt_quantizer.py
new file mode 100644
index 000000000..9f85d1ecd
--- /dev/null
+++ b/tests/test_models/test_quantizers/test_trt_quantizer.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch.nn as nn
+
+
+class ToyModel(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+ # TODO
+
+
+class TestTRTQuantizer(TestCase):
+ """TODO.
+
+ Args:
+ TestCase (_type_): _description_
+ """
+
+ def test_init(self):
+ pass
+
+ def test_prepare(self):
+ pass
+
+ def test_convert(self):
+ pass
+
+ def test_states(self):
+ pass
+
+ def test_forward(self):
+ pass
diff --git a/tests/test_models/test_subnet/test_fix_subnet.py b/tests/test_models/test_subnet/test_fix_subnet.py
index 010372212..0137a8274 100644
--- a/tests/test_models/test_subnet/test_fix_subnet.py
+++ b/tests/test_models/test_subnet/test_fix_subnet.py
@@ -57,8 +57,12 @@ def test_load_fix_subnet(self):
# fix subnet is dict
fix_subnet = {
- 'mutable1': 'conv1',
- 'mutable2': 'conv2',
+ 'mutable1': {
+ 'chosen': 'conv1'
+ },
+ 'mutable2': {
+ 'chosen': 'conv2'
+ },
}
model = MockModel()
@@ -80,8 +84,12 @@ def test_load_fix_subnet(self):
def test_export_fix_subnet(self):
# get FixSubnet
fix_subnet = {
- 'mutable1': 'conv1',
- 'mutable2': 'conv2',
+ 'mutable1': {
+ 'chosen': 'conv1'
+ },
+ 'mutable2': {
+ 'chosen': 'conv2'
+ },
}
model = MockModel()
@@ -95,6 +103,14 @@ def test_export_fix_subnet(self):
model.mutable2.current_choice = 'conv2'
exported_fix_subnet = export_fix_subnet(model)
+ mutable1_dump_chosen = exported_fix_subnet['mutable1']
+ mutable2_dump_chosen = exported_fix_subnet['mutable2']
+
+ mutable1_chosen_dict = dict(chosen=mutable1_dump_chosen.chosen)
+ mutable2_chosen_dict = dict(chosen=mutable2_dump_chosen.chosen)
+
+ exported_fix_subnet['mutable1'] = mutable1_chosen_dict
+ exported_fix_subnet['mutable2'] = mutable2_chosen_dict
self.assertDictEqual(fix_subnet, exported_fix_subnet)
def test_export_fix_subnet_with_derived_mutable(self) -> None:
@@ -102,7 +118,10 @@ def test_export_fix_subnet_with_derived_mutable(self) -> None:
fix_subnet = export_fix_subnet(model)
self.assertDictEqual(
fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()})
- fix_subnet['source_mutable']['current_choice'] = 4
+
+ fix_subnet['source_mutable'] = dict(
+ fix_subnet['source_mutable']._asdict())
+ fix_subnet['source_mutable']['chosen'] = 4
load_fix_subnet(model, fix_subnet)
assert model.source_mutable.current_choice == 4
assert model.derived_mutable.current_choice == 8
@@ -114,7 +133,10 @@ def test_export_fix_subnet_with_derived_mutable(self) -> None:
'source_mutable': model.source_mutable.dump_chosen(),
'derived_mutable': model.derived_mutable.dump_chosen()
})
- fix_subnet['source_mutable']['current_choice'] = 2
+
+ fix_subnet['source_mutable'] = dict(
+ fix_subnet['source_mutable']._asdict())
+ fix_subnet['source_mutable']['chosen'] = 2
load_fix_subnet(model, fix_subnet)
assert model.source_mutable.current_choice == 2
assert model.derived_mutable.current_choice == 4
diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py
new file mode 100644
index 000000000..671922f69
--- /dev/null
+++ b/tests/test_models/test_task_modules/test_custom_tracer.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry
+from mmrazor.testing import ConvBNReLU
+
+
+class testCustomTracer(TestCase):
+
+ def test_init(self):
+ tracer = CustomTracer()
+ assert tracer.skipped_methods.__len__() == 0
+
+ def test_trace(self):
+ tracer = CustomTracer()
+ model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
+ graph = tracer.trace(model) # noqa: F841
+
+ def test_auto_skip_call_module(self):
+ pass
+
+ def test_auto_skip_call_method(self):
+ pass
+
+ def test_configurable_skipped_methods(self):
+ pass
+
+
+class testUntracedMethodRgistry(TestCase):
+
+ def test_init(self):
+ self.assertEqual(len(UntracedMethodRegistry.method_dict), 0)
+
+ def test_add_method(self):
+ pass
diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py
index 60bcef4ba..2acb58e95 100644
--- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py
+++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py
@@ -4,6 +4,7 @@
import pytest
import torch
+from mmcv.cnn.bricks import Conv2dAdaptivePadding
from torch import Tensor
from torch.nn import Conv2d, Module, Parameter
@@ -124,8 +125,17 @@ def test_estimate(self) -> None:
flops_count = results['flops']
params_count = results['params']
- self.assertGreater(flops_count, 0)
- self.assertGreater(params_count, 0)
+ self.assertEqual(flops_count, 44.158)
+ self.assertEqual(params_count, 0.001)
+
+ fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+ results = estimator.estimate(
+ model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+ flops_count = results['flops']
+ params_count = results['params']
+
+ self.assertEqual(flops_count, 44.958)
+ self.assertEqual(params_count, 0.001)
def test_register_module(self) -> None:
fool_add_constant = FoolConvModule()
@@ -151,6 +161,17 @@ def test_disable_sepc_counter(self) -> None:
self.assertLess(rest_flops_count, 45.158)
self.assertLess(rest_params_count, 0.701)
+ fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+ flops_params_cfg = dict(
+ input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter'])
+ rest_results = estimator.estimate(
+ model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+ rest_flops_count = rest_results['flops']
+ rest_params_count = rest_results['params']
+
+ self.assertEqual(rest_flops_count, 0)
+ self.assertEqual(rest_params_count, 0)
+
def test_estimate_spec_module(self) -> None:
fool_add_constant = FoolConvModule()
flops_params_cfg = dict(
diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py
new file mode 100644
index 000000000..b1beaedce
--- /dev/null
+++ b/tests/test_visualizer/test_visualizer.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
+import pytest
+import torch
+from mmengine.visualization import Visualizer
+
+from mmrazor.visualization.local_visualizer import modify
+
+
+class TestVisualizer(TestCase):
+
+ def setUp(self):
+ """Setup the demo image in every test method.
+
+ TestCase calls functions in this order: setUp() -> testMethod() ->
+ tearDown() -> cleanUp()
+ """
+ self.image = np.random.randint(
+ 0, 256, size=(10, 10, 3)).astype('uint8')
+
+ def test_draw_featmap(self):
+ visualizer = Visualizer()
+ visualizer.draw_featmap = modify
+ image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')
+
+ # must be Tensor
+ with pytest.raises(
+ AssertionError,
+ match='`featmap` should be torch.Tensor, but got '
+ ""):
+ visualizer.draw_featmap(np.ones((3, 3, 3)))
+
+ # test tensor format
+ with pytest.raises(
+ AssertionError, match='Input dimension must be 3, but got 4'):
+ visualizer.draw_featmap(torch.randn(1, 1, 3, 3))
+
+ # test overlaid_image shape
+ with pytest.warns(Warning):
+ visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image)
+
+ # test resize_shape
+ featmap = visualizer.draw_featmap(
+ torch.randn(1, 4, 3), resize_shape=(6, 7))
+ assert featmap.shape[:2] == (6, 7)
+ featmap = visualizer.draw_featmap(
+ torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
+ assert featmap.shape[:2] == (6, 7)
+
+ # test channel_reduction parameter
+ # mode only supports 'squeeze_mean' and 'select_max'
+ with pytest.raises(AssertionError):
+ visualizer.draw_featmap(
+ torch.randn(2, 3, 3), channel_reduction='xx')
+
+ featmap = visualizer.draw_featmap(
+ torch.randn(2, 3, 3), channel_reduction='squeeze_mean')
+ assert featmap.shape[:2] == (3, 3)
+ featmap = visualizer.draw_featmap(
+ torch.randn(2, 3, 3), channel_reduction='select_max')
+ assert featmap.shape[:2] == (3, 3)
+ featmap = visualizer.draw_featmap(
+ torch.randn(2, 3, 3), channel_reduction='pixel_wise_max')
+ assert featmap.shape[:2] == (3, 3)
+ featmap = visualizer.draw_featmap(
+ torch.randn(2, 4, 3),
+ overlaid_image=image,
+ channel_reduction='pixel_wise_max')
+ assert featmap.shape[:2] == (3, 3)
+
+ # test topk parameter
+ with pytest.raises(
+ AssertionError,
+ match='The input tensor channel dimension must be 1 or 3 '
+ 'when topk is less than 1, but the channel '
+ 'dimension you input is 6, you can use the '
+ 'channel_reduction parameter or set topk '
+ 'greater than 0 to solve the error'):
+ visualizer.draw_featmap(
+ torch.randn(6, 3, 3), channel_reduction=None, topk=0)
+
+ featmap = visualizer.draw_featmap(
+ torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
+ assert featmap.shape[:2] == (3, 3)
+ featmap = visualizer.draw_featmap(
+ torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
+ assert featmap.shape[:2] == (4, 3)
+
+ featmap = visualizer.draw_featmap(
+ torch.randn(3, 4, 3),
+ overlaid_image=image,
+ channel_reduction=None,
+ topk=-1)
+ assert featmap.shape[:2] == (3, 3)
+ featmap = visualizer.draw_featmap(
+ torch.randn(6, 3, 3),
+ channel_reduction=None,
+ topk=4,
+ arrangement=(2, 2))
+ assert featmap.shape[:2] == (6, 6)
+ featmap = visualizer.draw_featmap(
+ torch.randn(6, 3, 3),
+ channel_reduction=None,
+ topk=4,
+ arrangement=(1, 4))
+ assert featmap.shape[:2] == (3, 12)
+ with pytest.raises(
+ AssertionError,
+ match='The product of row and col in the `arrangement` '
+ 'is less than topk, please set '
+ 'the `arrangement` correctly'):
+ visualizer.draw_featmap(
+ torch.randn(6, 3, 3),
+ channel_reduction=None,
+ topk=4,
+ arrangement=(1, 2))
+
+ # test gray
+ featmap = visualizer.draw_featmap(
+ torch.randn(6, 3, 3),
+ overlaid_image=np.random.randint(
+ 0, 256, size=(3, 3), dtype='uint8'),
+ channel_reduction=None,
+ topk=4,
+ arrangement=(2, 2))
+ assert featmap.shape[:2] == (6, 6)
diff --git a/tools/ckpt_demo.py b/tools/ckpt_demo.py
new file mode 100644
index 000000000..ee257390c
--- /dev/null
+++ b/tools/ckpt_demo.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+ckpt_path = '/mnt/lustre/humu/experiments/adaround/quantizied.pth'
+# ckpt_path =
+# '/mnt/petrelfs/humu/share/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'
+# ckpt_path = '/tmp/humu/resnet18_uniform8/checkpoint.pth.tar'
+# ckpt_path = '/tmp/humu/resnet18_uniform8/quantized_checkpoint.pth.tar'
+
+state_dict = torch.load(ckpt_path, map_location='cpu')
+
+for k, v in state_dict['state_dict'].items():
+ print(k)
diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh
index 6dd67e574..3c74ec6ec 100644
--- a/tools/slurm_test.sh
+++ b/tools/slurm_test.sh
@@ -1,24 +1,10 @@
#!/usr/bin/env bash
-set -x
-
-PARTITION=$1
-JOB_NAME=$2
-CONFIG=$3
-CHECKPOINT=$4
-GPUS=${GPUS:-8}
-GPUS_PER_NODE=${GPUS_PER_NODE:-8}
-CPUS_PER_TASK=${CPUS_PER_TASK:-5}
-PY_ARGS=${@:5}
-SRUN_ARGS=${SRUN_ARGS:-""}
+CONFIG=$1
+CHECKPOINT=$2
+GPUS=$3
+PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-srun -p ${PARTITION} \
- --job-name=${JOB_NAME} \
- --gres=gpu:${GPUS_PER_NODE} \
- --ntasks=${GPUS} \
- --ntasks-per-node=${GPUS_PER_NODE} \
- --cpus-per-task=${CPUS_PER_TASK} \
- --kill-on-bad-exit=1 \
- ${SRUN_ARGS} \
- python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+ $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
diff --git a/tools/tracer_demo.py b/tools/tracer_demo.py
new file mode 100644
index 000000000..88334d6aa
--- /dev/null
+++ b/tools/tracer_demo.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.fx as fx
+from mmengine.config import Config
+from mmengine.registry import MODELS
+
+from mmrazor.models.task_modules.tracer import custom_symbolic_trace
+
+cfg_path = 'configs/quantization/ptq/demo.py'
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+def extract_subgraph(graphmodule, block_slice):
+ subgraph = copy.deepcopy(graphmodule.graph)
+ block_start, block_end = block_slice[:2]
+ for node in subgraph.nodes:
+ if node.name == 'inputs':
+ input_node = node
+ if node.name == block_start.name:
+ node.replace_input_with(node.prev, input_node)
+ if node.name == block_end.name:
+ output_node = node
+ if node.op == 'output':
+ node.replace_input_with(node.prev, output_node)
+ subgraph.lint()
+ subgraph_module = fx.GraphModule(graphmodule, subgraph)
+ subgraph_module.graph.eliminate_dead_code()
+ subgraph_module.recompile()
+ return subgraph_module
+
+
+def extract_blocks(graphmodule, key_word='layer'):
+ block_slices = []
+ block_slice = []
+ pre_stage_index, pre_block_index = 0, 0
+ cur_stage_index, cur_block_index = 0, 0
+ for node in graphmodule.graph.nodes:
+ if key_word not in node.name:
+ continue
+ else:
+ items = node.name.split('_')
+ for i, item in enumerate(items):
+ if key_word in item:
+ cur_stage_index = int(item[5:])
+ cur_block_index = int(items[i + 1])
+ break
+ if (cur_block_index != pre_block_index) or (cur_stage_index !=
+ pre_stage_index):
+ block_slice.append(node.prev)
+ if len(block_slice) == 2:
+ block_slices.append(block_slice)
+ block_slice = []
+ block_slice.append(node)
+
+ pre_stage_index, pre_block_index = cur_stage_index, cur_block_index
+
+ return block_slices
+
+
+def extract_layers(graphmodule, layer_types):
+ layer_slices = []
+ for node in graphmodule.graph.nodes:
+ if node.op == 'call_module':
+ m = node.graph.owning_module.get_submodule(node.target)
+ if isinstance(m, _ADAROUND_SUPPORT_TYPE):
+ layer_slices.append((node, node))
+ return layer_slices
+
+
+def main():
+ # load config
+ cfg = Config.fromfile(cfg_path)
+ model = MODELS.build(cfg.model)
+ symbolic_traced = custom_symbolic_trace(
+ model, concrete_args={'mode': 'tensor'})
+ # block_slices = extract_blocks(symbolic_traced)
+ block_slices = extract_layers(
+ symbolic_traced, layer_types=_ADAROUND_SUPPORT_TYPE)
+
+ for b in block_slices:
+ print(b[0].name, b[1].name)
+
+ print('#' * 100)
+ subgraph = extract_subgraph(symbolic_traced, block_slices[0])
+ print(subgraph.code)
+ for name, layer in subgraph.named_modules():
+ print(name, layer)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/visualizations/demo.jpg b/tools/visualizations/demo.jpg
new file mode 100644
index 000000000..dd613cee3
Binary files /dev/null and b/tools/visualizations/demo.jpg differ
diff --git a/tools/visualizations/feature_diff_visualization.py b/tools/visualizations/feature_diff_visualization.py
new file mode 100644
index 000000000..6ecb19c9a
--- /dev/null
+++ b/tools/visualizations/feature_diff_visualization.py
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+
+import mmcv
+import torch
+from mmengine.config import Config
+from mmengine.registry import VISUALIZERS
+from mmengine.utils import import_modules_from_strings
+
+from mmrazor.models.task_modules import RecorderManager
+from mmrazor.utils import register_all_modules
+from mmrazor.visualization.local_visualizer import modify
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Feature map visualization')
+ parser.add_argument('img', help='Image file')
+ parser.add_argument(
+ 'config1', help='train config file path for the first model')
+ parser.add_argument(
+ 'config2', help='train config file path for the second model')
+ parser.add_argument('vis-config', help='visualization config file path')
+ parser.add_argument(
+ 'checkpoint1', help='Checkpoint file for the first model')
+ parser.add_argument(
+ 'checkpoint2', help='Checkpoint file for the second model')
+ parser.add_argument('--out-file', default=None, help='Path to output file')
+ parser.add_argument(
+ '--device', default='cpu', help='Device used for inference')
+ parser.add_argument('--repo', help='the corresponding repo name')
+ parser.add_argument(
+ '--use-norm',
+ action='store_true',
+ help='normalize the featmap before visualization')
+ parser.add_argument(
+ '--overlaid', action='store_true', help='overlaid image')
+ parser.add_argument(
+ '--channel-reduction',
+ help='Reduce multiple channels to a single channel. The optional value'
+ ' is \'squeeze_mean\', \'select_max\' or \'pixel_wise_max\'.',
+ default=None)
+ parser.add_argument(
+ '--topk',
+ help='If channel_reduction is not None and topk > 0, it will select '
+ 'topk channel to show by the sum of each channel. If topk <= 0, '
+ 'tensor_chw is assert to be one or three.',
+ type=int,
+ default=20)
+ parser.add_argument(
+ '--arrangement',
+ nargs='+',
+ type=int,
+ help='the arrangement of featmap when channel_reduction is not None '
+ 'and topk > 0.',
+ default=[4, 5])
+ parser.add_argument(
+ '--resize-shape',
+ nargs='+',
+ type=int,
+ help='the shape to scale the feature map',
+ default=None)
+ parser.add_argument(
+ '--alpha', help='the transparency of featmap', default=0.5)
+
+ parser.add_argument('--local_rank', type=int, default=0)
+
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
+
+def norm(feat):
+ N, C, H, W = feat.shape
+ feat = feat.permute(1, 0, 2, 3).reshape(C, -1)
+ mean = feat.mean(dim=-1, keepdim=True)
+ std = feat.std(dim=-1, keepdim=True)
+ centered = (feat - mean) / (std + 1e-6)
+ centered = centered.reshape(C, N, H, W).permute(1, 0, 2, 3)
+ return centered
+
+
+def main(args):
+ register_all_modules(False)
+ mod = import_modules_from_strings(f'{args.repo}.utils')
+ mod.register_all_modules()
+
+ apis = import_modules_from_strings(f'{args.repo}.apis')
+ inference_model, init_model = None, None
+ for attr_name in dir(apis):
+ if 'inference_' in attr_name:
+ inference_model = getattr(apis, attr_name)
+ if 'init_' in attr_name:
+ init_model = getattr(apis, attr_name)
+ assert inference_model and init_model
+
+ model1 = init_model(args.config1, args.checkpoint1, device=args.device)
+ # init visualizer
+ visualizer = VISUALIZERS.build(model1.cfg.visualizer)
+ visualizer.draw_featmap = modify
+
+ model2 = init_model(args.config2, args.checkpoint2, device=args.device)
+
+ visualization_cfg = Config.fromfile(args.vis_config)
+ recorder_cfg1 = visualization_cfg.recorders1
+ mappings1 = visualization_cfg.mappings1
+ recorder_cfg2 = visualization_cfg.recorders2
+ mappings2 = visualization_cfg.mappings2
+
+ recorder_manager1 = RecorderManager(recorder_cfg1)
+ recorder_manager1.initialize(model1)
+
+ recorder_manager2 = RecorderManager(recorder_cfg2)
+ recorder_manager2.initialize(model2)
+
+ with recorder_manager1:
+ # test a single image
+ _ = inference_model(model1, args.img)
+
+ with recorder_manager2:
+ # test a single image
+ _ = inference_model(model2, args.img)
+
+ overlaid_image = mmcv.imread(
+ args.img, channel_order='rgb') if args.overlaid else None
+
+ for name1, name2 in zip(mappings1.keys(), mappings2.keys()):
+ record1 = mappings1[name1]
+ recorder1 = recorder_manager1.get_recorder(record1.recorder)
+ record_idx = getattr(record1, 'record_idx', 0)
+ data_idx = getattr(record1, 'data_idx')
+ feats1 = recorder1.get_record_data(record_idx, data_idx)
+ if isinstance(feats1, torch.Tensor):
+ feats1 = (feats1, )
+
+ record2 = mappings2[name2]
+ recorder2 = recorder_manager2.get_recorder(record2.recorder)
+ record_idx = getattr(record2, 'record_idx', 0)
+ data_idx = getattr(record2, 'data_idx')
+ feats2 = recorder2.get_record_data(record_idx, data_idx)
+ if isinstance(feats2, torch.Tensor):
+ feats2 = (feats2, )
+
+ for i, (feat1, feat2) in enumerate(zip(feats1, feats2)):
+ diff = torch.abs(feat1 - feat2)
+ if args.use_norm:
+ diff = norm(diff)
+ drawn_img = visualizer.draw_featmap(
+ diff[0],
+ overlaid_image,
+ args.channel_reduction,
+ topk=args.topk,
+ arrangement=tuple(args.arrangement),
+ resize_shape=tuple(args.resize_shape)
+ if args.resize_shape else None,
+ alpha=args.alpha)
+ visualizer.add_datasample(
+ f'model1_{name1}_model2_{name2}_{i}',
+ drawn_img,
+ show=args.out_file is None,
+ wait_time=0.1,
+ out_file=args.out_file)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
diff --git a/tools/visualizations/feature_visualization.py b/tools/visualizations/feature_visualization.py
new file mode 100644
index 000000000..ee2b373d5
--- /dev/null
+++ b/tools/visualizations/feature_visualization.py
@@ -0,0 +1,155 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+
+import mmcv
+import torch
+from mmengine.config import Config, DictAction
+from mmengine.registry import VISUALIZERS
+from mmengine.utils import import_modules_from_strings
+
+from mmrazor.models.task_modules import RecorderManager
+from mmrazor.utils import register_all_modules
+from mmrazor.visualization.local_visualizer import modify
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Feature map visualization')
+ parser.add_argument('img', help='Image file')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('vis-config', help='visualization config file path')
+ parser.add_argument('checkpoint', help='Checkpoint file')
+ parser.add_argument('--out-file', default=None, help='Path to output file')
+ parser.add_argument(
+ '--device', default='cpu', help='Device used for inference')
+ parser.add_argument('--repo', help='the corresponding repo name')
+ parser.add_argument(
+ '--use-norm',
+ action='store_true',
+ help='normalize the featmap before visualization')
+ parser.add_argument(
+ '--overlaid', action='store_true', help='overlaid image')
+ parser.add_argument(
+ '--channel-reduction',
+ help='Reduce multiple channels to a single channel. The optional value'
+ ' is \'squeeze_mean\', \'select_max\' or \'pixel_wise_max\'.',
+ default=None)
+ parser.add_argument(
+ '--topk',
+ type=int,
+ help='If channel_reduction is not None and topk > 0, it will select '
+ 'topk channel to show by the sum of each channel. If topk <= 0, '
+ 'tensor_chw is assert to be one or three.',
+ default=20)
+ parser.add_argument(
+ '--arrangement',
+ nargs='+',
+ type=int,
+ help='the arrangement of featmap when channel_reduction is not None '
+ 'and topk > 0.',
+ default=[4, 5])
+ parser.add_argument(
+ '--resize-shape',
+ nargs='+',
+ type=int,
+ help='the shape to scale the feature map',
+ default=None)
+ parser.add_argument(
+ '--alpha', help='the transparency of featmap', default=0.5)
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.',
+ default={})
+
+ parser.add_argument('--local_rank', type=int, default=0)
+
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
+
+def norm(feat):
+ N, C, H, W = feat.shape
+ feat = feat.permute(1, 0, 2, 3).reshape(C, -1)
+ mean = feat.mean(dim=-1, keepdim=True)
+ std = feat.std(dim=-1, keepdim=True)
+ centered = (feat - mean) / (std + 1e-6)
+ centered = centered.reshape(C, N, H, W).permute(1, 0, 2, 3)
+ return centered
+
+
+def main(args):
+ register_all_modules(False)
+ mod = import_modules_from_strings(f'{args.repo}.utils')
+ mod.register_all_modules()
+
+ apis = import_modules_from_strings(f'{args.repo}.apis')
+ inference_model, init_model = None, None
+ for attr_name in dir(apis):
+ if 'inference_' in attr_name:
+ inference_model = getattr(apis, attr_name)
+ if 'init_' in attr_name:
+ init_model = getattr(apis, attr_name)
+ assert inference_model and init_model
+
+ model = init_model(args.config, args.checkpoint, device=args.device)
+ # init visualizer
+ visualizer = VISUALIZERS.build(model.cfg.visualizer)
+ visualizer.draw_featmap = modify
+
+ visualization_cfg = Config.fromfile(args.vis_config)
+ recorder_cfg = visualization_cfg.recorders
+ mappings = visualization_cfg.mappings
+ recorder_manager = RecorderManager(recorder_cfg)
+ recorder_manager.initialize(model)
+
+ with recorder_manager:
+ # test a single image
+ result = inference_model(model, args.img)
+
+ overlaid_image = mmcv.imread(
+ args.img, channel_order='rgb') if args.overlaid else None
+
+ for name, record in mappings.items():
+ recorder = recorder_manager.get_recorder(record.recorder)
+ record_idx = getattr(record, 'record_idx', 0)
+ data_idx = getattr(record, 'data_idx')
+ feats = recorder.get_record_data(record_idx, data_idx)
+ if isinstance(feats, torch.Tensor):
+ feats = (feats, )
+
+ for i, feat in enumerate(feats):
+ if args.use_norm:
+ feat = norm(feat)
+ drawn_img = visualizer.draw_featmap(
+ feat[0],
+ overlaid_image,
+ args.channel_reduction,
+ topk=args.topk,
+ arrangement=tuple(args.arrangement),
+ resize_shape=tuple(args.resize_shape)
+ if args.resize_shape else None,
+ alpha=args.alpha)
+ visualizer.add_datasample(
+ f'{name}_{i}',
+ drawn_img,
+ data_sample=result,
+ draw_gt=False,
+ show=args.out_file is None,
+ wait_time=0.1,
+ out_file=args.out_file,
+ **args.cfg_options)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
diff --git a/tools/visualizations/vis_configs/backbone_feature_diff_visualization.py b/tools/visualizations/vis_configs/backbone_feature_diff_visualization.py
new file mode 100644
index 000000000..7bc34d90a
--- /dev/null
+++ b/tools/visualizations/vis_configs/backbone_feature_diff_visualization.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# configs for the 1st model
+recorders1 = dict(
+ backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
+mappings1 = dict(
+ p3=dict(recorder='backbone', data_idx=0),
+ p4=dict(recorder='backbone', data_idx=1),
+ p5=dict(recorder='backbone', data_idx=2),
+ p6=dict(recorder='backbone', data_idx=3))
+
+# configs for the 2nd model
+recorders2 = dict(
+ backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
+mappings2 = dict(
+ p3=dict(recorder='backbone', data_idx=0),
+ p4=dict(recorder='backbone', data_idx=1),
+ p5=dict(recorder='backbone', data_idx=2),
+ p6=dict(recorder='backbone', data_idx=3))
diff --git a/tools/visualizations/vis_configs/backbone_feature_visualization.py b/tools/visualizations/vis_configs/backbone_feature_visualization.py
new file mode 100644
index 000000000..1c8038dff
--- /dev/null
+++ b/tools/visualizations/vis_configs/backbone_feature_visualization.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+recorders = dict(
+ backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
+mappings = dict(
+ p3=dict(recorder='backbone', data_idx=0),
+ p4=dict(recorder='backbone', data_idx=1),
+ p5=dict(recorder='backbone', data_idx=2),
+ p6=dict(recorder='backbone', data_idx=3))
diff --git a/tools/visualizations/vis_configs/fpn_feature_diff_visualization.py b/tools/visualizations/vis_configs/fpn_feature_diff_visualization.py
new file mode 100644
index 000000000..c6c172fb6
--- /dev/null
+++ b/tools/visualizations/vis_configs/fpn_feature_diff_visualization.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# configs for the 1st model
+recorders1 = dict(
+ neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
+mappings1 = dict(
+ p3=dict(recorder='neck', data_idx=0),
+ p4=dict(recorder='neck', data_idx=1),
+ p5=dict(recorder='neck', data_idx=2),
+ p6=dict(recorder='neck', data_idx=3))
+
+# configs for the 2nd model
+recorders2 = dict(
+ neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
+mappings2 = dict(
+ p3=dict(recorder='neck', data_idx=0),
+ p4=dict(recorder='neck', data_idx=1),
+ p5=dict(recorder='neck', data_idx=2),
+ p6=dict(recorder='neck', data_idx=3))
diff --git a/tools/visualizations/vis_configs/fpn_feature_visualization.py b/tools/visualizations/vis_configs/fpn_feature_visualization.py
new file mode 100644
index 000000000..40b6b3f1b
--- /dev/null
+++ b/tools/visualizations/vis_configs/fpn_feature_visualization.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+recorders = dict(
+ neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
+mappings = dict(
+ p3=dict(recorder='neck', data_idx=0),
+ p4=dict(recorder='neck', data_idx=1),
+ p5=dict(recorder='neck', data_idx=2),
+ p6=dict(recorder='neck', data_idx=3))