Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support receptive field search of CNN models. (TPAMI paper rf-next) #2056

Merged
merged 31 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a184844
support rfsearch
gasvn Apr 28, 2022
cef9214
add labs for rfsearch
gasvn Apr 28, 2022
dae649b
format
gasvn Jun 15, 2022
4e9cea9
format
gasvn Jun 15, 2022
a2d003a
add docstring and type hints
gasvn Jun 16, 2022
7f6f3d4
clean code
gasvn Jun 16, 2022
198832f
rm unused func
gasvn Jun 16, 2022
6ad6876
update code
gasvn Jul 3, 2022
ef25353
update code
gasvn Jul 6, 2022
1cf082b
update code
gasvn Jul 7, 2022
c718054
update details
gasvn Oct 3, 2022
22fb259
fix details
gasvn Oct 3, 2022
5cadcd1
support asymmetric kernel
lzyhha Nov 15, 2022
6712ab1
support asymmetric kernel
lzyhha Nov 17, 2022
6855c03
Apply suggestions from code review
lzyhha Nov 19, 2022
ea372a1
Apply suggestions from code review
lzyhha Nov 19, 2022
333cad9
Apply suggestions from code review
lzyhha Nov 20, 2022
eb363ed
Apply suggestions from code review
lzyhha Nov 20, 2022
7285f1d
Apply suggestions from code review
lzyhha Nov 20, 2022
f598a80
Apply suggestions from code review
lzyhha Nov 20, 2022
9b73758
Apply suggestions from code review
lzyhha Nov 20, 2022
a73fcfc
add unit tests for rfsearch
lzyhha Nov 21, 2022
87c30b2
set device for Conv2dRFSearchOp
lzyhha Nov 21, 2022
656e591
Apply suggestions from code review
lzyhha Nov 21, 2022
c1fbd12
remove unused function search_estimate_only
lzyhha Nov 21, 2022
9329a7a
move unit tests
lzyhha Nov 21, 2022
1809e50
Update tests/test_cnn/test_rfsearch/test_operator.py
lzyhha Nov 21, 2022
860577c
Apply suggestions from code review
lzyhha Nov 22, 2022
b6bbcb2
Update mmcv/cnn/rfsearch/operator.py
lzyhha Nov 22, 2022
26955e7
change logger
lzyhha Nov 22, 2022
b339444
Update mmcv/cnn/rfsearch/operator.py
lzyhha Nov 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmcv/cnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .rfsearch import Conv2dRFSearchOp, RFSearchHook
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
Expand All @@ -37,5 +38,6 @@
'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg', 'Conv2dRFSearchOp',
'RFSearchHook'
]
5 changes: 5 additions & 0 deletions mmcv/cnn/rfsearch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp
from .search import RFSearchHook

__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook']
170 changes: 170 additions & 0 deletions mmcv/cnn/rfsearch/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from mmcv.runner import BaseModule
from mmcv.utils.logging import get_logger
from .utils import expand_rates, get_single_padding

logger = get_logger('mmcv')


class BaseConvRFSearchOp(BaseModule):
"""Based class of ConvRFSearchOp.

Args:
op_layer (nn.Module): pytorch module, e,g, Conv2d
global_config (dict): config dict.
"""

def __init__(self, op_layer: nn.Module, global_config: dict):
super().__init__()
self.op_layer = op_layer
self.global_config = global_config

def normlize(self, weights: nn.Parameter) -> nn.Parameter:
"""Normalize weights.

Args:
weights (nn.Parameter): Weights to be normalized.

Returns:
nn.Parameters: Normalized weights.
"""
abs_weights = torch.abs(weights)
normalized_weights = abs_weights / torch.sum(abs_weights)
return normalized_weights


