forked from open-mmlab/mmcv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support receptive field search of CNN models (open-mmlab#2056)
* support rfsearch * add labs for rfsearch * format * format * add docstring and type hints * clean code Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * rm unused func * update code * update code * update code * update details * fix details * support asymmetric kernel * support asymmetric kernel * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * add unit tests for rfsearch * set device for Conv2dRFSearchOp * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * remove unused function search_estimate_only * move unit tests * Update tests/test_cnn/test_rfsearch/test_operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Yue Zhou <592267829@qq.com> * change logger * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: lzyhha <819814373@qq.com> Co-authored-by: Zhongyu Li <44114862+lzyhha@users.noreply.github.com> Co-authored-by: Yue Zhou <592267829@qq.com>
- Loading branch information
1 parent
595623f
commit 2e92c88
Showing
7 changed files
with
970 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp | ||
from .search import RFSearchHook | ||
|
||
__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import copy | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from torch import Tensor | ||
|
||
from mmcv.runner import BaseModule | ||
from mmcv.utils.logging import get_logger | ||
from .utils import expand_rates, get_single_padding | ||
|
||
logger = get_logger('mmcv') | ||
|
||
|
||
class BaseConvRFSearchOp(BaseModule): | ||
"""Based class of ConvRFSearchOp. | ||
Args: | ||
op_layer (nn.Module): pytorch module, e,g, Conv2d | ||
global_config (dict): config dict. | ||
""" | ||
|
||
def __init__(self, op_layer: nn.Module, global_config: dict): | ||
super().__init__() | ||
self.op_layer = op_layer | ||
self.global_config = global_config | ||
|
||
def normlize(self, weights: nn.Parameter) -> nn.Parameter: | ||
"""Normalize weights. | ||
Args: | ||
weights (nn.Parameter): Weights to be normalized. | ||
Returns: | ||
nn.Parameters: Normalized weights. | ||
""" | ||
abs_weights = torch.abs(weights) | ||
normalized_weights = abs_weights / torch.sum(abs_weights) | ||
return normalized_weights | ||
|
||
|
||
class Conv2dRFSearchOp(BaseConvRFSearchOp): | ||
"""Enable Conv2d with receptive field searching ability. | ||
Args: | ||
op_layer (nn.Module): pytorch module, e,g, Conv2d | ||
global_config (dict): config dict. Defaults to None. | ||
By default this must include: | ||
- "init_alphas": The value for initializing weights of each branch. | ||
- "num_branches": The controller of the size of | ||
search space (the number of branches). | ||
- "exp_rate": The controller of the sparsity of search space. | ||
- "mmin": The minimum dilation rate. | ||
- "mmax": The maximum dilation rate. | ||
Extra keys may exist, but are used by RFSearchHook, e.g., "step", | ||
"max_step", "search_interval", and "skip_layer". | ||
verbose (bool): Determines whether to print rf-next | ||
related logging messages. | ||
Defaults to True. | ||
""" | ||
|
||
def __init__(self, | ||
op_layer: nn.Module, | ||
global_config: dict, | ||
verbose: bool = True): | ||
super().__init__(op_layer, global_config) | ||
assert global_config is not None, 'global_config is None' | ||
self.num_branches = global_config['num_branches'] | ||
assert self.num_branches in [2, 3] | ||
self.verbose = verbose | ||
init_dilation = op_layer.dilation | ||
self.dilation_rates = expand_rates(init_dilation, global_config) | ||
if self.op_layer.kernel_size[ | ||
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0: | ||
self.dilation_rates = [(op_layer.dilation[0], r[1]) | ||
for r in self.dilation_rates] | ||
if self.op_layer.kernel_size[ | ||
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0: | ||
self.dilation_rates = [(r[0], op_layer.dilation[1]) | ||
for r in self.dilation_rates] | ||
|
||
self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches)) | ||
if self.verbose: | ||
logger.info(f'Expand as {self.dilation_rates}') | ||
nn.init.constant_(self.branch_weights, global_config['init_alphas']) | ||
|
||
def forward(self, input: Tensor) -> Tensor: | ||
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)]) | ||
if len(self.dilation_rates) == 1: | ||
outputs = [ | ||
nn.functional.conv2d( | ||
input, | ||
weight=self.op_layer.weight, | ||
bias=self.op_layer.bias, | ||
stride=self.op_layer.stride, | ||
padding=self.get_padding(self.dilation_rates[0]), | ||
dilation=self.dilation_rates[0], | ||
groups=self.op_layer.groups, | ||
) | ||
] | ||
else: | ||
outputs = [ | ||
nn.functional.conv2d( | ||
input, | ||
weight=self.op_layer.weight, | ||
bias=self.op_layer.bias, | ||
stride=self.op_layer.stride, | ||
padding=self.get_padding(r), | ||
dilation=r, | ||
groups=self.op_layer.groups, | ||
) * norm_w[i] for i, r in enumerate(self.dilation_rates) | ||
] | ||
output = outputs[0] | ||
for i in range(1, len(self.dilation_rates)): | ||
output += outputs[i] | ||
return output | ||
|
||
def estimate_rates(self): | ||
"""Estimate new dilation rate based on trained branch_weights.""" | ||
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)]) | ||
if self.verbose: | ||
logger.info('Estimate dilation {} with weight {}.'.format( | ||
self.dilation_rates, | ||
norm_w.detach().cpu().numpy().tolist())) | ||
|
||
sum0, sum1, w_sum = 0, 0, 0 | ||
for i in range(len(self.dilation_rates)): | ||
sum0 += norm_w[i].item() * self.dilation_rates[i][0] | ||
sum1 += norm_w[i].item() * self.dilation_rates[i][1] | ||
w_sum += norm_w[i].item() | ||
estimated = [ | ||
np.clip( | ||
int(round(sum0 / w_sum)), self.global_config['mmin'], | ||
self.global_config['mmax']).item(), | ||
np.clip( | ||
int(round(sum1 / w_sum)), self.global_config['mmin'], | ||
self.global_config['mmax']).item() | ||
] | ||
self.op_layer.dilation = tuple(estimated) | ||
self.op_layer.padding = self.get_padding(self.op_layer.dilation) | ||
self.dilation_rates = [tuple(estimated)] | ||
if self.verbose: | ||
logger.info(f'Estimate as {tuple(estimated)}') | ||
|
||
def expand_rates(self): | ||
"""Expand dilation rate.""" | ||
dilation = self.op_layer.dilation | ||
dilation_rates = expand_rates(dilation, self.global_config) | ||
if self.op_layer.kernel_size[ | ||
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0: | ||
dilation_rates = [(dilation[0], r[1]) for r in dilation_rates] | ||
if self.op_layer.kernel_size[ | ||
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0: | ||
dilation_rates = [(r[0], dilation[1]) for r in dilation_rates] | ||
|
||
self.dilation_rates = copy.deepcopy(dilation_rates) | ||
if self.verbose: | ||
logger.info(f'Expand as {self.dilation_rates}') | ||
nn.init.constant_(self.branch_weights, | ||
self.global_config['init_alphas']) | ||
|
||
def get_padding(self, dilation): | ||
padding = (get_single_padding(self.op_layer.kernel_size[0], | ||
self.op_layer.stride[0], dilation[0]), | ||
get_single_padding(self.op_layer.kernel_size[1], | ||
self.op_layer.stride[1], dilation[1])) | ||
return padding |
Oops, something went wrong.