Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent 70c65f0 commit 1913f97
Showing 1 changed file with 88 additions and 55 deletions.
143 changes: 88 additions & 55 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
j = bx * max_threads + tx
with ib.if_scope(j < nkeep):
src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
# Fill in out with sorted boxes
with ib.for_range(0, 4) as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]

Expand Down Expand Up @@ -532,7 +531,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

with ib.for_range(0, 4) as k:
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]

out_scores[i * num_anchors + j] = data[src_offset + score_index]

if id_index >= 0:
Expand Down Expand Up @@ -600,6 +598,7 @@ def nms_inner_loop(ib, j):
with ib.if_scope(iou >= iou_threshold):
# invalidate the box k
out_scores[i, k] = -1.0

if return_indices is False and id_index >= 0:
out_class_ids[i, k] = -1.0

Expand All @@ -615,6 +614,7 @@ def nms_inner_loop(ib, j):
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iterations if we have already reached max_output_size boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
Expand Down Expand Up @@ -688,6 +688,75 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
return sort_tensor


def _run_nms(
    data,
    data_buf,
    sort_tensor,
    valid_count,
    indices,
    max_output_size,
    iou_threshold,
    force_suppress,
    top_k,
    coord_start,
    id_index,
    score_index,
    return_indices,
):
    """Wrap ``nms_ir`` in a ``te.extern`` op and return its output tensors.

    Buffers for ``sort_tensor``, ``valid_count`` and ``indices`` are declared
    here; the buffer for ``data`` is supplied by the caller as ``data_buf``.
    Output shapes are derived from the first two dimensions of ``data``.
    All remaining arguments (``max_output_size`` through ``return_indices``)
    are forwarded to ``nms_ir`` unchanged.

    Returns
    -------
    The result of ``te.extern`` with five outputs, in this order: sorted
    bboxes, sorted scores, sorted class ids, box indices and the number of
    valid boxes.
    """
    # Input buffers, in the same order as the input tensors passed to te.extern.
    in_buffers = [
        data_buf,
        tvm.tir.decl_buffer(
            sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
        ),
        tvm.tir.decl_buffer(valid_count.shape, "int32", "valid_count_buf", data_alignment=4),
        tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8),
    ]

    num_batches = data.shape[0]
    anchors_per_batch = data.shape[1]
    per_anchor_shape = (num_batches, anchors_per_batch)

    out_shapes = [
        (num_batches, anchors_per_batch, 4),  # sorted bboxes
        per_anchor_shape,  # sorted scores
        per_anchor_shape,  # sorted class ids
        per_anchor_shape,  # box indices
        (num_batches, 1),  # number of valid boxes
    ]
    out_dtypes = [data.dtype, "float32", "float32", "int32", "int32"]

    def _emit_ir(ins, outs):
        # ins follow the order of the input tensor list below; outs follow out_shapes.
        return nms_ir(
            ins[0],
            ins[1],
            ins[2],
            ins[3],
            outs[0],
            outs[1],
            outs[2],
            outs[3],
            outs[4],
            max_output_size,
            iou_threshold,
            force_suppress,
            top_k,
            coord_start,
            id_index,
            score_index,
            return_indices,
        )

    return te.extern(
        out_shapes,
        [data, sort_tensor, valid_count, indices],
        _emit_ir,
        dtype=out_dtypes,
        in_buffers=in_buffers,
        name="nms",
        tag="nms",
    )


def _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index):
    """Placeholder: not implemented yet.

    All arguments are currently ignored and ``None`` is always returned.
    Presumably this is meant to concatenate the separate bbox, score and
    class-id outputs back into a single tensor -- TODO confirm the intended
    layout before implementing.
    """
    # TODO: implement the concatenation; until then this always yields None.
    return None


def non_max_suppression(
data,
valid_count,
Expand Down Expand Up @@ -779,63 +848,27 @@ def non_max_suppression(
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]

valid_count_dtype = "int32"
valid_count_buf = tvm.tir.decl_buffer(
valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4
)

score_shape = (batch_size, num_anchors)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

sort_tensor = _get_sorted_indices(data, data_buf, score_index, score_shape)
sort_tensor_buf = tvm.tir.decl_buffer(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)

indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

bbox_shape = (batch_size, num_anchors, 4)
class_id_shape = score_shape
box_indices_shape = score_shape
num_valid_boxes_shape = (batch_size, 1)

out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = te.extern(
[bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
ins[1],
ins[2],
ins[3],
outs[0], # sorted bbox
outs[1], # sorted scores
outs[2], # sorted class ids
outs[3], # box_indices
outs[4], # num_valid_boxes
max_output_size,
iou_threshold,
force_suppress,
top_k,
coord_start,
id_index,
score_index,
return_indices,
),
dtype=[data.dtype, "float32", "float32", "int32", "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
sort_tensor = _get_sorted_indices(data, data_buf, score_index, (data.shape[0], data.shape[1]))

out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = _run_nms(
data,
data_buf,
sort_tensor,
valid_count,
indices,
max_output_size,
iou_threshold,
force_suppress,
top_k,
coord_start,
id_index,
score_index,
return_indices,
)

if return_indices:
return [box_indices, num_valid_boxes]

# TODO: do concat
return out_bboxes
# if id_index >= 0:
# return concatenate([out_bboxes

# return out
return _concatenate_outputs(out_bboxes, out_scores, out_sorted_ids, score_index, id_index)

0 comments on commit 1913f97

Please sign in to comment.