Skip to content

Commit

Permalink
tf frontend update
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent cde4a1f commit 7218b2f
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,9 @@ def _impl(inputs, attr, params, mod):
boxes = inputs[0]
scores = inputs[1]
try:
max_output_boxes_per_class = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0])
max_output_boxes_per_class = int(
np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0]
)
except Exception:
try:
max_output_boxes_per_class = (
Expand All @@ -822,35 +824,28 @@ def _impl(inputs, attr, params, mod):
)
boxes = _op.squeeze(boxes, axis=[2])
# Transpose (batch_size, num_boxes, num_classes) -> (batch_size, num_classes, num_boxes)
scores = _op.transpose(scores, [0, 2, 1])
indices, count = _op.vision.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="tensorflow"
scores_trans = _op.transpose(scores, [0, 2, 1])

indices, num_detections = _op.vision.all_class_non_max_suppression(
boxes,
scores_trans,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
output_format="tensorflow",
)
# Slice indices to count
three = _op.const(np.array([3]), dtype="int64")
begin = _op.const(np.array([0, 0]), dtype="int64")
end = _op.concatenate([count, three], axis=0)
strides = _op.const(np.array([1, 1]), dtype="int64")
indices = _op.strided_slice(indices, begin, end, strides)

# Trim or pad to max_total_size

# Get NMSed boxes.
box_indices = _op.take(indices, _op.const([0, 2]), axis=1)
#box_indices = _op.
box_indicies_flat = _op.prod(indices, axis=[1])
#boxes_flat = _op.reshape(boxes, [-1, 4])
nmsed_boxes = _op.take(boxes, box_indices)

nmsed_box_indices = _op.take(indices, 0, axis=2)
nmsed_classes = _op.take(indices, 1, axis=2)
nmsed_boxes = _op.gather_nd(boxes, nmsed_box_indices, batch_dims=1)
nmsed_scores = _op.gather_nd(scores, indices, batch_dims=1)

if attr["clip_boxes"]:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))
# Get NMSed scores, classes, count
nmsed_scores = _op.take(scores, indices)
nmsed_classes = _op.take(indices, _op.const([1]), axis=1)
nms_count = count # min(nms_count, total_size)

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, nms_count]), 4
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4
)

return _impl
Expand Down

0 comments on commit 7218b2f

Please sign in to comment.