[Feature] Enable exporting to ONNX for PointRend #4977

Closed
1 change: 1 addition & 0 deletions docs/tutorials/onnx2tensorrt.md
@@ -90,6 +90,7 @@ The table below lists the models that are guaranteed to be convertible to TensorRT
 | RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | |
 | Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
 | Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
+| PointRend | `configs/point_rend/point_rend_r50_caffe_fpn_mstrain_1x_coco.py` | Y | Y | |
 
 Notes:
 
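For reference, once the PointRend model is exported to ONNX, building the TensorRT engine follows the usual mmcv path. Below is a minimal sketch assuming an mmcv-full build with TensorRT ops; the file names, the `'input'` tensor name, and the shape ranges are placeholder assumptions, and the `onnx2trt`/`save_trt_engine` API may differ across mmcv versions:

```python
import onnx
from mmcv.tensorrt import onnx2trt, save_trt_engine

# Placeholder paths and shapes; adjust to the exported PointRend model.
onnx_model = onnx.load('pointrend.onnx')
opt_shape_dict = {
    'input': [[1, 3, 400, 600],     # min shape
              [1, 3, 800, 1216],    # opt shape
              [1, 3, 1344, 1344]]   # max shape
}
trt_engine = onnx2trt(onnx_model, opt_shape_dict, max_workspace_size=1 << 30)
save_trt_engine(trt_engine, 'pointrend.trt')
```

Note that the ONNX file itself must be exported with `ONNX_BACKEND=MMCVTensorRT` set in the environment, so that the ScatterElements workaround in `_mask_point_forward_test` (see below) is baked into the graph.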
1 change: 1 addition & 0 deletions docs/tutorials/pytorch2onnx.md
@@ -233,6 +233,7 @@ The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime
 | Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
 | Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
 | CornerNet | `configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py` | Y | N | no flip, no batch inference, tested with torch==1.7.0 and onnxruntime==1.5.1. |
+| PointRend | `configs/point_rend/point_rend_r50_caffe_fpn_mstrain_1x_coco.py` | Y | Y | |
 
 Notes:
 
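As background for the code changes below: export goes through plain `torch.onnx.export` tracing, so every `torch.onnx.is_in_onnx_export()` branch in this PR is baked into the exported graph. A minimal sketch of the core call follows; the real `tools/deployment/pytorch2onnx.py` additionally wraps the detector so it accepts a bare tensor, and the input shape, tensor name, and opset here are assumptions:

```python
import torch

def export_sketch(model, output_file='pointrend.onnx'):
    model.eval()  # tracing must run in inference mode
    dummy_input = torch.randn(1, 3, 800, 1216)  # assumed input resolution
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=['input'],
        opset_version=11,
        do_constant_folding=True)
```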
28 changes: 16 additions & 12 deletions mmdet/models/roi_heads/mask_heads/mask_point_head.py
@@ -154,7 +154,7 @@ def _get_target_single(self, rois, rel_roi_points, pos_assigned_gt_inds,
                     0, pos_assigned_gt_inds))
             gt_masks_th = gt_masks_th.unsqueeze(1)
             rel_img_points = rel_roi_point_to_rel_img_point(
-                rois, rel_roi_points, gt_masks_th.shape[2:])
+                rois, rel_roi_points, gt_masks_th)
             point_targets = point_sample(gt_masks_th,
                                          rel_img_points).squeeze(1)
         else:
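The one-line change above passes the mask tensor itself instead of the static tuple `gt_masks_th.shape[2:]`, so the helper can read height and width in a way that stays traceable under ONNX export. For orientation, a rough sketch of what mmcv's `rel_roi_point_to_rel_img_point` computes; this is a simplified paraphrase (the real helper splits the work into rel-to-abs and abs-to-rel steps), not the exact implementation:

```python
import torch

def rel_roi_point_to_rel_img_point_sketch(rois, rel_roi_points, img,
                                          spatial_scale=1.):
    """rois: (N, 5) as (batch_ind, x1, y1, x2, y2); rel_roi_points: (N, P, 2)
    in [0, 1] relative to each RoI. Returns image-relative [0, 1] coords,
    reading H/W from the tensor so they stay dynamic when traced."""
    h, w = img.shape[2], img.shape[3]
    abs_x = rois[:, None, 1] + rel_roi_points[..., 0] * (
        rois[:, None, 3] - rois[:, None, 1])
    abs_y = rois[:, None, 2] + rel_roi_points[..., 1] * (
        rois[:, None, 4] - rois[:, None, 2])
    return torch.stack([abs_x * spatial_scale / w,
                        abs_y * spatial_scale / h], dim=-1)
```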
@@ -285,16 +285,20 @@ def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
         num_points = cfg.subdivision_num_points
         uncertainty_map = self._get_uncertainty(mask_pred, pred_label)
         num_rois, _, mask_height, mask_width = uncertainty_map.shape
-        h_step = 1.0 / mask_height
-        w_step = 1.0 / mask_width
-
-        uncertainty_map = uncertainty_map.view(num_rois,
-                                               mask_height * mask_width)
-        num_points = min(mask_height * mask_width, num_points)
+        if isinstance(mask_height, torch.Tensor):
+            h_step = 1.0 / mask_height.float()
+            w_step = 1.0 / mask_width.float()
+        else:
+            h_step = 1.0 / mask_height
+            w_step = 1.0 / mask_width
+        # cast to int to avoid dynamic K for TopK op in ONNX
+        mask_size = int(mask_height * mask_width)
+        uncertainty_map = uncertainty_map.view(num_rois, mask_size)
+        num_points = min(mask_size, num_points)
         point_indices = uncertainty_map.topk(num_points, dim=1)[1]
-        point_coords = uncertainty_map.new_zeros(num_rois, num_points, 2)
-        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
-                                                mask_width).float() * w_step
-        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
-                                                mask_width).float() * h_step
+        xs = w_step / 2.0 + (point_indices.long() %
+                             mask_width).float() * w_step
+        ys = h_step / 2.0 + (point_indices.long() //
+                             mask_width).float() * h_step
+        point_coords = torch.stack([xs, ys], dim=2)
         return point_indices, point_coords
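To make the index arithmetic above concrete: for an `H x W` uncertainty map flattened row-major, a flat index `i` maps to the center of cell `(i // W, i % W)` in normalized coordinates. A tiny standalone check (values chosen arbitrarily):

```python
import torch

mask_height, mask_width = 2, 3
h_step, w_step = 1.0 / mask_height, 1.0 / mask_width
point_indices = torch.tensor([[0, 4]])  # flat indices into the 2x3 grid

xs = w_step / 2.0 + (point_indices.long() % mask_width).float() * w_step
ys = h_step / 2.0 + (point_indices.long() // mask_width).float() * h_step
print(torch.stack([xs, ys], dim=2))
# tensor([[[0.1667, 0.2500],    # index 0 -> center of cell (0, 0)
#          [0.5000, 0.7500]]])  # index 4 -> center of cell (1, 1)
```

The `int(...)` cast matters because tracing `mask_height * mask_width` as a tensor would make `k` in `topk` a dynamic input, and the ONNX TopK op (and TensorRT) handles a dynamic K poorly.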
155 changes: 119 additions & 36 deletions mmdet/models/roi_heads/point_rend_roi_head.py
@@ -1,4 +1,5 @@
 # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
+import os
 
 import torch
 import torch.nn.functional as F
@@ -64,24 +65,39 @@ def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
"""Sample fine grained feats from each level feature map and
         concatenate them together."""
         num_imgs = len(img_metas)
+        batch_size = x[0].shape[0]
+        num_rois = rois.shape[0]
         fine_grained_feats = []
         for idx in range(self.mask_roi_extractor.num_inputs):
             feats = x[idx]
             spatial_scale = 1. / float(
                 self.mask_roi_extractor.featmap_strides[idx])
-            point_feats = []
-            for batch_ind in range(num_imgs):
-                # unravel batch dim
-                feat = feats[batch_ind].unsqueeze(0)
-                inds = (rois[:, 0].long() == batch_ind)
-                if inds.any():
-                    rel_img_points = rel_roi_point_to_rel_img_point(
-                        rois[inds], rel_roi_points[inds], feat.shape[2:],
-                        spatial_scale).unsqueeze(0)
-                    point_feat = point_sample(feat, rel_img_points)
-                    point_feat = point_feat.squeeze(0).transpose(0, 1)
-                    point_feats.append(point_feat)
-            fine_grained_feats.append(torch.cat(point_feats, dim=0))
+            # support export to ONNX with batch dim
+            if torch.onnx.is_in_onnx_export():
+                rel_img_points = rel_roi_point_to_rel_img_point(
+                    rois, rel_roi_points, feats, spatial_scale)
+                channels = feats.shape[1]
+                num_points = rel_img_points.shape[1]
+                rel_img_points = rel_img_points.reshape(
+                    batch_size, -1, num_points, 2)
+                point_feats = point_sample(feats, rel_img_points)
+                point_feats = point_feats.transpose(1, 2).reshape(
+                    num_rois, channels, num_points)
+                fine_grained_feats.append(point_feats)
+            else:
+                point_feats = []
+                for batch_ind in range(num_imgs):
+                    # unravel batch dim
+                    feat = feats[batch_ind].unsqueeze(0)
+                    inds = (rois[:, 0].long() == batch_ind)
+                    if inds.any():
+                        rel_img_points = rel_roi_point_to_rel_img_point(
+                            rois[inds], rel_roi_points[inds], feat,
+                            spatial_scale).unsqueeze(0)
+                        point_feat = point_sample(feat, rel_img_points)
+                        point_feat = point_feat.squeeze(0).transpose(0, 1)
+                        point_feats.append(point_feat)
+                fine_grained_feats.append(torch.cat(point_feats, dim=0))
         return torch.cat(fine_grained_feats, dim=1)
 
     def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
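Both branches above end by feeding normalized coordinates to `point_sample`. For orientation, a rough sketch of the core of that mmcv helper, which is essentially `F.grid_sample` over [0, 1] coordinates (the real version also forwards extra kwargs):

```python
import torch.nn.functional as F

def point_sample_sketch(feats, points, align_corners=False):
    """feats: (B, C, H, W); points in [0, 1], shape (B, P, 2) or (B, H', W', 2)."""
    add_dim = False
    if points.dim() == 3:
        add_dim = True
        points = points.unsqueeze(2)  # -> (B, P, 1, 2)
    grid = 2.0 * points - 1.0         # grid_sample expects [-1, 1]
    output = F.grid_sample(feats, grid, align_corners=align_corners)
    if add_dim:
        output = output.squeeze(3)    # -> (B, C, P)
    return output
```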
@@ -115,8 +131,26 @@ def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
         point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
         refined_mask_pred = refined_mask_pred.reshape(
             num_rois, channels, mask_height * mask_width)
-        refined_mask_pred = refined_mask_pred.scatter_(
-            2, point_indices, mask_point_pred)
+
+        is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+        # avoid ScatterElements op in ONNX for TensorRT
+        if torch.onnx.is_in_onnx_export() and is_trt_backend:
+            mask_shape = refined_mask_pred.shape
+            point_shape = point_indices.shape
+            inds_dim0 = torch.arange(point_shape[0]).reshape(
+                point_shape[0], 1, 1).expand_as(point_indices)
+            inds_dim1 = torch.arange(point_shape[1]).reshape(
+                1, point_shape[1], 1).expand_as(point_indices)
+            inds_1d = inds_dim0.reshape(
+                -1) * mask_shape[1] * mask_shape[2] + inds_dim1.reshape(
+                    -1) * mask_shape[2] + point_indices.reshape(-1)
+            refined_mask_pred = refined_mask_pred.reshape(-1)
+            refined_mask_pred[inds_1d] = mask_point_pred.reshape(-1)
+            refined_mask_pred = refined_mask_pred.reshape(*mask_shape)
+        else:
+            refined_mask_pred = refined_mask_pred.scatter_(
+                2, point_indices, mask_point_pred)
+
         refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                    mask_height, mask_width)
 
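The TensorRT branch above rewrites `scatter_` along dim 2 as a single indexed write into the flattened tensor, because the ScatterElements op that `scatter_` exports to is not supported by the targeted TensorRT backend. A small self-contained equivalence check, using distinct indices per row as `topk` guarantees:

```python
import torch

N, C, HW, P = 2, 4, 6, 3
mask = torch.zeros(N, C, HW)
vals = torch.rand(N, C, P)
# distinct indices per (n, c) row, as topk would produce
inds = torch.stack([torch.randperm(HW)[:P]
                    for _ in range(N * C)]).reshape(N, C, P)

ref = mask.clone().scatter_(2, inds, vals)

flat = mask.clone().reshape(-1)
i0 = torch.arange(N).reshape(N, 1, 1).expand_as(inds)
i1 = torch.arange(C).reshape(1, C, 1).expand_as(inds)
inds_1d = (i0 * C * HW + i1 * HW + inds).reshape(-1)
flat[inds_1d] = vals.reshape(-1)

assert torch.equal(ref, flat.reshape(N, C, HW))
```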
@@ -133,44 +167,93 @@ def simple_test_mask(self,
         scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
         num_imgs = len(det_bboxes)
         if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
             if torch.onnx.is_in_onnx_export():
                 raise RuntimeError('[ONNX Error] Can not record MaskHead '
                                    'as it has not been executed this time')
             segm_results = [[[] for _ in range(self.mask_head.num_classes)]
                             for _ in range(num_imgs)]
         else:
+            # The length of proposals of different batches may be different.
+            # In order to form a batch, a padding operation is required.
+            if isinstance(det_bboxes, list):
+                # padding to form a batch
+                max_size = max([bboxes.size(0) for bboxes in det_bboxes])
+                for i, (bbox, label) in enumerate(zip(det_bboxes, det_labels)):
+                    supplement_bbox = bbox.new_full(
+                        (max_size - bbox.size(0), bbox.size(1)), 0)
+                    supplement_label = label.new_full(
+                        (max_size - label.size(0), ), 0)
+                    det_bboxes[i] = torch.cat((supplement_bbox, bbox), dim=0)
+                    det_labels[i] = torch.cat((supplement_label, label), dim=0)
+                det_bboxes = torch.stack(det_bboxes, dim=0)
+                det_labels = torch.stack(det_labels, dim=0)
+
+            batch_size = det_bboxes.size(0)
+            num_proposals_per_img = det_bboxes.shape[1]
+
             # if det_bboxes is rescaled to the original image size, we need to
             # rescale it back to the testing scale to obtain RoIs.
-            if rescale and not isinstance(scale_factors[0], float):
-                scale_factors = [
-                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
-                    for scale_factor in scale_factors
-                ]
-            _bboxes = [
-                det_bboxes[i][:, :4] *
-                scale_factors[i] if rescale else det_bboxes[i][:, :4]
-                for i in range(len(det_bboxes))
-            ]
-            mask_rois = bbox2roi(_bboxes)
+            det_bboxes = det_bboxes[..., :4]
+            if rescale:
+                if not isinstance(scale_factors[0], float):
+                    scale_factors = det_bboxes.new_tensor(scale_factors)
+                det_bboxes = det_bboxes * scale_factors.unsqueeze(1)
+            batch_index = torch.arange(
+                det_bboxes.size(0),
+                device=det_bboxes.device).float().view(-1, 1, 1).expand(
+                    det_bboxes.size(0), det_bboxes.size(1), 1)
+            mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+            mask_rois = mask_rois.view(-1, 5)
             mask_results = self._mask_forward(x, mask_rois)
-            # split batch mask prediction back to each image
             mask_pred = mask_results['mask_pred']
-            num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
-            mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
-            mask_rois = mask_rois.split(num_mask_roi_per_img, 0)
 
+            # Support exporting to ONNX
+            if torch.onnx.is_in_onnx_export():
+                max_shape = img_metas[0]['img_shape_for_onnx']
+                num_det = det_bboxes.shape[1]
+                det_bboxes = det_bboxes.reshape(-1, 4)
+                det_labels = det_labels.reshape(-1)
+
+                mask_pred = self._mask_point_forward_test(
+                    x, mask_rois, det_labels, mask_pred, img_metas)
+
+                segm_results = self.mask_head.get_seg_masks(
+                    mask_pred, det_bboxes, det_labels, self.test_cfg,
+                    max_shape, scale_factors[0], rescale)
+                segm_results = segm_results.reshape(batch_size, num_det,
+                                                    max_shape[0], max_shape[1])
+                return segm_results
+
+            # Recover the batch dimension
+            mask_preds = mask_pred.reshape(batch_size, num_proposals_per_img,
+                                           *mask_pred.shape[1:])
+            mask_rois = mask_rois.view(batch_size, -1, 5)
+
             # apply mask post-processing to each image individually
             segm_results = []
             for i in range(num_imgs):
-                if det_bboxes[i].shape[0] == 0:
+                mask_pred = mask_preds[i]
+                det_bbox = det_bboxes[i]
+                det_label = det_labels[i]
+                mask_rois_i = mask_rois[i]
+
+                # remove padding
+                supplement_mask = det_bbox[..., -1] != 0
+                mask_pred = mask_pred[supplement_mask]
+                det_bbox = det_bbox[supplement_mask]
+                det_label = det_label[supplement_mask]
+
+                if det_label.shape[0] == 0:
                     segm_results.append(
                         [[] for _ in range(self.mask_head.num_classes)])
                 else:
                     x_i = [xx[[i]] for xx in x]
                     mask_rois_i = mask_rois[i]
-                    mask_rois_i[:, 0] = 0  # TODO: remove this hack
+                    if not torch.onnx.is_in_onnx_export():
+                        mask_rois_i[:, 0] = 0  # TODO: remove this hack
                     mask_pred_i = self._mask_point_forward_test(
-                        x_i, mask_rois_i, det_labels[i], mask_preds[i],
-                        [img_metas])
+                        x_i, mask_rois_i, det_label, mask_pred, [img_metas])
                     segm_result = self.mask_head.get_seg_masks(
-                        mask_pred_i, _bboxes[i], det_labels[i], self.test_cfg,
+                        mask_pred_i, det_bbox, det_label, self.test_cfg,
                         ori_shapes[i], scale_factors[i], rescale)
                     segm_results.append(segm_result)
         return segm_results
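A quick illustration of the RoI construction used above (shapes arbitrary): concatenating a per-image batch index as column 0 reproduces the `(batch_ind, x1, y1, x2, y2)` layout that `bbox2roi` produced in the old code path, but with pure tensor ops that trace cleanly to ONNX:

```python
import torch

batch_size, num_det = 2, 3
det_bboxes = torch.rand(batch_size, num_det, 4)  # (x1, y1, x2, y2) per image

batch_index = torch.arange(
    batch_size, device=det_bboxes.device).float().view(-1, 1, 1).expand(
        batch_size, num_det, 1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1).view(-1, 5)
print(mask_rois[:, 0])  # tensor([0., 0., 0., 1., 1., 1.])
```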