diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a99c1084f60e5..c1087d2dafdf3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -799,7 +799,9 @@ def _impl(inputs, attr, params, mod): boxes = inputs[0] scores = inputs[1] try: - max_output_boxes_per_class = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0]) + max_output_boxes_per_class = int( + np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0] + ) except Exception: try: max_output_boxes_per_class = ( @@ -822,35 +824,28 @@ def _impl(inputs, attr, params, mod): ) boxes = _op.squeeze(boxes, axis=[2]) # Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes) - scores = _op.transpose(scores, [0, 2, 1]) - indices, count = _op.vision.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="tensorflow" + scores_trans = _op.transpose(scores, [0, 2, 1]) + + indices, num_detections = _op.vision.all_class_non_max_suppression( + boxes, + scores_trans, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + output_format="tensorflow", ) - # Slice indices to count - three = _op.const(np.array([3]), dtype="int64") - begin = _op.const(np.array([0, 0]), dtype="int64") - end = _op.concatenate([count, three], axis=0) - strides = _op.const(np.array([1, 1]), dtype="int64") - indices = _op.strided_slice(indices, begin, end, strides) - - # Trim or pad to max_total_size - - # Get NMSed boxes. - box_indices = _op.take(indices, _op.const([0, 2]), axis=1) - #box_indices = _op. - box_indicies_flat = _op.prod(indices, axis=[1]) - #boxes_flat = _op.reshape(boxes, [-1, 4]) - nmsed_boxes = _op.take(boxes, box_indices) + + nmsed_box_indices = _op.take(indices, 0, axis=2) + nmsed_classes = _op.take(indices, 1, axis=2) + nmsed_boxes = _op.gather_nd(boxes, nmsed_box_indices, batch_dims=1) + nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1) + if attr["clip_boxes"]: nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) - # Get NMSed scores, classes, count - nmsed_scores = _op.take(scores, indices) - nmsed_classes = _op.take(indices, _op.const([1]), axis=1) - nms_count = count # min(nms_count, total_size) return _expr.TupleWrapper( - _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, nms_count]), 4 + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 ) return _impl