[TOPI] Simplify GPU NMS IR and optimize a bit (#7136)
* remove get_valid_counts from pytorch nms

* fix pytorch nms for negative score

* merge reset by -1

* move max_out_size handling to triangle loop

* update torch nms test

* fuse the last two kernels

* parallelize the first kernel

* merge first and last kernel

* remove unnecessary cases

* fix typo

* revert pytorch frontend change

* fuse rearrange step with triangle loop

* fix max_output_size handling

* check if already suppressed

* fix topi vision test by wrapping tir const around int argument

* fix for num anchors = 0 case

* fix missing zero init of num valid boxes when the input is empty

* add some comments and missing doc

* typo fix

* add a guard against zero dim grid / thread block inside ir_builder

* typo fix

* trigger CI
masahi authored Dec 21, 2020
1 parent 9914685 commit 82942fb
Showing 2 changed files with 102 additions and 181 deletions.
4 changes: 4 additions & 0 deletions python/tvm/tir/ir_builder.py
@@ -21,6 +21,7 @@

from . import stmt as _stmt
from . import expr as _expr
from . import op


class WithScope(object):
@@ -200,6 +201,9 @@ def scope_attr(self, node, attr_key, value):
node = _expr.StringImm(node)
if isinstance(value, string_types):
value = _expr.StringImm(value)
# thread_extent could be zero for dynamic workloads
if attr_key == "thread_extent":
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
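
A quick note on the ir_builder change above: the guard matters for dynamically shaped workloads where the launch extent is an expression that may evaluate to zero. A minimal sketch (illustrative only, using the public tvm.tir.ir_builder API) of the situation it covers:

import tvm
from tvm import te

# Illustrative sketch, not part of this commit.
ib = tvm.tir.ir_builder.create()
n = te.var("n")  # dynamic workload size; may be 0 at runtime
bx = te.thread_axis("blockIdx.x")
# With this change the emitted AttrStmt extent becomes max(1, n) rather than n,
# so the generated kernel never launches a zero-dimensional grid.
ib.scope_attr(bx, "thread_extent", n)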
279 changes: 98 additions & 181 deletions python/tvm/topi/cuda/nms.py
@@ -51,68 +51,8 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


def rearrange_indices_out_ir(data, output, valid_box_count):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.
Parameters
----------
data : tvm.te.Tensor or numpy NDArray
NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6] or
[batch_size, num_anchors, 5], or 2-D
tensor with shape [batch_size, num_anchors].
one: tvm.tir.const
Constant one with the same dtype as data.
batch_size: tvm.tir.IntImm or tvm.tir.Var
Batch size. We need to pass it in since hybrid script doesn't support
binding variable to symbolic dim.
num_anchors: tvm.tir.IntImm or tvm.tir.Var
Number of anchors.
Returns
-------
output : tvm.te.Tensor or numpy NDArray
2-D tensor with shape [batch_size, num_anchors].
valid_box_count : tvm.te.Tensor or numpy NDArray
Tensor with shape [batch_size, 1], indicates
the valid number of boxes.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
valid_box_count = ib.buffer_ptr(valid_box_count)
output = ib.buffer_ptr(output)

with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", batch_size)
valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
valid_idx[0] = 0
with ib.for_range(0, num_anchors, name="j") as j:
with ib.if_scope(data[i, j] >= 0):
with ib.if_scope(data[i, j] > num_anchors):
output[i, valid_idx[0]] = 0
valid_idx[0] = valid_idx[0] + 1
with ib.else_scope():
output[i, valid_idx[0]] = data[i, j]
valid_idx[0] = valid_idx[0] + 1
with ib.else_scope():
with ib.if_scope(data[i, j] < -num_anchors):
output[i, valid_idx[0]] = 0
valid_idx[0] = valid_idx[0] + 1
with ib.if_scope(j >= valid_idx[0]):
output[i, j] = -1
valid_box_count[i, 0] = valid_idx[0]

return ib.get()
def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)
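
ceil_div rounds the quotient up and is used below to size launch grids; for example, 1000 anchors at 256 threads per block needs ceil_div(1000, 256) = 4 blocks. A plain-Python equivalent, for illustration only:

def ceil_div_py(a, b):
    # same arithmetic as tvm.tir.indexdiv(a + b - 1, b), on Python ints
    return (a + b - 1) // b

assert ceil_div_py(1000, 256) == 4  # blocks for 1000 anchors, 256 threads/block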


def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
@@ -400,6 +340,7 @@ def nms_ir(
indices,
out,
box_indices,
num_valid_boxes,
max_output_size,
iou_threshold,
force_suppress,
@@ -430,7 +371,15 @@
is not used before non_max_suppression.
out : Buffer
Output buffer.
Output buffer, to be filled with sorted boxes.
box_indices : Buffer
An indices tensor mapping sorted indices to original indices.
This is the first output of NMS when return_indices=True.
num_valid_boxes : Buffer
Record the number of boxes that have survived IOU tests.
This is the second output of NMS when return_indices=True.
max_output_size : int
Max number of output valid boxes for each instance.
@@ -509,6 +458,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
sorted_index = ib.buffer_ptr(sorted_index)
valid_count = ib.buffer_ptr(valid_count)
indices = ib.buffer_ptr(indices)
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)

@@ -523,132 +473,111 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = by
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
nkeep = if_then_else(
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
)
with ib.for_range(0, nkeep) as j:
j = bx * max_threads + tx
with ib.if_scope(j < num_anchors):
box_indices[i * num_anchors + j] = -1
with ib.if_scope(j < nkeep):
# Fill in out with sorted boxes
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = data[
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
with ib.for_range(0, valid_count[i] - nkeep) as j:
with ib.else_scope():
# Indices > nkeep are discarded
with ib.if_scope(j < num_anchors):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1
out[(base_idx + j * box_data_length + k)] = -1.0
with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
with ib.for_range(0, box_data_length) as k:
offset = base_idx + j * box_data_length + k
out[offset] = data[offset]
box_indices[i * num_anchors + j] = j
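
The effect of the rewritten first kernel can be summarized per batch row as follows (a rough NumPy sketch, illustrative only; the exact thread mapping and data layout are simplified): copy the top nkeep sorted boxes into out and record their sorted indices, or, when no IOU filtering is requested, copy the valid boxes through unchanged, with everything else marked -1.

import numpy as np

# Rough NumPy sketch of the first kernel's per-batch effect (illustrative only).
def reorder_sketch(data, sorted_index, valid_count, top_k, iou_threshold):
    batch, num_anchors, box_len = data.shape
    out = np.full_like(data, -1.0)
    box_indices = np.full((batch, num_anchors), -1, dtype="int32")
    for i in range(batch):
        if iou_threshold > 0 and valid_count[i] > 0:
            nkeep = min(top_k, valid_count[i]) if top_k > 0 else valid_count[i]
            for j in range(nkeep):  # one GPU thread per (i, j) in the TIR above
                out[i, j] = data[i, sorted_index[i, j]]
                box_indices[i, j] = sorted_index[i, j]
        else:
            n = valid_count[i]
            out[i, :n] = data[i, :n]
            box_indices[i, :n] = np.arange(n)
    return out, box_indices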

with ib.new_scope():
nthread_by = batch_size
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
i = by
base_idx = i * num_anchors * box_data_length
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0

def nms_inner_loop(ib, j):
offset_j = j * box_data_length

with ib.for_range(0, j) as k:
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_j + score_index] > -1.0,  # skip if already suppressed
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0

# Has the box j survived IOU tests?
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
num_valid_boxes_local[0] += 1

if isinstance(max_output_size, int):
max_output_size = tvm.tir.const(max_output_size)

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.for_range(0, j) as k:
offset_k = k * box_data_length
with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
)
):
offset_j = j * box_data_length
with ib.if_scope(
tvm.tir.all(
j > k,
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_by = batch_size
nthread_bz = box_data_length
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
i = by
j = tid
k = bz
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
pass
with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j

with ib.new_scope():
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
i = bx
base_idx = i * num_anchors * box_data_length
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - valid_count[i]) as j:
with ib.for_range(0, box_data_length) as k:
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
with ib.if_scope(max_output_size > 0):
with ib.for_range(0, valid_count[i]) as j:
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(num_valid_boxes[0] == max_output_size):
with ib.for_range(0, box_data_length) as k:
out[base_idx + offset_j + k] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.if_scope(
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
):
with ib.if_scope(max_output_size > 0):
# No need to do more iterations once we already have max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
num_valid_boxes[0] += 1
nms_inner_loop(ib, j)

if return_indices:
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = batch_size // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = bx * max_threads + tx
with ib.if_scope(i < batch_size):
with ib.for_range(0, valid_count[i]) as j:
idx = box_indices[i * num_anchors + j]
with ib.if_scope(idx >= 0):
box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]
num_valid_boxes[i] = num_valid_boxes_local[0]

with ib.else_scope():
num_valid_boxes[i] = 0

return ib.get()
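
Stripped of the class-id, force_suppress, and return_indices details, the fused triangular loop above boils down to the following plain-Python sketch (illustrative only; the IOU computation is passed in rather than reproduced): a box survives only if no earlier surviving box overlaps it by at least iou_threshold, and scanning stops once max_output_size survivors have been collected.

# Simplified Python sketch of the triangular suppression loop (illustrative only).
def nms_triangle_sketch(boxes, iou, iou_threshold, max_output_size):
    # boxes are assumed to be pre-sorted by descending score, as in the kernel.
    keep = []
    for j in range(len(boxes)):
        if max_output_size > 0 and len(keep) >= max_output_size:
            break  # mirrors the early-exit guard around nms_inner_loop
        suppressed = any(iou(boxes[k], boxes[j]) >= iou_threshold for k in keep)
        if not suppressed:
            keep.append(j)
    return keep, len(keep)  # analogues of box_indices and num_valid_boxes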

@@ -816,13 +745,11 @@ def non_max_suppression(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)

indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

out, box_indices = te.extern(
[data.shape, score_shape],
out, box_indices, num_valid_boxes = te.extern(
[data.shape, score_shape, [batch_size, 1]],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
@@ -831,6 +758,7 @@
ins[3],
outs[0],
outs[1],
outs[2],
max_output_size,
iou_threshold,
force_suppress,
@@ -840,24 +768,13 @@
score_index,
return_indices,
),
dtype=[data.dtype, "int32"],
dtype=[data.dtype, "int32", "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
)

if return_indices:
out_shape = box_indices.shape
valid_box_count_shape = [box_indices.shape[0], 1]
valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
return te.extern(
[out_shape, valid_box_count_shape],
[box_indices],
lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
dtype="int32",
out_buffers=[output, valid_box_count],
name="rearrange_indices_out_gpu",
tag="rearrange_indices_out_gpu",
)
return [box_indices, num_valid_boxes]

return out
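
With this change, a caller of the CUDA op with return_indices=True receives the indices tensor and the per-batch valid-box count directly; the extra rearrange_indices_out_gpu kernel is no longer emitted. A hedged usage sketch (shapes and parameter values are illustrative, and the compute must be built under a CUDA target):

import tvm
from tvm import te, topi

data = te.placeholder((1, 2500, 6), name="data")  # [batch, num_anchors, 6]
valid_count = te.placeholder((1,), dtype="int32", name="valid_count")
indices = te.placeholder((1, 2500), dtype="int32", name="indices")

with tvm.target.Target("cuda"):
    # Illustrative call; with return_indices=True the op now returns
    # [box_indices, num_valid_boxes] without a separate rearrange kernel.
    box_indices, num_valid_boxes = topi.cuda.non_max_suppression(
        data, valid_count, indices,
        max_output_size=-1, iou_threshold=0.5,
        force_suppress=False, top_k=-1, return_indices=True,
    )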
