Skip to content

Commit

Permalink
[IR] Try to improve nms and get_valid_count (#3282)
Browse files (browse the repository at this point in the history)
* improve nms

* add back get_valid_count syncs
  • Loading branch information
Laurawly authored and vinx13 committed Jun 5, 2019
1 parent befd8c1 commit f2ddb19
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions topi/python/topi/cuda/nms.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -457,15 +457,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
box_indices = ib.buffer_ptr(box_indices)
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")

max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
k = bx * max_threads + tx
j = bx * max_threads + tx

iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
Expand All @@ -480,22 +480,22 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
nkeep = if_then_else( \
tvm.all(top_k > 0, top_k < valid_count[i]),
top_k, valid_count[i])
with ib.for_range(0, nkeep) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = \
data[(base_idx + sorted_index[i * num_anchors + j] \
* box_data_length + k)]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
with ib.for_range(0, valid_count[i] - nkeep) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < valid_count[i] - nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(k < valid_count[i]):
with ib.for_range(0, valid_count[i]) as k:
offset_k = k * box_data_length
with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
tvm.any(force_suppress > 0, id_index < 0, \
Expand All @@ -506,35 +506,29 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0
box_indices[i * num_anchors + k] = -1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.else_scope():
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(k < box_data_length):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - valid_count[i]) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < num_anchors - valid_count[i]):
with ib.for_range(0, box_data_length) as k:
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
with ib.if_scope(max_output_size > 0):
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(num_valid_boxes[0] == max_output_size):
with ib.if_scope(k < box_data_length):
with ib.for_range(0, box_data_length) as k:
out[base_idx + offset_j + k] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.else_scope():
num_valid_boxes[0] += 1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))

return ib.get()

Expand Down

0 comments on commit f2ddb19

Please sign in to comment.