Skip to content

Commit

Permalink
Revert "handling case when num detections is smaller than max_total_s…
Browse files Browse the repository at this point in the history
…ize"

This reverts commit 61e70b8.
  • Loading branch information
masahi committed May 30, 2021
1 parent 6725150 commit 0fa8805
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,14 +818,12 @@ def all_class_impl(
max_total_size,
output_format="tensorflow",
)
k = _op.minimum(_op.const(max_total_size, dtype="int64"), _op.max(num_detections))
nmsed_scores, topk_indices = _op.topk(
selected_scores, k, axis=1, ret_type="both"
selected_scores, k=max_total_size, axis=1, ret_type="both"
)
num_detections = _op.minimum(num_detections, k)
topk_indices = _op.expand_dims(topk_indices, axis=0)
indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1)

num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64"))
nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32")
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)
Expand All @@ -834,7 +832,7 @@ def all_class_impl(
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

# Fill in invalid boxes with 0
# Fill in invalid entries with 0
box_range = _op.arange(
_op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64"
)
Expand Down

0 comments on commit 0fa8805

Please sign in to comment.