From 3fe91e8846b6d2075ae1d9a162c4b70b08cc8024 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 29 May 2021 06:59:39 +0900 Subject: [PATCH] batch issue fixed --- python/tvm/relay/frontend/tensorflow.py | 15 ++++++++++++--- python/tvm/topi/cuda/nms.py | 11 +++++++++-- src/relay/op/vision/nms.cc | 2 +- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 10b94d3b1ab20..d6def8043b843 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -855,10 +855,19 @@ def _impl(inputs, attr, params, mod): nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) # Fill in invalid entries with 0 - box_range = _op.arange(_expr.const(0, dtype="int64"), _op.cast(max_total_size, "int64"), dtype="int64") + box_range = _op.arange( + _expr.const(0, dtype="int64"), _op.cast(max_total_size, "int64"), dtype="int64" + ) batch_size = indices_shape[0] - box_range = _op.tile(box_range, _op.const([batch_size])) - valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(num_detections, axis=1)), "float32") + + if isinstance(batch_size, tvm.tir.Any): + box_range_2d = _op.tile(box_range, _op.concatenate([batch_size, 1])) + else: + box_range_2d = _op.tile(box_range, _op.const([batch_size, 1])) + + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) nmsed_scores = nmsed_scores * valid_mask nmsed_classes = nmsed_classes * valid_mask nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index f765e1ba4d39c..a5981a9dd5270 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1156,11 +1156,18 @@ def all_class_non_max_suppression( selected_indices, selected_scores = collect_selected_indices_tf( selected_indices, selected_scores, num_detections_per_batch, row_offsets ) + + # TODO + # max_total_detections = reduction.max(num_total_detections) # selected_scores = strided_slice( - # selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)] + # selected_scores, begin=[0, 0], end=expand_dims(max_total_detections, 0) # ) topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0] topk_indices = expand_dims(topk_indices, axis=0) final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1) + + # TODO # num_detections = minimum(num_total_detections, max_detection_per_batch) - return [final_indices, num_total_detections] + num_detections = num_total_detections + + return [final_indices, num_detections] diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 9524d268a9266..1e63ccd04721b 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -170,7 +170,7 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs ICHECK(param->max_total_size) << "max_total_size required for tf mode"; Integer max_total_size = param->max_total_size.value(); std::vector oshape{batch, max_total_size, 2}; - std::vector countshape{1}; + std::vector countshape{batch}; fields.push_back(TensorType(oshape, DataType::Int(64))); fields.push_back(TensorType(countshape, DataType::Int(64))); }