diff --git a/mmcv/ops/point_sample.py b/mmcv/ops/point_sample.py
index 0d3147c155..b40ccaba82 100644
--- a/mmcv/ops/point_sample.py
+++ b/mmcv/ops/point_sample.py
@@ -1,15 +1,19 @@
 # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
 from os import path as osp
+from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn.modules.utils import _pair
 from torch.onnx.operators import shape_as_tensor
 
 
-def bilinear_grid_sample(im, grid, align_corners=False):
+def bilinear_grid_sample(im: Tensor,
+                         grid: Tensor,
+                         align_corners: bool = False) -> Tensor:
     """Given an input and a flow-field grid, computes the output using input
     values and pixel locations from grid. Supported only bilinear interpolation
     method to sample the input pixels.
 
@@ -17,7 +21,7 @@ def bilinear_grid_sample(im, grid, align_corners=False):
     Args:
         im (torch.Tensor): Input feature map, shape (N, C, H, W)
         grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
-        align_corners {bool}: If set to True, the extrema (-1 and 1) are
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
             considered as referring to the center points of the input’s
             corner pixels. If set to False, they are instead considered as
             referring to the corner points of the input’s corner pixels,
@@ -85,14 +89,14 @@ def bilinear_grid_sample(im, grid, align_corners=False):
     return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
 
 
-def is_in_onnx_export_without_custom_ops():
+def is_in_onnx_export_without_custom_ops() -> bool:
     from mmcv.ops import get_onnxruntime_op_path
     ort_custom_op_path = get_onnxruntime_op_path()
     return torch.onnx.is_in_onnx_export(
     ) and not osp.exists(ort_custom_op_path)
 
 
-def normalize(grid):
+def normalize(grid: Tensor) -> Tensor:
     """Normalize input grid from [-1, 1] to [0, 1]
 
     Args:
@@ -105,7 +109,7 @@ def normalize(grid):
     return (grid + 1.0) / 2.0
 
 
-def denormalize(grid):
+def denormalize(grid: Tensor) -> Tensor:
     """Denormalize input grid from range [0, 1] to [-1, 1]
 
     Args:
@@ -118,7 +122,8 @@ def denormalize(grid):
     return grid * 2.0 - 1.0
 
 
-def generate_grid(num_grid, size, device):
+def generate_grid(num_grid: int, size: Tuple[int, int],
+                  device: torch.device) -> Tensor:
     """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
     space.
 
@@ -139,7 +144,8 @@ def generate_grid(num_grid, size, device):
     return grid.view(1, -1, 2).expand(num_grid, -1, -1)
 
 
-def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+def rel_roi_point_to_abs_img_point(rois: Tensor,
+                                   rel_roi_points: Tensor) -> Tensor:
     """Convert roi based relative point coordinates to image based absolute
     point coordinates.
 
@@ -170,7 +176,7 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
     return abs_img_points
 
 
-def get_shape_from_feature_map(x):
+def get_shape_from_feature_map(x: Tensor) -> Tensor:
     """Get spatial resolution of input feature map considering exporting to
     onnx mode.
 
@@ -189,7 +195,9 @@ def get_shape_from_feature_map(x):
     return img_shape
 
 
-def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
+                                   img: Union[tuple, Tensor],
+                                   spatial_scale: float = 1.) -> Tensor:
     """Convert image based absolute point coordinates to image based
     relative coordinates for sampling.
 
@@ -220,10 +228,10 @@ def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
     return abs_img_points / scale * spatial_scale
 
 
-def rel_roi_point_to_rel_img_point(rois,
-                                   rel_roi_points,
-                                   img,
-                                   spatial_scale=1.):
+def rel_roi_point_to_rel_img_point(rois: Tensor,
+                                   rel_roi_points: Tensor,
+                                   img: Union[tuple, Tensor],
+                                   spatial_scale: float = 1.) -> Tensor:
     """Convert roi based relative point coordinates to image based absolute
     point coordinates.
 
@@ -247,7 +255,10 @@ def rel_roi_point_to_rel_img_point(rois,
     return rel_img_point
 
 
-def point_sample(input, points, align_corners=False, **kwargs):
+def point_sample(input: Tensor,
+                 points: Tensor,
+                 align_corners: bool = False,
+                 **kwargs) -> Tensor:
     """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
     Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
     lie inside ``[0, 1] x [0, 1]`` square.
@@ -285,7 +296,10 @@
 
 class SimpleRoIAlign(nn.Module):
 
-    def __init__(self, output_size, spatial_scale, aligned=True):
+    def __init__(self,
+                 output_size: Tuple[int],
+                 spatial_scale: float,
+                 aligned: bool = True) -> None:
         """Simple RoI align in PointRend, faster than standard RoIAlign.
 
         Args:
@@ -303,7 +317,7 @@ def __init__(self, output_size, spatial_scale, aligned=True):
         self.use_torchvision = False
         self.aligned = aligned
 
-    def forward(self, features, rois):
+    def forward(self, features: Tensor, rois: Tensor) -> Tensor:
         num_imgs = features.size(0)
         num_rois = rois.size(0)
         rel_roi_points = generate_grid(
@@ -339,7 +353,7 @@ def forward(self, features, rois):
 
         return roi_feats
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         format_str = self.__class__.__name__
         format_str += '(output_size={}, spatial_scale={}'.format(
             self.output_size, self.spatial_scale)
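
Usage sketch (not part of the patch): the snippet below only exercises the annotated signatures. It assumes that bilinear_grid_sample reproduces the bilinear, zeros-padding behaviour of torch.nn.functional.grid_sample, as its docstring suggests, and that point_sample accepts 3D normalized coordinates in [0, 1] x [0, 1] and returns features of shape (N, C, P). The tensor shapes and tolerance are illustrative, not taken from the patch.

# Sketch only; requires an environment where mmcv.ops is importable.
import torch
import torch.nn.functional as F

from mmcv.ops.point_sample import bilinear_grid_sample, point_sample

im = torch.rand(2, 3, 16, 16)          # (N, C, H, W) feature map
grid = torch.rand(2, 8, 8, 2) * 2 - 1  # (N, Hg, Wg, 2), values in [-1, 1]

# bilinear_grid_sample is expected to match grid_sample (bilinear, zeros padding).
out = bilinear_grid_sample(im, grid, align_corners=False)
ref = F.grid_sample(im, grid, mode='bilinear', padding_mode='zeros',
                    align_corners=False)
assert out.shape == (2, 3, 8, 8)
assert torch.allclose(out, ref, atol=1e-6)

# point_sample takes normalized point coordinates in [0, 1] x [0, 1].
points = torch.rand(2, 32, 2)          # (N, P, 2)
feats = point_sample(im, points, align_corners=False)
assert feats.shape == (2, 3, 32)       # (N, C, P)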