Skip to content

Commit

Permalink
[Feature] Add performance predictor (#306)
Browse files Browse the repository at this point in the history
* add predictor with 4 handlers

* [Improvement] Update Candidate with multi-dim search constraints. (#322)

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constrain

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* update metric_predictor:
1. update MetricPredictor;
2. add predictor config for searching;
3. add predictor in evolution_search_loop.

* add UT for predictor

* add MLPHandler

* patch optional.txt for predictors

* patch test_evolution_search_loop

* refactor apis of predictor and handlers

* fix ut and remove predictor_cfg in predictor

* adapt new mutable & mutator design

* fix ut

* remove unnecessary assert after rebase

* move predictor-build in __init__ & simplify estimator-build

Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
  • Loading branch information
gaoyang07 and Yue Sun authored Nov 14, 2022
1 parent fb42405 commit 18fc50f
Show file tree
Hide file tree
Showing 17 changed files with 1,190 additions and 117 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Evolution-search config that scores candidates with a metric predictor.
# Base config this file extends; `_base_.<name>` below refers to its fields.
_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']

model = dict(norm_training=True)

train_cfg = dict(
    # NOTE(review): `_delete_=True` presumably drops the inherited
    # `train_cfg` keys instead of merging with them (mmengine config
    # convention) -- confirm against the mmengine Config docs.
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    # The search loop reuses the base config's validation dataloader and
    # evaluator.
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator,
    max_epochs=20,
    num_candidates=50,
    top_k=10,
    num_mutation=25,
    num_crossover=25,
    mutate_prob=0.1,
    # Constraints used for screening candidates.
    constraints_range=dict(flops=(0., 360.)),
    # Used for building the metric predictor of the search loop.
    predictor_cfg=dict(
        type='mmrazor.MetricPredictor',
        # Presumably the number of randomly sampled subnets used to
        # pre-train the predictor when no samples file is given -- verify
        # against `EvolutionSearchLoop._init_predictor`.
        train_samples=20,
        handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
)
11 changes: 7 additions & 4 deletions mmrazor/engine/hooks/estimate_resources_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import TASK_UTILS

DATA_BATCH = Optional[Sequence[dict]]

Expand All @@ -23,7 +23,7 @@ class EstimateResourcesHook(Hook):
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default to True.
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
Default to dict().
Default to None.
Example:
>>> add the `EstimateResourcesHook` in custom_hooks as follows:
Expand All @@ -41,11 +41,14 @@ class EstimateResourcesHook(Hook):
def __init__(self,
interval: int = -1,
by_epoch: bool = True,
estimator_cfg: Dict[str, Any] = dict(),
estimator_cfg: Dict[str, Any] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
self.estimator = ResourceEstimator(**estimator_cfg)
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

def after_val_epoch(self,
runner,
Expand Down
146 changes: 85 additions & 61 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from mmengine import fileio
from mmengine.dist import broadcast_object_list
Expand All @@ -14,7 +14,6 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
Expand Down Expand Up @@ -45,8 +44,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
crossover_prob (float): The probability of crossover. Defaults to 0.5.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
estimator_cfg (dict, Optional): Used for building a resource estimator.
Defaults to None.
predictor_cfg (dict, Optional): Used for building a metric predictor.
Defaults to None.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy/top1'.
init_candidates (str, optional): The candidates file path, which is
Expand All @@ -68,7 +69,8 @@ def __init__(self,
mutate_prob: float = 0.1,
crossover_prob: float = 0.5,
constraints_range: Dict[str, Any] = dict(flops=(0., 330.)),
resource_estimator_cfg: Optional[Dict] = None,
estimator_cfg: Optional[Dict] = None,
predictor_cfg: Optional[Dict] = None,
score_key: str = 'accuracy/top1',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
Expand Down Expand Up @@ -109,56 +111,28 @@ def __init__(self,
else:
self.model = runner.model

# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)

def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict): A
resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object build from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')

resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore

if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore

return estimator # type: ignore
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

# initialize predictor
self.use_predictor = False
self.predictor_cfg = predictor_cfg
if self.predictor_cfg is not None:
self.predictor_cfg['score_key'] = self.score_key
self.predictor_cfg['search_groups'] = \
self.model.mutator.search_groups
self.predictor = TASK_UTILS.build(self.predictor_cfg)

def run(self) -> None:
"""Launch searching."""
self.runner.call_hook('before_train')

if self.predictor_cfg is not None:
self._init_predictor()

if self.resume_from:
self._resume()

Expand All @@ -174,7 +148,7 @@ def run_epoch(self) -> None:
"""Iterate one epoch.
Steps:
1. Sample some new candidates from the supernet.Then Append them
1. Sample some new candidates from the supernet. Then Append them
to the candidates, Thus make its number equal to the specified
number.
2. Validate these candidates(step 1) and update their scores.
Expand Down Expand Up @@ -240,8 +214,8 @@ def update_candidates_scores(self) -> None:
top-k candidates."""
for i, candidate in enumerate(self.candidates.subnets):
self.model.set_subnet(candidate)
metrics = self._val_candidate()
score = metrics[self.score_key] \
metrics = self._val_candidate(use_predictor=self.use_predictor)
score = round(metrics[self.score_key], 2) \
if len(metrics) != 0 else 0.
self.candidates.set_resource(i, score, 'score')
self.runner.logger.info(
Expand All @@ -250,7 +224,7 @@ def update_candidates_scores(self) -> None:
f'Flops: {self.candidates.resources("flops")[i]} '
f'Params: {self.candidates.resources("params")[i]} '
f'Latency: {self.candidates.resources("latency")[i]} '
f'Score: {self.candidates.scores} ')
f'Score: {self.candidates.scores[i]} ')

def gen_mutation_candidates(self):
"""Generate specified number of mutation candidates."""
Expand Down Expand Up @@ -340,13 +314,23 @@ def _save_best_fix_subnet(self):
f'{save_name} saved in {self.runner.work_dir}.')

@torch.no_grad()
def _val_candidate(self) -> Dict:
"""Run validation."""
self.runner.model.eval()
for data_batch in self.dataloader:
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(outputs, data_batch)
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
def _val_candidate(self, use_predictor: bool = False) -> Dict:
"""Run validation.
Args:
use_predictor (bool): Whether to use predictor to get metrics.
Defaults to False.
"""
if use_predictor:
assert self.predictor is not None
metrics = self.predictor.predict(self.model)
else:
self.runner.model.eval()
for data_batch in self.dataloader:
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(
data_samples=outputs, data_batch=data_batch)
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
return metrics

def _save_searcher_ckpt(self) -> None:
Expand Down Expand Up @@ -391,3 +375,43 @@ def _check_constraints(
constraints_range=self.constraints_range)

return is_pass, results

def _init_predictor(self):
    """Initialize the metric predictor, pre-training it when required.

    If ``self.predictor.handler_ckpt`` is set, the predictor checkpoint is
    loaded and no training happens. Otherwise training samples are
    obtained in one of two ways and the predictor is fitted on them:

    1. ``self.predictor.train_samples`` is a str: it is treated as a file
       path, loaded with ``fileio.load`` and its ``'subnets'`` entry
       becomes ``self.candidates``.
    2. Otherwise: ``self.predictor.train_samples`` temporarily replaces
       ``self.num_candidates`` while candidates are sampled and scored
       with real evaluation (the predictor must not be in use yet).

    After fitting, rank 0 saves the predictor under
    ``<work_dir>/predictor``. In every case the predictor is then enabled
    (``self.use_predictor = True``) and ``self.candidates`` is reset to an
    empty ``Candidates``.
    """
    logger = self.runner.logger
    if self.predictor.handler_ckpt:
        self.predictor.load_checkpoint()
        logger.info(
            f'Loaded Checkpoints from {self.predictor.handler_ckpt}')
    else:
        logger.info('No predictor checkpoints found. '
                    'Start pre-training the predictor.')
        if isinstance(self.predictor.train_samples, str):
            logger.info('Find specified samples in '
                        f'{self.predictor.train_samples}')
            train_samples = fileio.load(self.predictor.train_samples)
            self.candidates = train_samples['subnets']
        else:
            logger.info(
                'Without specified samples. Start random sampling.')
            # Checked before touching any state so a failure here leaves
            # the loop untouched.
            assert self.use_predictor is False, (
                'Real evaluation is required when initializing predictor.')
            temp_num_candidates = self.num_candidates
            self.num_candidates = self.predictor.train_samples
            try:
                self.sample_candidates()
                self.update_candidates_scores()
            finally:
                # Always restore the configured population size, even if
                # sampling or evaluation raises.
                self.num_candidates = temp_num_candidates

        # Encode every sampled subnet and fit the predictor on the
        # (encoding, score) pairs.
        inputs = np.array([
            self.predictor.model2vector(candidate)
            for candidate in self.candidates.subnets
        ])
        labels = np.array(self.candidates.scores)
        self.predictor.fit(inputs, labels)
        if self.runner.rank == 0:
            predictor_dir = self.predictor.save_checkpoint(
                osp.join(self.runner.work_dir, 'predictor'))
            logger.info(
                f'Predictor pre-trained, saved in {predictor_dir}.')
    self.use_predictor = True
    # Start the actual search from an empty candidate pool.
    self.candidates = Candidates()
58 changes: 8 additions & 50 deletions mmrazor/engine/runner/subnet_sampler_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os
import random
Expand All @@ -13,7 +12,6 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates
from mmrazor.utils import SupportRandomSubnet
Expand Down Expand Up @@ -102,8 +100,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
candidates. Defaults to 'accuracy/top1'.
constraints_range (Dict[str, Any]): Constraints to be used for
screening candidates. Defaults to dict(flops=(0, 330)).
resource_estimator_cfg (dict, Optional): Used for building a
resource estimator. Defaults to None.
estimator_cfg (dict, Optional): Used for building a resource estimator.
Defaults to None.
num_candidates (int): The number of the candidates consist of samples
from supernet and itself. Defaults to 1000.
num_samples (int): The number of sample in each sampling subnet.
Expand Down Expand Up @@ -138,7 +136,7 @@ def __init__(self,
val_interval: int = 1000,
score_key: str = 'accuracy/top1',
constraints_range: Dict[str, Any] = dict(flops=(0, 330)),
resource_estimator_cfg: Optional[Dict] = None,
estimator_cfg: Optional[Dict] = None,
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
Expand Down Expand Up @@ -176,51 +174,11 @@ def __init__(self,
self.candidates = Candidates()
self.top_k_candidates = Candidates()

# Build resource estimator.
resource_estimator_cfg = dict(
) if resource_estimator_cfg is None else resource_estimator_cfg
self.estimator = self.build_resource_estimator(resource_estimator_cfg)

def build_resource_estimator(
self, resource_estimator: Union[ResourceEstimator,
Dict]) -> ResourceEstimator:
"""Build resource estimator for search loop.
Examples of ``resource_estimator``:
# `ResourceEstimator` will be used
resource_estimator = dict()
# custom resource_estimator
resource_estimator = dict(type='mmrazor.ResourceEstimator')
Args:
resource_estimator (ResourceEstimator or dict):
A resource_estimator or a dict to build resource estimator.
If ``resource_estimator`` is a resource estimator object,
just returns itself.
Returns:
:obj:`ResourceEstimator`: Resource estimator object build from
``resource_estimator``.
"""
if isinstance(resource_estimator, ResourceEstimator):
return resource_estimator
elif not isinstance(resource_estimator, dict):
raise TypeError(
'resource estimator should be a ResourceEstimator object or'
f'dict, but got {resource_estimator}')

resource_estimator_cfg = copy.deepcopy(
resource_estimator) # type: ignore

if 'type' in resource_estimator_cfg:
estimator = TASK_UTILS.build(resource_estimator_cfg)
else:
estimator = ResourceEstimator(
**resource_estimator_cfg) # type: ignore

return estimator # type: ignore
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)

def run(self) -> None:
"""Launch training."""
Expand Down
6 changes: 4 additions & 2 deletions mmrazor/models/algorithms/nas/dsnas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def __init__(self,
**kwargs):
super().__init__(architecture, data_preprocessor, **kwargs)

if estimator_cfg is None:
estimator_cfg = dict(type='mmrazor.ResourceEstimator')
# initialize estimator
estimator_cfg = dict() if estimator_cfg is None else estimator_cfg
if 'type' not in estimator_cfg:
estimator_cfg['type'] = 'mmrazor.ResourceEstimator'
self.estimator = TASK_UTILS.build(estimator_cfg)
if fix_subnet:
# Avoid circular import
Expand Down
1 change: 1 addition & 0 deletions mmrazor/models/task_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delivery import * # noqa: F401,F403
from .estimators import ResourceEstimator
from .predictor import * # noqa: F401,F403
from .recorder import * # noqa: F401,F403
from .tracer import * # noqa: F401,F403

Expand Down
4 changes: 4 additions & 0 deletions mmrazor/models/task_modules/predictor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .metric_predictor import MetricPredictor

__all__ = ['MetricPredictor']
Loading

0 comments on commit 18fc50f

Please sign in to comment.