class Conv2dRFSearchOp(BaseConvRFSearchOp):
"""Enable Conv2d with receptive field searching ability.

Args:
op_layer (nn.Module): pytorch module, e,g, Conv2d
global_config (dict): config dict. Defaults to None.
gasvn marked this conversation as resolved.
Show resolved Hide resolved
By default this must include:

- "init_alphas": The value for initializing weights of each branch.
- "num_branches": The controller of the size of
search space (the number of branches).
- "exp_rate": The controller of the sparsity of search space.
- "mmin": The minimum dilation rate.
- "mmax": The maximum dilation rate.

Extra keys may exist, but are used by RFSearchHook, e.g., "step",
"max_step", "search_interval", and "skip_layer".
verbose (bool): Determines whether to print rf-next
related logging messages.
Defaults to True.
"""

def __init__(self,
op_layer: nn.Module,
global_config: dict,
verbose: bool = True):
super().__init__(op_layer, global_config)
assert global_config is not None, 'global_config is None'
self.num_branches = global_config['num_branches']
assert self.num_branches in [2, 3]
self.verbose = verbose
init_dilation = op_layer.dilation
self.dilation_rates = expand_rates(init_dilation, global_config)
if self.op_layer.kernel_size[
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
self.dilation_rates = [(op_layer.dilation[0], r[1])
for r in self.dilation_rates]
if self.op_layer.kernel_size[
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
self.dilation_rates = [(r[0], op_layer.dilation[1])
for r in self.dilation_rates]

self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
nn.init.constant_(self.branch_weights, global_config['init_alphas'])

def forward(self, input: Tensor) -> Tensor:
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
if len(self.dilation_rates) == 1:
outputs = [
nn.functional.conv2d(
input,
weight=self.op_layer.weight,
bias=self.op_layer.bias,
stride=self.op_layer.stride,
padding=self.get_padding(self.dilation_rates[0]),
dilation=self.dilation_rates[0],
groups=self.op_layer.groups,
)
]
else:
outputs = [
nn.functional.conv2d(
input,
weight=self.op_layer.weight,
bias=self.op_layer.bias,
stride=self.op_layer.stride,
padding=self.get_padding(r),
dilation=r,
groups=self.op_layer.groups,
) * norm_w[i] for i, r in enumerate(self.dilation_rates)
]
output = outputs[0]
for i in range(1, len(self.dilation_rates)):
output += outputs[i]
return output

def estimate_rates(self):
"""Estimate new dilation rate based on trained branch_weights."""
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
if self.verbose:
logger.info('Estimate dilation {} with weight {}.'.format(
self.dilation_rates,
norm_w.detach().cpu().numpy().tolist()))

sum0, sum1, w_sum = 0, 0, 0
for i in range(len(self.dilation_rates)):
sum0 += norm_w[i].item() * self.dilation_rates[i][0]
sum1 += norm_w[i].item() * self.dilation_rates[i][1]
w_sum += norm_w[i].item()
estimated = [
np.clip(
int(round(sum0 / w_sum)), self.global_config['mmin'],
self.global_config['mmax']).item(),
np.clip(
int(round(sum1 / w_sum)), self.global_config['mmin'],
self.global_config['mmax']).item()
]
self.op_layer.dilation = tuple(estimated)
self.op_layer.padding = self.get_padding(self.op_layer.dilation)
self.dilation_rates = [tuple(estimated)]
if self.verbose:
logger.info(f'Estimate as {tuple(estimated)}')

def expand_rates(self):
"""Expand dilation rate."""
dilation = self.op_layer.dilation
dilation_rates = expand_rates(dilation, self.global_config)
if self.op_layer.kernel_size[
0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
if self.op_layer.kernel_size[
1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

self.dilation_rates = copy.deepcopy(dilation_rates)
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
nn.init.constant_(self.branch_weights,
self.global_config['init_alphas'])

def get_padding(self, dilation):
padding = (get_single_padding(self.op_layer.kernel_size[0],
self.op_layer.stride[0], dilation[0]),
get_single_padding(self.op_layer.kernel_size[1],
self.op_layer.stride[1], dilation[1]))
return padding
Loading