better parallelize get_valid_counts
mbrookhart committed Dec 10, 2020
1 parent f332512 · commit 75af88f
Showing 1 changed file with 153 additions and 41 deletions.
python/tvm/topi/cuda/nms.py: 194 changes (153 additions, 41 deletions)
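
This commit replaces the old single-kernel get_valid_counts IR, which walked all anchors serially within each batch, with a three-stage pipeline: flag each box that passes the score/class-id filter, scan the flags per batch to compute each surviving box's compacted position, then scatter the boxes into those positions. The NumPy sketch below is purely illustrative of that decomposition and mirrors the three IR builders in the diff; the function name get_valid_counts_ref and the NumPy implementation are mine, not part of the commit or of TVM's API.

import numpy as np


def get_valid_counts_ref(data, score_threshold=0.0, id_index=0, score_index=1):
    # Stage 1 (get_valid_boxes_ir): flag boxes whose score passes the threshold
    # and, when id_index >= 0, whose class id is non-negative.
    batch_size, num_anchors, _ = data.shape
    valid_boxes = data[:, :, score_index] > score_threshold
    if id_index >= 0:
        valid_boxes &= data[:, :, id_index] >= 0
    valid_boxes = valid_boxes.astype(np.int32)

    # Stage 2 (get_valid_indices_ir): a per-batch count plus an exclusive scan
    # gives each surviving box its compacted destination; filtered boxes get -1.
    valid_count = valid_boxes.sum(axis=1).astype(np.int32)
    valid_indices = np.cumsum(valid_boxes, axis=1) - valid_boxes
    valid_indices = np.where(valid_boxes == 1, valid_indices, -1)

    # Stage 3 (get_valid_counts_ir): pad the outputs with -1, then scatter the
    # surviving boxes (and their original anchor indices) into their slots.
    out = np.full_like(data, -1)
    out_indices = np.full((batch_size, num_anchors), -1, dtype=np.int32)
    for b in range(batch_size):
        for j in range(num_anchors):
            dst = valid_indices[b, j]
            if dst >= 0:
                out[b, dst] = data[b, j]
                out_indices[b, dst] = j
    return valid_count, out, out_indices
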
@@ -118,9 +118,89 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
return ib.get()


def get_valid_counts_ir(
data, valid_count, out, out_indices, score_threshold, id_index, score_index
):
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)

valid_boxes = ib.buffer_ptr(valid_boxes)
if isinstance(score_threshold, float):
score_threshold = tvm.tir.FloatImm("float32", score_threshold)
id_index = tvm.tir.IntImm("int32", id_index)
score_index = tvm.tir.IntImm("int32", score_index)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * max_threads + tx
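# One thread per (batch, anchor) pair: flag the box as valid (1) when its score
# exceeds the threshold and, if id_index >= 0, its class id is non-negative.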

with ib.if_scope(tid < num_anchors):
i = by
j = tid
score = data[(i * num_anchors + j) * elem_length + score_index]
with ib.if_scope(
tvm.tir.all(
score > score_threshold,
tvm.tir.any(
id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
),
)
):
valid_boxes[i * num_anchors + j] = 1
with ib.else_scope():
valid_boxes[i * num_anchors + j] = 0
return ib.get()


def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
batch_size = valid_boxes.shape[0]
num_anchors = valid_boxes.shape[1]

ib = tvm.tir.ir_builder.create()

valid_boxes = ib.buffer_ptr(valid_boxes)

valid_count = ib.buffer_ptr(valid_count)
valid_indices = ib.buffer_ptr(valid_indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = batch_size // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
# TODO(mbrookhart): Parallelize the sum and cumsum here
current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
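# Serial fallback for now (see the TODO above): each thread owns one batch, sums
# the 0/1 flags into valid_count and records an exclusive running index as each
# valid box's compacted destination; filtered boxes get -1.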
with ib.if_scope(tid < batch_size):
current_index[0] = 0
valid_count[tid] = 0
with ib.for_range(0, num_anchors) as j:
idx = tid * num_anchors + j
valid_count[tid] = valid_count[tid] + valid_boxes[idx]
with ib.if_scope(valid_boxes[idx] == 1):
valid_indices[idx] = current_index[0]
current_index[0] = current_index[0] + 1
with ib.else_scope():
valid_indices[idx] = -1
return ib.get()


def get_valid_counts_ir(data, valid_indices, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
top of input data.
@@ -158,47 +238,51 @@ def get_valid_counts_ir(

data = ib.buffer_ptr(data)

valid_count = ib.buffer_ptr(valid_count)
valid_indices = ib.buffer_ptr(valid_indices)
out = ib.buffer_ptr(out)
out_indices = ib.buffer_ptr(out_indices)
one = tvm.tir.const(1, dtype=out.dtype)
if isinstance(score_threshold, float):
score_threshold = tvm.tir.FloatImm("float32", score_threshold)
id_index = tvm.tir.IntImm("int32", id_index)
score_index = tvm.tir.IntImm("int32", score_index)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_by = batch_size
nthread_bz = elem_length
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = batch_size // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
valid_count[tid] = 0
i = tid
with ib.for_range(0, num_anchors) as j:
score = data[(i * num_anchors + j) * elem_length + score_index]
with ib.if_scope(
tvm.tir.all(
score > score_threshold,
tvm.tir.any(
id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
),
)
):
with ib.for_range(0, elem_length) as k:
out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[
(i * num_anchors + j) * elem_length + k
]
out_indices[i * num_anchors + valid_count[i]] = j
valid_count[i] += 1
with ib.if_scope(j >= valid_count[i]):
with ib.for_range(0, elem_length) as k:
out[(i * num_anchors + j) * elem_length + k] = -one
out_indices[i * num_anchors + j] = -1
with ib.if_scope(tid < num_anchors):
i = by
j = tid
k = bz
out[(i * num_anchors + j) * elem_length + k] = -one
out_indices[i * num_anchors + j] = -1
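# The launch above pads every (batch, anchor, element) slot of out (and every
# out_indices entry) with -1; the launch below scatters the surviving boxes into
# the compacted positions recorded in valid_indices.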
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
with ib.if_scope(tid < num_anchors):
i = by
j = tid
k = bz
with ib.if_scope(valid_indices[i, tid] >= 0):
out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
(i * num_anchors + j) * elem_length + k
]
out_indices[i * num_anchors + valid_indices[i, tid]] = j
return ib.get()


@@ -231,23 +315,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
batch_size = data.shape[0]
num_anchors = data.shape[1]
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
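# Stage 1: flag, per (batch, anchor), whether the box passes the score and
# class-id filter.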
valid_boxes_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8
)
valid_boxes = te.extern(
[(batch_size, num_anchors)],
[data],
lambda ins, outs: get_valid_boxes_ir(
ins[0], outs[0], score_threshold, id_index, score_index
),
dtype=["int32"],
in_buffers=[data_buf],
out_buffers=[valid_boxes_buf],
name="get_valid_boxes",
tag="get_valid_boxes_gpu",
)
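# Stage 2: a per-batch sum and exclusive scan of the 0/1 flags; valid_count[b]
# is the number of surviving boxes and valid_indices[b, j] is box j's compacted
# destination (or -1 if it was filtered out).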

valid_indices_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
)
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)
valid_count, valid_indices = te.extern(
[(batch_size,), (batch_size, num_anchors)],
[valid_boxes],
lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
dtype=["int32"],
in_buffers=[valid_boxes_buf],
out_buffers=[valid_count_buf, valid_indices_buf],
name="get_valid_indices",
tag="get_valid_indices_gpu",
)
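# Stage 3: pad the outputs with -1 and scatter each surviving box (and its
# original anchor index) into the slot given by valid_indices.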

out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
out_indices_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "out_buf", data_alignment=8
)

valid_count, out, out_indices = te.extern(
[(batch_size,), data.shape, (batch_size, num_anchors)],
[data],
lambda ins, outs: get_valid_counts_ir(
ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index
),
out, out_indices = te.extern(
[data.shape, (batch_size, num_anchors)],
[data, valid_indices],
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf],
out_buffers=[valid_count_buf, out_buf, out_indices_buf],
in_buffers=[data_buf, valid_indices_buf],
out_buffers=[out_buf, out_indices_buf],
name="get_valid_counts",
tag="get_valid_counts_gpu",
)
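As a closing illustration of the output layout, here is a tiny worked input pushed through the get_valid_counts_ref sketch from the note above the diff (it reuses that function and its numpy import; the box layout [class_id, score, x1, y1, x2, y2] and the threshold are arbitrary choices that match the default id_index=0, score_index=1):

data = np.array(
    [[
        [0.0, 0.9, 0.10, 0.10, 0.40, 0.40],   # kept: score 0.9 > 0.5 and class id >= 0
        [-1.0, 0.8, 0.20, 0.20, 0.50, 0.50],  # dropped: class id is negative
        [1.0, 0.6, 0.30, 0.30, 0.60, 0.60],   # kept: score 0.6 > 0.5 and class id >= 0
    ]],
    dtype="float32",
)

valid_count, out, out_indices = get_valid_counts_ref(data, score_threshold=0.5)
print(valid_count)    # [2]
print(out_indices)    # [[ 0  2 -1]]: boxes 0 and 2 are compacted to the front
print(out[0, 2])      # all -1: unused slots stay padded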
