Skip to content

Commit

Permalink
zero padding working
Browse files Browse the repository at this point in the history
This reverts commit 58c3413.
  • Loading branch information
masahi committed May 29, 2021
1 parent ce7848b commit 19e3e84
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
19 changes: 9 additions & 10 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,10 +828,6 @@ def _impl(inputs, attr, params, mod):
# Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes)
scores_trans = _op.transpose(scores, [0, 2, 1])

print(max_output_boxes_per_class)
print(iou_threshold)
print(score_threshold)

indices, num_detections = _op.vision.all_class_non_max_suppression(
boxes,
scores_trans,
Expand All @@ -846,10 +842,11 @@ def _impl(inputs, attr, params, mod):
# )

nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.take(indices, _op.const(0), axis=2)
nmsed_classes = _op.cast(_op.take(indices, _op.const(0), axis=2), "float32")
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)

indices_dims = len(_infer_shape(indices, mod))
indices_shape = _infer_shape(indices, mod)
indices_dims = len(indices_shape)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1)

Expand All @@ -858,11 +855,13 @@ def _impl(inputs, attr, params, mod):
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

# Fill in invalid entries with 0
box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32")
box_range = _op.broadcast_to(_op.cast(box_range, "int64"), _op.shape_of(nmsed_scores))
valid_mask = _op.cast(_op.less(box_range, num_detections), "float32")
box_range = _op.arange(_expr.const(0, dtype="int64"), _op.cast(max_total_size, "int64"), dtype="int64")
batch_size = indices_shape[0]
box_range = _op.tile(box_range, _op.const([batch_size]))
valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(num_detections, axis=1)), "float32")
nmsed_scores = nmsed_scores * valid_mask
# nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)
nmsed_classes = nmsed_classes * valid_mask
nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,5 @@ def all_class_non_max_suppression(
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
print(final_indices.shape)
print(num_total_detections.shape)
# num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_total_detections]

0 comments on commit 19e3e84

Please sign in to comment.