Skip to content

Commit

Permalink
BUG P0 (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire authored Sep 14, 2022
1 parent b87afb9 commit e912e86
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 6 deletions.
63 changes: 63 additions & 0 deletions mmdeploy/backend/coreml/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,66 @@ def coreml_nms(context, node):
max_boxes=max_boxes)

context.add(tuple(results), torch_name=node.outputs[0])


@register_torch_op
def log2(context, node):
"""bind log2."""
import numpy as np
inputs = _get_inputs(context, node)
x = inputs[0]
log_x = mb.log(x=x)
context.add(mb.mul(x=log_x, y=1 / np.log(2.0)), node.name)


@register_torch_op
def roi_align(context, node):
"""roi align."""
inputs = _get_inputs(context, node)

x = context[node.inputs[0]]
input_shape = x.shape # (B, C, h_in, w_in)
if len(input_shape) != 4:
raise ValueError(
'"CropResize" op: expected input rank 4, got {}'.format(x.rank))

const_box_info = True
if context[node.inputs[1]].val is None or context[
node.inputs[2]].val is None:
const_box_info = False

extrapolation_value = context[node.inputs[2]].val
# CoreML index information along with boxes
if const_box_info:
boxes = context[node.inputs[1]].val
# CoreML expects boxes/ROI in
# [N, 1, 5, 1, 1] format
boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
else:
boxes = inputs[1]
boxes = mb.reshape(
x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
# Get Height and Width of crop
h_out = inputs[3]
w_out = inputs[4]

# Torch input format: [B, C, h_in, w_in]
# CoreML input format: [B, C, h_in, w_in]

# Crop Resize
x = mb.crop_resize(
x=x,
roi=boxes,
target_height=h_out.val,
target_width=w_out.val,
normalized_coordinates=False,
spatial_scale=extrapolation_value,
box_coordinate_mode='CORNERS_WIDTH_FIRST',
sampling_mode='OFFSET_CORNERS',
)

# CoreML output format: [N, 1, C, h_out, w_out]
# Torch output format: [N, C, h_out, w_out]
x = mb.squeeze(x=x, axes=[1])

context.add(x, torch_name=node.outputs[0])
18 changes: 12 additions & 6 deletions mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
Expand Down Expand Up @@ -104,11 +104,17 @@ def rpn_head__get_bboxes(ctx,

if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
prior_inds = topk_inds.new_zeros((1, 1))
anchors = anchors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
bbox_pred, scores = gather_topk(
bbox_pred,
scores,
inds=topk_inds,
batch_size=batch_size,
is_batched=True)
anchors = gather_topk(
anchors,
inds=topk_inds,
batch_size=batch_size,
is_batched=False)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,39 @@ def single_roi_extractor__forward__openvino(ctx,
args = (output_size, featmap_strides, sample_num, rois, *feats)
result = SingleRoIExtractorOpenVINO.apply(*args)
return result


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend=Backend.COREML.value)
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def single_roi_extractor__forward__coreml(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` of SingleRoIExtractor for coreml."""
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
if num_levels == 1:
assert len(rois) > 0, 'The number of rois should be positive'
self.roi_layers[0].use_torchvision = True
return self.roi_layers[0](feats[0], rois)

target_lvls = self.map_roi_levels(rois, num_levels)

if roi_scale_factor is not None:
rois = self.roi_rescale(rois, roi_scale_factor)

for i in range(num_levels):
mask = target_lvls == i
# inds = mask.nonzero(as_tuple=False).squeeze(1)
rois_t = rois * mask.unsqueeze(-1)
# use the roi align in torhcvision
self.roi_layers[i].use_torchvision = True
roi_feats_t = self.roi_layers[i](feats[i], rois_t)
roi_feats = roi_feats + roi_feats_t * (rois_t[:, -1] > 0).reshape(
-1, 1, 1, 1)
# slice to recover original size
return roi_feats

0 comments on commit e912e86

Please sign in to comment.