Skip to content

Commit

Permalink
[Feature] Support receptive field search of CNN models (open-mmlab#2056)
Browse files Browse the repository at this point in the history
* support rfsearch

* add labs for rfsearch

* format

* format

* add docstring and type hints

* clean code

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* rm unused func

* update code

* update code

* update code

* update  details

* fix details

* support asymmetric kernel

* support asymmetric kernel

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* add unit tests for rfsearch

* set device for Conv2dRFSearchOp

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* remove unused function search_estimate_only

* move unit tests

* Update tests/test_cnn/test_rfsearch/test_operator.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/rfsearch/operator.py

Co-authored-by: Yue Zhou <592267829@qq.com>

* change logger

* Update mmcv/cnn/rfsearch/operator.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: lzyhha <819814373@qq.com>
Co-authored-by: Zhongyu Li <44114862+lzyhha@users.noreply.github.com>
Co-authored-by: Yue Zhou <592267829@qq.com>
  • Loading branch information
5 people committed Mar 20, 2023
1 parent e7adffb commit 6c108ee
Show file tree
Hide file tree
Showing 7 changed files with 968 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmcv/cnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .rfsearch import Conv2dRFSearchOp, RFSearchHook
from .utils import fuse_conv_bn, get_model_complexity_info
from .vgg import VGG, make_vgg_layer

Expand All @@ -23,5 +24,5 @@
'Scale', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d',
'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'fuse_conv_bn',
'get_model_complexity_info'
'get_model_complexity_info', 'Conv2dRFSearchOp', 'RFSearchHook'
]
5 changes: 5 additions & 0 deletions mmcv/cnn/rfsearch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp
from .search import RFSearchHook

__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook']
170 changes: 170 additions & 0 deletions mmcv/cnn/rfsearch/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding

logger = MMLogger.get_current_instance()


class BaseConvRFSearchOp(BaseModule):
"""Based class of ConvRFSearchOp.
Args:
op_layer (nn.Module): pytorch module, e,g, Conv2d
global_config (dict): config dict.
"""

def __init__(self, op_layer: nn.Module, global_config: dict):
super().__init__()
self.op_layer = op_layer
self.global_config = global_config

def normlize(self, weights: nn.Parameter) -> nn.Parameter:
"""Normalize weights.
Args:
weights (nn.Parameter): Weights to be normalized.
Returns:
nn.Parameters: Normalized weights.
"""
abs_weights = torch.abs(weights)
normalized_weights = abs_weights / torch.sum(abs_weights)
return normalized_weights


class Conv2dRFSearchOp(BaseConvRFSearchOp):
"""Enable Conv2d with receptive field searching ability.
Args:
op_layer (nn.Module): pytorch module, e,g, Conv2d
global_config (dict): config dict. Defaults to None.
By default this must include:
- "init_alphas": The value for initializing weights of each branch.
- "num_branches": The controller of the size of
search space (the number of branches).
- "exp_rate": The controller of the sparsity of search space.
- "mmin": The minimum dilation rate.
- "mmax": The maximum dilation rate.
Extra keys may exist, but are used by RFSearchHook, e.g., "step",
"max_step", "search_interval", and "skip_layer".
verbose (bool): Determines whether to print rf-next
related logging messages.
Defaults to True.
"""

def __init__(self,
op_layer: nn.Module,
global_config: dict,
verbose: bool = True):
super().__init__(op_layer, global_config)
assert global_config is not None, 'global_config is None'
self.num_branches = global_config['num_branches']
assert self.num_branches in [2, 3]
self.verbose = verbose
init_dilation = op_layer.dilation
self.dilation_rates = expand_rates(init_dilation, global_config)
if self.op_layer.kernel_size[
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
self.dilation_rates = [(op_layer.dilation[0], r[1])
for r in self.dilation_rates]
if self.op_layer.kernel_size[
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
self.dilation_rates = [(r[0], op_layer.dilation[1])
for r in self.dilation_rates]

self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
nn.init.constant_(self.branch_weights, global_config['init_alphas'])

def forward(self, input: Tensor) -> Tensor:
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
if len(self.dilation_rates) == 1:
outputs = [
nn.functional.conv2d(
input,
weight=self.op_layer.weight,
bias=self.op_layer.bias,
stride=self.op_layer.stride,
padding=self.get_padding(self.dilation_rates[0]),
dilation=self.dilation_rates[0],
groups=self.op_layer.groups,
)
]
else:
outputs = [
nn.functional.conv2d(
input,
weight=self.op_layer.weight,
bias=self.op_layer.bias,
stride=self.op_layer.stride,
padding=self.get_padding(r),
dilation=r,
groups=self.op_layer.groups,
) * norm_w[i] for i, r in enumerate(self.dilation_rates)
]
output = outputs[0]
for i in range(1, len(self.dilation_rates)):
output += outputs[i]
return output

def estimate_rates(self):
"""Estimate new dilation rate based on trained branch_weights."""
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
if self.verbose:
logger.info('Estimate dilation {} with weight {}.'.format(
self.dilation_rates,
norm_w.detach().cpu().numpy().tolist()))

sum0, sum1, w_sum = 0, 0, 0
for i in range(len(self.dilation_rates)):
sum0 += norm_w[i].item() * self.dilation_rates[i][0]
sum1 += norm_w[i].item() * self.dilation_rates[i][1]
w_sum += norm_w[i].item()
estimated = [
np.clip(
int(round(sum0 / w_sum)), self.global_config['mmin'],
self.global_config['mmax']).item(),
np.clip(
int(round(sum1 / w_sum)), self.global_config['mmin'],
self.global_config['mmax']).item()
]
self.op_layer.dilation = tuple(estimated)
self.op_layer.padding = self.get_padding(self.op_layer.dilation)
self.dilation_rates = [tuple(estimated)]
if self.verbose:
logger.info(f'Estimate as {tuple(estimated)}')

def expand_rates(self):
"""Expand dilation rate."""
dilation = self.op_layer.dilation
dilation_rates = expand_rates(dilation, self.global_config)
if self.op_layer.kernel_size[
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
if self.op_layer.kernel_size[
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

self.dilation_rates = copy.deepcopy(dilation_rates)
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
nn.init.constant_(self.branch_weights,
self.global_config['init_alphas'])

def get_padding(self, dilation):
padding = (get_single_padding(self.op_layer.kernel_size[0],
self.op_layer.stride[0], dilation[0]),
get_single_padding(self.op_layer.kernel_size[1],
self.op_layer.stride[1], dilation[1]))
return padding
Loading

0 comments on commit 6c108ee

Please sign in to comment.