Add type hints for mmcv/ops #1995

Merged: 27 commits, Jun 20, 2022

Commits (changes shown below are from all commits)
6ed7a82  Merge Master (May 26, 2022)
035627f  Add typehint in mmcv/ops/* (May 26, 2022)
673358e  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 26, 2022)
959bfe0  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 27, 2022)
bbbcb42  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 27, 2022)
ed59600  Fix (May 28, 2022)
cca0aef  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 29, 2022)
156e02a  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 29, 2022)
968fabf  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 30, 2022)
03132b7  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 30, 2022)
7c7c4b3  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, May 30, 2022)
bbe9ebf  Update mmcv/ops/roi_align.py (triple-Mu, May 31, 2022)
c52913d  Fix (May 31, 2022)
b0b0c8f  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, Jun 1, 2022)
db32d55  Fix (Jun 1, 2022)
be5dc96  Fix (Jun 1, 2022)
1e3f137  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, Jun 2, 2022)
8c03056  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, Jun 3, 2022)
e02b1d4  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, Jun 8, 2022)
2dc676a  Update mmcv/ops/riroi_align_rotated.py (triple-Mu, Jun 17, 2022)
9845d2e  Update mmcv/ops/riroi_align_rotated.py (triple-Mu, Jun 17, 2022)
6b84192  Merge branch 'open-mmlab:master' into tripleMu-typehint1 (triple-Mu, Jun 17, 2022)
ad1df0e  remove type hints of all symbolic methods (zhouzaida, Jun 18, 2022)
df0aa4c  remove type hints of all symbolic methods (zhouzaida, Jun 18, 2022)
88ce86f  minor refinement (zhouzaida, Jun 18, 2022)
c038621  minor refinement (zhouzaida, Jun 18, 2022)
6161a7f  minor fix (zhouzaida, Jun 18, 2022)
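
All of the diffs below apply the same pattern: the forward and backward methods of each torch.autograd.Function, plus the public nn.Module wrappers, gain annotations, while symbolic methods are left unannotated (see the "remove type hints of all symbolic methods" commits). A minimal sketch of that pattern follows; ExampleFunction is a hypothetical op used only for illustration, not part of mmcv.

from typing import Any, Tuple

import torch
from torch.autograd import Function


class ExampleFunction(Function):

    @staticmethod
    def symbolic(g, input, scale):
        # symbolic methods stay untyped in this PR
        return g.op('mmcv::Example', input, scale_f=scale)

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor,
                scale: float = 1.0) -> torch.Tensor:
        ctx.scale = scale
        return input * scale

    @staticmethod
    def backward(ctx: Any,
                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        # one gradient slot per forward argument; None for non-tensor args
        return grad_output * ctx.scale, None
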
14 changes: 10 additions & 4 deletions mmcv/ops/psa_mask.py
@@ -1,4 +1,7 @@
# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
from typing import Optional, Tuple

import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
@@ -20,7 +23,8 @@ def symbolic(g, input, psa_type, mask_size):
mask_size_i=mask_size)

@staticmethod
def forward(ctx, input, psa_type, mask_size):
def forward(ctx, input: torch.Tensor, psa_type: str,
mask_size: int) -> torch.Tensor:
ctx.psa_type = psa_type
ctx.mask_size = _pair(mask_size)
ctx.save_for_backward(input)
@@ -45,7 +49,9 @@ def forward(ctx, input, psa_type, mask_size):
return output

@staticmethod
def backward(ctx, grad_output):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None]:
input = ctx.saved_tensors[0]
psa_type = ctx.psa_type
h_mask, w_mask = ctx.mask_size
@@ -71,7 +77,7 @@ def backward(ctx, grad_output):

class PSAMask(nn.Module):

def __init__(self, psa_type, mask_size=None):
def __init__(self, psa_type: str, mask_size: Optional[tuple] = None):
super().__init__()
assert psa_type in ['collect', 'distribute']
if psa_type == 'collect':
@@ -82,7 +88,7 @@ def __init__(self, psa_type, mask_size=None):
self.mask_size = mask_size
self.psa_type = psa_type

def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return psa_mask(input, self.psa_type_enum, self.mask_size)

def __repr__(self):
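
A minimal usage sketch of the annotated PSAMask API (not part of the diff; it assumes an mmcv build with the psa_mask op compiled, and the shapes are illustrative):

import torch
from mmcv.ops import PSAMask

# the input channel count must equal h_mask * w_mask
psa = PSAMask('collect', mask_size=(7, 7))
x = torch.rand(2, 7 * 7, 16, 16)   # (N, h_mask * w_mask, H, W)
out = psa(x)                       # (N, H * W, H, W) -> (2, 256, 16, 16)
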
41 changes: 24 additions & 17 deletions mmcv/ops/riroi_align_rotated.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function

@@ -11,14 +14,14 @@
class RiRoIAlignRotatedFunction(Function):

@staticmethod
def forward(ctx,
features,
rois,
out_size,
spatial_scale,
num_samples=0,
num_orientations=8,
clockwise=False):
def forward(ctx: Any,
features: torch.Tensor,
rois: torch.Tensor,
out_size: Union[int, tuple],
spatial_scale: float,
num_samples: int = 0,
num_orientations: int = 8,
clockwise: bool = False) -> torch.Tensor:
if isinstance(out_size, int):
out_h = out_size
out_w = out_size
@@ -54,7 +57,9 @@ def forward(ctx,
return output

@staticmethod
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Optional[Tuple[torch.Tensor, None, None, None, None, None, None]]:
feature_size = ctx.feature_size
spatial_scale = ctx.spatial_scale
num_orientations = ctx.num_orientations
@@ -67,7 +72,7 @@ def backward(ctx, grad_output):
out_w = grad_output.size(3)
out_h = grad_output.size(2)

grad_input = grad_rois = None
grad_input = None

if ctx.needs_input_grad[0]:
grad_input = rois.new_zeros(batch_size, num_channels, feature_h,
@@ -83,7 +88,8 @@ def backward(ctx, grad_output):
num_orientations=num_orientations,
clockwise=clockwise)

return grad_input, grad_rois, None, None, None, None, None
return grad_input, None, None, None, None, None, None
return None


riroi_align_rotated = RiRoIAlignRotatedFunction.apply
@@ -111,11 +117,11 @@ class RiRoIAlignRotated(nn.Module):
"""

def __init__(self,
out_size,
spatial_scale,
num_samples=0,
num_orientations=8,
clockwise=False):
out_size: tuple,
spatial_scale: float,
num_samples: int = 0,
num_orientations: int = 8,
clockwise: bool = False):
super().__init__()

self.out_size = out_size
@@ -124,7 +130,8 @@ def __init__(self,
self.num_orientations = int(num_orientations)
self.clockwise = clockwise

def forward(self, features, rois):
def forward(self, features: torch.Tensor,
rois: torch.Tensor) -> torch.Tensor:
return RiRoIAlignRotatedFunction.apply(features, rois, self.out_size,
self.spatial_scale,
self.num_samples,
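
A usage sketch of the annotated RiRoIAlignRotated module (illustrative only; a CUDA build of mmcv is assumed, and the roi layout (batch_index, cx, cy, w, h, angle) is the rotated-box convention assumed here):

import torch
from mmcv.ops import RiRoIAlignRotated

# the channel dimension must be divisible by num_orientations
align = RiRoIAlignRotated(out_size=(7, 7), spatial_scale=1.0 / 8,
                          num_samples=2, num_orientations=8)
feats = torch.rand(1, 64, 50, 50, device='cuda')
rois = torch.tensor([[0., 100., 100., 40., 20., 0.5]], device='cuda')
out = align(feats, rois)           # (num_rois, 64, 7, 7)
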
34 changes: 18 additions & 16 deletions mmcv/ops/roi_align.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any

import torch
import torch.nn as nn
from torch.autograd import Function
@@ -62,14 +64,14 @@ def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
mode_s=pool_mode)

@staticmethod
def forward(ctx,
input,
rois,
output_size,
spatial_scale=1.0,
sampling_ratio=0,
pool_mode='avg',
aligned=True):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: int,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
pool_mode: str = 'avg',
aligned: bool = True) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
@@ -108,7 +110,7 @@ def forward(ctx,

@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
rois, argmax_y, argmax_x = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)
# complex head architecture may cause grad_output uncontiguous.
@@ -175,12 +177,12 @@ class RoIAlign(nn.Module):
},
cls_name='RoIAlign')
def __init__(self,
output_size,
spatial_scale=1.0,
sampling_ratio=0,
pool_mode='avg',
aligned=True,
use_torchvision=False):
output_size: tuple,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
pool_mode: str = 'avg',
aligned: bool = True,
use_torchvision: bool = False):
super().__init__()

self.output_size = _pair(output_size)
@@ -190,7 +192,7 @@ def __init__(self,
self.aligned = aligned
self.use_torchvision = use_torchvision

def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""
Args:
input: NCHW images
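
A usage sketch of the annotated RoIAlign module (illustrative shapes; rois follow the (batch_index, x1, y1, x2, y2) convention at the original image scale, and an mmcv build with the op compiled for the device used is assumed):

import torch
from mmcv.ops import RoIAlign

roi_align = RoIAlign(output_size=(7, 7), spatial_scale=1.0 / 4,
                     sampling_ratio=2, pool_mode='avg', aligned=True)
feats = torch.rand(1, 16, 32, 32)               # NCHW feature map
rois = torch.tensor([[0., 8., 8., 72., 72.]])   # (batch_index, x1, y1, x2, y2)
out = roi_align(feats, rois)                    # (num_rois, 16, 7, 7)
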
36 changes: 21 additions & 15 deletions mmcv/ops/roi_align_rotated.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
@@ -37,14 +40,14 @@ def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
clockwise_i=clockwise)

@staticmethod
def forward(ctx,
input,
rois,
output_size,
spatial_scale,
sampling_ratio=0,
aligned=True,
clockwise=False):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: Union[int, tuple],
spatial_scale: float,
sampling_ratio: int = 0,
aligned: bool = True,
clockwise: bool = False) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
@@ -71,7 +74,10 @@ def forward(ctx,
return output

@staticmethod
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], None, None,
None, None, None]:
feature_size = ctx.feature_size
rois = ctx.saved_tensors[0]
assert feature_size is not None
@@ -151,11 +157,11 @@ class RoIAlignRotated(nn.Module):
},
cls_name='RoIAlignRotated')
def __init__(self,
output_size,
spatial_scale,
sampling_ratio=0,
aligned=True,
clockwise=False):
output_size: Union[int, tuple],
spatial_scale: float,
sampling_ratio: int = 0,
aligned: bool = True,
clockwise: bool = False):
super().__init__()

self.output_size = _pair(output_size)
@@ -164,7 +170,7 @@ def __init__(self,
self.aligned = aligned
self.clockwise = clockwise

def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return RoIAlignRotatedFunction.apply(input, rois, self.output_size,
self.spatial_scale,
self.sampling_ratio, self.aligned,
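
A usage sketch of the annotated RoIAlignRotated module (illustrative; each roi is assumed to be (batch_index, cx, cy, w, h, angle) with the angle in radians):

import torch
from mmcv.ops import RoIAlignRotated

roi_align_rot = RoIAlignRotated(output_size=(7, 7), spatial_scale=1.0 / 4,
                                sampling_ratio=2, aligned=True,
                                clockwise=False)
feats = torch.rand(1, 16, 32, 32)
rois = torch.tensor([[0., 60., 60., 40., 20., 0.3]])
out = roi_align_rot(feats, rois)   # (num_rois, 16, 7, 7)
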
18 changes: 14 additions & 4 deletions mmcv/ops/roi_pool.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
@@ -23,7 +25,11 @@ def symbolic(g, input, rois, output_size, spatial_scale):
spatial_scale_f=spatial_scale)

@staticmethod
def forward(ctx, input, rois, output_size, spatial_scale=1.0):
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
output_size: Union[int, tuple],
spatial_scale: float = 1.0) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
@@ -49,7 +55,9 @@ def forward(ctx, input, rois, output_size, spatial_scale=1.0):

@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(
ctx: Any, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None]:
rois, argmax = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)

@@ -70,13 +78,15 @@ def backward(ctx, grad_output):

class RoIPool(nn.Module):

def __init__(self, output_size, spatial_scale=1.0):
def __init__(self,
output_size: Union[int, tuple],
spatial_scale: float = 1.0):
super().__init__()

self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)

def forward(self, input, rois):
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return roi_pool(input, rois, self.output_size, self.spatial_scale)

def __repr__(self):
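
A usage sketch of the annotated RoIPool module (illustrative; roi_pool may be available only in CUDA builds of mmcv, so the tensors are placed on the GPU here):

import torch
from mmcv.ops import RoIPool

roi_pool = RoIPool(output_size=(7, 7), spatial_scale=1.0 / 4)
feats = torch.rand(1, 16, 32, 32, device='cuda')
rois = torch.tensor([[0., 8., 8., 72., 72.]], device='cuda')
out = roi_pool(feats, rois)        # (num_rois, 16, 7, 7)
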
19 changes: 14 additions & 5 deletions mmcv/ops/roiaware_pool3d.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union

import torch
from torch import nn as nn
from torch.autograd import Function
@@ -25,7 +27,10 @@ class RoIAwarePool3d(nn.Module):
Default: 'max'.
"""

def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
def __init__(self,
out_size: Union[int, tuple],
max_pts_per_voxel: int = 128,
mode: str = 'max'):
super().__init__()

self.out_size = out_size
@@ -34,7 +39,8 @@ def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
pool_mapping = {'max': 0, 'avg': 1}
self.mode = pool_mapping[mode]

def forward(self, rois, pts, pts_feature):
def forward(self, rois: torch.Tensor, pts: torch.Tensor,
pts_feature: torch.Tensor) -> torch.Tensor:
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@@ -55,8 +61,9 @@ def forward(self, rois, pts, pts_feature):
class RoIAwarePool3dFunction(Function):

@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
pts_feature: torch.Tensor, out_size: Union[int, tuple],
max_pts_per_voxel: int, mode: int) -> torch.Tensor:
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@@ -108,7 +115,9 @@ def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
return pooled_features

@staticmethod
def backward(ctx, grad_out):
def backward(
ctx: Any, grad_out: torch.Tensor
) -> Tuple[None, None, torch.Tensor, None, None, None]:
ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

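
A usage sketch of the annotated RoIAwarePool3d module (illustrative; boxes and points are in LiDAR coordinates as described in the docstrings above, and a CUDA build is assumed):

import torch
from mmcv.ops import RoIAwarePool3d

pool = RoIAwarePool3d(out_size=4, max_pts_per_voxel=128, mode='max')
# rois: (N, 7) boxes as (cx, cy, cz, dx, dy, dz, heading)
rois = torch.tensor([[0., 0., 0., 4., 4., 2., 0.]], device='cuda')
pts = torch.rand(100, 3, device='cuda') * 4 - 2    # (npoints, 3)
pts_feature = torch.rand(100, 16, device='cuda')   # (npoints, C)
out = pool(rois, pts, pts_feature)                 # (N, 4, 4, 4, 16)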