Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints in mmcv/ops/point_sample.py #2019

Merged
merged 2 commits into from
May 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions mmcv/ops/point_sample.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa

from os import path as osp
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair
from torch.onnx.operators import shape_as_tensor


def bilinear_grid_sample(im, grid, align_corners=False):
def bilinear_grid_sample(im: Tensor,
grid: Tensor,
align_corners: bool = False) -> Tensor:
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
"""Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels.

Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners {bool}: If set to True, the extrema (-1 and 1) are
align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input’s
corner pixels. If set to False, they are instead considered as
referring to the corner points of the input’s corner pixels,
Expand Down Expand Up @@ -85,14 +89,14 @@ def bilinear_grid_sample(im, grid, align_corners=False):
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)


def is_in_onnx_export_without_custom_ops():
def is_in_onnx_export_without_custom_ops() -> bool:
    """Check whether an ONNX export is running without the compiled
    ONNX Runtime custom-op library being available on disk.

    Returns:
        bool: True only when :func:`torch.onnx.is_in_onnx_export` reports an
        active export and the custom-op shared library path does not exist.
    """
    from mmcv.ops import get_onnxruntime_op_path

    # Short-circuit: only probe the filesystem when actually exporting.
    if not torch.onnx.is_in_onnx_export():
        return False
    custom_op_path = get_onnxruntime_op_path()
    return not osp.exists(custom_op_path)


def normalize(grid):
def normalize(grid: Tensor) -> Tensor:
"""Normalize input grid from [-1, 1] to [0, 1]

Args:
Expand All @@ -105,7 +109,7 @@ def normalize(grid):
return (grid + 1.0) / 2.0


def denormalize(grid):
def denormalize(grid: Tensor) -> Tensor:
"""Denormalize input grid from range [0, 1] to [-1, 1]

Args:
Expand All @@ -118,7 +122,8 @@ def denormalize(grid):
return grid * 2.0 - 1.0


def generate_grid(num_grid, size, device):
def generate_grid(num_grid: int, size: Tuple[int, int],
device: torch.device) -> Tensor:
"""Generate regular square grid of points in [0, 1] x [0, 1] coordinate
space.

Expand All @@ -139,7 +144,8 @@ def generate_grid(num_grid, size, device):
return grid.view(1, -1, 2).expand(num_grid, -1, -1)


def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
def rel_roi_point_to_abs_img_point(rois: Tensor,
rel_roi_points: Tensor) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute
point coordinates.

Expand Down Expand Up @@ -170,7 +176,7 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
return abs_img_points


def get_shape_from_feature_map(x):
def get_shape_from_feature_map(x: Tensor) -> Tensor:
"""Get spatial resolution of input feature map considering exporting to
onnx mode.

Expand All @@ -189,7 +195,9 @@ def get_shape_from_feature_map(x):
return img_shape


def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
img: Union[tuple, Tensor],
spatial_scale: float = 1.) -> Tensor:
"""Convert image based absolute point coordinates to image based relative
coordinates for sampling.

Expand Down Expand Up @@ -220,10 +228,10 @@ def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
return abs_img_points / scale * spatial_scale


def rel_roi_point_to_rel_img_point(rois,
rel_roi_points,
img,
spatial_scale=1.):
def rel_roi_point_to_rel_img_point(rois: Tensor,
rel_roi_points: Tensor,
img: Union[tuple, Tensor],
spatial_scale: float = 1.) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute
point coordinates.

Expand All @@ -247,7 +255,10 @@ def rel_roi_point_to_rel_img_point(rois,
return rel_img_point


def point_sample(input, points, align_corners=False, **kwargs):
def point_sample(input: Tensor,
points: Tensor,
align_corners: bool = False,
**kwargs) -> Tensor:
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square.
Expand Down Expand Up @@ -285,7 +296,10 @@ def point_sample(input, points, align_corners=False, **kwargs):

class SimpleRoIAlign(nn.Module):

def __init__(self, output_size, spatial_scale, aligned=True):
def __init__(self,
output_size: Tuple[int],
spatial_scale: float,
aligned: bool = True) -> None:
"""Simple RoI align in PointRend, faster than standard RoIAlign.

Args:
Expand All @@ -303,7 +317,7 @@ def __init__(self, output_size, spatial_scale, aligned=True):
self.use_torchvision = False
self.aligned = aligned

def forward(self, features, rois):
def forward(self, features: Tensor, rois: Tensor) -> Tensor:
num_imgs = features.size(0)
num_rois = rois.size(0)
rel_roi_points = generate_grid(
Expand Down Expand Up @@ -339,7 +353,7 @@ def forward(self, features, rois):

return roi_feats

def __repr__(self):
def __repr__(self) -> str:
format_str = self.__class__.__name__
format_str += '(output_size={}, spatial_scale={}'.format(
self.output_size, self.spatial_scale)
Expand Down