From eb25bb75774b52d8823ed98517456f59583bf534 Mon Sep 17 00:00:00 2001 From: "P.Huang" <37200926+FreakieHuang@users.noreply.github.com> Date: Tue, 13 Sep 2022 20:53:43 +0800 Subject: [PATCH 1/7] [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281) * init * TD: CRDLoss * complete UT * fix docstrings * fix ci * update * fix CI * DONE * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * add UT: CRD_ClsDataset * init * TODO: UT test formatting. * init * crd dataset wrapper * update docstring Co-authored-by: huangpengsheng --- configs/distill/mmcls/crd/README.md | 30 ++ .../crd/crd_neck_r50_r18_8xb16_cifar10.py | 108 +++++++ .../mmcls/crd/datasets/crd_cifar10_bs16.py | 49 ++++ mmrazor/datasets/__init__.py | 5 + mmrazor/datasets/crd_dataset_wrapper.py | 254 ++++++++++++++++ mmrazor/datasets/transforms/__init__.py | 4 + mmrazor/datasets/transforms/formatting.py | 73 +++++ .../architectures/connectors/__init__.py | 3 +- .../architectures/connectors/crd_connector.py | 47 +++ mmrazor/models/losses/__init__.py | 3 +- mmrazor/models/losses/crd_loss.py | 271 ++++++++++++++++++ mmrazor/utils/placeholder.py | 2 + mmrazor/utils/setup_env.py | 1 + tests/data/dataset/a/1.JPG | 0 tests/data/dataset/ann.json | 28 ++ tests/data/dataset/ann.txt | 3 + tests/data/dataset/b/2.jpeg | 0 tests/data/dataset/b/subb/3.jpg | 0 tests/data/dataset/classes.txt | 2 + tests/data/dataset/multi_label_ann.json | 28 ++ tests/test_datasets/test_datasets.py | 94 ++++++ .../test_transforms/test_formatting.py | 56 ++++ .../test_connectors/test_connectors.py | 19 +- .../test_losses/test_distillation_losses.py | 35 ++- 24 files changed, 1109 insertions(+), 6 deletions(-) create mode 100644 configs/distill/mmcls/crd/README.md create mode 100644 configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py create mode 100644 configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py create mode 100644 mmrazor/datasets/__init__.py create mode 100644 mmrazor/datasets/crd_dataset_wrapper.py create mode 100644 mmrazor/datasets/transforms/__init__.py create mode 100644 mmrazor/datasets/transforms/formatting.py create mode 100644 mmrazor/models/architectures/connectors/crd_connector.py create mode 100644 mmrazor/models/losses/crd_loss.py create mode 100644 tests/data/dataset/a/1.JPG create mode 100644 tests/data/dataset/ann.json create mode 100644 tests/data/dataset/ann.txt create mode 100644 tests/data/dataset/b/2.jpeg create mode 100644 tests/data/dataset/b/subb/3.jpg create mode 100644 tests/data/dataset/classes.txt create mode 100644 tests/data/dataset/multi_label_ann.json create mode 100644 tests/test_datasets/test_datasets.py create mode 100644 tests/test_datasets/test_transforms/test_formatting.py diff --git a/configs/distill/mmcls/crd/README.md b/configs/distill/mmcls/crd/README.md new file mode 100644 index 000000000..0f02f365e --- /dev/null +++ b/configs/distill/mmcls/crd/README.md @@ -0,0 +1,30 @@ +# CONTRASTIVE REPRESENTATION DISTILLATION + +> [CONTRASTIVE REPRESENTATION DISTILLATION](https://arxiv.org/abs/1910.10699) + +## Abstract + +Often we wish to transfer representational knowledge from one neural network to another. Examples include distilling a large network into a smaller one, transferring knowledge from one sensory modality to a second, or ensembling a collection of models into a single estimator. Knowledge distillation, the standard approach to these problems, minimizes the KL divergence between the probabilistic outputs of a teacher and student network. We demonstrate that this objective ignores important structural knowledge of the teacher network. This motivates an alternative objective by which we train a student to capture significantly more information in the teacher’s representation of the data. We formulate this objective as contrastive learning. Experiments demonstrate that our resulting new objective outperforms knowledge distillation and other cutting-edge distillers on a variety of knowledge transfer tasks, including single model compression, ensemble distillation, and cross-modal transfer. Our method sets a new state-of-the-art in many transfer tasks, and sometimes even outperforms the teacher network when combined with knowledge distillation.[Original code](http://github.com/HobbitLong/RepDistiller) + +![pipeline](../../../../docs/en/imgs/model_zoo/crd/pipeline.jpg) + +## Citation + +```latex +@article{tian2019contrastive, + title={Contrastive representation distillation}, + author={Tian, Yonglong and Krishnan, Dilip and Isola, Phillip}, + journal={arXiv preprint arXiv:1910.10699}, + year={2019} +} +``` + +## Results and models + +| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | +| ------- | --------- | --------- | --------- | --------- | ------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | ResNet-18 | ResNet-50 | 94.79 | 99.86 | [config](crd_neck_r50_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) \|[model](<>) \| [log](<>) | + +## Acknowledgement + +Shout out to @chengshuang18 for his special contribution. diff --git a/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py b/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py new file mode 100644 index 000000000..4e36e9a2a --- /dev/null +++ b/configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py @@ -0,0 +1,108 @@ +_base_ = [ + 'mmcls::_base_/datasets/cifar10_bs16.py', + 'mmcls::_base_/schedules/cifar10_bs128.py', + 'mmcls::_base_/default_runtime.py' +] + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet50_8xb16_cifar10.py', pretrained=True), + teacher_ckpt='resnet50_b16x8_cifar10_20210528-f54bfad9.pth', + distiller=dict( + type='ConfigurableDistiller', + student_recorders=dict( + neck=dict(type='ModuleOutputs', source='neck.gap'), + data_samples=dict(type='ModuleInputs', source='')), + teacher_recorders=dict( + neck=dict(type='ModuleOutputs', source='neck.gap')), + distill_losses=dict(loss_crd=dict(type='CRDLoss', loss_weight=0.8)), + connectors=dict( + loss_crd_stu=dict(type='CRDConnector', dim_in=512, dim_out=128), + loss_crd_tea=dict(type='CRDConnector', dim_in=2048, dim_out=128)), + loss_forward_mappings=dict( + loss_crd=dict( + s_feats=dict( + from_student=True, + recorder='neck', + connector='loss_crd_stu'), + t_feats=dict( + from_student=False, + recorder='neck', + connector='loss_crd_tea'), + data_samples=dict( + from_student=True, recorder='data_samples', data_idx=1))))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') + +# change `CIFAR10` dataset to `CRDDataset` dataset. +dataset_type = 'CIFAR10' +train_pipeline = [ + dict(_scope_='mmcls', type='RandomCrop', crop_size=32, padding=4), + dict(_scope_='mmcls', type='RandomFlip', prob=0.5, direction='horizontal'), + dict(_scope_='mmrazor', type='PackCRDClsInputs'), +] + +test_pipeline = [ + dict(_scope_='mmrazor', type='PackCRDClsInputs'), +] + +ori_train_dataset = dict( + _scope_='mmcls', + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline) + +crd_train_dataset = dict( + _scope_='mmrazor', + type='CRDDataset', + dataset=ori_train_dataset, + neg_num=16384, + sample_mode='exact', + percent=1.0) + +ori_test_dataset = dict( + _scope_='mmcls', + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline) + +crd_test_dataset = dict( + _scope_='mmrazor', + type='CRDDataset', + dataset=ori_test_dataset, + neg_num=16384, + sample_mode='exact', + percent=1.0) + +train_dataloader = dict( + _delete_=True, + batch_size=16, + num_workers=2, + dataset=crd_train_dataset, + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + _delete_=True, + batch_size=16, + num_workers=2, + dataset=crd_test_dataset, + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) diff --git a/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py b/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py new file mode 100644 index 000000000..c7cb74c39 --- /dev/null +++ b/configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py @@ -0,0 +1,49 @@ +# dataset settings +dataset_type = 'CIFAR10' +preprocess_cfg = dict( + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='RandomCrop', crop_size=32, padding=4), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + dict(type='PackClsInputs'), +] + +neg_num = 16384 +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline, + neg_num=neg_num), + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline, + neg_num=neg_num), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmrazor/datasets/__init__.py b/mmrazor/datasets/__init__.py new file mode 100644 index 000000000..5cfa79460 --- /dev/null +++ b/mmrazor/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .crd_dataset_wrapper import CRDDataset +from .transforms import PackCRDClsInputs + +__all__ = ['PackCRDClsInputs', 'CRDDataset'] diff --git a/mmrazor/datasets/crd_dataset_wrapper.py b/mmrazor/datasets/crd_dataset_wrapper.py new file mode 100644 index 000000000..308bc1e4c --- /dev/null +++ b/mmrazor/datasets/crd_dataset_wrapper.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import Any, Dict, List, Union + +import numpy as np +from mmengine.dataset.base_dataset import BaseDataset, force_full_init + +from mmrazor.registry import DATASETS + + +@DATASETS.register_module() +class CRDDataset: + """A wrapper of `CRD` dataset. + + Suitable for image classification datasets like CIFAR. Following + the sampling strategy in the `paper `_, + in each epoch, each data sample has contrast information. + Contrast information for an image is indices of negetive data samples. + Note: + ``CRDDataset`` should not inherit from ``BaseDataset`` + since ``get_subset`` and ``get_subset_`` could produce ambiguous + meaning sub-dataset which conflicts with original dataset. If you + want to use a sub-dataset of ``CRDDataset``, you should set + ``indices`` arguments for wrapped dataset which inherit from + ``BaseDataset``. + Args: + dataset (BaseDataset or dict): The dataset to be repeated. + neg_num (int): number of negetive data samples. + percent (float): sampling percentage. + lazy_init (bool, optional): whether to load annotation during + instantiation. Defaults to False + num_classes (int, optional): Number of classes. Defaults to None. + sample_mode (str, optional): Data sampling mode. Defaults to 'exact'. + """ + + def __init__(self, + dataset: Union[BaseDataset, dict], + neg_num: int, + percent: float, + lazy_init: bool = False, + num_classes: int = None, + sample_mode: str = 'exact') -> None: + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`BaseDataset` instance, but got {type(dataset)}') + self._metainfo = self.dataset.metainfo + + self._fully_initialized = False + + # CRD unique attributes. + self.num_classes = num_classes + self.neg_num = neg_num + self.sample_mode = sample_mode + self.percent = percent + + if not lazy_init: + self.full_init() + + def _parse_fullset_contrast_info(self) -> None: + """parse contrast information of the whole dataset.""" + assert self.sample_mode in [ + 'exact', 'random' + ], ('`sample_mode` must in [`exact`, `random`], ' + f'but get `{self.sample_mode}`') + + # Handle special occasion: + # if dataset's ``CLASSES`` is not list of consecutive integers, + # e.g. [2, 3, 5]. + num_classes: int = self.num_classes # type: ignore + if num_classes is None: + num_classes = len(self.dataset.CLASSES) + + if not self.dataset.test_mode: # type: ignore + # Parse info. + self.gt_labels = self.dataset.get_gt_labels() + self.num_samples: int = self.dataset.__len__() + + self.cls_positive: List[List[int]] = [[] + for _ in range(num_classes) + ] # type: ignore + for i in range(self.num_samples): + self.cls_positive[self.gt_labels[i]].append(i) + + self.cls_negative: List[List[int]] = [[] + for i in range(num_classes) + ] # type: ignore + for i in range(num_classes): # type: ignore + for j in range(num_classes): # type: ignore + if j == i: + continue + self.cls_negative[i].extend(self.cls_positive[j]) + + self.cls_positive = [ + np.asarray(self.cls_positive[i]) + for i in range(num_classes) # type: ignore + ] + self.cls_negative = [ + np.asarray(self.cls_negative[i]) + for i in range(num_classes) # type: ignore + ] + + if 0 < self.percent < 1: + n = int(len(self.cls_negative[0]) * self.percent) + self.cls_negative = [ + np.random.permutation(self.cls_negative[i])[0:n] + for i in range(num_classes) # type: ignore + ] + + self.cls_positive = np.asarray(self.cls_positive) + self.cls_negative = np.asarray(self.cls_negative) + + @property + def metainfo(self) -> dict: + """Get the meta information of the repeated dataset. + + Returns: + dict: The meta information of repeated dataset. + """ + return copy.deepcopy(self._metainfo) + + def _get_contrast_info(self, data: Dict, idx: int) -> Dict: + """Get contrast information for each data sample.""" + if self.sample_mode == 'exact': + pos_idx = idx + elif self.sample_mode == 'random': + pos_idx = np.random.choice(self.cls_positive[self.gt_labels[idx]], + 1) + pos_idx = pos_idx[0] # type: ignore + else: + raise NotImplementedError(self.sample_mode) + replace = True if self.neg_num > \ + len(self.cls_negative[self.gt_labels[idx]]) else False + neg_idx = np.random.choice( + self.cls_negative[self.gt_labels[idx]], + self.neg_num, + replace=replace) + contrast_sample_idxs = np.hstack((np.asarray([pos_idx]), neg_idx)) + data['contrast_sample_idxs'] = contrast_sample_idxs + return data + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._parse_fullset_contrast_info() + + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> Dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + Returns: + dict: The idx-th annotation of the dataset. + """ + data_info = self.dataset.get_data_info(idx) # type: ignore + if not self.dataset.test_mode: # type: ignore + data_info = self._get_contrast_info(data_info, idx) + return data_info + + def prepare_data(self, idx) -> Any: + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + data_info = self.get_data_info(idx) + return self.dataset.pipeline(data_info) + + def __getitem__(self, idx: int) -> dict: + """Get the idx-th image and data information of dataset after + ``self.pipeline``, and ``full_init`` will be called if the dataset has + not been fully initialized. + + During training phase, if ``self.pipeline`` get ``None``, + ``self._rand_another`` will be called until a valid image is fetched or + the maximum limit of refetech is reached. + + Args: + idx (int): The index of self.data_list. + + Returns: + dict: The idx-th image and data information of dataset after + ``self.pipeline``. + """ + # Performing full initialization by calling `__getitem__` will consume + # extra memory. If a dataset is not fully initialized by setting + # `lazy_init=True` and then fed into the dataloader. Different workers + # will simultaneously read and parse the annotation. It will cost more + # time and memory, although this may work. Therefore, it is recommended + # to manually call `full_init` before dataset fed into dataloader to + # ensure all workers use shared RAM from master process. + if not self._fully_initialized: + warnings.warn( + 'Please call `full_init()` method manually to accelerate ' + 'the speed.') + self.full_init() + + if self.dataset.test_mode: + data = self.prepare_data(idx) + if data is None: + raise Exception('Test time pipline should not get `None` ' + 'data_sample') + return data + + for _ in range(self.dataset.max_refetch + 1): + data = self.prepare_data(idx) + # Broken images or random augmentations may cause the returned data + # to be None + if data is None: + idx = self.dataset._rand_another() + continue + return data + + raise Exception( + f'Cannot find valid image after {self.dataset.max_refetch}! ' + 'Please check your image path and pipeline') + + @force_full_init + def __len__(self): + return len(self.dataset) + + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') diff --git a/mmrazor/datasets/transforms/__init__.py b/mmrazor/datasets/transforms/__init__.py new file mode 100644 index 000000000..cb1bebc46 --- /dev/null +++ b/mmrazor/datasets/transforms/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .formatting import PackCRDClsInputs + +__all__ = ['PackCRDClsInputs'] diff --git a/mmrazor/datasets/transforms/formatting.py b/mmrazor/datasets/transforms/formatting.py new file mode 100644 index 000000000..d2ba63ddc --- /dev/null +++ b/mmrazor/datasets/transforms/formatting.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from mmcls.datasets.transforms.formatting import PackClsInputs, to_tensor + from mmcls.structures import ClsDataSample +except ImportError: + from mmrazor.utils import get_placeholder + PackClsInputs = get_placeholder('mmcls') + to_tensor = get_placeholder('mmcls') + ClsDataSample = get_placeholder('mmcls') + +import warnings +from typing import Any, Dict, Generator + +import numpy as np +import torch + +from mmrazor.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class PackCRDClsInputs(PackClsInputs): + + def transform(self, results: Dict) -> Dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`ClsDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + packed_results['inputs'] = to_tensor(img) + else: + warnings.warn( + 'Cannot get "img" in the input dict of `PackClsInputs`,' + 'please make sure `LoadImageFromFile` has been added ' + 'in the data pipeline or images have been loaded in ' + 'the dataset.') + + data_sample = ClsDataSample() + if 'gt_label' in results: + gt_label = results['gt_label'] + data_sample.set_gt_label(gt_label) + + if 'sample_idx' in results: + # transfer `sample_idx` to Tensor + self.meta_keys: Generator[Any, None, None] = ( + key for key in self.meta_keys if key != 'sample_idx') + value = results['sample_idx'] + if isinstance(value, int): + value = torch.tensor(value).to(torch.long) + data_sample.set_data(dict(sample_idx=value)) + + if 'contrast_sample_idxs' in results: + value = results['contrast_sample_idxs'] + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + data_sample.set_data(dict(contrast_sample_idxs=value)) + + img_meta = {k: results[k] for k in self.meta_keys if k in results} + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py index 962282d64..c12aa60d7 100644 --- a/mmrazor/models/architectures/connectors/__init__.py +++ b/mmrazor/models/architectures/connectors/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .byot_connector import BYOTConnector from .convmodule_connector import ConvModuleConncetor +from .crd_connector import CRDConnector from .factor_transfer_connectors import Paraphraser, Translator from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector from .ofd_connector import OFDTeacherConnector @@ -9,5 +10,5 @@ __all__ = [ 'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector', 'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector', - 'TorchNNConnector', 'OFDTeacherConnector' + 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector' ] diff --git a/mmrazor/models/architectures/connectors/crd_connector.py b/mmrazor/models/architectures/connectors/crd_connector.py new file mode 100644 index 000000000..48648c75d --- /dev/null +++ b/mmrazor/models/architectures/connectors/crd_connector.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS +from .base_connector import BaseConnector + + +@MODELS.register_module() +class CRDConnector(BaseConnector): + """Connector with linear layer. + + Args: + dim_in (int, optional): input channels. Defaults to 1024. + dim_out (int, optional): output channels. Defaults to 128. + """ + + def __init__(self, + dim_in: int = 1024, + dim_out: int = 128, + **kwargs) -> None: + super(CRDConnector, self).__init__(**kwargs) + self.linear = nn.Linear(dim_in, dim_out) + self.l2norm = Normalize(2) + + def forward_train(self, x: torch.Tensor) -> torch.Tensor: + x = x.view(x.size(0), -1) + x = self.linear(x) + x = self.l2norm(x) + return x + + +class Normalize(nn.Module): + """normalization layer. + + Args: + power (int, optional): power. Defaults to 2. + """ + + def __init__(self, power: int = 2) -> None: + super(Normalize, self).__init__() + self.power = power + + def forward(self, x: torch.Tensor) -> torch.Tensor: + norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) + out = x.div(norm) + return out diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index c0a751ec8..a145ba914 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ab_loss import ABLoss from .at_loss import ATLoss +from .crd_loss import CRDLoss from .cwd import ChannelWiseDivergence from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss from .decoupled_kd import DKDLoss @@ -18,5 +19,5 @@ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss', 'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss', - 'L1Loss', 'FBKDLoss' + 'L1Loss', 'FBKDLoss', 'CRDLoss' ] diff --git a/mmrazor/models/losses/crd_loss.py b/mmrazor/models/losses/crd_loss.py new file mode 100644 index 000000000..4ca81aaf5 --- /dev/null +++ b/mmrazor/models/losses/crd_loss.py @@ -0,0 +1,271 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Union + +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class CRDLoss(nn.Module): + """Variate CRD Loss, ICLR 2020. + + https://arxiv.org/abs/1910.10699 + Args: + loss_weight (float, optional): loss weight. Defaults to 1.0. + temperature (float, optional): temperature. Defaults to 0.07. + neg_num (int, optional): number of negative samples. Defaults to 16384. + sample_n (int, optional): number of total samples. Defaults to 50000. + dim_out (int, optional): output channels. Defaults to 128. + momentum (float, optional): momentum. Defaults to 0.5. + eps (double, optional): eps. Defaults to 1e-7. + """ + + def __init__(self, + loss_weight: float = 1.0, + temperature=0.07, + neg_num=16384, + sample_n=50000, + dim_out=128, + momentum=0.5, + eps=1e-7): + super().__init__() + self.loss_weight = loss_weight + self.eps = eps + + self.contrast = ContrastMemory(dim_out, sample_n, neg_num, temperature, + momentum) + self.criterion_s_t = ContrastLoss(sample_n, eps=self.eps) + + def forward(self, s_feats, t_feats, data_samples): + input_data = data_samples[0] + assert 'sample_idx' in input_data, \ + 'you should pass a dict with key `sample_idx` in mimic function.' + assert isinstance( + input_data.sample_idx, torch.Tensor + ), f'`sample_idx` must be a tensor, but get {type(input_data.sample_idx)}' # noqa: E501 + + sample_idxs = torch.stack( + [sample.sample_idx for sample in data_samples]) + if 'contrast_sample_idxs' in input_data: + assert isinstance( + input_data.contrast_sample_idxs, torch.Tensor + ), f'`contrast_sample_idxs` must be a tensor, but get {type(input_data.contrast_sample_idxs)}' # noqa: E501 + contrast_sample_idxs = torch.stack( + [sample.contrast_sample_idxs for sample in data_samples]) + else: + contrast_sample_idxs = None + out_s, out_t = self.contrast(s_feats, t_feats, sample_idxs, + contrast_sample_idxs) + s_loss = self.criterion_s_t(out_s) + t_loss = self.criterion_s_t(out_t) + loss = s_loss + t_loss + return loss + + +class ContrastLoss(nn.Module): + """contrastive loss, corresponding to Eq (18) + + Args: + n_data (int): number of data + eps (float, optional): eps. Defaults to 1e-7. + """ + + def __init__(self, n_data: int, eps: float = 1e-7): + super(ContrastLoss, self).__init__() + self.n_data = n_data + self.eps = eps + + def forward(self, x): + bsz = x.shape[0] + m = x.size(1) - 1 + + # noise distribution + Pn = 1 / float(self.n_data) + + # loss for positive pair + P_pos = x.select(1, 0) + log_D1 = torch.div(P_pos, P_pos.add(m * Pn + self.eps)).log_() + + # loss for neg_sample negative pair + P_neg = x.narrow(1, 1, m) + log_D0 = torch.div(P_neg.clone().fill_(m * Pn), + P_neg.add(m * Pn + self.eps)).log_() + + loss = -(log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz + + return loss + + +class ContrastMemory(nn.Module): + """memory buffer that supplies large amount of negative samples. + + https://github.com/HobbitLong/RepDistiller/blob/master/crd/memory.py + + Args: + dim_out (int, optional): output channels. Defaults to 128. + n_sample (int, optional): number of total samples. + Defaults to 50000. + neg_sample (int, optional): number of negative samples. + Defaults to 16384. + T (float, optional): temperature. Defaults to 0.07. + momentum (float, optional): momentum. Defaults to 0.5. + """ + + def __init__(self, + dim_out: int, + n_sample: int, + neg_sample: int, + T: float = 0.07, + momentum: float = 0.5): + super(ContrastMemory, self).__init__() + self.n_sample = n_sample + self.unigrams = torch.ones(self.n_sample) + self.multinomial = AliasMethod(self.unigrams) + # self.multinomial.cuda() + self.neg_sample = neg_sample + + self.register_buffer('params', + torch.tensor([neg_sample, T, -1, -1, momentum])) + stdv = 1. / math.sqrt(dim_out / 3) + self.register_buffer( + 'memory_v1', + torch.rand(n_sample, dim_out).mul_(2 * stdv).add_(-stdv)) + self.register_buffer( + 'memory_v2', + torch.rand(n_sample, dim_out).mul_(2 * stdv).add_(-stdv)) + + def forward(self, + feat_s: torch.Tensor, + feat_t: torch.Tensor, + idx: torch.Tensor, + sample_idx: Union[None, torch.Tensor] = None) -> torch.Tensor: + neg_sample = int(self.params[0].item()) + T = self.params[1].item() + Z_s = self.params[2].item() + Z_t = self.params[3].item() + + momentum = self.params[4].item() + bsz = feat_s.size(0) + n_sample = self.memory_v1.size(0) + dim_out = self.memory_v1.size(1) + + # original score computation + if sample_idx is None: + sample_idx = self.multinomial.draw(bsz * (self.neg_sample + 1))\ + .view(bsz, -1) + sample_idx.select(1, 0).copy_(idx.data) + # sample + weight_s = torch.index_select(self.memory_v1, 0, + sample_idx.view(-1)).detach() + weight_s = weight_s.view(bsz, neg_sample + 1, dim_out) + out_t = torch.bmm(weight_s, feat_t.view(bsz, dim_out, 1)) + out_t = torch.exp(torch.div(out_t, T)) + # sample + weight_t = torch.index_select(self.memory_v2, 0, + sample_idx.view(-1)).detach() + weight_t = weight_t.view(bsz, neg_sample + 1, dim_out) + out_s = torch.bmm(weight_t, feat_s.view(bsz, dim_out, 1)) + out_s = torch.exp(torch.div(out_s, T)) + + # set Z if haven't been set yet + if Z_s < 0: + self.params[2] = out_s.mean() * n_sample + Z_s = self.params[2].clone().detach().item() + print('normalization constant Z_s is set to {:.1f}'.format(Z_s)) + if Z_t < 0: + self.params[3] = out_t.mean() * n_sample + Z_t = self.params[3].clone().detach().item() + print('normalization constant Z_t is set to {:.1f}'.format(Z_t)) + + # compute out_s, out_t + out_s = torch.div(out_s, Z_s).contiguous() + out_t = torch.div(out_t, Z_t).contiguous() + + # update memory + with torch.no_grad(): + l_pos = torch.index_select(self.memory_v1, 0, idx.view(-1)) + l_pos.mul_(momentum) + l_pos.add_(torch.mul(feat_s, 1 - momentum)) + l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) + updated_v1 = l_pos.div(l_norm) + self.memory_v1.index_copy_(0, idx, updated_v1) + + ab_pos = torch.index_select(self.memory_v2, 0, idx.view(-1)) + ab_pos.mul_(momentum) + ab_pos.add_(torch.mul(feat_t, 1 - momentum)) + ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) + updated_v2 = ab_pos.div(ab_norm) + self.memory_v2.index_copy_(0, idx, updated_v2) + + return out_s, out_t + + +class AliasMethod(object): + """ + From: https://hips.seas.harvard.edu/blog/2013/03/03/ + the-alias-method-efficient-sampling-with-many-discrete-outcomes/ + + Args: + probs (torch.Tensor): probility vector. + """ + + def __init__(self, probs: torch.Tensor) -> None: + + if probs.sum() > 1: + probs.div_(probs.sum()) + neg_sample = len(probs) + self.prob = torch.zeros(neg_sample) + self.alias = torch.LongTensor([0] * neg_sample) + + # Sort the data into the outcomes with probabilities + # that are larger and smaller than 1/neg_sample. + smaller = [] + larger = [] + for kk, prob in enumerate(probs): + self.prob[kk] = neg_sample * prob + if self.prob[kk] < 1.0: + smaller.append(kk) + else: + larger.append(kk) + + # Loop though and create little binary mixtures that + # appropriately allocate the larger outcomes over the + # overall uniform mixture. + while len(smaller) > 0 and len(larger) > 0: + small = smaller.pop() + large = larger.pop() + + self.alias[small] = large + self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] + + if self.prob[large] < 1.0: + smaller.append(large) + else: + larger.append(large) + + for last_one in smaller + larger: + self.prob[last_one] = 1 + + def cuda(self): + """To cuda device.""" + self.prob = self.prob.cuda() + self.alias = self.alias.cuda() + + def draw(self, N: int) -> torch.Tensor: + """Draw N samples from multinomial.""" + neg_sample = self.alias.size(0) + + kk = torch.zeros( + N, dtype=torch.long, + device=self.prob.device).random_(0, neg_sample) + prob = self.prob.index_select(0, kk) + alias = self.alias.index_select(0, kk) + # b is whether a random number is greater than q + b = torch.bernoulli(prob) + oq = kk.mul(b.long()) + oj = alias.mul((1 - b).long()) + + return oq + oj diff --git a/mmrazor/utils/placeholder.py b/mmrazor/utils/placeholder.py index 622e687d1..c81979000 100644 --- a/mmrazor/utils/placeholder.py +++ b/mmrazor/utils/placeholder.py @@ -5,8 +5,10 @@ def get_placeholder(string: str) -> object: Args: string (str): the dependency's name, i.e. `mmcls` + Raises: ImportError: raise it when the dependency is not installed properly. + Returns: object: PlaceHolder instance. """ diff --git a/mmrazor/utils/setup_env.py b/mmrazor/utils/setup_env.py index 392658f84..385be8624 100644 --- a/mmrazor/utils/setup_env.py +++ b/mmrazor/utils/setup_env.py @@ -61,6 +61,7 @@ def register_all_modules(init_default_scope: bool = True) -> None: Defaults to True. """ # noqa + import mmrazor.datasets # noqa: F401,F403 import mmrazor.engine # noqa: F401,F403 import mmrazor.models # noqa: F401,F403 import mmrazor.structures # noqa: F401,F403 diff --git a/tests/data/dataset/a/1.JPG b/tests/data/dataset/a/1.JPG new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/ann.json b/tests/data/dataset/ann.json new file mode 100644 index 000000000..a55539329 --- /dev/null +++ b/tests/data/dataset/ann.json @@ -0,0 +1,28 @@ +{ + "metainfo": { + "categories": [ + { + "category_name": "first", + "id": 0 + }, + { + "category_name": "second", + "id": 1 + } + ] + }, + "data_list": [ + { + "img_path": "a/1.JPG", + "gt_label": 0 + }, + { + "img_path": "b/2.jpeg", + "gt_label": 1 + }, + { + "img_path": "b/subb/2.jpeg", + "gt_label": 1 + } + ] +} diff --git a/tests/data/dataset/ann.txt b/tests/data/dataset/ann.txt new file mode 100644 index 000000000..f929e873b --- /dev/null +++ b/tests/data/dataset/ann.txt @@ -0,0 +1,3 @@ +a/1.JPG 0 +b/2.jpeg 1 +b/subb/3.jpg 1 diff --git a/tests/data/dataset/b/2.jpeg b/tests/data/dataset/b/2.jpeg new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/b/subb/3.jpg b/tests/data/dataset/b/subb/3.jpg new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/dataset/classes.txt b/tests/data/dataset/classes.txt new file mode 100644 index 000000000..c012a51e6 --- /dev/null +++ b/tests/data/dataset/classes.txt @@ -0,0 +1,2 @@ +bus +car diff --git a/tests/data/dataset/multi_label_ann.json b/tests/data/dataset/multi_label_ann.json new file mode 100644 index 000000000..5cd8a84d0 --- /dev/null +++ b/tests/data/dataset/multi_label_ann.json @@ -0,0 +1,28 @@ +{ + "metainfo": { + "categories": [ + { + "category_name": "first", + "id": 0 + }, + { + "category_name": "second", + "id": 1 + } + ] + }, + "data_list": [ + { + "img_path": "a/1.JPG", + "gt_label": [0] + }, + { + "img_path": "b/2.jpeg", + "gt_label": [1] + }, + { + "img_path": "b/subb/2.jpeg", + "gt_label": [0, 1] + } + ] +} diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py new file mode 100644 index 000000000..1e6031a97 --- /dev/null +++ b/tests/test_datasets/test_datasets.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import pickle +import tempfile +from unittest import TestCase + +import numpy as np + +from mmrazor.registry import DATASETS +from mmrazor.utils import register_all_modules + +register_all_modules() +ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset')) + + +class Test_CRD_CIFAR10(TestCase): + DATASET_TYPE = 'CRD_CIFAR10' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + data_prefix = tmpdir.name + cls.DEFAULT_ARGS = dict( + data_prefix=data_prefix, pipeline=[], test_mode=False) + + dataset_class = DATASETS.get(cls.DATASET_TYPE) + base_folder = osp.join(data_prefix, dataset_class.base_folder) + os.mkdir(base_folder) + + cls.fake_imgs = np.random.randint( + 0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8) + cls.fake_labels = np.random.randint(0, 10, size=(6, )) + cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + batch1 = dict( + data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist()) + with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f: + f.write(pickle.dumps(batch1)) + + batch2 = dict( + data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist()) + with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f: + f.write(pickle.dumps(batch2)) + + test_batch = dict( + data=cls.fake_imgs[4:], fine_labels=cls.fake_labels[4:].tolist()) + with open(osp.join(base_folder, 'test_batch'), 'wb') as f: + f.write(pickle.dumps(test_batch)) + + meta = {dataset_class.meta['key']: cls.fake_classes} + meta_filename = dataset_class.meta['filename'] + with open(osp.join(base_folder, meta_filename), 'wb') as f: + f.write(pickle.dumps(meta)) + + dataset_class.train_list = [['data_batch_1', None], + ['data_batch_2', None]] + dataset_class.test_list = [['test_batch', None]] + dataset_class.meta['md5'] = None + + def test_initialize(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test overriding metainfo by `metainfo` argument + cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + + # Test overriding metainfo by `classes` argument + cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + + classes_file = osp.join(ASSETS_ROOT, 'classes.txt') + cfg = {**self.DEFAULT_ARGS, 'classes': classes_file} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, ('bus', 'car')) + self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1}) + + # Test invalid classes + cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)} + with self.assertRaisesRegex(ValueError, "type "): + dataset_class(**cfg) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup() + + +class Test_CRD_CIFAR100(Test_CRD_CIFAR10): + DATASET_TYPE = 'CRD_CIFAR100' diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py new file mode 100644 index 000000000..46aa671df --- /dev/null +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +import unittest + +import numpy as np +import torch +from mmcls.structures import ClsDataSample +from mmengine.data import LabelData + +from mmrazor.datasets.transforms import PackCRDClsInputs + + +class TestPackClsInputs(unittest.TestCase): + + def setUp(self): + """Setup the model and optimizer which are used in every test method. + + TestCase calls functions in this order: setUp() -> testMethod() -> + tearDown() -> cleanUp() + """ + data_prefix = osp.join(osp.dirname(__file__), '../../data') + img_path = osp.join(data_prefix, 'color.jpg') + rng = np.random.RandomState(0) + self.results1 = { + 'sample_idx': 1, + 'img_path': img_path, + 'ori_height': 300, + 'ori_width': 400, + 'height': 600, + 'width': 800, + 'scale_factor': 2.0, + 'flip': False, + 'img': rng.rand(300, 400), + 'gt_label': rng.randint(3, ), + # TODO. + 'contrast_sample_idxs': rng.randint() + } + self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip') + + def test_transform(self): + transform = PackCRDClsInputs(meta_keys=self.meta_keys) + results = transform(copy.deepcopy(self.results1)) + self.assertIn('inputs', results) + self.assertIsInstance(results['inputs'], torch.Tensor) + self.assertIn('data_sample', results) + self.assertIsInstance(results['data_sample'], ClsDataSample) + + data_sample = results['data_sample'] + self.assertIsInstance(data_sample.gt_label, LabelData) + + def test_repr(self): + transform = PackCRDClsInputs(meta_keys=self.meta_keys) + self.assertEqual( + repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})') diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py index 56abd0c42..5d44efd75 100644 --- a/tests/test_models/test_architectures/test_connectors/test_connectors.py +++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py @@ -3,7 +3,7 @@ import torch -from mmrazor.models import (BYOTConnector, ConvModuleConncetor, +from mmrazor.models import (BYOTConnector, ConvModuleConncetor, CRDConnector, FBKDStudentConnector, FBKDTeacherConnector, Paraphraser, TorchFunctionalConnector, TorchNNConnector, Translator) @@ -40,6 +40,23 @@ def test_convmodule_connector(self): with self.assertRaises(AssertionError): _ = ConvModuleConncetor(**convmodule_connector_cfg) + def test_crd_connector(self): + dim_out = 128 + crd_stu_connector = CRDConnector( + **dict(dim_in=1 * 5 * 5, dim_out=dim_out)) + + crd_tea_connector = CRDConnector( + **dict(dim_in=3 * 5 * 5, dim_out=dim_out)) + + assert crd_stu_connector.linear.in_features == 1 * 5 * 5 + assert crd_stu_connector.linear.out_features == dim_out + assert crd_tea_connector.linear.in_features == 3 * 5 * 5 + assert crd_tea_connector.linear.out_features == dim_out + + s_output = crd_stu_connector.forward_train(self.s_feat) + t_output = crd_tea_connector.forward_train(self.t_feat) + assert s_output.size() == t_output.size() + def test_ft_connector(self): stu_connector = Translator(**dict(in_channel=1, out_channel=2)) diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index c3ab16949..4328f7865 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -2,11 +2,12 @@ from unittest import TestCase import torch +from mmengine.data import BaseDataElement from mmrazor import digit_version -from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FBKDLoss, - FTLoss, InformationEntropyLoss, KDSoftCELoss, - OFDLoss, OnehotLikeLoss) +from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss, + FBKDLoss, FTLoss, InformationEntropyLoss, + KDSoftCELoss, OFDLoss, OnehotLikeLoss) class TestLosses(TestCase): @@ -69,6 +70,34 @@ def test_ab_loss(self): self.normal_test_2d(ab_loss) self.normal_test_3d(ab_loss) + def _mock_crd_data_sample(self, sample_idx_list): + data_samples = [] + for _idx in sample_idx_list: + data_sample = BaseDataElement() + data_sample.set_data(dict(sample_idx=_idx)) + data_samples.append(data_sample) + return data_samples + + def test_crd_loss(self): + crd_loss = CRDLoss(**dict(neg_num=5, sample_n=10, dim_out=6)) + sample_idx_list = torch.tensor(list(range(5))) + data_samples = self._mock_crd_data_sample(sample_idx_list) + loss = crd_loss.forward(self.feats_1d, self.feats_1d, data_samples) + self.assertTrue(loss.numel() == 1) + + # test the calculation + s_feat_0 = torch.randn((5, 6)) + t_feat_0 = torch.randn((5, 6)) + crd_loss_num_0 = crd_loss.forward(s_feat_0, t_feat_0, data_samples) + assert crd_loss_num_0 != torch.tensor(0.0) + + s_feat_1 = torch.randn((5, 6)) + t_feat_1 = torch.rand((5, 6)) + sample_idx_list_1 = torch.tensor(list(range(5))) + data_samples_1 = self._mock_crd_data_sample(sample_idx_list_1) + crd_loss_num_1 = crd_loss.forward(s_feat_1, t_feat_1, data_samples_1) + assert crd_loss_num_1 != torch.tensor(0.0) + def test_dkd_loss(self): dkd_loss_cfg = dict(loss_weight=1.0) dkd_loss = DKDLoss(**dkd_loss_cfg) From 4e800373935a6b78d39c9616e2f341516088e5ec Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Wed, 14 Sep 2022 20:39:49 +0800 Subject: [PATCH 2/7] [Improvement] Update estimator with api revision (#277) * update estimator usage and fix bugs * refactor api of estimator & add inner check methods * fix docstrings * update search loop and config * fix lint * update unittest * decouple mmdet dependency and fix lint Co-authored-by: humu789 --- .../spos/spos_mobilenet_search_8xb128_in1k.py | 2 +- .../spos_shufflenet_search_8xb128_in1k.py | 2 +- .../detnas_frcnn_shufflenet_search_coco_1x.py | 4 +- .../engine/runner/evolution_search_loop.py | 45 ++-- mmrazor/engine/runner/subnet_sampler_loop.py | 41 ++-- mmrazor/engine/runner/utils/__init__.py | 3 +- mmrazor/engine/runner/utils/check.py | 48 ++++ .../task_modules/estimators/base_estimator.py | 44 ++-- .../estimators/counters/__init__.py | 10 +- .../counters/flops_params_counter.py | 103 +++++--- .../estimators/counters/latency_counter.py | 71 ++++-- .../estimators/resource_estimator.py | 232 +++++++++++------- .../single_stage_detector_loss_calculator.py | 7 +- .../test_estimators/test_flops_params.py | 63 +++-- .../test_evolution_search_loop.py | 7 +- .../test_runners/test_subnet_sampler_loop.py | 23 +- tests/test_runners/test_utils/test_check.py | 40 +++ 17 files changed, 466 insertions(+), 279 deletions(-) create mode 100644 mmrazor/engine/runner/utils/check.py create mode 100644 tests/test_runners/test_utils/test_check.py diff --git a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py index f670e5c0c..4f5edb316 100644 --- a/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 465 * 1e6), + flops_range=(0., 465.), score_key='accuracy/top1') diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py index 6f8dc9366..f3f963e40 100644 --- a/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py @@ -13,5 +13,5 @@ num_mutation=25, num_crossover=25, mutate_prob=0.1, - flops_range=(0., 330 * 1e6), + flops_range=(0., 330.), score_key='accuracy/top1') diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py index 0bd3b71fe..d1dd1637a 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py @@ -13,5 +13,5 @@ num_mutation=20, num_crossover=20, mutate_prob=0.1, - flops_range=None, - score_key='bbox_mAP') + flops_range=(0., 300.), + score_key='coco/bbox_mAP') diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a72704e40..a9a76b383 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import os import os.path as osp import random @@ -14,11 +13,11 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates, export_fix_subnet from mmrazor.utils import SupportRandomSubnet -from .utils import crossover +from .utils import check_subnet_flops, crossover @LOOPS.register_module() @@ -42,10 +41,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover (int): The number of candidates got by crossover. Defaults to 25. mutate_prob (float): The probability of mutation. Defaults to 0.1. - flops_range (tuple, optional): flops_range to be used for screening - candidates. - spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + flops_range (tuple, optional): It is used for screening candidates. + resource_estimator_cfg (dict): The config for building estimator, which + is be used to estimate the flops of sampled subnet. Defaults to + None, which means default config is used. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. init_candidates (str, optional): The candidates file path, which is @@ -65,8 +64,8 @@ def __init__(self, num_mutation: int = 25, num_crossover: int = 25, mutate_prob: float = 0.1, - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), - spec_modules: List = [], + flops_range: Optional[Tuple[float, float]] = (0., 330.), + resource_estimator_cfg: Optional[dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -85,7 +84,6 @@ def __init__(self, self.num_candidates = num_candidates self.top_k = top_k self.flops_range = flops_range - self.spec_modules = spec_modules self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover @@ -101,6 +99,10 @@ def __init__(self, correct init candidates file' self.top_k_candidates = Candidates() + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) if self.runner.distributed: self.model = runner.model.module @@ -299,17 +301,10 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) - - if self.flops_range[0] <= flops <= self.flops_range[1]: - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index c2b4d2176..1127aab21 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import math import os import random @@ -13,10 +12,11 @@ from mmengine.utils import is_list_of from torch.utils.data import DataLoader -from mmrazor.models.task_modules.estimators import get_model_complexity_info +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet +from mmrazor.structures import Candidates from mmrazor.utils import SupportRandomSubnet +from .utils import check_subnet_flops class BaseSamplerTrainLoop(IterBasedTrainLoop): @@ -103,8 +103,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop): score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. flops_range (dict): Constraints to be used for screening candidates. - spec_modules (list): Used for specify modules need to counter. - Defaults to list(). + resource_estimator_cfg (dict): The config for building estimator, which + is be used to estimate the flops of sampled subnet. Defaults to + None, which means default config is used. num_candidates (int): The number of the candidates consist of samples from supernet and itself. Defaults to 1000. num_samples (int): The number of sample in each sampling subnet. @@ -138,8 +139,8 @@ def __init__(self, val_begin: int = 1, val_interval: int = 1000, score_key: str = 'accuracy/top1', - flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), - spec_modules: List = [], + flops_range: Optional[Tuple[float, float]] = (0., 330), + resource_estimator_cfg: Optional[dict] = None, num_candidates: int = 1000, num_samples: int = 10, top_k: int = 5, @@ -163,7 +164,6 @@ def __init__(self, self.score_key = score_key self.flops_range = flops_range - self.spec_modules = spec_modules self.num_candidates = num_candidates self.num_samples = num_samples self.top_k = top_k @@ -177,6 +177,10 @@ def __init__(self, self.candidates = Candidates() self.top_k_candidates = Candidates() + if resource_estimator_cfg is None: + self.estimator = ResourceEstimator() + else: + self.estimator = ResourceEstimator(**resource_estimator_cfg) def run(self) -> None: """Launch training.""" @@ -317,20 +321,13 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: Returns: bool: The result of checking. """ - if self.flops_range is None: - return True - - self.model.set_subnet(random_subnet) - fix_mutable = export_fix_subnet(self.model) - copied_model = copy.deepcopy(self.model) - load_fix_subnet(copied_model, fix_mutable) - flops, _ = get_model_complexity_info( - copied_model, spec_modules=self.spec_modules) - - if self.flops_range[0] <= flops <= self.flops_range[1]: - return True - else: - return False + is_pass = check_subnet_flops( + model=self.model, + subnet=random_subnet, + estimator=self.estimator, + flops_range=self.flops_range) + + return is_pass def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" diff --git a/mmrazor/engine/runner/utils/__init__.py b/mmrazor/engine/runner/utils/__init__.py index 7aaf29539..ec2f2cb29 100644 --- a/mmrazor/engine/runner/utils/__init__.py +++ b/mmrazor/engine/runner/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .check import check_subnet_flops from .genetic import crossover -__all__ = ['crossover'] +__all__ = ['crossover', 'check_subnet_flops'] diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py new file mode 100644 index 000000000..e2fdcfcc6 --- /dev/null +++ b/mmrazor/engine/runner/utils/check.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Tuple + +import torch.nn as nn + +from mmrazor.models import ResourceEstimator +from mmrazor.structures import export_fix_subnet, load_fix_subnet +from mmrazor.utils import SupportRandomSubnet + +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + + +def check_subnet_flops( + model: nn.Module, + subnet: SupportRandomSubnet, + estimator: ResourceEstimator, + flops_range: Optional[Tuple[float, float]] = None) -> bool: + """Check whether is beyond flops constraints. + + Returns: + bool: The result of checking. + """ + if flops_range is None: + return True + + assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') + model.set_subnet(subnet) + fix_mutable = export_fix_subnet(model) + copied_model = copy.deepcopy(model) + load_fix_subnet(copied_model, fix_mutable) + + model_to_check = model.architecture + if isinstance(model_to_check, BaseDetector): + results = estimator.estimate(model=model_to_check.backbone) + else: + results = estimator.estimate(model=model_to_check) + + flops = results['flops'] + flops_mix, flops_max = flops_range + if flops_mix <= flops <= flops_max: # type: ignore + return True + else: + return False diff --git a/mmrazor/models/task_modules/estimators/base_estimator.py b/mmrazor/models/task_modules/estimators/base_estimator.py index 22a82d105..1a6f69264 100644 --- a/mmrazor/models/task_modules/estimators/base_estimator.py +++ b/mmrazor/models/task_modules/estimators/base_estimator.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Dict, Tuple, Union import torch.nn @@ -12,44 +12,40 @@ class BaseEstimator(metaclass=ABCMeta): """The base class of Estimator, used for estimating model infos. Args: - default_shape (tuple): Input data's default shape, for calculating + input_shape (tuple): Input data's default shape, for calculating resources consume. Defaults to (1, 3, 224, 224). - units (str): Resource units. Defaults to 'M'. - disabled_counters (list): List of disabled spec op counters. - Defaults to None. + units (dict): A dict including required units. Default to dict(). as_strings (bool): Output FLOPs and params counts in a string form. Default to False. - measure_inference (bool): whether to measure infer speed or not. - Default to False. """ def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = None, - as_strings: bool = False, - measure_inference: bool = False): - assert len(default_shape) in [3, 4, 5], \ - f'Unsupported shape: {default_shape}' - self.default_shape = default_shape + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(), + as_strings: bool = False): + assert len(input_shape) in [ + 3, 4, 5 + ], ('The length of input_shape must be in [3, 4, 5]. ' + f'Got `{len(input_shape)}`.') + self.input_shape = input_shape self.units = units - self.disabled_counters = disabled_counters self.as_strings = as_strings - self.measure_inference = measure_inference @abstractmethod - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. Args: model: The measured model. - resource_args (Dict[str, float]): resources information. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Returns: - Dict[str, float]): A dict that containing resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). """ pass diff --git a/mmrazor/models/task_modules/estimators/counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/__init__.py index 0a6adee48..721987ec1 100644 --- a/mmrazor/models/task_modules/estimators/counters/__init__.py +++ b/mmrazor/models/task_modules/estimators/counters/__init__.py @@ -1,10 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .flops_params_counter import (get_model_complexity_info, - params_units_convert) -from .latency_counter import repeat_measure_inference_speed +from .flops_params_counter import get_model_flops_params +from .latency_counter import get_model_latency from .op_counters import * # noqa: F401,F403 -__all__ = [ - 'get_model_complexity_info', 'params_units_convert', - 'repeat_measure_inference_speed' -] +__all__ = ['get_model_flops_params', 'get_model_latency'] diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index 31e998a2a..f31208248 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys from functools import partial +from typing import Dict import torch import torch.nn as nn @@ -8,19 +9,21 @@ from mmrazor.registry import TASK_UTILS -def get_model_complexity_info(model, - input_shape=(1, 3, 224, 224), - spec_modules=[], - disabled_counters=[], - print_per_layer_stat=False, - as_strings=False, - input_constructor=None, - flush=False, - ost=sys.stdout): - """Get complexity information of a model. This method can calculate FLOPs - and parameter counts of a model with corresponding input shape. It can also - print complexity information for each layer in a model. Supported layers - are listed as below: +def get_model_flops_params(model, + input_shape=(1, 3, 224, 224), + spec_modules=[], + disabled_counters=[], + print_per_layer_stat=False, + units=dict(flops='M', params='M'), + as_strings=False, + seperate_return: bool = False, + input_constructor=None, + flush=False, + ost=sys.stdout): + """Get FLOPs and parameters of a model. This method can calculate FLOPs and + parameter counts of a model with corresponding input shape. It can also + print FLOPs and params for each layer in a model. Supported layers are + listed as below: - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, @@ -39,16 +42,20 @@ def get_model_complexity_info(model, Args: model (nn.Module): The model for complexity calculation. input_shape (tuple): Input shape (including batchsize) used for - calculation. Default to (1, 3, 224, 224) + calculation. Default to (1, 3, 224, 224). spec_modules (list): A list that contains the names of several spec modules, which users want to get resources infos of them. e.g., ['backbone', 'head'], ['backbone.layer1']. Default to []. disabled_counters (list): One can limit which ops' spec would be calculated. Default to []. - print_per_layer_stat (bool): Whether to print complexity information + print_per_layer_stat (bool): Whether to print FLOPs and params for each layer in a model. Default to True. + units (dict): A dict including converted FLOPs and params units. + Default to dict(flops='M', params='M'). as_strings (bool): Output FLOPs and params counts in a string form. Default to True. + seperate_return (bool): Whether to return the resource information + separately. Default to False. input_constructor (None | callable): If specified, it takes a callable method that generates input. otherwise, it will generate a random tensor with input shape to calculate FLOPs. Default to None. @@ -60,12 +67,16 @@ def get_model_complexity_info(model, tuple[float | str] | dict[str, float]: If `as_strings` is set to True, it will return FLOPs and parameter counts in a string format. Otherwise, it will return those in a float number format. - If len(spec_modules) > 0, it will return a resource info dict with - FLOPs and parameter counts of each spec module in float format. + NOTE: If seperate_return, it will return a resource info dict with + FLOPs & params counts of each spec module in float|string format. """ assert type(input_shape) is tuple assert len(input_shape) >= 1 assert isinstance(model, nn.Module) + if seperate_return and not len(spec_modules): + raise AssertionError('`seperate_return` can only be set to True when ' + '`spec_modules` are not empty.') + flops_params_model = add_flops_params_counting_methods(model) flops_params_model.eval() flops_params_model.start_flops_params_count(disabled_counters) @@ -96,34 +107,44 @@ def get_model_complexity_info(model, ost=ost, flush=flush) + if units is not None: + flops_count = params_units_convert(flops_count, units['flops']) + params_count = params_units_convert(params_count, units['params']) + + if as_strings: + flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs' + params_suffix = ' ' + units['params'] if units else '' + if len(spec_modules): + flops_count, params_count = 0.0, 0.0 module_names = [name for name, _ in flops_params_model.named_modules()] for module in spec_modules: assert module in module_names, \ f'All modules in spec_modules should be in the measured ' \ - f'flops_params_model. Got module {module} in spec_modules.' - spec_modules_resources = dict() - accumulate_sub_module_flops_params(flops_params_model) + f'flops_params_model. Got module `{module}` in spec_modules.' + spec_modules_resources: Dict[str, dict] = dict() + accumulate_sub_module_flops_params(flops_params_model, units=units) for name, module in flops_params_model.named_modules(): if name in spec_modules: spec_modules_resources[name] = dict() spec_modules_resources[name]['flops'] = module.__flops__ spec_modules_resources[name]['params'] = module.__params__ + flops_count += module.__flops__ + params_count += module.__params__ if as_strings: - spec_modules_resources[name]['flops'] = str( - params_units_convert(module.__flops__, - 'G')) + ' GFLOPs' - spec_modules_resources[name]['params'] = str( - params_units_convert(module.__params__, 'M')) + ' M' + spec_modules_resources[name]['flops'] = \ + str(module.__flops__) + flops_suffix + spec_modules_resources[name]['params'] = \ + str(module.__params__) + params_suffix flops_params_model.stop_flops_params_count() - if len(spec_modules): + if seperate_return: return spec_modules_resources if as_strings: - flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs' - params_string = str(params_units_convert(params_count, 'M')) + ' M' + flops_string = str(flops_count) + flops_suffix + params_string = str(params_count) + params_suffix return flops_string, params_string return flops_count, params_count @@ -164,7 +185,7 @@ def params_units_convert(num_params, units='M', precision=3): def print_model_with_flops_params(model, total_flops, total_params, - units='G', + units=dict(flops='M', params='M'), precision=3, ost=sys.stdout, flush=False): @@ -174,7 +195,9 @@ def print_model_with_flops_params(model, model (nn.Module): The model to be printed. total_flops (float): Total FLOPs of the model. total_params (float): Total parameter counts of the model. - units (str | None): Converted FLOPs units. Default to 'G'. + units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to ('M', 'M'). precision (int): Digit number after the decimal point. Default to 3. ost (stream): same as `file` param in :func:`print`. Default to sys.stdout. @@ -200,8 +223,8 @@ def print_model_with_flops_params(model, >>> return x >>> model = ExampleModel() >>> x = (3, 16, 16) - to print the complexity information state for each layer, you can use - >>> get_model_complexity_info(model, x) + to print the FLOPs and params state for each layer, you can use + >>> get_model_flops_params(model, x) or directly use >>> print_model_with_flops_params(model, 4579784.0, 37361) ExampleModel( @@ -241,11 +264,11 @@ def flops_repr(self): accumulated_flops_cost = self.accumulate_flops() flops_string = str( params_units_convert( - accumulated_flops_cost, units=units, - precision=precision)) + ' ' + units + 'FLOPs' + accumulated_flops_cost, units['flops'], + precision=precision)) + ' ' + units['flops'] + 'FLOPs' params_string = str( - params_units_convert( - accumulated_num_params, units='M', precision=precision)) + ' M' + params_units_convert(accumulated_num_params, units['params'], + precision)) + ' M' return ', '.join([ params_string, '{:.3%} Params'.format(accumulated_num_params / total_params), @@ -277,12 +300,15 @@ def del_extra_repr(m): model.apply(del_extra_repr) -def accumulate_sub_module_flops_params(model): +def accumulate_sub_module_flops_params(model, units=None): """Accumulate FLOPs and params for each module in the model. Each module in the model will have the `__flops__` and `__params__` parameters. Args: model (nn.Module): The model to be accumulated. + units (tuple | none): A tuple pair including converted FLOPs & params + units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'. + Default to None. """ def accumulate_params(module): @@ -310,6 +336,9 @@ def accumulate_flops(module): _params = accumulate_params(module) module.__flops__ = _flops module.__params__ = _params + if units is not None: + module.__flops__ = params_units_convert(_flops, units['flops']) + module.__params__ = params_units_convert(_params, units['params']) def get_model_parameters_number(model): diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index e3e91c54e..a4241e313 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -1,71 +1,89 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging import time -from typing import Any, Dict +from typing import Tuple, Union import torch from mmengine.logging import print_log -def repeat_measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100, - repeat_num: int = 1) -> float: +def get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + unit: str = 'ms', + as_strings: bool = False, + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100, + repeat_num: int = 1) -> Union[float, str]: """Repeat speed measure for multi-times to get more precise results. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + unit (str): Unit of latency in string format. Default to 'ms'. + as_strings (bool): Output latency counts in a string form. + Default to False. + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. log_interval (Optional[int]): Interval num for logging the results. + Default to 100. repeat_num (Optional[int]): Num of times to repeat the measurement. + Default to 1. Returns: - fps (float): The measured inference speed of the model. + latency (Union[float, str]): The measured inference speed of the model. + if ``as_strings=True``, it will return latency in string format. """ assert repeat_num >= 1 fps_list = [] for _ in range(repeat_num): - fps_list.append( - measure_inference_speed(model, resource_args, max_iter, num_warmup, - log_interval)) + _get_model_latency(model, input_shape, max_iter, num_warmup, + log_interval)) + + latency = round(1000 / fps_list[0], 1) if repeat_num > 1: - fps_list_ = [round(fps, 1) for fps in fps_list] + _fps_list = [round(fps, 1) for fps in fps_list] times_per_img_list = [round(1000 / fps, 1) for fps in fps_list] - mean_fps_ = sum(fps_list_) / len(fps_list_) + _mean_fps = sum(_fps_list) / len(_fps_list) mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list) print_log( - f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, ' + f'Overall fps: {_fps_list}[{_mean_fps:.1f}] img / s, ' f'times per image: ' f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img', logger='current', level=logging.DEBUG) - return mean_times_per_img + latency = mean_times_per_img + + if as_strings: + latency = str(latency) + ' ' + unit # type: ignore - latency = round(1000 / fps_list[0], 1) return latency -def measure_inference_speed(model: torch.nn.Module, - resource_args: Dict[str, Any], - max_iter: int = 100, - num_warmup: int = 5, - log_interval: int = 100) -> float: +def _get_model_latency(model: torch.nn.Module, + input_shape: Tuple = (1, 3, 224, 224), + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100) -> float: """Measure inference speed on GPU devices. Args: model (torch.nn.Module): The measured model. - resource_args (Dict[str, float]): resources information. - max_iter (Optional[int]): Max iteration num for inference speed test. + input_shape (tuple): Input shape (including batchsize) used for + calculation. Default to (1, 3, 224, 224). + max_iter (Optional[int]): Max iteration num for the measurement. + Default to 100. num_warmup (Optional[int]): Iteration num for warm-up stage. + Default to 5. log_interval (Optional[int]): Interval num for logging the results. + Default to 100. Returns: fps (float): The measured inference speed of the model. @@ -78,10 +96,11 @@ def measure_inference_speed(model: torch.nn.Module, device = 'cuda' else: raise NotImplementedError('To use cpu to test latency not supported.') + # benchmark with {max_iter} image and take the average for i in range(1, max_iter): if device == 'cuda': - data = torch.rand(resource_args['input_shape']).cuda() + data = torch.rand(input_shape).cuda() torch.cuda.synchronize() start_time = time.perf_counter() diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index 6d4342866..ac5292d0c 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Tuple +from typing import Dict, Optional, Tuple, Union import torch.nn -from mmengine.dist import broadcast_object_list, is_main_process from mmrazor.registry import TASK_UTILS from .base_estimator import BaseEstimator -from .counters import (get_model_complexity_info, params_units_convert, - repeat_measure_inference_speed) +from .counters import get_model_flops_params, get_model_latency @TASK_UTILS.register_module() @@ -15,24 +13,30 @@ class ResourceEstimator(BaseEstimator): """Estimator for calculating the resources consume. Args: - default_shape (tuple): Input data's default shape, for calculating - resources consume. Defaults to (1, 3, 224, 224) - units (str): Resource units. Defaults to 'M'. - disabled_counters (list): List of disabled spec op counters. - Defaults to None. - NOTE: disabled_counters contains the op counter class names - in estimator.op_counters that require to be disabled, - such as 'ConvCounter', 'BatchNorm2dCounter', ... + input_shape (tuple): Input data's default shape, for calculating + resources consume. Defaults to (1, 3, 224, 224). + units (dict): Dict that contains converted FLOPs/params/latency units. + Default to dict(flops='M', params='M', latency='ms'). + as_strings (bool): Output FLOPs/params/latency counts in a string + form. Default to False. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. Examples: >>> # direct calculate resource consume of nn.Conv2d >>> conv2d = nn.Conv2d(3, 32, 3) - >>> estimator = ResourceEstimator() - >>> estimator.estimate( - ... model=conv2d, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + >>> estimator = ResourceEstimator(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=conv2d) {'flops': 3.444, 'params': 0.001, 'latency': 0.0} + >>> # direct calculate resource consume of nn.Conv2d + >>> conv2d = nn.Conv2d(3, 32, 3) + >>> estimator = ResourceEstimator() + >>> flops_params_cfg = dict(input_shape=(1, 3, 32, 32)) + >>> estimator.estimate(model=conv2d, flops_params_cfg) + {'flops': 0.806, 'params': 0.001, 'latency': 0.0} + >>> # calculate resources of custom modules >>> class CustomModule(nn.Module): ... @@ -51,17 +55,14 @@ class ResourceEstimator(BaseEstimator): ... module.__params__ += 700000 ... >>> model = CustomModule() - >>> estimator.estimate( - ... model=model, - ... resource_args=dict(input_shape=(1, 3, 64, 64))) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64)) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 1.0, 'params': 0.7, 'latency': 0.0} ... >>> # calculate resources of custom modules with disable_counters - >>> estimator.estimate( - ... model=model, - ... resource_args=dict( - ... input_shape=(1, 3, 64, 64), - ... disabled_counters=['CustomModuleCounter'])) + >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64), + ... disabled_counters=['CustomModuleCounter']) + >>> estimator.estimate(model=model, flops_params_cfg) {'flops': 0.0, 'params': 0.0, 'latency': 0.0} >>> # calculate resources of mmrazor.models @@ -69,87 +70,146 @@ class ResourceEstimator(BaseEstimator): mmrazor.engine.hooks.estimate_resources_hook for details. """ - def __init__(self, - default_shape: Tuple = (1, 3, 224, 224), - units: str = 'M', - disabled_counters: List[str] = [], - as_strings: bool = False, - measure_inference: bool = False): - super().__init__(default_shape, units, disabled_counters, as_strings, - measure_inference) - - def estimate( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, Any]: + def __init__( + self, + input_shape: Tuple = (1, 3, 224, 224), + units: Dict = dict(flops='M', params='M', latency='ms'), + as_strings: bool = False, + flops_params_cfg: Optional[dict] = None, + latency_cfg: Optional[dict] = None, + ): + super().__init__(input_shape, units, as_strings) + if not isinstance(units, dict): + raise TypeError('units for estimator should be a dict', + f'but got `{type(units)}`') + for unit_key in units: + if unit_key not in ['flops', 'params', 'latency']: + raise KeyError(f'Got invalid key `{unit_key}` in units. ', + 'Should be `flops`, `params` or `latency`.') + if flops_params_cfg: + self.flops_params_cfg = flops_params_cfg + else: + self.flops_params_cfg = dict() + self.latency_cfg = latency_cfg if latency_cfg else dict() + + def estimate(self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. + This method will first parse the merged :attr:`self.flops_params_cfg` + and the :attr:`self.latency_cfg` to check whether the keys are valid. + Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. + latency_cfg (dict): Cfg for estimating latency. Default to None. + + NOTE: If the `flops_params_cfg` and `latency_cfg` are both None, + this method will only estimate FLOPs/params with default settings. Returns: - Dict[str, str]): A dict that containing resource results(flops, - params and latency). + Dict[str, Union[float, str]]): A dict that contains the resource + results(FLOPs, params and latency). """ resource_metrics = dict() - if is_main_process(): - measure_inference = resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters - model.eval() - flops, params = get_model_complexity_info(model, **resource_args) - if measure_inference: - latency = repeat_measure_inference_speed( - model, resource_args, max_iter=100, repeat_num=2) - else: - latency = 0.0 - as_strings = resource_args.get('as_strings', self.as_strings) - if as_strings and self.units is not None: - raise ValueError('Set units to None, when as_trings=True.') - if self.units is not None: - flops = params_units_convert(flops, self.units) - params = params_units_convert(params, self.units) - resource_metrics.update({ - 'flops': flops, - 'params': params, - 'latency': latency - }) - results = [resource_metrics] + measure_latency = True if latency_cfg else False + + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) else: - results = [None] # type: ignore + flops_params_cfg = self.flops_params_cfg - broadcast_object_list(results) + if latency_cfg: + latency_cfg = {**self.latency_cfg, **latency_cfg} + self._check_latency_cfg(latency_cfg) + latency_cfg = self._set_default_resource_params(latency_cfg) + else: + latency_cfg = self.latency_cfg + + model.eval() + flops, params = get_model_flops_params(model, **flops_params_cfg) + if measure_latency: + latency = get_model_latency(model, **latency_cfg) + else: + latency = '0.0 ms' if self.as_strings else 0.0 # type: ignore - return results[0] + resource_metrics.update({ + 'flops': flops, + 'params': params, + 'latency': latency + }) + return resource_metrics - def estimate_spec_modules( - self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() - ) -> Dict[str, float]: - """Estimate the resources(flops/params/latency) of the spec modules. + def estimate_separation_modules( + self, + model: torch.nn.Module, + flops_params_cfg: dict = None) -> Dict[str, Union[float, str]]: + """Estimate FLOPs and params of the spec modules with separate return. Args: model: The measured model. - resource_args (Dict[str, float]): Args for resources estimation. - NOTE: resource_args have the same items() as the init cfgs. + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + Default to None. Returns: - Dict[str, float]): A dict that containing resource results(flops, - params) of each modules in resource_args['spec_modules']. + Dict[str, Union[float, str]]): A dict that contains the FLOPs and + params results (string | float format) of each modules in the + ``flops_params_cfg['spec_modules']``. """ - assert 'spec_modules' in resource_args, \ - 'spec_modules is required when calling estimate_spec_modules().' + if flops_params_cfg: + flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg} + self._check_flops_params_cfg(flops_params_cfg) + flops_params_cfg = self._set_default_resource_params( + flops_params_cfg) + else: + flops_params_cfg = self.flops_params_cfg + flops_params_cfg['seperate_return'] = True - resource_args.pop('measure_inference', False) - if 'input_shape' not in resource_args.keys(): - resource_args['input_shape'] = self.default_shape - if 'disabled_counters' not in resource_args.keys(): - resource_args['disabled_counters'] = self.disabled_counters + assert len(flops_params_cfg['spec_modules']), ( + 'spec_modules can not be empty when calling ' + f'`estimate_separation_modules` of {self.__class__.__name__} ') model.eval() - spec_modules_resources = get_model_complexity_info( - model, **resource_args) - + spec_modules_resources = get_model_flops_params( + model, **flops_params_cfg) return spec_modules_resources + + def _check_flops_params_cfg(self, flops_params_cfg: dict) -> None: + """Check the legality of ``flops_params_cfg``. + + Args: + flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. + """ + for key in flops_params_cfg: + if key not in get_model_flops_params.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in flops_params_cfg.') + + def _check_latency_cfg(self, latency_cfg: dict) -> None: + """Check the legality of ``latency_cfg``. + + Args: + latency_cfg (dict): Cfg for estimating latency. + """ + for key in latency_cfg: + if key not in get_model_latency.__code__.co_varnames[ + 1:]: # type: ignore + raise KeyError(f'Got invalid key `{key}` in latency_cfg.') + + def _set_default_resource_params(self, cfg: dict) -> dict: + """Set default attributes for the input cfgs. + + Args: + cfg (dict): flops_params_cfg or latency_cfg. + """ + default_common_settings = ['input_shape', 'units', 'as_strings'] + for key in default_common_settings: + if key not in cfg: + cfg[key] = getattr(self, key) + return cfg diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py index 85f25eaff..5365831b9 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdet.models import BaseDetector from mmrazor.registry import TASK_UTILS +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + # todo: adapt to mmdet 2.0 @TASK_UTILS.register_module() diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 99be89bca..60bcef4ba 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -118,9 +118,9 @@ def sample_choice(self, model: Module) -> None: def test_estimate(self) -> None: fool_conv2d = FoolConv2d() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_conv2d, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_conv2d, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -129,9 +129,9 @@ def test_estimate(self) -> None: def test_register_module(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) results = estimator.estimate( - model=fool_add_constant, - resource_args=dict(input_shape=(1, 3, 224, 224))) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] @@ -140,46 +140,65 @@ def test_register_module(self) -> None: def test_disable_sepc_counter(self) -> None: fool_add_constant = FoolConvModule() + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), + disabled_counters=['FoolAddConstantCounter']) rest_results = estimator.estimate( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), - disabled_counters=['FoolAddConstantCounter'])) + model=fool_add_constant, flops_params_cfg=flops_params_cfg) rest_flops_count = rest_results['flops'] rest_params_count = rest_results['params'] self.assertLess(rest_flops_count, 45.158) self.assertLess(rest_params_count, 0.701) - def test_estimate_spec_modules(self) -> None: + def test_estimate_spec_module(self) -> None: fool_add_constant = FoolConvModule() - results = estimator.estimate_spec_modules( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['add_constant'])) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), + spec_modules=['add_constant', 'conv2d']) + results = estimator.estimate( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) + flops_count = results['flops'] + params_count = results['params'] + + self.assertEqual(flops_count, 45.158) + self.assertEqual(params_count, 0.701) + + def test_estimate_separation_modules(self) -> None: + fool_add_constant = FoolConvModule() + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['add_constant']) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) self.assertGreater(results['add_constant']['flops'], 0) with pytest.raises(AssertionError): - results = estimator.estimate_spec_modules( - model=fool_add_constant, - resource_args=dict( - input_shape=(1, 3, 224, 224), spec_modules=['backbone'])) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=['backbone']) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) + + with pytest.raises(AssertionError): + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), spec_modules=[]) + results = estimator.estimate_separation_modules( + model=fool_add_constant, flops_params_cfg=flops_params_cfg) def test_estimate_subnet(self) -> None: - resource_args = dict(input_shape=(1, 3, 224, 224)) + flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) model = MODELS.build(BACKBONE_CFG) self.sample_choice(model) copied_model = copy.deepcopy(model) results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count = results['flops'] params_count = results['params'] fix_subnet = export_fix_subnet(model) load_fix_subnet(copied_model, fix_subnet) subnet_results = estimator.estimate( - model=copied_model, resource_args=resource_args) + model=copied_model, flops_params_cfg=flops_params_cfg) subnet_flops_count = subnet_results['flops'] subnet_params_count = subnet_results['params'] @@ -188,8 +207,8 @@ def test_estimate_subnet(self) -> None: # test whether subnet estimate will affect original model copied_model = copy.deepcopy(model) - results_after_estimate = \ - estimator.estimate(model=copied_model, resource_args=resource_args) + results_after_estimate = estimator.estimate( + model=copied_model, flops_params_cfg=flops_params_cfg) flops_count_after_estimate = results_after_estimate['flops'] params_count_after_estimate = results_after_estimate['params'] diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index 14e642c57..f30019274 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -112,10 +112,7 @@ def test_init(self): self.assertEqual(loop.candidates, fake_candidates) @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.evolution_search_loop.get_model_complexity_info' - ) - def test_run_epoch(self, mock_flops, mock_export_fix_subnet): + def test_run_epoch(self, mock_export_fix_subnet): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -155,7 +152,7 @@ def test_run_epoch(self, mock_flops, mock_export_fix_subnet): self.runner.work_dir = self.temp_dir fake_subnet = {'1': 'choice1', '2': 'choice2'} loop.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) + loop._check_constraints = MagicMock(return_value=True) mock_export_fix_subnet.return_value = fake_subnet loop.run_epoch() self.assertEqual(len(loop.candidates), 4) diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py index 0f26c5aeb..fca29b823 100644 --- a/tests/test_runners/test_subnet_sampler_loop.py +++ b/tests/test_runners/test_subnet_sampler_loop.py @@ -192,30 +192,15 @@ def test_sample_subnet(self): self.assertEqual(subnet, fake_subnet) self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1) - @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet') - @patch( - 'mmrazor.engine.runner.subnet_sampler_loop.get_model_complexity_info') - def test_run(self, mock_flops, mock_export_fix_subnet): - # test run with flops_range=None - cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run1' - runner = Runner.from_cfg(cfg) - fake_subnet = {'1': 'choice1', '2': 'choice2'} - runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - runner.train() - - self.assertEqual(runner.iter, runner.max_iters) - assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl')) - + def test_run(self): # test run with _check_constraints cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_run2' - cfg.train_cfg.flops_range = (0, 100) + cfg.experiment_name = 'test_run1' runner = Runner.from_cfg(cfg) fake_subnet = {'1': 'choice1', '2': 'choice2'} runner.model.sample_subnet = MagicMock(return_value=fake_subnet) - mock_flops.return_value = (50., 1) - mock_export_fix_subnet.return_value = fake_subnet + loop = runner.build_train_loop(cfg.train_cfg) + loop._check_constraints = MagicMock(return_value=True) runner.train() self.assertEqual(runner.iter, runner.max_iters) diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py new file mode 100644 index 000000000..b9bd57989 --- /dev/null +++ b/tests/test_runners/test_utils/test_check.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +from mmrazor.engine.runner.utils import check_subnet_flops + +try: + from mmdet.models.detectors import BaseDetector +except ImportError: + from mmrazor.utils import get_placeholder + BaseDetector = get_placeholder('mmdet') + + +@patch('mmrazor.models.ResourceEstimator') +@patch('mmrazor.models.SPOS') +def test_check_subnet_flops(mock_model, mock_estimator): + # flops_range = None + flops_range = None + fake_subnet = {'1': 'choice1', '2': 'choice2'} + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is True + + # flops_range is not None + # architecturte is BaseDetector + flops_range = (0., 100.) + mock_model.architecture = BaseDetector + fake_results = {'flops': 50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is True + + # flops_range is not None + # architecturte is BaseDetector + flops_range = (0., 100.) + fake_results = {'flops': -50.} + mock_estimator.estimate.return_value = fake_results + result = check_subnet_flops(mock_model, fake_subnet, mock_estimator, + flops_range) + assert result is False From d07dee9887f60346b2d8c3c929e7b91ebd44e865 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Thu, 15 Sep 2022 14:52:54 +0800 Subject: [PATCH 3/7] [Fix] Fix tracer (#273) * test image_classifier_loss_calculator * fix backward tracer * update SingleStageDetectorPseudoLoss * merge --- .../image_classifier_loss_calculator.py | 14 +++++-- .../single_stage_detector_loss_calculator.py | 21 +++++++--- mmrazor/models/task_modules/tracer/parsers.py | 33 ++++++--------- .../test_tracer/test_backward_tracer.py | 42 ++++++++++++++++++- .../test_tracer/test_loss_calculator.py | 25 +++++++++++ 5 files changed, 104 insertions(+), 31 deletions(-) create mode 100644 tests/test_core/test_tracer/test_loss_calculator.py diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py index c06342bbc..65e908e30 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py @@ -13,9 +13,17 @@ @TASK_UTILS.register_module() class ImageClassifierPseudoLoss: """Calculate the pseudo loss to trace the topology of a `ImageClassifier` - in MMClassification with `BackwardTracer`.""" + in MMClassification with `BackwardTracer`. + + Args: + input_shape (Tuple): The shape of the pseudo input. Defaults to + (2, 3, 224, 224). + """ + + def __init__(self, input_shape=(2, 3, 224, 224)): + self.input_shape = input_shape def __call__(self, model: ImageClassifier) -> torch.Tensor: - pseudo_img = torch.rand(1, 3, 224, 224) + pseudo_img = torch.rand(self.input_shape) pseudo_output = model(pseudo_img) - return sum(pseudo_output) + return pseudo_output.sum() diff --git a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py index 5365831b9..f8554580d 100644 --- a/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py +++ b/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py @@ -4,19 +4,28 @@ from mmrazor.registry import TASK_UTILS try: - from mmdet.models.detectors import BaseDetector + from mmdet.models import SingleStageDetector except ImportError: from mmrazor.utils import get_placeholder - BaseDetector = get_placeholder('mmdet') + SingleStageDetector = get_placeholder('mmdet') -# todo: adapt to mmdet 2.0 @TASK_UTILS.register_module() class SingleStageDetectorPseudoLoss: + """Calculate the pseudo loss to trace the topology of a + `SingleStageDetector` in MMDetection with `BackwardTracer`. - def __call__(self, model: BaseDetector) -> torch.Tensor: - pseudo_img = torch.rand(1, 3, 224, 224) - pseudo_output = model.forward_dummy(pseudo_img) + Args: + input_shape (Tuple): The shape of the pseudo input. Defaults to + (2, 3, 224, 224). + """ + + def __init__(self, input_shape=(2, 3, 224, 224)): + self.input_shape = input_shape + + def __call__(self, model: SingleStageDetector) -> torch.Tensor: + pseudo_img = torch.rand(self.input_shape) + pseudo_output = model(pseudo_img) out = torch.tensor(0.) for levels in pseudo_output: out += sum([level.sum() for level in levels]) diff --git a/mmrazor/models/task_modules/tracer/parsers.py b/mmrazor/models/task_modules/tracer/parsers.py index c5994a27a..efcfc8613 100644 --- a/mmrazor/models/task_modules/tracer/parsers.py +++ b/mmrazor/models/task_modules/tracer/parsers.py @@ -118,27 +118,16 @@ def parse_cat(tracer, grad_fn, module2name, param2module, cur_path, >>> # ``out`` is obtained by concatenating two tensors """ parents = grad_fn.next_functions - concat_id = '_'.join([str(id(p)) for p in parents]) - name = f'concat_{concat_id}' - # If a module is not a shared module and it has been visited during - # forward, its parent modules must have been traced already. - # However, a shared module will be visited more than once during - # forward, so it is still need to be traced even if it has been - # visited. - if (name in visited and visited[name] and name not in shared_module): - pass - else: - visited[name] = True - sub_path_lists = list() - for i, parent in enumerate(parents): - sub_path_list = PathList() - tracer.backward_trace(parent, module2name, param2module, Path(), - sub_path_list, visited, shared_module) - sub_path_lists.append(sub_path_list) - cur_path.append(PathConcatNode(name, sub_path_lists)) - - result_paths.append(copy.deepcopy(cur_path)) - cur_path.pop(-1) + sub_path_lists = list() + for i, parent in enumerate(parents): + sub_path_list = PathList() + tracer.backward_trace(parent, module2name, param2module, Path(), + sub_path_list, visited, shared_module) + sub_path_lists.append(sub_path_list) + cur_path.append(PathConcatNode('CatNode', sub_path_lists)) + + result_paths.append(copy.deepcopy(cur_path)) + cur_path.pop(-1) def parse_norm(tracer, grad_fn, module2name, param2module, cur_path, @@ -174,6 +163,8 @@ def parse_norm(tracer, grad_fn, module2name, param2module, cur_path, DEFAULT_BACKWARD_TRACER: Dict[str, Callable] = { + 'ConvolutionBackward': parse_conv, + 'SlowConv2DBackward': parse_conv, 'ThnnConv2DBackward': parse_conv, 'CudnnConvolutionBackward': parse_conv, 'MkldnnConvolutionBackward': parse_conv, diff --git a/tests/test_core/test_tracer/test_backward_tracer.py b/tests/test_core/test_tracer/test_backward_tracer.py index 93dddd56a..33f5eff78 100644 --- a/tests/test_core/test_tracer/test_backward_tracer.py +++ b/tests/test_core/test_tracer/test_backward_tracer.py @@ -57,6 +57,31 @@ def forward(self, x: Tensor) -> Tensor: return output +class MultiConcatModel3(Module): + + def __init__(self) -> None: + super().__init__() + + self.op1 = nn.Conv2d(3, 8, 1) + self.op2 = nn.Conv2d(3, 8, 1) + self.op3 = nn.Conv2d(3, 8, 1) + self.op4 = nn.Conv2d(24, 8, 1) + self.op5 = nn.Conv2d(24, 8, 1) + self.op6 = nn.Conv2d(24, 8, 1) + self.op7 = nn.Conv2d(24, 8, 1) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + cat1 = torch.cat([x1, x2, x3], dim=1) + x4 = self.op4(cat1) + x5 = self.op5(cat1) + x6 = self.op6(cat1) + x7 = self.op7(cat1) + return torch.cat([x4, x5, x6, x7], dim=1) + + class ResBlock(Module): def __init__(self) -> None: @@ -77,8 +102,11 @@ def forward(self, x: Tensor) -> Tensor: class ToyCNNPseudoLoss: + def __init__(self, input_shape=(2, 3, 16, 16)): + self.input_shape = input_shape + def __call__(self, model): - pseudo_img = torch.rand(2, 3, 16, 16) + pseudo_img = torch.rand(self.input_shape) pseudo_output = model(pseudo_img) return pseudo_output.sum() @@ -158,6 +186,18 @@ def test_trace_multi_cat(self) -> None: 'op1', 'op2', 'op3' ] + model = MultiConcatModel3() + tracer = BackwardTracer(loss_calculator=loss_calculator) + path_list = tracer.trace(model) + assert len(path_list) == 1 + + nonpass2parents = path_list.find_nodes_parents(NONPASS_NODES) + assert nonpass2parents['op1'] == list() + assert nonpass2parents['op2'] == list() + assert nonpass2parents['op3'] == list() + assert nonpass2parents['op4'] == nonpass2parents['op5'] == \ + nonpass2parents['op6'] == nonpass2parents['op7'] + def test_repr(self): toy_node = PathConvNode('op1') assert repr(toy_node) == 'PathConvNode(\'op1\')' diff --git a/tests/test_core/test_tracer/test_loss_calculator.py b/tests/test_core/test_tracer/test_loss_calculator.py new file mode 100644 index 000000000..309b96091 --- /dev/null +++ b/tests/test_core/test_tracer/test_loss_calculator.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine.hub import get_model + +from mmrazor.models.task_modules.tracer import (ImageClassifierPseudoLoss, + SingleStageDetectorPseudoLoss) + + +class TestLossCalculator(TestCase): + + def test_image_classifier_pseudo_loss(self): + model = get_model( + 'mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False) + loss_calculator = ImageClassifierPseudoLoss() + loss = loss_calculator(model) + assert isinstance(loss, torch.Tensor) and loss.dim() == 0 + + def test_single_stage_detector_pseudo_loss(self): + model = get_model( + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', pretrained=False) + loss_calculator = SingleStageDetectorPseudoLoss() + loss = loss_calculator(model) + assert isinstance(loss, torch.Tensor) and loss.dim() == 0 From 8d603d917e591dd0b6bf876584f66a4c25cda816 Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Thu, 29 Sep 2022 16:48:47 +0800 Subject: [PATCH 4/7] [Feature] Add Dsnas Algorithm (#226) * [tmp] Update Dsnas * [tmp] refactor arch_loss & flops_loss * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs. * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage. * remove old evaluators * remove old evaluators * update param_scheduler config * merge dev-1.x into gy/estimator * add flops_loss in Dsnas using ResourcesEstimator * get resources before mutator.prepare_from_supernet * delete unness broadcast api from gml * broadcast spec_modules_resources when estimating * update early fix mechanism for Dsnas * fix merge * update units in estimator * minor change * fix data_preprocessor api * add flops_loss_coef * remove DsnasOptimWrapper * fix bn eps and data_preprocessor * fix bn weight decay bug * add betas for mutator optimizer * set diff_rank_seed=True for dsnas * fix start_factor of lr when warm up * remove .module in non-ddp mode * add GlobalAveragePoolingWithDropout * add UT for dsnas * remove unness channel adjustment for shufflenetv2 * update supernet configs * delete unness dropout * delete unness part with minor change on dsnas * minor change on the flag of search stage * update README and subnet configs * add UT for OneHotMutableOP --- .../dsnas_shufflenet_supernet.py | 28 ++ .../_base_/settings/imagenet_bs1024_dsnas.py | 102 +++++ .../DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml | 20 + configs/nas/mmcls/dsnas/README.md | 43 +++ .../mmcls/dsnas/dsnas_subnet_8xb128_in1k.py | 29 ++ .../mmcls/dsnas/dsnas_supernet_8xb128_in1k.py | 36 ++ mmrazor/engine/__init__.py | 4 +- mmrazor/models/algorithms/__init__.py | 5 +- mmrazor/models/algorithms/nas/__init__.py | 5 +- mmrazor/models/algorithms/nas/dsnas.py | 347 ++++++++++++++++++ mmrazor/models/mutables/__init__.py | 5 +- .../mutables/mutable_module/__init__.py | 5 +- .../mutable_module/diff_mutable_module.py | 105 +++++- .../module_mutator/diff_module_mutator.py | 4 +- .../mutators/module_mutator/module_mutator.py | 22 +- mmrazor/structures/subnet/fix_subnet.py | 32 +- .../test_models/test_algorithms/test_dsnas.py | 222 +++++++++++ .../test_mutables/test_onehotop.py | 203 ++++++++++ 18 files changed, 1187 insertions(+), 30 deletions(-) create mode 100644 configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py create mode 100644 configs/_base_/settings/imagenet_bs1024_dsnas.py create mode 100644 configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml create mode 100644 configs/nas/mmcls/dsnas/README.md create mode 100644 configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py create mode 100644 configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py create mode 100644 mmrazor/models/algorithms/nas/dsnas.py create mode 100644 tests/test_models/test_algorithms/test_dsnas.py create mode 100644 tests/test_models/test_mutables/test_onehotop.py diff --git a/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py b/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py new file mode 100644 index 000000000..f73c8b90e --- /dev/null +++ b/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py @@ -0,0 +1,28 @@ +norm_cfg = dict(type='BN', eps=0.01) + +_STAGE_MUTABLE = dict( + type='mmrazor.OneHotMutableOP', + fix_threshold=0.3, + candidates=dict( + shuffle_3x3=dict( + type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg), + shuffle_5x5=dict( + type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg), + shuffle_7x7=dict( + type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg), + shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg))) + +arch_setting = [ + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: channel, num_blocks, mutable_cfg. + [64, 4, _STAGE_MUTABLE], + [160, 4, _STAGE_MUTABLE], + [320, 8, _STAGE_MUTABLE], + [640, 4, _STAGE_MUTABLE] +] + +nas_backbone = dict( + type='mmrazor.SearchableShuffleNetV2', + widen_factor=1.0, + arch_setting=arch_setting, + norm_cfg=norm_cfg) diff --git a/configs/_base_/settings/imagenet_bs1024_dsnas.py b/configs/_base_/settings/imagenet_bs1024_dsnas.py new file mode 100644 index 000000000..bf266c51c --- /dev/null +++ b/configs/_base_/settings/imagenet_bs1024_dsnas.py @@ -0,0 +1,102 @@ +# dataset settings +dataset_type = 'mmcls.ImageNet' +data_preprocessor = dict( + type='mmcls.ClsDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict(type='mmcls.RandomResizedCrop', scale=224), + dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), + dict(type='mmcls.PackClsInputs'), +] + +test_pipeline = [ + dict(type='mmcls.LoadImageFromFile'), + dict(type='mmcls.ResizeEdge', scale=256, edge='short'), + dict(type='mmcls.CenterCrop', crop_size=224), + dict(type='mmcls.PackClsInputs'), +] + +train_dataloader = dict( + batch_size=128, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=128, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator + +# optimizer +paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0) + +optim_wrapper = dict( + constructor='mmrazor.SeparateOptimWrapperConstructor', + architecture=dict( + optimizer=dict( + type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5), + paramwise_cfg=paramwise_cfg), + mutator=dict( + optimizer=dict( + type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5, + 0.999)))) + +search_epochs = 85 +# leanring policy +param_scheduler = dict( + architecture=[ + dict( + type='mmcls.LinearLR', + end=5, + start_factor=0.2, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='mmcls.CosineAnnealingLR', + T_max=240, + begin=5, + end=search_epochs, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='mmcls.CosineAnnealingLR', + T_max=160, + begin=search_epochs, + end=240, + eta_min=0.0, + by_epoch=True, + convert_to_iter_based=True) + ], + mutator=[]) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=240) +val_cfg = dict() +test_cfg = dict() diff --git a/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml new file mode 100644 index 000000000..d2fa294d3 --- /dev/null +++ b/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml @@ -0,0 +1,20 @@ +backbone.layers.0.0: shuffle_3x3 +backbone.layers.0.1: shuffle_3x3 +backbone.layers.0.2: shuffle_xception +backbone.layers.0.3: shuffle_3x3 +backbone.layers.1.0: shuffle_xception +backbone.layers.1.1: shuffle_7x7 +backbone.layers.1.2: shuffle_3x3 +backbone.layers.1.3: shuffle_3x3 +backbone.layers.2.0: shuffle_xception +backbone.layers.2.1: shuffle_xception +backbone.layers.2.2: shuffle_7x7 +backbone.layers.2.3: shuffle_xception +backbone.layers.2.4: shuffle_xception +backbone.layers.2.5: shuffle_xception +backbone.layers.2.6: shuffle_7x7 +backbone.layers.2.7: shuffle_3x3 +backbone.layers.3.0: shuffle_3x3 +backbone.layers.3.1: shuffle_xception +backbone.layers.3.2: shuffle_xception +backbone.layers.3.3: shuffle_3x3 diff --git a/configs/nas/mmcls/dsnas/README.md b/configs/nas/mmcls/dsnas/README.md new file mode 100644 index 000000000..6a085eb78 --- /dev/null +++ b/configs/nas/mmcls/dsnas/README.md @@ -0,0 +1,43 @@ +# DSNAS + +> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf) + + + +## Abstract + +Most existing NAS methods require two-stage parameter optimization. +However, performance of the same architecture in the two stages correlates poorly. +Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining. + +![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg) + +## Results and models + +### Supernet + +| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks | +| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: | +| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched | + +**Note**: + +1. There **might be(not all the case)** some small differences in our experiment in order to be consistent with other repos in OpenMMLab. For example, + normalize images in data preprocessing; resize by cv2 rather than PIL in training; dropout is not used in network. **Please refer to corresponding config for details.** +2. We convert the official searched checkpoint DSNASsearch240.pth into mmrazor-style and evaluate with pytorch1.8_cuda11.0, Top-1 is 74.1 and Top-5 is 91.51. +3. The implementation of ShuffleNetV2 in official DSNAS is different from OpenMMLab's and we follow the structure design in OpenMMLab. Note that with the + origin ShuffleNetV2 design in official DSNAS, the Top-1 is 73.92 and Top-5 is 91.59. +4. The finetune stage in our implementation refers to the 'search-from-search' stage mentioned in official DSNAS. +5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may be different from the origin repo. + +## Citation + +```latex +@inproceedings{hu2020dsnas, + title={Dsnas: Direct neural architecture search without parameter retraining}, + author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={12084--12092}, + year={2020} +} +``` diff --git a/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py new file mode 100644 index 000000000..ca30a5946 --- /dev/null +++ b/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py @@ -0,0 +1,29 @@ +_base_ = ['./dsnas_supernet_8xb128_in1k.py'] + +# NOTE: Replace this with the mutable_cfg searched by yourself. +fix_subnet = { + 'backbone.layers.0.0': 'shuffle_3x3', + 'backbone.layers.0.1': 'shuffle_7x7', + 'backbone.layers.0.2': 'shuffle_3x3', + 'backbone.layers.0.3': 'shuffle_5x5', + 'backbone.layers.1.0': 'shuffle_3x3', + 'backbone.layers.1.1': 'shuffle_3x3', + 'backbone.layers.1.2': 'shuffle_3x3', + 'backbone.layers.1.3': 'shuffle_7x7', + 'backbone.layers.2.0': 'shuffle_xception', + 'backbone.layers.2.1': 'shuffle_3x3', + 'backbone.layers.2.2': 'shuffle_3x3', + 'backbone.layers.2.3': 'shuffle_5x5', + 'backbone.layers.2.4': 'shuffle_3x3', + 'backbone.layers.2.5': 'shuffle_5x5', + 'backbone.layers.2.6': 'shuffle_7x7', + 'backbone.layers.2.7': 'shuffle_7x7', + 'backbone.layers.3.0': 'shuffle_xception', + 'backbone.layers.3.1': 'shuffle_3x3', + 'backbone.layers.3.2': 'shuffle_7x7', + 'backbone.layers.3.3': 'shuffle_3x3', +} + +model = dict(fix_subnet=fix_subnet) + +find_unused_parameters = False diff --git a/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py new file mode 100644 index 000000000..ea821da40 --- /dev/null +++ b/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py @@ -0,0 +1,36 @@ +_base_ = [ + 'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py', + 'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py', + 'mmcls::_base_/default_runtime.py', +] + +# model +model = dict( + type='mmrazor.Dsnas', + architecture=dict( + type='ImageClassifier', + data_preprocessor=_base_.data_preprocessor, + backbone=_base_.nas_backbone, + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict( + type='LabelSmoothLoss', + num_classes=1000, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=(1, 5))), + mutator=dict(type='mmrazor.DiffModuleMutator'), + pretrain_epochs=15, + finetune_epochs=_base_.search_epochs, +) + +model_wrapper_cfg = dict( + type='mmrazor.DsnasDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +randomness = dict(seed=48, diff_rank_seed=True) diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index fd221b577..f2df86a83 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -10,6 +10,6 @@ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop', - 'EstimateResourcesHook' + 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook', + 'SelfDistillValLoop' ] diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index bbc0e5755..2d96f3a96 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -3,12 +3,13 @@ from .distill import (DAFLDataFreeDistillation, DataFreeDistillation, FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) -from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP +from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', - 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation' + 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas', + 'DsnasDDP' ] diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py index b18fd339d..17eab7e86 100644 --- a/mmrazor/models/algorithms/nas/__init__.py +++ b/mmrazor/models/algorithms/nas/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .autoslim import AutoSlim, AutoSlimDDP from .darts import Darts, DartsDDP +from .dsnas import Dsnas, DsnasDDP from .spos import SPOS -__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP'] +__all__ = [ + 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP' +] diff --git a/mmrazor/models/algorithms/nas/dsnas.py b/mmrazor/models/algorithms/nas/dsnas.py new file mode 100644 index 000000000..62c2c7f04 --- /dev/null +++ b/mmrazor/models/algorithms/nas/dsnas.py @@ -0,0 +1,347 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from mmengine.dist import get_dist_info +from mmengine.logging import MessageHub +from mmengine.model import BaseModel, MMDistributedDataParallel +from mmengine.optim import OptimWrapper, OptimWrapperDict +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.mutables.base_mutable import BaseMutable +from mmrazor.models.mutators import DiffModuleMutator +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODEL_WRAPPERS, MODELS, TASK_UTILS +from mmrazor.structures import load_fix_subnet +from mmrazor.utils import FixMutable +from ..base import BaseAlgorithm + + +@MODELS.register_module() +class Dsnas(BaseAlgorithm): + """Implementation of `DSNAS `_ + + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutator (dict|:obj:`DiffModuleMutator`): The config of + :class:`DiffModuleMutator` or built mutator. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. + pretrain_epochs (int): Num of epochs for supernet pretraining. + finetune_epochs (int): Num of epochs for subnet finetuning. + flops_constraints (float): Flops constraints for judging whether to + backward flops loss or not. Default to 300.0(M). + estimator_cfg (Dict[str, Any]): Used for building a resource estimator. + Default to None. + norm_training (bool): Whether to set norm layers to training mode, + namely, not freeze running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to False. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. Defaults to None. + init_cfg (dict): Init config for ``BaseModule``. + + Note: + Dsnas doesn't require retraining. It has 3 stages in searching: + 1. `cur_epoch` < `pretrain_epochs` refers to supernet pretraining. + 2. `pretrain_epochs` <= `cur_epoch` < `finetune_epochs` refers to + normal supernet training while mutator is updated. + 3. `cur_epoch` >= `finetune_epochs` refers to subnet finetuning. + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: Optional[Union[DiffModuleMutator, Dict]] = None, + fix_subnet: Optional[FixMutable] = None, + pretrain_epochs: int = 0, + finetune_epochs: int = 80, + flops_constraints: float = 300.0, + estimator_cfg: Dict[str, Any] = None, + norm_training: bool = False, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None, + **kwargs): + super().__init__(architecture, data_preprocessor, **kwargs) + + if estimator_cfg is None: + estimator_cfg = dict(type='mmrazor.ResourceEstimator') + self.estimator = TASK_UTILS.build(estimator_cfg) + if fix_subnet: + # Avoid circular import + from mmrazor.structures import load_fix_subnet + + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self.architecture, fix_subnet) + self.is_supernet = False + else: + assert mutator is not None, \ + 'mutator cannot be None when fix_subnet is None.' + if isinstance(mutator, DiffModuleMutator): + self.mutator = mutator + elif isinstance(mutator, dict): + self.mutator = MODELS.build(mutator) + else: + raise TypeError('mutator should be a `dict` or ' + f'`DiffModuleMutator` instance, but got ' + f'{type(mutator)}') + + self.mutable_module_resources = self._get_module_resources() + # Mutator is an essential component of the NAS algorithm. It + # provides some APIs commonly used by NAS. + # Before using it, you must do some preparations according to + # the supernet. + self.mutator.prepare_from_supernet(self.architecture) + self.is_supernet = True + self.search_space_name_list = list( + self.mutator.name2mutable.keys()) + + self.norm_training = norm_training + self.pretrain_epochs = pretrain_epochs + self.finetune_epochs = finetune_epochs + if pretrain_epochs >= finetune_epochs: + raise ValueError(f'Pretrain stage (optional) must be done before ' + f'finetuning stage. Got `{pretrain_epochs}` >= ' + f'`{finetune_epochs}`.') + + self.flops_loss_coef = 1e-2 + self.flops_constraints = flops_constraints + _, self.world_size = get_dist_info() + + def search_subnet(self): + """Search subnet by mutator.""" + + # Avoid circular import + from mmrazor.structures import export_fix_subnet + + subnet = self.mutator.sample_choices() + self.mutator.set_choices(subnet) + return export_fix_subnet(self) + + def fix_subnet(self): + """Fix subnet when finetuning.""" + subnet = self.mutator.sample_choices() + self.mutator.set_choices(subnet) + for module in self.architecture.modules(): + if isinstance(module, BaseMutable): + if not module.is_fixed: + module.fix_chosen(module.current_choice) + self.is_supernet = False + + def train(self, mode=True): + """Convert the model into eval mode while keep normalization layer + unfreezed.""" + + super().train(mode) + if self.norm_training and not mode: + for module in self.architecture.modules(): + if isinstance(module, _BatchNorm): + module.training = True + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """The iteration step during training. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. + """ + if isinstance(optim_wrapper, OptimWrapperDict): + log_vars = dict() + self.message_hub = MessageHub.get_current_instance() + cur_epoch = self.message_hub.get_info('epoch') + need_update_mutator = self.need_update_mutator(cur_epoch) + + # TODO process the input + if cur_epoch == self.finetune_epochs and self.is_supernet: + # synchronize arch params to start the finetune stage. + for k, v in self.mutator.arch_params.items(): + dist.broadcast(v, src=0) + self.fix_subnet() + + # 1. update architecture + with optim_wrapper['architecture'].optim_context(self): + pseudo_data = self.data_preprocessor(data, True) + supernet_batch_inputs = pseudo_data['inputs'] + supernet_data_samples = pseudo_data['data_samples'] + supernet_loss = self( + supernet_batch_inputs, supernet_data_samples, mode='loss') + + supernet_losses, supernet_log_vars = self.parse_losses( + supernet_loss) + optim_wrapper['architecture'].backward( + supernet_losses, retain_graph=need_update_mutator) + optim_wrapper['architecture'].step() + optim_wrapper['architecture'].zero_grad() + log_vars.update(add_prefix(supernet_log_vars, 'supernet')) + + # 2. update mutator + if need_update_mutator: + with optim_wrapper['mutator'].optim_context(self): + mutator_loss = self.compute_mutator_loss() + mutator_losses, mutator_log_vars = \ + self.parse_losses(mutator_loss) + optim_wrapper['mutator'].update_params(mutator_losses) + log_vars.update(add_prefix(mutator_log_vars, 'mutator')) + # handle the grad of arch params & weights + self.handle_grads() + + else: + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): + pseudo_data = self.data_preprocessor(data, True) + batch_inputs = pseudo_data['inputs'] + data_samples = pseudo_data['data_samples'] + losses = self(batch_inputs, data_samples, mode='loss') + parsed_losses, log_vars = self.parse_losses(losses) + optim_wrapper.update_params(parsed_losses) + + return log_vars + + def _get_module_resources(self): + """Get resources of spec modules.""" + + spec_modules = [] + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + for choice in module.choices: + spec_modules.append(name + '._candidates.' + choice) + + mutable_module_resources = self.estimator.estimate_separation_modules( + self.architecture, dict(spec_modules=spec_modules)) + + return mutable_module_resources + + def need_update_mutator(self, cur_epoch: int) -> bool: + """Whether to update mutator.""" + if cur_epoch >= self.pretrain_epochs and \ + cur_epoch < self.finetune_epochs: + return True + return False + + def compute_mutator_loss(self) -> Dict[str, torch.Tensor]: + """Compute mutator loss. + + In this method, arch_loss & flops_loss[optional] are computed + by traversing arch_weights & probs in search groups. + + Returns: + Dict: Loss of the mutator. + """ + arch_loss = 0.0 + flops_loss = 0.0 + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + k = str(self.search_space_name_list.index(name)) + probs = F.softmax(self.mutator.arch_params[k], -1) + arch_loss += torch.log( + (module.arch_weights * probs).sum(-1)).sum() + + # get the index of op with max arch weights. + index = (module.arch_weights == 1).nonzero().item() + _module_key = name + '._candidates.' + module.choices[index] + flops_loss += probs[index] * \ + self.mutable_module_resources[_module_key]['flops'] + + mutator_loss = dict(arch_loss=arch_loss / self.world_size) + + copied_model = copy.deepcopy(self) + fix_mutable = copied_model.search_subnet() + load_fix_subnet(copied_model, fix_mutable) + + subnet_flops = self.estimator.estimate(copied_model)['flops'] + if subnet_flops >= self.flops_constraints: + mutator_loss['flops_loss'] = \ + (flops_loss * self.flops_loss_coef) / self.world_size + + return mutator_loss + + def handle_grads(self): + """Handle grads of arch params & arch weights.""" + for name, module in self.architecture.named_modules(): + if isinstance(module, BaseMutable): + k = str(self.search_space_name_list.index(name)) + self.mutator.arch_params[k].grad.data.mul_( + module.arch_weights.grad.data.sum()) + module.arch_weights.grad.zero_() + + +@MODEL_WRAPPERS.register_module() +class DsnasDDP(MMDistributedDataParallel): + + def __init__(self, + *, + device_ids: Optional[Union[List, int, torch.device]] = None, + **kwargs) -> None: + if device_ids is None: + if os.environ.get('LOCAL_RANK') is not None: + device_ids = [int(os.environ['LOCAL_RANK'])] + super().__init__(device_ids=device_ids, **kwargs) + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating are also defined in + this method, such as GAN. + """ + if isinstance(optim_wrapper, OptimWrapperDict): + log_vars = dict() + self.message_hub = MessageHub.get_current_instance() + cur_epoch = self.message_hub.get_info('epoch') + need_update_mutator = self.module.need_update_mutator(cur_epoch) + + # TODO process the input + if cur_epoch == self.module.finetune_epochs and \ + self.module.is_supernet: + # synchronize arch params to start the finetune stage. + for k, v in self.module.mutator.arch_params.items(): + dist.broadcast(v, src=0) + self.module.fix_subnet() + + # 1. update architecture + with optim_wrapper['architecture'].optim_context(self): + pseudo_data = self.module.data_preprocessor(data, True) + supernet_batch_inputs = pseudo_data['inputs'] + supernet_data_samples = pseudo_data['data_samples'] + supernet_loss = self( + supernet_batch_inputs, supernet_data_samples, mode='loss') + + supernet_losses, supernet_log_vars = self.module.parse_losses( + supernet_loss) + optim_wrapper['architecture'].backward( + supernet_losses, retain_graph=need_update_mutator) + optim_wrapper['architecture'].step() + optim_wrapper['architecture'].zero_grad() + log_vars.update(add_prefix(supernet_log_vars, 'supernet')) + + # 2. update mutator + if need_update_mutator: + with optim_wrapper['mutator'].optim_context(self): + mutator_loss = self.module.compute_mutator_loss() + mutator_losses, mutator_log_vars = \ + self.module.parse_losses(mutator_loss) + optim_wrapper['mutator'].update_params(mutator_losses) + log_vars.update(add_prefix(mutator_log_vars, 'mutator')) + # handle the grad of arch params & weights + self.module.handle_grads() + + else: + # Enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): + pseudo_data = self.module.data_preprocessor(data, True) + batch_inputs = pseudo_data['inputs'] + data_samples = pseudo_data['data_samples'] + losses = self(batch_inputs, data_samples, mode='loss') + parsed_losses, log_vars = self.module.parse_losses(losses) + optim_wrapper.update_params(parsed_losses) + + return log_vars diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 917364607..074eda445 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -3,12 +3,13 @@ from .mutable_channel import (MutableChannel, OneShotMutableChannel, SlimmableMutableChannel) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, - OneShotMutableModule, OneShotMutableOP) + OneHotMutableOP, OneShotMutableModule, + OneShotMutableOP) from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable', - 'MutableValue', 'OneShotMutableValue' + 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP' ] diff --git a/mmrazor/models/mutables/mutable_module/__init__.py b/mmrazor/models/mutables/mutable_module/__init__.py index d1904e8c8..bcf10c3a8 100644 --- a/mmrazor/models/mutables/mutable_module/__init__.py +++ b/mmrazor/models/mutables/mutable_module/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule, - DiffMutableOP) + DiffMutableOP, OneHotMutableOP) from .mutable_module import MutableModule from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP __all__ = [ 'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', - 'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule' + 'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule', + 'OneHotMutableOP' ] diff --git a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py index 9269d4a2a..379c59235 100644 --- a/mmrazor/models/mutables/mutable_module/diff_mutable_module.py +++ b/mmrazor/models/mutables/mutable_module/diff_mutable_module.py @@ -37,8 +37,9 @@ def __init__(self, **kwargs) -> None: def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None) -> Any: - """Calls either :func:`forward_fixed` or :func:`forward_choice` - depending on whether :func:`is_fixed` is ``True``. + """Calls either :func:`forward_fixed` or :func:`forward_arch_param` + depending on whether :func:`is_fixed` is ``True`` and whether + :func:`arch_param` is None. To reduce the coupling between `Mutable` and `Mutator`, the `arch_param` is generated by the `Mutator` and is passed to the @@ -52,6 +53,9 @@ def forward(self, x (Any): input data for forward computation. arch_param (nn.Parameter, optional): the architecture parameters for ``DiffMutableModule``. + + Returns: + Any: the result of forward """ if self.is_fixed: return self.forward_fixed(x) @@ -97,6 +101,10 @@ class DiffMutableOP(DiffMutableModule[str, str]): Args: candidates (dict[str, dict]): the configs for the candidate operations. + fix_threshold (float): The threshold that determines whether to fix + the choice of current module as the op with the maximum `probs`. + It happens when the maximum prob is `fix_threshold` or more higher + then all the other probs. Default to 1.0. module_kwargs (dict[str, dict], optional): Module initialization named arguments. Defaults to None. alias (str, optional): alias of the `MUTABLE`. @@ -109,6 +117,7 @@ class DiffMutableOP(DiffMutableModule[str, str]): def __init__( self, candidates: Dict[str, Dict], + fix_threshold: float = 1.0, module_kwargs: Optional[Dict[str, Dict]] = None, alias: Optional[str] = None, init_cfg: Optional[Dict] = None, @@ -120,6 +129,10 @@ def __init__( f'but got: {len(candidates)}' self._is_fixed = False + if fix_threshold < 0 or fix_threshold > 1.0: + raise ValueError( + f'The fix_threshold should be in [0, 1]. Got {fix_threshold}.') + self.fix_threshold = fix_threshold self._candidates = self._build_ops(candidates, self.module_kwargs) @staticmethod @@ -242,6 +255,94 @@ def choices(self) -> List[str]: return list(self._candidates.keys()) +@MODELS.register_module() +class OneHotMutableOP(DiffMutableOP): + """A type of ``MUTABLES`` for one-hot sample based architecture search, + such as DSNAS. Search the best module by learnable parameters `arch_param`. + + Args: + candidates (dict[str, dict]): the configs for the candidate + operations. + module_kwargs (dict[str, dict], optional): Module initialization named + arguments. Defaults to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ + + def sample_weights(self, + arch_param: nn.Parameter, + probs: torch.Tensor, + random_sample: bool = False) -> Tensor: + """Use one-hot distributions to sample the arch weights based on the + arch params. + + Args: + arch_param (nn.Parameter): architecture parameters for + `DiffMutableModule`. + probs (Tensor): the probs of choice. + random_sample (bool): Whether to random sample arch weights or not + Defaults to False. + + Returns: + Tensor: Sampled one-hot arch weights. + """ + import torch.distributions as D + if random_sample: + uni = torch.ones_like(arch_param) + m = D.one_hot_categorical.OneHotCategorical(uni) + else: + m = D.one_hot_categorical.OneHotCategorical(probs=probs) + return m.sample() + + def forward_arch_param(self, + x: Any, + arch_param: Optional[nn.Parameter] = None + ) -> Tensor: + """Forward with architecture parameters. + + Args: + x (Any): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + arch_param (str, optional): architecture parameters for + `DiffMutableModule`. + + Returns: + Tensor: the result of forward with ``arch_param``. + """ + if arch_param is None: + return self.forward_all(x) + else: + # compute the probs of choice + probs = self.compute_arch_probs(arch_param=arch_param) + + if not self.is_fixed: + self.arch_weights = self.sample_weights(arch_param, probs) + sorted_param = torch.topk(probs, 2) + index = ( + sorted_param[0][0] - sorted_param[0][1] >= + self.fix_threshold) + if index: + self.fix_chosen(self.choices[index]) + + if self.is_fixed: + index = self.choices.index(self._chosen[0]) + self.arch_weights.data.zero_() + self.arch_weights.data[index].fill_(1.0) + self.arch_weights.requires_grad_() + + # forward based on self.arch_weights + outputs = list() + for prob, module in zip(self.arch_weights, + self._candidates.values()): + if prob > 0.: + outputs.append(prob * module(x)) + + return sum(outputs) + + @MODELS.register_module() class DiffChoiceRoute(DiffMutableModule[str, List[str]]): """A type of ``MUTABLES`` for Neural Architecture Search, which can select diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py index e50314d5f..ac6358049 100644 --- a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py +++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py @@ -88,8 +88,8 @@ def sample_choices(self): choices = dict() for group_id, mutables in self.search_groups.items(): - arch_parm = self.arch_params[str(group_id)] - choice = mutables[0].sample_choice(arch_parm) + arch_param = self.arch_params[str(group_id)] + choice = mutables[0].sample_choice(arch_param) choices[group_id] = choice return choices diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py index 7fa9e31d5..dc045932b 100644 --- a/mmrazor/models/mutators/module_mutator/module_mutator.py +++ b/mmrazor/models/mutators/module_mutator/module_mutator.py @@ -54,6 +54,24 @@ def prepare_from_supernet(self, supernet: Module) -> None: """ self._build_search_groups(supernet) + @property + def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: + """Search space of supernet. + + Note: + To get the mapping: module name to mutable. + + Raises: + RuntimeError: Called before search space has been parsed. + + Returns: + Dict[str, MUTABLE_TYPE]: The name2mutable dict. + """ + if self._name2mutable is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access name2mutable!') + return self._name2mutable + @property def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: """Search group of supernet. @@ -80,6 +98,8 @@ def _build_name_mutable_mapping( for name, module in supernet.named_modules(): if isinstance(module, self.mutable_class_type): name2mutable[name] = module + self._name2mutable = name2mutable + return name2mutable def _build_alias_names_mapping(self, @@ -121,7 +141,7 @@ def _build_search_groups(self, supernet: Module) -> None: >>> import torch >>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP - >>> # Assume that a toy model consists of three mutabels + >>> # Assume that a toy model consists of three mutables >>> # whose name are op1,op2,op3. The corresponding >>> # alias names of the three mutables are a1, a1, a2. >>> model = ToyModel() diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index c29a0b181..2b142f6ea 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -50,21 +50,22 @@ def load_fix_subnet(model: nn.Module, # In the corresponding mutable, it will check whether the `chosen` # format is correct. if isinstance(module, BaseMutable): - if getattr(module, 'alias', None): - alias = module.alias - assert alias in fix_mutable, \ - f'The alias {alias} is not in fix_modules, ' \ - 'please check your `fix_mutable`.' - chosen = fix_mutable.get(alias, None) - else: - mutable_name = name.lstrip(prefix) - if mutable_name not in fix_mutable and \ - not isinstance(module, DerivedMutable): - raise RuntimeError( - f'The module name {mutable_name} is not in ' - 'fix_mutable, please check your `fix_mutable`.') - chosen = fix_mutable.get(mutable_name, None) - module.fix_chosen(chosen) + if not module.is_fixed: + if getattr(module, 'alias', None): + alias = module.alias + assert alias in fix_mutable, \ + f'The alias {alias} is not in fix_modules, ' \ + 'please check your `fix_mutable`.' + chosen = fix_mutable.get(alias, None) + else: + mutable_name = name.lstrip(prefix) + if mutable_name not in fix_mutable and \ + not isinstance(module, DerivedMutable): + raise RuntimeError( + f'The module name {mutable_name} is not in ' + 'fix_mutable, please check your `fix_mutable`.') + chosen = fix_mutable.get(mutable_name, None) + module.fix_chosen(chosen) # convert dynamic op to static op _dynamic_to_static(model) @@ -89,7 +90,6 @@ def export_fix_subnet(model: nn.Module, if isinstance(module, DerivedMutable) and not dump_derived_mutable: continue - assert not module.is_fixed if module.alias: fix_subnet[module.alias] = module.dump_chosen() else: diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py new file mode 100644 index 000000000..929840148 --- /dev/null +++ b/tests/test_models/test_algorithms/test_dsnas.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from unittest import TestCase +from unittest.mock import patch + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from mmcls.structures import ClsDataSample +from mmengine.model import BaseModel +from mmengine.optim import build_optim_wrapper +from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict +from torch import Tensor +from torch.optim import SGD + +from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP +from mmrazor.models.algorithms.nas.dsnas import DsnasDDP +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +@MODELS.register_module() +class ToyDiffModule(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) + self.candidates = dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ) + module_kwargs = dict(in_channels=3, out_channels=8, stride=1) + + self.mutable = OneHotMutableOP( + candidates=self.candidates, module_kwargs=module_kwargs) + self.bn = nn.BatchNorm2d(8) + + def forward(self, batch_inputs, data_samples=None, mode='tensor'): + if mode == 'loss': + out = self.bn(self.mutable(batch_inputs)) + return dict(loss=out) + elif mode == 'predict': + out = self.bn(self.mutable(batch_inputs)) + 1 + return out + elif mode == 'tensor': + out = self.bn(self.mutable(batch_inputs)) + 2 + return out + + +class TestDsnas(TestCase): + + def setUp(self) -> None: + self.device: str = 'cpu' + + OPTIMIZER_CFG = dict( + type='SGD', + lr=0.5, + momentum=0.9, + nesterov=True, + weight_decay=0.0001) + + self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG) + + def test_init(self) -> None: + # initiate dsnas when `norm_training` is True. + model = ToyDiffModule() + mutator = DiffModuleMutator() + algo = Dsnas(architecture=model, mutator=mutator, norm_training=True) + algo.eval() + self.assertTrue(model.bn.training) + + # initiate Dsnas with built mutator + model = ToyDiffModule() + mutator = DiffModuleMutator() + algo = Dsnas(model, mutator) + self.assertIs(algo.mutator, mutator) + + # initiate Dsnas with unbuilt mutator + mutator = dict(type='DiffModuleMutator') + algo = Dsnas(model, mutator) + self.assertIsInstance(algo.mutator, DiffModuleMutator) + + # initiate Dsnas when `fix_subnet` is not None + fix_subnet = {'mutable': 'torch_conv2d_5x5'} + algo = Dsnas(model, mutator, fix_subnet=fix_subnet) + self.assertEqual(algo.architecture.mutable.num_choices, 1) + + # initiate Dsnas with error type `mutator` + with self.assertRaisesRegex(TypeError, 'mutator should be'): + Dsnas(model, model) + + def test_forward_loss(self) -> None: + inputs = torch.randn(1, 3, 8, 8) + model = ToyDiffModule() + + # supernet + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + algo = Dsnas(model, mutator) + loss = algo(inputs, mode='loss') + self.assertIsInstance(loss, dict) + + # subnet + fix_subnet = {'mutable': 'torch_conv2d_5x5'} + algo = Dsnas(model, fix_subnet=fix_subnet) + loss = algo(inputs, mode='loss') + self.assertIsInstance(loss, dict) + + def _prepare_fake_data(self): + imgs = torch.randn(16, 3, 224, 224).to(self.device) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 1000, + (16, ))).to(self.device) + ] + return {'inputs': imgs, 'data_samples': data_samples} + + def test_search_subnet(self) -> None: + model = ToyDiffModule() + + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + algo = Dsnas(model, mutator) + subnet = algo.search_subnet() + self.assertIsInstance(subnet, dict) + + @patch('mmengine.logging.message_hub.MessageHub.get_info') + def test_dsnas_train_step(self, mock_get_info) -> None: + model = ToyDiffModule() + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + mock_get_info.return_value = 2 + + algo = Dsnas(model, mutator) + data = self._prepare_fake_data() + optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG) + loss = algo.train_step(data, optim_wrapper) + + self.assertTrue(isinstance(loss['loss'], Tensor)) + + algo = Dsnas(model, mutator) + optim_wrapper_dict = OptimWrapperDict( + architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + loss = algo.train_step(data, optim_wrapper_dict) + + self.assertIsNotNone(loss) + + +class TestDsnasDDP(TestDsnas): + + @classmethod + def setUpClass(cls) -> None: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + + # initialize the process group + if torch.cuda.is_available(): + backend = 'nccl' + cls.device = 'cuda' + else: + backend = 'gloo' + dist.init_process_group(backend, rank=0, world_size=1) + + def prepare_model(self, device_ids=None) -> Dsnas: + model = ToyDiffModule().to(self.device) + mutator = DiffModuleMutator().to(self.device) + mutator.prepare_from_supernet(model) + + algo = Dsnas(model, mutator) + + return DsnasDDP( + module=algo, find_unused_parameters=True, device_ids=device_ids) + + @classmethod + def tearDownClass(cls) -> None: + dist.destroy_process_group() + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='cuda device is not avaliable') + def test_init(self) -> None: + ddp_model = self.prepare_model() + self.assertIsInstance(ddp_model, DsnasDDP) + + @patch('mmengine.logging.message_hub.MessageHub.get_info') + def test_dsnasddp_train_step(self, mock_get_info) -> None: + model = ToyDiffModule() + mutator = DiffModuleMutator() + mutator.prepare_from_supernet(model) + mock_get_info.return_value = 2 + + algo = Dsnas(model, mutator) + ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + data = self._prepare_fake_data() + optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG) + loss = ddp_model.train_step(data, optim_wrapper) + + self.assertIsNotNone(loss) + + algo = Dsnas(model, mutator) + ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + optim_wrapper_dict = OptimWrapperDict( + architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + loss = ddp_model.train_step(data, optim_wrapper_dict) + + self.assertIsNotNone(loss) diff --git a/tests/test_models/test_mutables/test_onehotop.py b/tests/test_models/test_mutables/test_onehotop.py new file mode 100644 index 000000000..4ace5870d --- /dev/null +++ b/tests/test_models/test_mutables/test_onehotop.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch +import torch.nn as nn + +from mmrazor.models import * # noqa:F403,F401 +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +class TestOneHotOP(TestCase): + + def test_forward_arch_param(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates']))) + output = op.forward_arch_param(input, arch_param=arch_param) + assert output is not None + + output = op.forward_arch_param(input, arch_param=None) + assert output is not None + + # test when some element of arch_param is 0 + arch_param = nn.Parameter(torch.ones(op.num_choices)) + output = op.forward_arch_param(input, arch_param=arch_param) + assert output is not None + + def test_forward_fixed(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + op.fix_chosen('torch_conv2d_7x7') + output = op.forward_fixed(input) + + assert output is not None + assert op.is_fixed is True + + def test_forward(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + # test set_forward_args + arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates']))) + op.set_forward_args(arch_param=arch_param) + output = op.forward(input) + assert output is not None + + # test dump_chosen + with pytest.raises(AssertionError): + op.dump_chosen() + + # test forward when is_fixed is True + op.fix_chosen('torch_conv2d_7x7') + output = op.forward(input) + + def test_property(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + op = MODELS.build(op_cfg) + + assert len(op.choices) == 3 + + # test is_fixed propty + assert op.is_fixed is False + + # test is_fixed setting + op.fix_chosen('torch_conv2d_5x5') + + with pytest.raises(AttributeError): + op.is_fixed = True + + # test fix choice when is_fixed is True + with pytest.raises(AttributeError): + op.fix_chosen('torch_conv2d_3x3') + + def test_module_kwargs(self): + op_cfg = dict( + type='mmrazor.OneHotMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + in_channels=32, + out_channels=32, + stride=1, + ), + torch_maxpool_3x3=dict( + type='torchMaxPool2d', + kernel_size=3, + stride=1, + ), + torch_avgpool_3x3=dict( + type='torchAvgPool2d', + kernel_size=3, + stride=1, + ), + ), + ) + op = MODELS.build(op_cfg) + input = torch.randn(4, 32, 64, 64) + + op.fix_chosen('torch_avgpool_3x3') + output = op.forward(input) + assert output is not None From ef39c51bb9589dfc05b1537d0b6756cd91526893 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Sat, 8 Oct 2022 10:31:04 +0800 Subject: [PATCH 5/7] [Feature] Update train (#279) * support auto resume * add enable auto_scale_lr in train.py * support '--amp' option --- tools/train.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tools/train.py b/tools/train.py index 2ccff305c..146d0950f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import logging import os import os.path as osp from mmengine.config import Config, DictAction +from mmengine.logging import print_log from mmengine.runner import Runner from mmrazor.utils import register_all_modules @@ -13,6 +15,19 @@ def parse_args(): parser = argparse.ArgumentParser(description='Train an algorithm') parser.add_argument('config', help='train config file path') parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--amp', + action='store_true', + default=False, + help='enable automatic-mixed-precision training') + parser.add_argument( + '--auto-scale-lr', + action='store_true', + help='enable automatically scaling LR.') + parser.add_argument( + '--resume', + action='store_true', + help='resume from the latest checkpoint in the work_dir automatically') parser.add_argument( '--cfg-options', nargs='+', @@ -55,6 +70,35 @@ def main(): cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.type + if optim_wrapper == 'AmpOptimWrapper': + print_log( + 'AMP training is already enabled in your config.', + logger='current', + level=logging.WARNING) + else: + assert optim_wrapper == 'OptimWrapper', ( + '`--amp` is only supported when the optimizer wrapper type is ' + f'`OptimWrapper` but got {optim_wrapper}.') + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.loss_scale = 'dynamic' + + # enable automatically scaling LR + if args.auto_scale_lr: + if 'auto_scale_lr' in cfg and \ + 'enable' in cfg.auto_scale_lr and \ + 'base_batch_size' in cfg.auto_scale_lr: + cfg.auto_scale_lr.enable = True + else: + raise RuntimeError('Can not find "auto_scale_lr" or ' + '"auto_scale_lr.enable" or ' + '"auto_scale_lr.base_batch_size" in your' + ' configuration file.') + + cfg.resume = args.resume + # build the runner from config runner = Runner.from_cfg(cfg) From 6cd8c68d0f4c7ac6e0f48b278ef1df8af886c9ce Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Sat, 8 Oct 2022 10:35:48 +0800 Subject: [PATCH 6/7] [Fix] Fix darts metafile (#278) fix darts metafile --- configs/nas/mmcls/darts/metafile.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/nas/mmcls/darts/metafile.yml b/configs/nas/mmcls/darts/metafile.yml index 06fee9c85..4b0515b5b 100644 --- a/configs/nas/mmcls/darts/metafile.yml +++ b/configs/nas/mmcls/darts/metafile.yml @@ -13,7 +13,7 @@ Collections: Converted From: Code: https://github.com/quark0/darts Models: - - Name: darts_subnet_1xb96_cifar10 + - Name: darts_subnet_1xb96_cifar10_2.0 In Collection: Darts Metadata: Params(M): 3.42 @@ -24,5 +24,5 @@ Models: Metrics: Top 1 Accuracy: 97.32 Top 5 Accuracy: 99.94 - Config: configs/nas/darts/darts_subnet_1xb96_cifar10.py + Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py Weights: https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth From f98ac3416b7e1b4f0232a0a8080a997222b4b301 Mon Sep 17 00:00:00 2001 From: LKJacky <108643365+LKJacky@users.noreply.github.com> Date: Mon, 10 Oct 2022 10:06:57 +0800 Subject: [PATCH 7/7] fix ci (#284) * fix ci for circle ci * fix bug in test_metafiles * add pr_stage_test for github ci * add multiple version * fix ut * fix lint * Temporarily skip dataset UT * update github ci * add github lint ci * install wheel * remove timm from requirements * install wheel when test on windows * fix error * fix bug * remove github windows ci * fix device error of arch_params when DsnasDDP * fix CRD dataset ut * fix scope error * rm test_cuda in workflows of github * [Doc] fix typos in en/usr_guides Co-authored-by: liukai Co-authored-by: pppppM Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: huangpengsheng Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> --- .circleci/test.yml | 60 +++---- .github/workflows/build.yml | 159 ++++++++++++++++++ .github/workflows/lint.yml | 27 +++ configs/nas/mmcls/darts/metafile.yml | 2 +- .../2_train_different_types_algorithms.md | 2 +- mmrazor/datasets/crd_dataset_wrapper.py | 2 +- mmrazor/models/task_modules/tracer/parsers.py | 11 +- requirements/optional.txt | 2 +- tests/test_datasets/test_datasets.py | 38 ++--- .../test_transforms/test_formatting.py | 12 +- .../test_algorithms/test_autoslim.py | 3 +- .../test_models/test_algorithms/test_dsnas.py | 27 ++- .../test_algorithms/test_slimmable_network.py | 3 +- .../test_losses/test_distillation_losses.py | 2 +- .../test_mbv2_channel_mutator.py | 1 + 15 files changed, 262 insertions(+), 89 deletions(-) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/lint.yml diff --git a/.circleci/test.yml b/.circleci/test.yml index 92ac230c9..5da20de36 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -26,7 +26,6 @@ jobs: command: | pip install interrogate interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 80 mmrazor - build_cpu: parameters: # The python version must match available image tags in @@ -37,8 +36,6 @@ jobs: type: string torchvision: type: string - mmcv: - type: string docker: - image: cimg/python:<< parameters.python >> resource_class: large @@ -58,20 +55,21 @@ jobs: name: Install PyTorch command: | python -V - python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html - when: condition: - equal: [ "3.9.0", << parameters.python >> ] + equal: ["3.9.0", << parameters.python >>] steps: - run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake - run: name: Install mmrazor dependencies command: | - python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main - python -m pip install << parameters.mmcv >> - python -m pip install git+ssh://git@github.com/open-mmlab/mmclassification.git@dev-1.x - python -m pip install git+ssh://git@github.com/open-mmlab/mmdetection.git@dev-3.x - python -m pip install git+ssh://git@github.com/open-mmlab/mmsegmentation.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmengine.git@main + pip install -U openmim + mim install 'mmcv >= 2.0.0rc1' + pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x pip install -r requirements.txt - run: name: Build and install @@ -80,10 +78,9 @@ jobs: - run: name: Run unittests command: | - python -m coverage run --branch --source mmrazor -m pytest tests/ - python -m coverage xml - python -m coverage report -m - + coverage run --branch --source mmrazor -m pytest tests/ + coverage xml + coverage report -m build_cuda: parameters: torch: @@ -94,8 +91,6 @@ jobs: cudnn: type: integer default: 7 - mmcv: - type: string machine: image: ubuntu-2004-cuda-11.4:202110-01 # docker_layer_caching: true @@ -103,13 +98,13 @@ jobs: steps: - checkout - run: - # CLoning repos in VM since Docker doesn't have access to the private key + # Cloning repos in VM since Docker doesn't have access to the private key name: Clone Repos command: | - git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmengine.git /home/circleci/mmengine - git clone -b dev-3.x --depth 1 ssh://git@github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification - git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation + git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine + git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection + git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification + git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation - run: name: Build Docker image command: | @@ -117,10 +112,10 @@ jobs: docker run --gpus all -t -d -v /home/circleci/project:/mmrazor -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -v /home/circleci/mmclassification:/mmclassification -v /home/circleci/mmsegmentation:/mmsegmentation -w /mmrazor --name mmrazor mmrazor:gpu - run: name: Install mmrazor dependencies - # pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html command: | docker exec mmrazor pip install -e /mmengine - docker exec mmrazor pip install << parameters.mmcv >> + docker exec mmrazor pip install -U openmim + docker exec mmrazor mim install 'mmcv >= 2.0.0rc1' docker exec mmrazor pip install -e /mmdetection docker exec mmrazor pip install -e /mmclassification docker exec mmrazor pip install -e /mmsegmentation @@ -132,7 +127,7 @@ jobs: - run: name: Run unittests command: | - docker exec mmrazor python -m pytest tests/ + docker exec mmrazor pytest tests/ workflows: pr_stage_lint: @@ -144,10 +139,10 @@ workflows: branches: ignore: - dev-1.x + - 1.x pr_stage_test: when: - not: - << pipeline.parameters.lint_only >> + not: << pipeline.parameters.lint_only >> jobs: - lint: name: lint @@ -159,16 +154,14 @@ workflows: name: minimum_version_cpu torch: 1.6.0 torchvision: 0.7.0 - python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.6.0/mmcv_full-2.0.0rc1-cp36-cp36m-manylinux1_x86_64.whl + python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images requires: - lint - build_cpu: name: maximum_version_cpu - torch: 1.9.0 - torchvision: 0.10.0 + torch: 1.12.1 + torchvision: 0.13.1 python: 3.9.0 - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.9.0/mmcv_full-2.0.0rc1-cp39-cp39-manylinux1_x86_64.whl requires: - minimum_version_cpu - hold: @@ -181,20 +174,17 @@ workflows: # Use double quotation mark to explicitly specify its type # as string instead of number cuda: "10.2" - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu102/torch1.8.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl requires: - hold merge_stage_test: when: - not: - << pipeline.parameters.lint_only >> + not: << pipeline.parameters.lint_only >> jobs: - build_cuda: name: minimum_version_gpu torch: 1.6.0 # Use double quotation mark to explicitly specify its type # as string instead of number - mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu101/torch1.6.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl cuda: "10.1" filters: branches: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..5d15599f2 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,159 @@ +name: build + +on: + push: + paths-ignore: + - "README.md" + - "README_zh-CN.md" + - "model-index.yml" + - "configs/**" + - "docs/**" + - ".dev_scripts/**" + + pull_request: + paths-ignore: + - "README.md" + - "README_zh-CN.md" + - "docs/**" + - "demo/**" + - ".dev_scripts/**" + - ".circleci/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_linux: + runs-on: ubuntu-18.04 + strategy: + matrix: + python-version: [3.7] + torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0] + include: + - torch: 1.6.0 + torch_version: 1.6 + torchvision: 0.7.0 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + python-version: 3.8 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + python-version: 3.8 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + python-version: 3.8 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + python-version: 3.8 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + python-version: 3.8 + - torch: 1.12.0 + torch_version: 1.12 + torchvision: 0.13.0 + - torch: 1.12.0 + torch_version: 1.12 + torchvision: 0.13.0 + python-version: 3.8 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip + run: | + pip install pip --upgrade + pip install wheel + - name: Install PyTorch + run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install MMEngine + run: pip install git+https://github.com/open-mmlab/mmengine.git@main + - name: Install MMCV + run: | + pip install -U openmim + mim install 'mmcv >= 2.0.0rc1' + - name: Install MMCls + run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + - name: Install MMDet + run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + - name: Install MMSeg + run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + - name: Install other dependencies + run: pip install -r requirements.txt + - name: Build and install + run: rm -rf .eggs && pip install -e . + - name: Run unittests and generate coverage report + run: | + coverage run --branch --source mmrazor -m pytest tests/ + coverage xml + coverage report -m + # Upload coverage report for python3.8 && pytorch1.12.0 cpu + - name: Upload coverage to Codecov + if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}} + uses: codecov/codecov-action@v2 + with: + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false + + # test_windows: + # runs-on: ${{ matrix.os }} + # strategy: + # matrix: + # os: [windows-2022] + # python: [3.7] + # platform: [cpu] + # steps: + # - uses: actions/checkout@v2 + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@v2 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Upgrade pip + # run: | + # pip install pip --upgrade + # pip install wheel + # - name: Install lmdb + # run: pip install lmdb + # - name: Install PyTorch + # run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html + # - name: Install mmrazor dependencies + # run: | + # pip install git+https://github.com/open-mmlab/mmengine.git@main + # pip install -U openmim + # mim install 'mmcv >= 2.0.0rc1' + # pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + # pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + # pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + # pip install -r requirements.txt + # - name: Build and install + # run: | + # pip install -e . + # - name: Run unittests and generate coverage report + # run: | + # pytest tests/ diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..36422d008 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,27 @@ +name: lint + +on: [push, pull_request] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + - name: Install pre-commit hook + run: | + pip install pre-commit + pre-commit install + - name: Linting + run: pre-commit run --all-files + - name: Check docstring coverage + run: | + pip install interrogate + interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmrazor diff --git a/configs/nas/mmcls/darts/metafile.yml b/configs/nas/mmcls/darts/metafile.yml index 4b0515b5b..9594e6765 100644 --- a/configs/nas/mmcls/darts/metafile.yml +++ b/configs/nas/mmcls/darts/metafile.yml @@ -24,5 +24,5 @@ Models: Metrics: Top 1 Accuracy: 97.32 Top 5 Accuracy: 99.94 - Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py + Config: configs/nas/darts/darts_subnet_1xb96_cifar10_2.0.py Weights: https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth diff --git a/docs/en/user_guides/2_train_different_types_algorithms.md b/docs/en/user_guides/2_train_different_types_algorithms.md index 7acb025a0..06c689708 100644 --- a/docs/en/user_guides/2_train_different_types_algorithms.md +++ b/docs/en/user_guides/2_train_different_types_algorithms.md @@ -86,7 +86,7 @@ For example, the default `_channel_cfg_paths` is set in the config below. ```Python python ./tools/train.py \ - configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M \ + configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py \ --work-dir your_work_dir ``` diff --git a/mmrazor/datasets/crd_dataset_wrapper.py b/mmrazor/datasets/crd_dataset_wrapper.py index 308bc1e4c..aa62f383b 100644 --- a/mmrazor/datasets/crd_dataset_wrapper.py +++ b/mmrazor/datasets/crd_dataset_wrapper.py @@ -74,7 +74,7 @@ def _parse_fullset_contrast_info(self) -> None: # e.g. [2, 3, 5]. num_classes: int = self.num_classes # type: ignore if num_classes is None: - num_classes = len(self.dataset.CLASSES) + num_classes = max(self.dataset.get_gt_labels()) + 1 if not self.dataset.test_mode: # type: ignore # Parse info. diff --git a/mmrazor/models/task_modules/tracer/parsers.py b/mmrazor/models/task_modules/tracer/parsers.py index efcfc8613..c342da716 100644 --- a/mmrazor/models/task_modules/tracer/parsers.py +++ b/mmrazor/models/task_modules/tracer/parsers.py @@ -118,13 +118,20 @@ def parse_cat(tracer, grad_fn, module2name, param2module, cur_path, >>> # ``out`` is obtained by concatenating two tensors """ parents = grad_fn.next_functions + concat_id = '_'.join([str(id(p)) for p in parents]) + concat_id_list = [str(id(p)) for p in parents] + concat_id_list.sort() + concat_id = '_'.join(concat_id_list) + name = f'concat_{concat_id}' + + visited[name] = True sub_path_lists = list() - for i, parent in enumerate(parents): + for _, parent in enumerate(parents): sub_path_list = PathList() tracer.backward_trace(parent, module2name, param2module, Path(), sub_path_list, visited, shared_module) sub_path_lists.append(sub_path_list) - cur_path.append(PathConcatNode('CatNode', sub_path_lists)) + cur_path.append(PathConcatNode(name, sub_path_lists)) result_paths.append(copy.deepcopy(cur_path)) cur_path.pop(-1) diff --git a/requirements/optional.txt b/requirements/optional.txt index 609cc3925..32f7d6fd0 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,3 @@ albumentations>=0.3.2 scipy -timm +# timm diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py index 1e6031a97..1eaf72ec8 100644 --- a/tests/test_datasets/test_datasets.py +++ b/tests/test_datasets/test_datasets.py @@ -6,6 +6,7 @@ from unittest import TestCase import numpy as np +from mmcls.registry import DATASETS as CLS_DATASETS from mmrazor.registry import DATASETS from mmrazor.utils import register_all_modules @@ -15,7 +16,8 @@ class Test_CRD_CIFAR10(TestCase): - DATASET_TYPE = 'CRD_CIFAR10' + ORI_DATASET_TYPE = 'CIFAR10' + DATASET_TYPE = 'CRDDataset' @classmethod def setUpClass(cls) -> None: @@ -24,10 +26,11 @@ def setUpClass(cls) -> None: tmpdir = tempfile.TemporaryDirectory() cls.tmpdir = tmpdir data_prefix = tmpdir.name - cls.DEFAULT_ARGS = dict( + cls.ORI_DEFAULT_ARGS = dict( data_prefix=data_prefix, pipeline=[], test_mode=False) + cls.DEFAULT_ARGS = dict(neg_num=1, percent=0.5) - dataset_class = DATASETS.get(cls.DATASET_TYPE) + dataset_class = CLS_DATASETS.get(cls.ORI_DATASET_TYPE) base_folder = osp.join(data_prefix, dataset_class.base_folder) os.mkdir(base_folder) @@ -65,25 +68,16 @@ def test_initialize(self): dataset_class = DATASETS.get(self.DATASET_TYPE) # Test overriding metainfo by `metainfo` argument - cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}} + ori_cfg = { + **self.ORI_DEFAULT_ARGS, 'metainfo': { + 'classes': ('bus', 'car') + }, + 'type': self.ORI_DATASET_TYPE, + '_scope_': 'mmcls' + } + cfg = {'dataset': ori_cfg, **self.DEFAULT_ARGS} dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - - # Test overriding metainfo by `classes` argument - cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']} - dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - - classes_file = osp.join(ASSETS_ROOT, 'classes.txt') - cfg = {**self.DEFAULT_ARGS, 'classes': classes_file} - dataset = dataset_class(**cfg) - self.assertEqual(dataset.CLASSES, ('bus', 'car')) - self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1}) - - # Test invalid classes - cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)} - with self.assertRaisesRegex(ValueError, "type "): - dataset_class(**cfg) + self.assertEqual(dataset.dataset.CLASSES, ('bus', 'car')) @classmethod def tearDownClass(cls): @@ -91,4 +85,4 @@ def tearDownClass(cls): class Test_CRD_CIFAR100(Test_CRD_CIFAR10): - DATASET_TYPE = 'CRD_CIFAR100' + ORI_DATASET_TYPE = 'CIFAR100' diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 46aa671df..69e211aad 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -6,7 +6,7 @@ import numpy as np import torch from mmcls.structures import ClsDataSample -from mmengine.data import LabelData +from mmengine.structures import LabelData from mmrazor.datasets.transforms import PackCRDClsInputs @@ -34,7 +34,7 @@ def setUp(self): 'img': rng.rand(300, 400), 'gt_label': rng.randint(3, ), # TODO. - 'contrast_sample_idxs': rng.randint() + 'contrast_sample_idxs': rng.randint(3, ) } self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'flip') @@ -44,13 +44,13 @@ def test_transform(self): results = transform(copy.deepcopy(self.results1)) self.assertIn('inputs', results) self.assertIsInstance(results['inputs'], torch.Tensor) - self.assertIn('data_sample', results) - self.assertIsInstance(results['data_sample'], ClsDataSample) + self.assertIn('data_samples', results) + self.assertIsInstance(results['data_samples'], ClsDataSample) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertIsInstance(data_sample.gt_label, LabelData) def test_repr(self): transform = PackCRDClsInputs(meta_keys=self.meta_keys) self.assertEqual( - repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})') + repr(transform), f'PackCRDClsInputs(meta_keys={self.meta_keys})') diff --git a/tests/test_models/test_algorithms/test_autoslim.py b/tests/test_models/test_algorithms/test_autoslim.py index 850ce0937..f73222630 100644 --- a/tests/test_models/test_algorithms/test_autoslim.py +++ b/tests/test_models/test_algorithms/test_autoslim.py @@ -18,7 +18,8 @@ DISTILLER_TYPE = Union[torch.nn.Module, Dict] ARCHITECTURE_CFG = dict( - type='mmcls.ImageClassifier', + _scope_='mmcls', + type='ImageClassifier', backbone=dict(type='MobileNetV2', widen_factor=1.5), neck=dict(type='GlobalAveragePooling'), head=dict( diff --git a/tests/test_models/test_algorithms/test_dsnas.py b/tests/test_models/test_algorithms/test_dsnas.py index 929840148..9f6dfc902 100644 --- a/tests/test_models/test_algorithms/test_dsnas.py +++ b/tests/test_models/test_algorithms/test_dsnas.py @@ -170,19 +170,17 @@ def setUpClass(cls) -> None: os.environ['MASTER_PORT'] = '12345' # initialize the process group - if torch.cuda.is_available(): - backend = 'nccl' - cls.device = 'cuda' - else: - backend = 'gloo' + backend = 'nccl' if torch.cuda.is_available() else 'gloo' dist.init_process_group(backend, rank=0, world_size=1) def prepare_model(self, device_ids=None) -> Dsnas: - model = ToyDiffModule().to(self.device) - mutator = DiffModuleMutator().to(self.device) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + model = ToyDiffModule() + mutator = DiffModuleMutator() mutator.prepare_from_supernet(model) - algo = Dsnas(model, mutator) + algo = Dsnas(model, mutator).to(self.device) return DsnasDDP( module=algo, find_unused_parameters=True, device_ids=device_ids) @@ -199,24 +197,19 @@ def test_init(self) -> None: @patch('mmengine.logging.message_hub.MessageHub.get_info') def test_dsnasddp_train_step(self, mock_get_info) -> None: - model = ToyDiffModule() - mutator = DiffModuleMutator() - mutator.prepare_from_supernet(model) + ddp_model = self.prepare_model() mock_get_info.return_value = 2 - algo = Dsnas(model, mutator) - ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) data = self._prepare_fake_data() optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG) loss = ddp_model.train_step(data, optim_wrapper) self.assertIsNotNone(loss) - algo = Dsnas(model, mutator) - ddp_model = DsnasDDP(module=algo, find_unused_parameters=True) + ddp_model = self.prepare_model() optim_wrapper_dict = OptimWrapperDict( - architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)), - mutator=OptimWrapper(SGD(model.parameters(), lr=0.01))) + architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)), + mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01))) loss = ddp_model.train_step(data, optim_wrapper_dict) self.assertIsNotNone(loss) diff --git a/tests/test_models/test_algorithms/test_slimmable_network.py b/tests/test_models/test_algorithms/test_slimmable_network.py index 792e8e305..a9a7bc16e 100644 --- a/tests/test_models/test_algorithms/test_slimmable_network.py +++ b/tests/test_models/test_algorithms/test_slimmable_network.py @@ -15,7 +15,8 @@ from mmrazor.models.algorithms import SlimmableNetwork, SlimmableNetworkDDP MODEL_CFG = dict( - type='mmcls.ImageClassifier', + _scope_='mmcls', + type='ImageClassifier', backbone=dict(type='MobileNetV2', widen_factor=1.5), neck=dict(type='GlobalAveragePooling'), head=dict( diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index 4328f7865..37fea2baf 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -2,7 +2,7 @@ from unittest import TestCase import torch -from mmengine.data import BaseDataElement +from mmengine.structures import BaseDataElement from mmrazor import digit_version from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss, diff --git a/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py b/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py index 5b1f2a0e7..61ec34565 100644 --- a/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py +++ b/tests/test_models/test_mutators/test_classical_models/test_mbv2_channel_mutator.py @@ -15,6 +15,7 @@ from ..utils import load_and_merge_channel_cfgs MODEL_CFG = dict( + _scope_='mmcls', type='mmcls.ImageClassifier', backbone=dict(type='MobileNetV2', widen_factor=1.5), neck=dict(type='GlobalAveragePooling'),