Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type hints for mmcv/ops/... #1987

Merged
merged 5 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmcv/ops/cc_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mmcv.cnn import PLUGIN_LAYERS, Scale


def NEG_INF_DIAG(n, device):
def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
"""Returns a diagonal matrix of size [n, n].

The diagonal are all "-inf". This is for avoiding calculating the
Expand Down Expand Up @@ -41,15 +41,15 @@ class CrissCrossAttention(nn.Module):
in_channels (int): Channels of the input feature map.
"""

def __init__(self, in_channels):
def __init__(self, in_channels: int) -> None:
super().__init__()
self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
self.gamma = Scale(0.)
self.in_channels = in_channels

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""forward function of Criss-Cross Attention.

Args:
Expand Down Expand Up @@ -78,7 +78,7 @@ def forward(self, x):

return out

def __repr__(self):
def __repr__(self) -> str:
s = self.__class__.__name__
s += f'(in_channels={self.in_channels})'
return s
9 changes: 6 additions & 3 deletions mmcv/ops/contour_expand.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import numpy as np
import torch

Expand All @@ -7,8 +9,9 @@
ext_module = ext_loader.load_ext('_ext', ['contour_expand'])


def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
kernel_num):
def contour_expand(kernel_mask: Union[np.array, torch.Tensor],
internal_kernel_label: Union[np.array, torch.Tensor],
min_kernel_area: int, kernel_num: int) -> list:
"""Expand kernel contours so that foreground pixels are assigned into
instances.

Expand Down Expand Up @@ -42,7 +45,7 @@ def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
internal_kernel_label,
min_kernel_area=min_kernel_area,
kernel_num=kernel_num)
label = label.tolist()
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
label = label.tolist() # type: ignore
else:
label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
min_kernel_area, kernel_num)
Expand Down
10 changes: 8 additions & 2 deletions mmcv/ops/convex_iou.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou'])


def convex_giou(pointsets, polygons):
def convex_giou(pointsets: torch.Tensor,
polygons: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return generalized intersection-over-union (Jaccard index) between point
sets and polygons.

Expand All @@ -26,7 +31,8 @@ def convex_giou(pointsets, polygons):
return convex_giou, points_grad


def convex_iou(pointsets, polygons):
def convex_iou(pointsets: torch.Tensor,
polygons: torch.Tensor) -> torch.Tensor:
"""Return intersection-over-union (Jaccard index) between point sets and
polygons.

Expand Down