Commit 305c2a3

Add type hints for mmcv/cnn/bricks (#1993)
* Add type hints

* Add type hints in mmcv/cnn/bricks*

* Resolve conflict

* Fix

* Fix

* Minor fix

* Minor fix

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
triple-Mu and zhouzaida authored Jun 20, 2022
1 parent 2d3e42f commit 305c2a3
Showing 19 changed files with 207 additions and 155 deletions.
10 changes: 6 additions & 4 deletions mmcv/cnn/bricks/activation.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -28,12 +30,12 @@ class Clamp(nn.Module):
Default to 1.
"""

-    def __init__(self, min=-1., max=1.):
+    def __init__(self, min: float = -1., max: float = 1.):
super().__init__()
self.min = min
self.max = max

-    def forward(self, x):
+    def forward(self, x) -> torch.Tensor:
"""Forward function.
Args:
@@ -67,7 +69,7 @@ class GELU(nn.Module):
>>> output = m(input)
"""

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)


@@ -78,7 +80,7 @@ def forward(self, input):
ACTIVATION_LAYERS.register_module(module=nn.GELU)


-def build_activation_layer(cfg):
+def build_activation_layer(cfg: Dict) -> nn.Module:
"""Build activation layer.
Args:
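With these hints, build_activation_layer takes a config dict and is declared to return an nn.Module. A minimal usage sketch (the 'ReLU' registry key mirrors the act_cfg default in conv_module.py below; treat the exact import path as illustrative):

import torch
from mmcv.cnn import build_activation_layer

# Build a registered activation from a config dict, as the new Dict hint documents.
act = build_activation_layer(dict(type='ReLU'))
out = act(torch.randn(2, 3))  # returns a torch.Tensor of the same shape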
16 changes: 9 additions & 7 deletions mmcv/cnn/bricks/context_block.py
@@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
import torch
from torch import nn

from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS


-def last_zero_init(m):
+def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
@@ -34,10 +36,10 @@ class ContextBlock(nn.Module):
_abbr_ = 'context_block'

def __init__(self,
-                 in_channels,
-                 ratio,
-                 pooling_type='att',
-                 fusion_types=('channel_add', )):
+                 in_channels: int,
+                 ratio: float,
+                 pooling_type: str = 'att',
+                 fusion_types: tuple = ('channel_add', )):
super().__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
@@ -82,7 +84,7 @@ def reset_parameters(self):
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)

-    def spatial_pool(self, x):
+    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
@@ -108,7 +110,7 @@ def spatial_pool(self, x):

return context

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
# [N, C, 1, 1]
context = self.spatial_pool(x)

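The typed constructor makes the expected argument kinds explicit: in_channels is an int, ratio a float, and fusion_types a tuple. A brief sketch under the usual NCHW input assumption (import path illustrative):

import torch
from mmcv.cnn import ContextBlock

# ratio scales the bottleneck width; fusion_types defaults to ('channel_add',).
gc_block = ContextBlock(in_channels=16, ratio=0.25)
x = torch.randn(1, 16, 32, 32)
y = gc_block(x)  # same shape as x; spatial_pool yields a [N, C, 1, 1] context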
4 changes: 3 additions & 1 deletion mmcv/cnn/bricks/conv.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
+
from torch import nn

from .registry import CONV_LAYERS
@@ -9,7 +11,7 @@
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)


-def build_conv_layer(cfg, *args, **kwargs):
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
"""Build convolution layer.
Args:
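The Optional[Dict] hint documents that cfg may be None as well as a config dict. A small sketch (the None fallback behaviour is my reading of the Optional hint together with the 'Conv' registration above; import path illustrative):

from mmcv.cnn import build_conv_layer

# 'Conv' is the nn.Conv2d registration shown in this file.
conv = build_conv_layer(dict(type='Conv'), 3, 16, kernel_size=3, padding=1)
# cfg=None is now an explicitly typed option; the builder then uses its default conv type.
conv_default = build_conv_layer(None, 3, 16, kernel_size=3)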
20 changes: 11 additions & 9 deletions mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
+from typing import Tuple, Union

+import torch
from torch import nn
from torch.nn import functional as F

@@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
"""

def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
super().__init__(in_channels, out_channels, kernel_size, stride, 0,
dilation, groups, bias)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
img_h, img_w = x.size()[-2:]
kernel_h, kernel_w = self.weight.size()[-2:]
stride_h, stride_w = self.stride
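Conv2dAdaptivePadding subclasses nn.Conv2d but computes "same"-style padding from each input's size inside forward, so the typed forward is the interesting surface here. A short sketch (import path illustrative):

import torch
from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
x = torch.randn(1, 3, 15, 15)
# Padding is chosen per input so the output spatial size is ceil(15 / 2) = 8.
y = conv(x)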
44 changes: 25 additions & 19 deletions mmcv/cnn/bricks/conv_module.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
+from typing import Dict, Optional, Tuple, Union

+import torch
import torch.nn as nn

from mmcv.utils import _BatchNorm, _InstanceNorm
@@ -68,21 +70,21 @@ class ConvModule(nn.Module):
_abbr_ = 'conv_block'

def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias='auto',
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=dict(type='ReLU'),
-                 inplace=True,
-                 with_spectral_norm=False,
-                 padding_mode='zeros',
-                 order=('conv', 'norm', 'act')):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: Union[bool, str] = 'auto',
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Optional[Dict] = dict(type='ReLU'),
+                 inplace: bool = True,
+                 with_spectral_norm: bool = False,
+                 padding_mode: str = 'zeros',
+                 order: tuple = ('conv', 'norm', 'act')):
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -143,18 +145,19 @@ def __init__(self,
norm_channels = out_channels
else:
norm_channels = in_channels
-            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.norm_name, norm = build_norm_layer(
+                norm_cfg, norm_channels)  # type: ignore
self.add_module(self.norm_name, norm)
if self.with_bias:
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
warnings.warn(
'Unnecessary conv bias before batch/instance norm')
else:
-            self.norm_name = None
+            self.norm_name = None  # type: ignore

# build activation layer
if self.with_activation:
-            act_cfg_ = act_cfg.copy()
+            act_cfg_ = act_cfg.copy()  # type: ignore
# nn.Tanh has no 'inplace' argument
if act_cfg_['type'] not in [
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
@@ -193,7 +196,10 @@ def init_weights(self):
if self.with_norm:
constant_init(self.norm, 1, bias=0)

-    def forward(self, x, activate=True, norm=True):
+    def forward(self,
+                x: torch.Tensor,
+                activate: bool = True,
+                norm: bool = True) -> torch.Tensor:
for layer in self.order:
if layer == 'conv':
if self.with_explicit_padding:
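ConvModule's constructor now spells out the config-dict conventions: conv_cfg, norm_cfg and act_cfg are all Optional[Dict], with act_cfg defaulting to dict(type='ReLU'). A minimal sketch of the typed interface (the 'BN' norm key is an assumption from mmcv's usual registry names):

import torch
from mmcv.cnn import ConvModule

block = ConvModule(
    3, 16, 3,
    padding=1,
    norm_cfg=dict(type='BN'),   # Optional[Dict]
    act_cfg=dict(type='ReLU'))  # matches the typed default
x = torch.randn(1, 3, 32, 32)
y = block(x)                  # runs conv -> norm -> act, following `order`
y = block(x, activate=False)  # the typed bool flags skip individual steps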
68 changes: 37 additions & 31 deletions mmcv/cnn/bricks/conv_ws.py
@@ -1,19 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
import torch
import torch.nn as nn
import torch.nn.functional as F

from .registry import CONV_LAYERS


-def conv_ws_2d(input,
-               weight,
-               bias=None,
-               stride=1,
-               padding=0,
-               dilation=1,
-               groups=1,
-               eps=1e-5):
+def conv_ws_2d(input: torch.Tensor,
+               weight: torch.Tensor,
+               bias: Optional[torch.Tensor] = None,
+               stride: Union[int, Tuple[int, int]] = 1,
+               padding: Union[int, Tuple[int, int]] = 0,
+               dilation: Union[int, Tuple[int, int]] = 1,
+               groups: int = 1,
+               eps: float = 1e-5) -> torch.Tensor:
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@@ -26,15 +29,15 @@ def conv_ws_2d(input,
class ConvWS2d(nn.Conv2d):

def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True,
-                 eps=1e-5):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 eps: float = 1e-5):
super().__init__(
in_channels,
out_channels,
@@ -46,7 +49,7 @@ def __init__(self,
bias=bias)
self.eps = eps

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)

@@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
"""

def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
super().__init__(
in_channels,
out_channels,
@@ -98,21 +101,24 @@ def __init__(self,
self.register_buffer('weight_beta',
torch.zeros(self.out_channels, 1, 1, 1))

-    def _get_weight(self, weight):
+    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
weight_flat = weight.view(weight.size(0), -1)
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
weight = (weight - mean) / std
weight = self.weight_gamma * weight + self.weight_beta
return weight

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self._get_weight(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
+                              local_metadata: Dict, strict: bool,
+                              missing_keys: List[str],
+                              unexpected_keys: List[str],
+                              error_msgs: List[str]) -> None:
"""Override default load function.
AWS overrides the function _load_from_state_dict to recover
Expand All @@ -124,7 +130,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
"""

self.weight_gamma.data.fill_(-1)
-        local_missing_keys = []
+        local_missing_keys: List = []
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, local_missing_keys,
unexpected_keys, error_msgs)
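Weight standardization normalizes each output filter by its mean and standard deviation (with eps for numerical stability) before the convolution, which is exactly what the typed conv_ws_2d signature now documents. A small sketch of both entry points (import path illustrative):

import torch
from mmcv.cnn.bricks.conv_ws import ConvWS2d, conv_ws_2d

conv = ConvWS2d(3, 8, kernel_size=3, padding=1, eps=1e-5)
x = torch.randn(1, 3, 16, 16)
y = conv(x)  # standardizes self.weight, then dispatches to the functional form

# The functional form mirrors F.conv2d plus the eps used in standardization.
w = torch.randn(8, 3, 3, 3)
y2 = conv_ws_2d(x, w, stride=1, padding=1)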