Skip to content

Commit

Permalink
[Relay, TF] Support converting TF combined_nms using Relay all_class_…
Browse files Browse the repository at this point in the history
…nms (apache#8174)

* import from branch

commit c86bcf4
Merge: 0fa8805 da75b2a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 12:13:29 2021 +0900

    Merge branch 'tmp' into all_class_nms_tf

commit 0fa8805
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 06:24:57 2021 +0900

    Revert "handling case when num detections is smaller than max_total_size"

    This reverts commit 61e70b8.

commit 6725150
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:43:06 2021 +0900

    handling case when num detections is smaller than max_total_size

commit 39549aa
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:32:37 2021 +0900

    simplify frontend

commit ca9470b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:25:13 2021 +0900

    update op definition

commit 47bdef9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 19:47:04 2021 +0900

    remove unnecessary mask

commit 445a7da
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:54:19 2021 +0900

    remove in_buffer

commit 71879b1
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:48:22 2021 +0900

    minor fix

commit 72e055a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:45:37 2021 +0900

    make it more readable

commit a1fe7c4
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:44:14 2021 +0900

    clean up

commit 0c659bf
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:33:54 2021 +0900

    improve sort on cpu

commit 480f6b7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:29:53 2021 +0900

    collect indices and scores in one kernel

commit 2b441c3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 15:47:31 2021 +0900

    initialization bug fixed in cuda

commit d43e801
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 15:23:09 2021 +0900

    cpu nms bug fixed

commit 025010e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 11:09:47 2021 +0900

    add cpu impl

commit 787d839
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 10:38:20 2021 +0900

    refactoring

commit 0540430
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 10:03:51 2021 +0900

    initial import

    commit 5ff0985625ec75f117af37017ebf4089dafb8a46
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 10:02:45 2021 +0900

        cleanup

    commit 199f9b6
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 10:00:15 2021 +0900

        Revert "add gather_nd shape func"

        This reverts commit 1ff4d53.

    commit 47a05c4
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:53:00 2021 +0900

        format

    commit 9dcd0f0
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:48:43 2021 +0900

        make it static

    commit eb06393
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:14:31 2021 +0900

        restore old impl and use it for q != 1 case

    commit 115a5df
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:00:40 2021 +0900

        fixed score gathering

    commit d203562
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 08:53:14 2021 +0900

        minimum fixed

    commit 3fe91e8
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 06:59:39 2021 +0900

        batch issue fixed

    commit 19e3e84
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 04:29:15 2021 +0900

        zero padding working

        This reverts commit 58c3413.

    commit ce7848b
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 13:12:47 2021 +0900

        pylint, do not use -1 for default value

    commit 968f3bd
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 13:07:31 2021 +0900

        rename to index_rank and make it Optional

    commit 9e06b84
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 18:01:59 2021 +0900

        fix pylint

    commit 81dc605
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:57:03 2021 +0900

        minor fix

    commit 54297b6
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:54:16 2021 +0900

        support dynamic scatter nd

    commit e25c225
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:33:19 2021 +0900

        gather_dim -> num_indices_per_tuple

    commit aaa6211
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:23:46 2021 +0900

        add dynamic gather_nd test

    commit 3a9fe5d
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:18:26 2021 +0900

        refactor gather_nd ref funcs

    commit 1ff4d53
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 14:36:34 2021 +0900

        add gather_nd shape func

    commit b020064
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 04:01:11 2021 +0900

        working on zero padding

    commit 4567417
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 03:21:52 2021 +0900

        working

    commit 7f5c76d
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 02:37:50 2021 +0900

        relay type inference works, debugging topi

    commit 4a4b8df
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 15:08:16 2021 +0900

        add max_total_size to attributes

    commit 7218b2f
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 14:50:58 2021 +0900

        tf frontend update

    commit cde4a1f
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 14:17:14 2021 +0900

        all class nms tf mode first cut

    commit 5f349f7
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 06:54:34 2021 +0900

        begin supporting per batch output

    commit 0044365
    Author: Trevor Morris <trevmorr@amazon.com>
    Date:   Mon May 3 19:46:28 2021 +0000

        initial

    commit 168a617
    Author: Trevor Morris <trevmorr@amazon.com>
    Date:   Fri Apr 16 20:31:32 2021 +0000

        initia;
        l

commit da75b2a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:58:19 2021 +0900

    do minimum in topi

commit 52c5e8a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:54:49 2021 +0900

    more simplify

commit 44d88cd
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:51:39 2021 +0900

    simplify

commit 74e1917
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:39:37 2021 +0900

    black

commit fc3a38e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:37:30 2021 +0900

    minor change

commit f88e2a3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:14:54 2021 +0900

    minor refactor

commit f2d7ed4
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 07:08:47 2021 +0900

    support the case when there is not enough box

commit 0f184a6
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 06:24:16 2021 +0900

    Revert "handling case when num detections is smaller than max_total_size"

    This reverts commit 61e70b8.

commit d7180f2
Merge: 61e70b8 06ac205
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:43:37 2021 +0900

    Merge branch 'gather_nd_shape_func' into tmp

commit 61e70b8
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:43:06 2021 +0900

    handling case when num detections is smaller than max_total_size

commit 453a79b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:32:37 2021 +0900

    simplify frontend

commit 2fc5f1e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun May 30 05:25:13 2021 +0900

    update op definition

commit 8afbd30
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 19:47:04 2021 +0900

    remove unnecessary mask

commit ff870f7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:54:19 2021 +0900

    remove in_buffer

commit e71b922
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:48:22 2021 +0900

    minor fix

commit b02faae
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:45:37 2021 +0900

    make it more readable

commit 6baee99
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:44:14 2021 +0900

    clean up

commit 7a2a2df
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:33:54 2021 +0900

    improve sort on cpu

commit afad2a2
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 16:29:53 2021 +0900

    collect indices and scores in one kernel

commit c5718e2
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 15:47:31 2021 +0900

    initialization bug fixed in cuda

commit 5623e3f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 15:23:09 2021 +0900

    cpu nms bug fixed

commit c40eaec
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 11:09:47 2021 +0900

    add cpu impl

commit 6c7aaeb
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 10:38:20 2021 +0900

    refactoring

commit 7b87922
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat May 29 10:03:51 2021 +0900

    initial import

    commit 5ff0985625ec75f117af37017ebf4089dafb8a46
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 10:02:45 2021 +0900

        cleanup

    commit 199f9b6
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 10:00:15 2021 +0900

        Revert "add gather_nd shape func"

        This reverts commit 1ff4d53.

    commit 47a05c4
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:53:00 2021 +0900

        format

    commit 9dcd0f0
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:48:43 2021 +0900

        make it static

    commit eb06393
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:14:31 2021 +0900

        restore old impl and use it for q != 1 case

    commit 115a5df
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 09:00:40 2021 +0900

        fixed score gathering

    commit d203562
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 08:53:14 2021 +0900

        minimum fixed

    commit 3fe91e8
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 06:59:39 2021 +0900

        batch issue fixed

    commit 19e3e84
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 04:29:15 2021 +0900

        zero padding working

        This reverts commit 58c3413.

    commit ce7848b
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 13:12:47 2021 +0900

        pylint, do not use -1 for default value

    commit 968f3bd
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 13:07:31 2021 +0900

        rename to index_rank and make it Optional

    commit 9e06b84
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 18:01:59 2021 +0900

        fix pylint

    commit 81dc605
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:57:03 2021 +0900

        minor fix

    commit 54297b6
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:54:16 2021 +0900

        support dynamic scatter nd

    commit e25c225
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:33:19 2021 +0900

        gather_dim -> num_indices_per_tuple

    commit aaa6211
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:23:46 2021 +0900

        add dynamic gather_nd test

    commit 3a9fe5d
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 17:18:26 2021 +0900

        refactor gather_nd ref funcs

    commit 1ff4d53
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 21 14:36:34 2021 +0900

        add gather_nd shape func

    commit b020064
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 04:01:11 2021 +0900

        working on zero padding

    commit 4567417
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 03:21:52 2021 +0900

        working

    commit 7f5c76d
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Sat May 29 02:37:50 2021 +0900

        relay type inference works, debugging topi

    commit 4a4b8df
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 15:08:16 2021 +0900

        add max_total_size to attributes

    commit 7218b2f
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 14:50:58 2021 +0900

        tf frontend update

    commit cde4a1f
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 14:17:14 2021 +0900

        all class nms tf mode first cut

    commit 5f349f7
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Fri May 28 06:54:34 2021 +0900

        begin supporting per batch output

    commit 0044365
    Author: Trevor Morris <trevmorr@amazon.com>
    Date:   Mon May 3 19:46:28 2021 +0000

        initial

    commit 168a617
    Author: Trevor Morris <trevmorr@amazon.com>
    Date:   Fri Apr 16 20:31:32 2021 +0000

        initia;
        l

commit 06ac205
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 28 13:12:47 2021 +0900

    pylint, do not use -1 for default value

commit 2adc426
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 28 13:07:31 2021 +0900

    rename to index_rank and make it Optional

commit c458da6
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 18:01:59 2021 +0900

    fix pylint

commit b7faf0f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:57:03 2021 +0900

    minor fix

commit c031641
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:54:16 2021 +0900

    support dynamic scatter nd

commit 56f3f0e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:33:19 2021 +0900

    gather_dim -> num_indices_per_tuple

commit 081823b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:23:46 2021 +0900

    add dynamic gather_nd test

commit 6b2655b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 17:18:26 2021 +0900

    refactor gather_nd ref funcs

commit f9f5dfb
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri May 21 14:36:34 2021 +0900

    add gather_nd shape func

* make combined nms converter public

* do topk on smaller score tensor

* update tests

* remove max_total_size attribute, do minimum in relay side

* fix topk

* update relay doc

* update doc

* fix pylint

* update shape func for tf mode and add test

* name change

* reject dynamic inputs

* revert gather_nd change

* do not try to support dynamic batch size in tile rep

* check batch_size is int

* fix dtype issue in scan

* fix slicing before topk
  • Loading branch information
masahi authored Jun 4, 2021
1 parent e0baf80 commit f4ec5fd
Show file tree
Hide file tree
Showing 13 changed files with 555 additions and 102 deletions.
12 changes: 10 additions & 2 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,19 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
}
};

