[Feature] enable exporting to onnx for PointRend #953

Merged · 24 commits · Jun 11, 2021
200 changes: 161 additions & 39 deletions mmcv/ops/point_sample.py
@@ -1,9 +1,94 @@
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa

from os import path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.onnx.operators import shape_as_tensor


def bilinear_grid_sample(im, grid, align_corners=False):
"""Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels.

Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input’s
corner pixels. If set to False, they are instead considered as
referring to the corner points of the input’s corner pixels,
making the sampling more resolution agnostic.
Returns:
torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
"""
n, c, h, w = im.shape
gn, gh, gw, _ = grid.shape
assert n == gn

x = grid[:, :, :, 0]
y = grid[:, :, :, 1]

if align_corners:
x = ((x + 1) / 2) * (w - 1)
y = ((y + 1) / 2) * (h - 1)
else:
x = ((x + 1) * w - 1) / 2
y = ((y + 1) * h - 1) / 2

x = x.view(n, -1)
y = y.view(n, -1)

x0 = torch.floor(x).long()
y0 = torch.floor(y).long()
x1 = x0 + 1
y1 = y0 + 1

wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
wb = ((x1 - x) * (y - y0)).unsqueeze(1)
wc = ((x - x0) * (y1 - y)).unsqueeze(1)
wd = ((x - x0) * (y - y0)).unsqueeze(1)

# Apply zero padding, matching the default padding_mode of grid_sample
im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
padded_h = h + 2
padded_w = w + 2
# shift point coordinates to account for the padding
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

# Clip coordinates to padded image size
x0 = torch.where(x0 < 0, torch.tensor(0), x0)
x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
x1 = torch.where(x1 < 0, torch.tensor(0), x1)
x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
y0 = torch.where(y0 < 0, torch.tensor(0), y0)
y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
y1 = torch.where(y1 < 0, torch.tensor(0), y1)
y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)

im_padded = im_padded.view(n, c, -1)

x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

Ia = torch.gather(im_padded, 2, x0_y0)
Ib = torch.gather(im_padded, 2, x0_y1)
Ic = torch.gather(im_padded, 2, x1_y0)
Id = torch.gather(im_padded, 2, x1_y1)

return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
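
As a quick sanity check (using the imports above), this pure-PyTorch implementation can be compared against F.grid_sample; a minimal sketch with random CPU tensors:

im = torch.rand(1, 3, 8, 8)             # feature map (N, C, H, W)
grid = torch.rand(1, 4, 4, 2) * 2 - 1   # sampling grid in [-1, 1]
out = bilinear_grid_sample(im, grid, align_corners=False)
ref = F.grid_sample(im, grid, align_corners=False)
assert torch.allclose(out, ref, atol=1e-5)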


def is_in_onnx_export_without_custom_ops():
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
return torch.onnx.is_in_onnx_export(
) and not osp.exists(ort_custom_op_path)


def normalize(grid):
@@ -70,46 +155,67 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
if rois.size(1) == 5:
rois = rois[:, 1:]
abs_img_points = rel_roi_points.clone()
abs_img_points[:, :, 0] = abs_img_points[:, :, 0] * (
rois[:, None, 2] - rois[:, None, 0])
abs_img_points[:, :, 1] = abs_img_points[:, :, 1] * (
rois[:, None, 3] - rois[:, None, 1])
abs_img_points[:, :, 0] += rois[:, None, 0]
abs_img_points[:, :, 1] += rois[:, None, 1]
# To avoid an error when exporting to ONNX, use independent
# variables instead of in-place computation
xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
xs += rois[:, None, 0]
ys += rois[:, None, 1]
abs_img_points = torch.stack([xs, ys], dim=2)
return abs_img_points
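
The move away from in-place slice assignment matters because in-place writes on traced tensors may export as unsupported ONNX nodes. A small sketch contrasting the two patterns (hypothetical tensor, illustration only):

pts = torch.rand(2, 5, 2)
# in-place form being avoided: pts[:, :, 0] = pts[:, :, 0] * 2.0
xs = pts[:, :, 0] * 2.0             # export-friendly: build new tensors
ys = pts[:, :, 1]
pts = torch.stack([xs, ys], dim=2)  # stack instead of writing in place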


def abs_img_point_to_rel_img_point(abs_img_points,
img_shape,
spatial_scale=1.):
def get_shape_from_feature_map(x):
"""Get spatial resolution of input feature map considering exporting to
onnx mode.

Args:
x (torch.Tensor): Input tensor, shape (N, C, H, W)
Returns:
torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
"""
if torch.onnx.is_in_onnx_export():
img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
x.device).float()
else:
img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
x.device).float()
return img_shape
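
During export, shape_as_tensor emits a Shape node so the spatial size stays dynamic in the graph, whereas reading x.shape directly would be traced as constants. A sketch of what the eager-mode branch returns, with assumed sizes:

x = torch.rand(1, 256, 32, 48)  # (N, C, H, W)
wh = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).float()
# wh is tensor([[[48., 32.]]]): (width, height), shape (1, 1, 2)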


def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
"""Convert image based absolute point coordinates to image based relative
coordinates for sampling.

Args:
abs_img_points (Tensor): Image based absolute point coordinates,
shape (N, P, 2)
img_shape (tuple): (height, width) of image or feature map.
img (tuple or torch.Tensor): Either the (height, width) of the image
or feature map, or the feature map tensor itself, shape (N, C, H, W).
spatial_scale (float): Scale points by this factor. Default: 1.

Returns:
Tensor: Image based relative point coordinates for sampling,
shape (N, P, 2)
"""

assert isinstance(img_shape, tuple) and len(img_shape) == 2
h, w = img_shape
scale = torch.tensor([w, h],
dtype=torch.float,
device=abs_img_points.device)
scale = scale.view(1, 1, 2)
rel_img_points = abs_img_points / scale * spatial_scale
assert (isinstance(img, tuple) and len(img) == 2) or \
(isinstance(img, torch.Tensor) and len(img.shape) == 4)

return rel_img_points
if isinstance(img, tuple):
h, w = img
scale = torch.tensor([w, h],
dtype=torch.float,
device=abs_img_points.device)
scale = scale.view(1, 1, 2)
else:
scale = get_shape_from_feature_map(img)

return abs_img_points / scale * spatial_scale
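
Both call forms are now accepted; a brief sketch with hypothetical values:

abs_pts = torch.tensor([[[16., 8.]]])  # (N, P, 2) absolute (x, y) points
rel = abs_img_point_to_rel_img_point(abs_pts, (32, 64))  # explicit (h, w)
feat = torch.rand(1, 8, 32, 64)
rel = abs_img_point_to_rel_img_point(abs_pts, feat)  # shape read from tensor
# both calls return tensor([[[0.2500, 0.2500]]])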


def rel_roi_point_to_rel_img_point(rois,
rel_roi_points,
img_shape,
img,
spatial_scale=1.):
"""Convert roi based relative point coordinates to image based absolute
point coordinates.
@@ -118,7 +224,7 @@ def rel_roi_point_to_rel_img_point(rois,
rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
rel_roi_points (Tensor): Point coordinates inside RoI, relative to
RoI, location, range (0, 1), shape (N, P, 2)
img_shape (tuple): (height, width) of image or feature map.
img (tuple or torch.Tensor): Either the (height, width) of the image
or feature map, or the feature map tensor itself, shape (N, C, H, W).
spatial_scale (float): Scale points by this factor. Default: 1.

Returns:
@@ -127,7 +233,7 @@ def rel_roi_point_to_rel_img_point(rois,
"""

abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img_shape,
rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
spatial_scale)

return rel_img_point
@@ -153,8 +259,15 @@ def point_sample(input, points, align_corners=False, **kwargs):
if points.dim() == 3:
add_dim = True
points = points.unsqueeze(2)
output = F.grid_sample(
input, denormalize(points), align_corners=align_corners, **kwargs)
if is_in_onnx_export_without_custom_ops():
# If the custom ops for ONNX Runtime are not compiled, use the python
# implementation of grid_sample so that the exported ONNX graph
# contains only supported nodes
output = bilinear_grid_sample(
input, denormalize(points), align_corners=align_corners)
else:
output = F.grid_sample(
input, denormalize(points), align_corners=align_corners, **kwargs)
if add_dim:
output = output.squeeze(3)
return output
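
Typical eager-mode usage, as a sketch with assumed shapes:

feats = torch.rand(1, 16, 32, 32)  # (N, C, H, W)
points = torch.rand(1, 10, 2)      # (N, P, 2), normalized to [0, 1]
sampled = point_sample(feats, points)
# sampled.shape == torch.Size([1, 16, 10])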
@@ -181,29 +294,38 @@ def __init__(self, output_size, spatial_scale, aligned=True):
self.aligned = aligned

def forward(self, features, rois):

num_imgs = features.size(0)
num_rois = rois.size(0)
rel_roi_points = generate_grid(
num_rois, self.output_size, device=rois.device)

point_feats = []
for batch_ind in range(num_imgs):
# unravel batch dim
feat = features[batch_ind].unsqueeze(0)
inds = (rois[:, 0].long() == batch_ind)
if inds.any():
rel_img_points = rel_roi_point_to_rel_img_point(
rois[inds], rel_roi_points[inds], feat.shape[2:],
self.spatial_scale).unsqueeze(0)
point_feat = point_sample(
feat, rel_img_points, align_corners=not self.aligned)
point_feat = point_feat.squeeze(0).transpose(0, 1)
point_feats.append(point_feat)
if torch.onnx.is_in_onnx_export():
rel_img_points = rel_roi_point_to_rel_img_point(
rois, rel_roi_points, features, self.spatial_scale)
rel_img_points = rel_img_points.reshape(num_imgs, -1,
*rel_img_points.shape[1:])
point_feats = point_sample(
features, rel_img_points, align_corners=not self.aligned)
point_feats = point_feats.transpose(1, 2)
else:
point_feats = []
for batch_ind in range(num_imgs):
# unravel batch dim
feat = features[batch_ind].unsqueeze(0)
inds = (rois[:, 0].long() == batch_ind)
if inds.any():
rel_img_points = rel_roi_point_to_rel_img_point(
rois[inds], rel_roi_points[inds], feat,
self.spatial_scale).unsqueeze(0)
point_feat = point_sample(
feat, rel_img_points, align_corners=not self.aligned)
point_feat = point_feat.squeeze(0).transpose(0, 1)
point_feats.append(point_feat)

point_feats = torch.cat(point_feats, dim=0)

channels = features.size(1)
roi_feats = torch.cat(point_feats, dim=0)
roi_feats = roi_feats.reshape(num_rois, channels, *self.output_size)
roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)

return roi_feats
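
With the batched ONNX branch above, a module built on SimpleRoIAlign can be exported without the per-image loop. A minimal export sketch (sizes, file name, and opset version are illustrative assumptions):

roi_align = SimpleRoIAlign(output_size=(7, 7), spatial_scale=1.0)
feats = torch.rand(1, 16, 32, 32)
rois = torch.tensor([[0., 4., 4., 20., 20.]])  # (batch_idx, x1, y1, x2, y2)
roi_feats = roi_align(feats, rois)             # shape (1, 16, 7, 7)
torch.onnx.export(roi_align, (feats, rois), 'simple_roi_align.onnx',
                  opset_version=11)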

40 changes: 40 additions & 0 deletions tests/test_ops/test_bilinear_grid_sample.py
@@ -0,0 +1,40 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TestBilinearGridSample(object):

def _test_bilinear_grid_sample(self,
dtype=torch.float,
align_corners=False,
multiplier=1,
precision=1e-3):
from mmcv.ops.point_sample import bilinear_grid_sample

input = torch.rand(1, 1, 20, 20, dtype=dtype)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
grid *= multiplier

out = bilinear_grid_sample(input, grid, align_corners=align_corners)
ref_out = F.grid_sample(input, grid, align_corners=align_corners)

assert np.allclose(out.data.detach().cpu().numpy(),
ref_out.data.detach().cpu().numpy(), precision)

def test_bilinear_grid_sample(self):
self._test_bilinear_grid_sample(torch.double, False)
self._test_bilinear_grid_sample(torch.double, True)
self._test_bilinear_grid_sample(torch.float, False)
self._test_bilinear_grid_sample(torch.float, True)
self._test_bilinear_grid_sample(torch.float, False)
self._test_bilinear_grid_sample(torch.float, True, 5)
self._test_bilinear_grid_sample(torch.float, False, 10)
self._test_bilinear_grid_sample(torch.float, True, -6)
self._test_bilinear_grid_sample(torch.float, False, -10)
self._test_bilinear_grid_sample(torch.double, True, 5)
self._test_bilinear_grid_sample(torch.double, False, 10)
self._test_bilinear_grid_sample(torch.double, True, -6)
self._test_bilinear_grid_sample(torch.double, False, -10)