From 0fa88051b3d30337674a07e579b78e8cb254cd66 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 30 May 2021 06:24:57 +0900 Subject: [PATCH] Revert "handling case when num detections is smaller than max_total_size" This reverts commit 61e70b82f338300224b22f4d6bdda349e7aa5aca. --- python/tvm/relay/frontend/tensorflow.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b5a9da17d6d8..8345254b2ed2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -818,14 +818,12 @@ def all_class_impl( max_total_size, output_format="tensorflow", ) - k = _op.minimum(_op.const(max_total_size, dtype="int64"), _op.max(num_detections)) nmsed_scores, topk_indices = _op.topk( - selected_scores, k, axis=1, ret_type="both" + selected_scores, k=max_total_size, axis=1, ret_type="both" ) - num_detections = _op.minimum(num_detections, k) topk_indices = _op.expand_dims(topk_indices, axis=0) indices = _op.gather_nd(selected_indices, topk_indices, batch_dims=1) - + num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64")) nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32") nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1) @@ -834,7 +832,7 @@ def all_class_impl( nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) - # Fill in invalid boxes with 0 + # Fill in invalid entries with 0 box_range = _op.arange( _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" )