/*! \brief Attributes used in non_maximum_suppression operator */
/*! \brief Attributes used in all_class_non_maximum_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
std::string output_format;

TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(output_format)
.set_default("onnx")
.describe(
"Output format, onnx or tensorflow. Returns outputs in a way that can be easily "
"consumed by each frontend.");
}
};

/*! \brief Attributes used in roi_align operators */
Expand Down
106 changes: 103 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,88 @@ def _impl(inputs, attr, params, mod):
return _impl


def convert_combined_nms_with_all_class_nms(
batch_size,
max_output_boxes_per_batch,
num_class,
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
clip_boxes,
):
"""Converts TF combined_nms using Relay all_class_max_suppression op"""
(selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression(
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
output_format="tensorflow",
)
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
assert isinstance(batch_size, int), "dynamic batch size not supported yet."
tile_batch_reps = _op.const([batch_size, 1])
box_range_2d = _op.tile(box_range, tile_batch_reps)
valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)

def select_topk(do_zero_pad):
def true_branch():
arange = _op.arange(
_op.const(0, dtype="int64"),
_op.const(max_output_boxes_per_batch, dtype="int64"),
dtype="int64",
)
pad = _op.full(
_op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,)
)
topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps)
nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
nmsed_scores = nmsed_scores * valid_mask
return nmsed_scores, topk_indices

