[Enhance] Add type hints in /ops: (#2030)
* [Enhance] Add type hints in /ops:
`fused_bias_leakyrelu.py`, `gather_points.py`, `group_points.py`.
There is no need to add type hints in `furthest_point_sample.py` and
`info.py`.
As for `focal_loss.py`, please see #1994.

* Modified the default value of a variable.

* [Enhance] Add type hints in:
`knn.py`, `masked_conv.py`, `merge_cells.py`, `min_area_polygons.py`,
`modulated_deform_conv.py`, `multi_scale_deform_attn.py`.

* Fix type hints.

* Fixed type hints.

* Remove type hints from `symbolic` methods.

* Add `@no_type_check` to make mypy skip checking certain methods (see the sketch below).
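
A minimal sketch of what the `@no_type_check` bullet refers to (the class below is a hypothetical stand-in, not code from this commit):

from typing import no_type_check

import torch


class ExampleModule(torch.nn.Module):  # hypothetical illustration

    @no_type_check
    def forward(self, x):
        # mypy skips this method entirely; useful when a signature
        # (e.g. an autograd Function's `ctx` argument) cannot be
        # annotated cleanly.
        return x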

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
WINDSKY45 and zhouzaida authored Jun 20, 2022
1 parent 305c2a3 commit b9a96e5
Showing 9 changed files with 154 additions and 108 deletions.
30 changes: 21 additions & 9 deletions mmcv/ops/fused_bias_leakyrelu.py
@@ -113,7 +113,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
     """
 
     @staticmethod
-    def forward(ctx, grad_output, out, negative_slope, scale):
+    def forward(ctx, grad_output: torch.Tensor, out: torch.Tensor,
+                negative_slope: float, scale: float) -> tuple:
         ctx.save_for_backward(out)
         ctx.negative_slope = negative_slope
         ctx.scale = scale
@@ -139,7 +140,8 @@ def forward(ctx, grad_output, out, negative_slope, scale):
         return grad_input, grad_bias
 
     @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
+    def backward(ctx, gradgrad_input: torch.Tensor,
+                 gradgrad_bias: nn.Parameter) -> tuple:
         out, = ctx.saved_tensors
 
         # The second order deviation, in fact, contains two parts, while the
@@ -160,7 +162,8 @@ def backward(ctx, gradgrad_input, gradgrad_bias):
 class FusedBiasLeakyReLUFunction(Function):
 
     @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
+    def forward(ctx, input: torch.Tensor, bias: nn.Parameter,
+                negative_slope: float, scale: float) -> torch.Tensor:
         empty = input.new_empty(0)
 
         out = ext_module.fused_bias_leakyrelu(
@@ -178,7 +181,7 @@ def forward(ctx, input, bias, negative_slope, scale):
         return out
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         out, = ctx.saved_tensors
 
         grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
@@ -204,26 +207,32 @@ class FusedBiasLeakyReLU(nn.Module):
     TODO: Implement the CPU version.
 
     Args:
-        channel (int): The channel number of the feature map.
+        num_channels (int): The channel number of the feature map.
         negative_slope (float, optional): Same as nn.LeakyRelu.
             Defaults to 0.2.
         scale (float, optional): A scalar to adjust the variance of the feature
            map. Defaults to 2**0.5.
     """
 
-    def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+    def __init__(self,
+                 num_channels: int,
+                 negative_slope: float = 0.2,
+                 scale: float = 2**0.5):
         super().__init__()
 
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.negative_slope = negative_slope
         self.scale = scale
 
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
                                     self.scale)
 
 
-def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+def fused_bias_leakyrelu(input: torch.Tensor,
+                         bias: nn.Parameter,
+                         negative_slope: float = 0.2,
+                         scale: float = 2**0.5) -> torch.Tensor:
     r"""Fused bias leaky ReLU function.
 
     This function is introduced in the StyleGAN2:
@@ -256,7 +265,10 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
         negative_slope, scale)
 
 
-def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+def bias_leakyrelu_ref(x: torch.Tensor,
+                       bias: nn.Parameter,
+                       negative_slope: float = 0.2,
+                       scale: float = 2**0.5) -> torch.Tensor:
 
     if bias is not None:
         assert bias.ndim == 1
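
For reference, a minimal usage sketch of the module typed above (assumes a CUDA build of mmcv-full, since the docstring notes the CPU version is not implemented):

import torch

from mmcv.ops import FusedBiasLeakyReLU

x = torch.randn(2, 64, 16, 16, device='cuda')
act = FusedBiasLeakyReLU(num_channels=64).cuda()
out = act(x)  # same shape as x; the learned per-channel bias is broadcast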
4 changes: 3 additions & 1 deletion mmcv/ops/gather_points.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 from torch.autograd import Function
 
@@ -37,7 +39,7 @@ def forward(ctx, features: torch.Tensor,
         return output
 
     @staticmethod
-    def backward(ctx, grad_out):
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
         idx, C, N = ctx.for_backwards
         B, npoint = idx.size()
 
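
The `Tuple[torch.Tensor, None]` annotation reflects that `forward` takes two inputs but only `features` is differentiable; the integer indices receive no gradient. A minimal usage sketch (CUDA op; shapes follow the docstrings):

import torch

from mmcv.ops import gather_points

features = torch.randn(2, 16, 128, device='cuda', requires_grad=True)  # (B, C, N)
idx = torch.randint(0, 128, (2, 32), device='cuda', dtype=torch.int32)  # (B, npoint)
out = gather_points(features, idx)  # (B, C, npoint)
out.sum().backward()  # gradient flows back to features only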
34 changes: 19 additions & 15 deletions mmcv/ops/group_points.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn as nn
@@ -37,15 +37,15 @@ class QueryAndGroup(nn.Module):
     """
 
     def __init__(self,
-                 max_radius,
-                 sample_num,
-                 min_radius=0,
-                 use_xyz=True,
-                 return_grouped_xyz=False,
-                 normalize_xyz=False,
-                 uniform_sample=False,
-                 return_unique_cnt=False,
-                 return_grouped_idx=False):
+                 max_radius: float,
+                 sample_num: int,
+                 min_radius: float = 0.,
+                 use_xyz: bool = True,
+                 return_grouped_xyz: bool = False,
+                 normalize_xyz: bool = False,
+                 uniform_sample: bool = False,
+                 return_unique_cnt: bool = False,
+                 return_grouped_idx: bool = False):
         super().__init__()
         self.max_radius = max_radius
         self.min_radius = min_radius
@@ -64,7 +64,12 @@ def __init__(self,
            assert not self.normalize_xyz, \
                'can not normalize grouped xyz when max_radius is None'
 
-    def forward(self, points_xyz, center_xyz, features=None):
+    def forward(
+        self,
+        points_xyz: torch.Tensor,
+        center_xyz: torch.Tensor,
+        features: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple]:
         """
         Args:
             points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the
@@ -75,7 +80,7 @@ def forward(self, points_xyz, center_xyz, features=None):
                 points.
 
         Returns:
-            torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
+            Tuple | torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
            concatenated coordinates and features of points.
         """
         # if self.max_radius is None, we will perform kNN instead of ball query
@@ -149,7 +154,7 @@ def __init__(self, use_xyz: bool = True):
     def forward(self,
                 xyz: torch.Tensor,
                 new_xyz: torch.Tensor,
-                features: torch.Tensor = None):
+                features: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Args:
             xyz (Tensor): (B, N, 3) xyz coordinates of the features.
@@ -210,8 +215,7 @@ def forward(ctx, features: torch.Tensor,
         return output
 
     @staticmethod
-    def backward(ctx,
-                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """
         Args:
             grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
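
The `Union[torch.Tensor, Tuple]` return annotation exists because `forward` can additionally return grouped xyz, unique counts, or grouping indices when the corresponding flags are set; with the defaults it returns a single tensor. A minimal usage sketch (CUDA op):

import torch

from mmcv.ops import QueryAndGroup

grouper = QueryAndGroup(max_radius=0.2, sample_num=16)
points_xyz = torch.randn(2, 256, 3, device='cuda')      # (B, N, 3)
center_xyz = points_xyz[:, :64, :].contiguous()         # (B, npoint, 3)
features = torch.randn(2, 32, 256, device='cuda')       # (B, C, N)
out = grouper(points_xyz, center_xyz, features)         # (B, 3 + C, npoint, sample_num)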
4 changes: 3 additions & 1 deletion mmcv/ops/knn.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch.autograd import Function
 
@@ -19,7 +21,7 @@ class KNN(Function):
     def forward(ctx,
                 k: int,
                 xyz: torch.Tensor,
-                center_xyz: torch.Tensor = None,
+                center_xyz: Optional[torch.Tensor] = None,
                 transposed: bool = False) -> torch.Tensor:
         """
         Args:
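
With `center_xyz` now `Optional`, passing `None` queries `xyz` against itself. A minimal usage sketch (CUDA op):

import torch

from mmcv.ops import knn

xyz = torch.randn(2, 256, 3, device='cuda')    # (B, N, 3)
center = torch.randn(2, 32, 3, device='cuda')  # (B, npoint, 3)
idx = knn(4, xyz, center)                      # (B, 4, npoint) indices into xyz
idx_self = knn(4, xyz, None)                   # center_xyz=None: query xyz itself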
31 changes: 20 additions & 11 deletions mmcv/ops/masked_conv.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -27,7 +28,13 @@ def symbolic(g, features, mask, weight, bias, padding, stride):
             stride_i=stride)
 
     @staticmethod
-    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+    def forward(ctx,
+                features: torch.Tensor,
+                mask: torch.Tensor,
+                weight: torch.nn.Parameter,
+                bias: torch.nn.Parameter,
+                padding: int = 0,
+                stride: int = 1) -> torch.Tensor:
         assert mask.dim() == 3 and mask.size(0) == 1
         assert features.dim() == 4 and features.size(0) == 1
         assert features.size()[2:] == mask.size()[1:]
@@ -75,7 +82,7 @@ def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
 
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         return (None, ) * 5
 
 
@@ -90,18 +97,20 @@ class MaskedConv2d(nn.Conv2d):
     """
 
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, ...]],
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(in_channels, out_channels, kernel_size, stride,
                          padding, dilation, groups, bias)
 
-    def forward(self, input, mask=None):
+    def forward(self,
+                input: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         if mask is None:  # fallback to the normal Conv2d
             return super().forward(input)
         else:
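
A minimal sketch of the now-`Optional` mask argument; note the asserts in `MaskedConvFunction.forward` restrict the masked path to batch size 1 and a `(1, H, W)` mask (CUDA op):

import torch

from mmcv.ops import MaskedConv2d

conv = MaskedConv2d(3, 8, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 3, 32, 32, device='cuda')
mask = (torch.rand(1, 32, 32, device='cuda') > 0.5).float()
y_masked = conv(x, mask)  # convolution evaluated only at masked positions
y_plain = conv(x)         # mask=None falls back to the ordinary nn.Conv2d path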
39 changes: 23 additions & 16 deletions mmcv/ops/merge_cells.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from abc import abstractmethod
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -19,7 +20,7 @@ class BaseMergeCell(nn.Module):
     another convolution layer.
 
     Args:
-        in_channels (int): number of input channels in out_conv layer.
+        fused_channels (int): number of input channels in out_conv layer.
         out_channels (int): number of output channels in out_conv layer.
         with_out_conv (bool): Whether to use out_conv layer
         out_conv_cfg (dict): Config dict for convolution layer, which should
@@ -42,18 +43,18 @@ class BaseMergeCell(nn.Module):
     """
 
     def __init__(self,
-                 fused_channels=256,
-                 out_channels=256,
-                 with_out_conv=True,
-                 out_conv_cfg=dict(
+                 fused_channels: Optional[int] = 256,
+                 out_channels: Optional[int] = 256,
+                 with_out_conv: bool = True,
+                 out_conv_cfg: dict = dict(
                      groups=1, kernel_size=3, padding=1, bias=True),
-                 out_norm_cfg=None,
-                 out_conv_order=('act', 'conv', 'norm'),
-                 with_input1_conv=False,
-                 with_input2_conv=False,
-                 input_conv_cfg=None,
-                 input_norm_cfg=None,
-                 upsample_mode='nearest'):
+                 out_norm_cfg: Optional[dict] = None,
+                 out_conv_order: tuple = ('act', 'conv', 'norm'),
+                 with_input1_conv: bool = False,
+                 with_input2_conv: bool = False,
+                 input_conv_cfg: Optional[dict] = None,
+                 input_norm_cfg: Optional[dict] = None,
+                 upsample_mode: str = 'nearest'):
         super().__init__()
         assert upsample_mode in ['nearest', 'bilinear']
         self.with_out_conv = with_out_conv
@@ -111,7 +112,10 @@ def _resize(self, x, size):
            x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
            return x
 
-    def forward(self, x1, x2, out_size=None):
+    def forward(self,
+                x1: torch.Tensor,
+                x2: torch.Tensor,
+                out_size: Optional[tuple] = None) -> torch.Tensor:
         assert x1.shape[:2] == x2.shape[:2]
         assert out_size is None or len(out_size) == 2
         if out_size is None:  # resize to larger one
@@ -131,7 +135,7 @@ def forward(self, x1, x2, out_size=None):
 
 class SumCell(BaseMergeCell):
 
-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs):
         super().__init__(in_channels, out_channels, **kwargs)
 
     def _binary_op(self, x1, x2):
@@ -140,7 +144,7 @@ def _binary_op(self, x1, x2):
 
 class ConcatCell(BaseMergeCell):
 
-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs):
         super().__init__(in_channels * 2, out_channels, **kwargs)
 
     def _binary_op(self, x1, x2):
@@ -150,7 +154,10 @@ def _binary_op(self, x1, x2):
 
 class GlobalPoolingCell(BaseMergeCell):
 
-    def __init__(self, in_channels=None, out_channels=None, **kwargs):
+    def __init__(self,
+                 in_channels: Optional[int] = None,
+                 out_channels: Optional[int] = None,
+                 **kwargs):
         super().__init__(in_channels, out_channels, **kwargs)
         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
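
A minimal usage sketch of the merge cells typed above; the smaller input is resized to the larger one before the binary op, so the two inputs only need matching batch and channel dims (runs on CPU):

import torch

from mmcv.ops.merge_cells import ConcatCell, SumCell

x1 = torch.randn(1, 16, 32, 32)
x2 = torch.randn(1, 16, 16, 16)  # upsampled to 32 x 32 inside the cell

sum_cell = SumCell(in_channels=16, out_channels=16)
cat_cell = ConcatCell(in_channels=16, out_channels=16)

out_sum = sum_cell(x1, x2)  # (1, 16, 32, 32), elementwise sum then out_conv
out_cat = cat_cell(x1, x2)  # (1, 16, 32, 32), channel concat then out_conv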
4 changes: 3 additions & 1 deletion mmcv/ops/min_area_polygons.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext('_ext', ['min_area_polygons'])
 
 
-def min_area_polygons(pointsets):
+def min_area_polygons(pointsets: torch.Tensor) -> torch.Tensor:
     """Find the smallest polygons that surrounds all points in the point sets.
 
     Args:
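
A minimal usage sketch of the newly annotated function (CUDA op; per the mmcv docs the input is an (N, 18) tensor of nine (x, y) points per set and the output is an (N, 8) tensor of polygon corners):

import torch

from mmcv.ops import min_area_polygons

pointsets = torch.rand(4, 18, device='cuda')  # 4 sets of nine (x, y) points
polygons = min_area_polygons(pointsets)       # (4, 8): four corners per polygon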