unpack input data
masahi committed Dec 28, 2020
1 parent 3a27397 commit 70c65f0
Showing 1 changed file with 99 additions and 63 deletions.
162 changes: 99 additions & 63 deletions python/tvm/topi/cuda/nms.py
@@ -21,7 +21,7 @@
from tvm import te

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .sort import argsort, argsort_thrust, is_thrust_available


def cuda_atomic_add_rule(op):
@@ -338,7 +338,9 @@ def nms_ir(
sorted_index,
valid_count,
indices,
out,
out_bboxes,
out_scores,
out_class_ids,
box_indices,
num_valid_boxes,
max_output_size,
@@ -458,9 +460,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
sorted_index = ib.buffer_ptr(sorted_index)
valid_count = ib.buffer_ptr(valid_count)
indices = ib.buffer_ptr(indices)
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
out = ib.buffer_ptr(out)

# outputs
out_bboxes = ib.buffer_ptr(out_bboxes)
out_scores = ib.buffer_ptr(out_scores)
out_class_ids = ib.buffer_ptr(out_class_ids)
box_indices = ib.buffer_ptr(box_indices)
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)

if isinstance(iou_threshold, float):
iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
@@ -483,36 +489,55 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = by
base_idx = i * num_anchors * box_data_length
base_src_idx = i * num_anchors * box_data_length
base_bbox_idx = i * num_anchors * 4

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
nkeep = if_then_else(
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
)
j = bx * max_threads + tx
with ib.if_scope(j < nkeep):
src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
# Fill in out with sorted boxes
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = data[
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
]
with ib.for_range(0, 4) as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]

out_scores[i * num_anchors + j] = data[src_idx + score_index]

if id_index >= 0:
out_class_ids[i * num_anchors + j] = data[src_idx + id_index]

with ib.else_scope():
# Indices > nkeep are discarded
# Only needed for return_indices = False case
if return_indices is False:
with ib.if_scope(j < num_anchors):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = -1.0
with ib.for_range(0, 4) as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0

out_scores[i, j] = -1.0

if id_index >= 0:
out_class_ids[i, j] = -1.0

if return_indices:
with ib.if_scope(j < num_anchors):
box_indices[i * num_anchors + j] = -1

with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
with ib.for_range(0, box_data_length) as k:
offset = base_idx + j * box_data_length + k
out[offset] = data[offset]
src_offset = base_src_idx + j * box_data_length

with ib.for_range(0, 4) as k:
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]

out_scores[i * num_anchors + j] = data[src_offset + score_index]

if id_index >= 0:
out_class_ids[i * num_anchors + j] = data[src_offset + id_index]

box_indices[i * num_anchors + j] = j

with ib.new_scope():
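
For orientation, the hunk above is the heart of the change: instead of copying whole rows of the packed data tensor (shape (batch, num_anchors, box_data_length)) into a packed out, the kernel now scatters each row into three dedicated buffers. A minimal NumPy sketch of the same unpacking, assuming the common [class_id, score, x1, y1, x2, y2] row layout (id_index=0, score_index=1, coord_start=2 are assumptions here, since they are runtime parameters in the real code):

import numpy as np

# Hypothetical packed input: one batch, five anchors, rows of [class_id, score, x1, y1, x2, y2].
data = np.random.rand(1, 5, 6).astype("float32")
id_index, score_index, coord_start = 0, 1, 2  # assumed field positions

# The per-element writes in the IR amount to this slicing:
out_class_ids = data[:, :, id_index]                  # shape (batch, num_anchors)
out_scores = data[:, :, score_index]                  # shape (batch, num_anchors)
out_bboxes = data[:, :, coord_start:coord_start + 4]  # shape (batch, num_anchors, 4)
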
@@ -526,7 +551,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):

i = by

base_idx = i * num_anchors * box_data_length
base_bbox_idx = i * num_anchors * 4
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
@@ -549,37 +574,35 @@ def nms_inner_loop(ib, j):

num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
offset_j = j * 4
num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx)

with ib.for_range(0, num_iter_per_thread) as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * box_data_length
offset_k = k * 4

with ib.if_scope(
tvm.tir.all(
k < num_anchors,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
out_scores[i, k] > 0, # is the box k still valid?
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
out_class_ids[i, k] == out_class_ids[i, j],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
out_bboxes,
base_bbox_idx + offset_j,
base_bbox_idx + offset_k,
)
with ib.if_scope(iou >= iou_threshold):
# invalidate the box k
out[base_idx + offset_k + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_k + id_index] = -1.0
out_scores[i, k] = -1.0
if return_indices is False and id_index >= 0:
out_class_ids[i, k] = -1.0

# Make sure to do the next loop in a lock step
ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

if isinstance(max_output_size, int):
@@ -589,7 +612,7 @@ def nms_inner_loop(ib, j):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we already reach max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
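
To make the suppression loop above easier to follow: box j, if still valid, invalidates every later box k whose IoU with it reaches iou_threshold, and the k's are striped across the thread block. A sequential Python sketch of that logic (class-id and force_suppress handling omitted; the names below are illustrative and not part of the patch):

import numpy as np

def iou(a, b):
    # a, b: [x1, y1, x2, y2]
    lt = np.maximum(a[:2], b[:2])
    rb = np.minimum(a[2:], b[2:])
    wh = np.clip(rb - lt, 0.0, None)
    inter = wh[0] * wh[1]
    area_a = max(a[2] - a[0], 0.0) * max(a[3] - a[1], 0.0)
    area_b = max(b[2] - b[0], 0.0) * max(b[3] - b[1], 0.0)
    return inter / max(area_a + area_b - inter, 1e-12)

def nms_sequential(bboxes, scores, iou_threshold, max_output_size=-1):
    # bboxes: (N, 4) sorted by descending score; scores: (N,), modified in place.
    num_valid = 0
    for j in range(len(bboxes)):
        if scores[j] <= -1.0:                 # box j was suppressed earlier
            continue
        if 0 < max_output_size <= num_valid:  # already kept enough boxes
            break
        num_valid += 1
        for k in range(j + 1, len(bboxes)):
            if scores[k] > 0 and iou(bboxes[j], bboxes[k]) >= iou_threshold:
                scores[k] = -1.0              # invalidate box k
    return num_valid
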
@@ -638,6 +661,33 @@ def _fetch_score_ir(data, score, axis):
return ib.get()


def _get_sorted_indices(data, data_buf, score_index, score_shape):
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
score_tensor = te.extern(
[score_shape],
[data],
lambda ins, outs: _fetch_score_ir(
ins[0],
outs[0],
score_index,
),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[score_buf],
name="fetch_score",
tag="fetch_score",
)

if is_thrust_available():
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype="int32"
)
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")

return sort_tensor


def non_max_suppression(
data,
valid_count,
@@ -736,54 +786,35 @@ def non_max_suppression(
valid_count_buf = tvm.tir.decl_buffer(
valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4
)
score_axis = score_index

score_shape = (batch_size, num_anchors)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
score_tensor = te.extern(
[score_shape],
[data],
lambda ins, outs: _fetch_score_ir(
ins[0],
outs[0],
score_axis,
),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[score_buf],
name="fetch_score",
tag="fetch_score",
)
target = tvm.target.Target.current()
if (
target
and target.kind.name == "cuda"
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
):
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)

sort_tensor = _get_sorted_indices(data, data_buf, score_index, score_shape)
sort_tensor_buf = tvm.tir.decl_buffer(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

out, box_indices, num_valid_boxes = te.extern(
[data.shape, score_shape, [batch_size, 1]],
bbox_shape = (batch_size, num_anchors, 4)
class_id_shape = score_shape
box_indices_shape = score_shape
num_valid_boxes_shape = (batch_size, 1)

out_bboxes, out_scores, out_sorted_ids, box_indices, num_valid_boxes = te.extern(
[bbox_shape, score_shape, class_id_shape, box_indices_shape, num_valid_boxes_shape],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
ins[1],
ins[2],
ins[3],
outs[0],
outs[1],
outs[2],
outs[0], # sorted bbox
outs[1], # sorted scores
outs[2], # sorted class ids
outs[3], # box_indices
outs[4], # num_valid_boxes
max_output_size,
iou_threshold,
force_suppress,
Expand All @@ -793,7 +824,7 @@ def non_max_suppression(
score_index,
return_indices,
),
dtype=[data.dtype, "int32", "int32"],
dtype=[data.dtype, "float32", "float32", "int32", "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
@@ -802,4 +833,9 @@
if return_indices:
return [box_indices, num_valid_boxes]

return out
# TODO: do concat
return out_bboxes
# if id_index >= 0:
# return concatenate([out_bboxes

# return out

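On the trailing TODO: when return_indices is False the caller previously got the packed tensor back, so the unpacked buffers eventually need to be concatenated again. A sketch of one way to do that with existing TOPI ops, assuming the usual [class_id, score, x1, y1, x2, y2] row layout; this is an illustration, not necessarily how the follow-up change implements it:

from tvm import topi

def pack_nms_output(out_bboxes, out_scores, out_class_ids, id_index):
    # Re-pack (batch, num_anchors, 4) boxes plus (batch, num_anchors) scores/class ids
    # into a single (batch, num_anchors, box_data_length) tensor.
    scores = topi.expand_dims(out_scores, axis=2)
    if id_index >= 0:
        class_ids = topi.expand_dims(out_class_ids, axis=2)
        return topi.concatenate([class_ids, scores, out_bboxes], axis=2)
    return topi.concatenate([scores, out_bboxes], axis=2)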