def false_branch():
if isinstance(max_output_boxes_per_class, int):
# Do topk on smaller input if possible
slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64")
selected_scores_slice = _op.strided_slice(
selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1]
)
else:
selected_scores_slice = selected_scores
return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both")

# TODO(masahi): support dynamic num_boxes
# return _expr.If(do_zero_pad, true_branch(), false_branch())
return true_branch() if do_zero_pad else false_branch()

assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet."
nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size)

indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1)
nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.take(indices, _op.const(0), axis=2)
nmsed_classes = _op.cast(nmsed_classes, "float32")
nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1)
num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))

if clip_boxes:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4
)


def _combined_nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
Expand Down Expand Up @@ -821,9 +903,27 @@ def _impl(inputs, attr, params, mod):
q = boxes_shape[2]
num_classes = scores_shape[2]

if q != num_classes:
# When q is 1, it means same box coords are used for all classes.
boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4))
assert isinstance(batch_size, int) and isinstance(
num_anchors, int
), "Dynamic inputs not supported yet"

if q == 1:
boxes = _op.squeeze(boxes, axis=[2])
scores_trans = _op.transpose(scores, [0, 2, 1])
max_output_boxes_per_batch = num_anchors * num_classes
return convert_combined_nms_with_all_class_nms(
batch_size,
max_output_boxes_per_batch,
num_classes,
boxes,
scores_trans,
max_output_size,
iou_threshold,
score_threshold,
max_total_size.data.numpy().item(),
attr["clip_boxes"],
)

boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4])
scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1])

Expand Down
10 changes: 9 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,15 @@ def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)
output_format = attrs.output_format
return topi_compute(
inputs[0],
inputs[1],
max_output_size,
iou_threshold,
score_threshold,
output_format,
)

return _compute_nms

Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def nms_shape_func(attrs, inputs, _):


@script
def _all_class_nms_shape_func(boxes_shape, scores_shape):
def _all_class_nms_shape_func_onnx(boxes_shape, scores_shape):
out_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

Expand All @@ -99,9 +99,27 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape):
return out_shape, count_shape


@script
def _all_class_nms_shape_func_tf(boxes_shape, scores_shape):
out_indices_shape = output_tensor((3,), "int64")
out_scores_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

out_indices_shape[0] = boxes_shape[0]
out_indices_shape[1] = scores_shape[1] * boxes_shape[1]
out_indices_shape[2] = int64(2)
out_scores_shape[0] = boxes_shape[0]
out_scores_shape[1] = scores_shape[1] * boxes_shape[1]
count_shape[0] = boxes_shape[0]

return out_indices_shape, out_scores_shape, count_shape


@reg.register_shape_func("vision.all_class_non_max_suppression", False)
def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func(inputs[0], inputs[1])
if attrs.output_format == "onnx":
return _all_class_nms_shape_func_onnx(inputs[0], inputs[1])
return _all_class_nms_shape_func_tf(inputs[0], inputs[1])


@script
Expand Down
41 changes: 35 additions & 6 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def non_max_suppression(


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
boxes,
scores,
max_output_boxes_per_class=-1,
iou_threshold=-1.0,
score_threshold=-1.0,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand All @@ -175,16 +180,31 @@ def all_class_non_max_suppression(
score_threshold : float or relay.Expr, optional
Score threshold to filter out low score boxes early
output_format : string, optional
"onnx" or "tensorflow". Specify by which frontends the outputs are
intented to be consumed.
Returns
-------
out : relay.Tuple
The output is a relay.Tuple of two tensors, the first is `indices` of size
`(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor
`num_total_detection` of shape `(1,)` representing the total number of selected boxes.
If `output_format` is "onnx", the output is a relay.Tuple of two tensors, the first is
`indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar
tensor `num_total_detection` of shape `(1,)` representing the total number of selected
boxes. The three values in `indices` encode batch, class, and box indices.
Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first,
in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
`batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
rows are valid.
If `output_format` is "tensorflow", the output is a relay.Tuple of three tensors, the first
is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of
size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
`(batch_size,)` representing the total number of selected boxes per batch. The two values
in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at
batch b, only the first `num_total_detection[b]` entries are valid. The second axis of
`indices` and `scores` are sorted within each class by box scores, but not across classes.
So the box indices and scores for the class 0 come first in a sorted order, followed by
the class 1 etc.
"""
if not isinstance(max_output_boxes_per_class, expr.Expr):
max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
Expand All @@ -194,6 +214,15 @@ def all_class_non_max_suppression(
score_threshold = expr.const(score_threshold, "float32")

out = _make.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
output_format,
)
return expr.TupleWrapper(out, 2)

if output_format == "onnx":
return expr.TupleWrapper(out, 2)

return expr.TupleWrapper(out, 3)
Loading

0 comments on commit f4ec5fd

Please sign in to comment.