Skip to content

Commit

Permalink
simplify frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 30, 2021
1 parent ca9470b commit 39549aa
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def _impl(inputs, attr, params, mod):

def _combined_nms():
def all_class_impl(
batch_size,
boxes,
scores,
max_output_boxes_per_class,
Expand All @@ -817,20 +818,16 @@ def all_class_impl(
max_total_size,
output_format="tensorflow",
)
topk_indices = _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="indices")
nmsed_scores, topk_indices = _op.topk(
selected_scores, k=max_total_size, axis=1, ret_type="both"
)
topk_indices = _op.expand_dims(topk_indices, axis=0)
indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1)
num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))

nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32")
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)

indices_shape = _infer_shape(indices, mod)
indices_dims = len(indices_shape)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1)

if clip_boxes:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))
Expand All @@ -839,8 +836,6 @@ def all_class_impl(
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
batch_size = indices_shape[0]

if isinstance(batch_size, tvm.tir.Any):
box_range_2d = _op.tile(box_range, _op.concatenate([batch_size, 1]))
else:
Expand All @@ -849,7 +844,6 @@ def all_class_impl(
valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)
nmsed_scores = nmsed_scores * valid_mask
nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
Expand Down Expand Up @@ -887,6 +881,7 @@ def _impl(inputs, attr, params, mod):
boxes = _op.squeeze(boxes, axis=[2])
scores_trans = _op.transpose(scores, [0, 2, 1])
return all_class_impl(
batch_size,
boxes,
scores_trans,
max_output_size,
Expand Down

0 comments on commit 39549aa

Please sign in to comment.