Fix Get Valid Counts when the number of boxes is zero (#7229)
Matthew Brookhart authored Jan 8, 2021
1 parent 4911a08 · commit 4e8cc4f
Showing 1 changed file with 89 additions and 83 deletions.
python/tvm/topi/cuda/nms.py: 89 additions & 83 deletions
@@ -151,98 +151,104 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
    valid_indices = ib.buffer_ptr(valid_indices)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

    with ib.if_scope(num_anchors > 0):
        # Copy boxes to valid_indices
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(num_anchors, max_threads)
            nthread_by = batch_size
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            by = te.thread_axis("blockIdx.y")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)
            ib.scope_attr(by, "thread_extent", nthread_by)
            tid = bx * nthread_tx + tx
            with ib.if_scope(tid < num_anchors):
                valid_indices[by, tid] = valid_boxes[by, tid]

        nthread_tx = max_threads
        nthread_bx = ceil_div(num_anchors, max_threads)
        nthread_by = batch_size

        ## The following algorithm performs parallel exclusive scan to get
        ## a tensor that can later be used to select valid indices
        # Up Sweep of exclusive scan
        lim = tvm.tir.generic.cast(
            tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
        )
        with ib.for_range(0, lim, dtype="int64") as l2_width:
            width = 2 << l2_width

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1,), name="start", scope="local")
                middle = ib.allocate("int64", (1,), name="middle", scope="local")
                end = ib.allocate("int64", (1,), name="end", scope="local")
                start[0] = width * tid
                with ib.if_scope(start[0] < num_anchors):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.te.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
                            by * num_anchors + middle[0] - 1
                        ]

        # Record each batch's total, then zero the last element so the
        # down sweep produces an exclusive scan
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
                valid_indices[(bx + 1) * num_anchors - 1] = 0

        # Down Sweep of exclusive scan
        with ib.for_range(0, lim, dtype="int64") as l2_width:
            width = 2 << (lim - l2_width - 1)

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1,), name="start", scope="local")
                middle = ib.allocate("int64", (1,), name="middle", scope="local")
                end = ib.allocate("int64", (1,), name="end", scope="local")
                tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
                start[0] = width * tid
                with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.tir.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
                        valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
                            by * num_anchors + end[0] - 1
                        ]
                        valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
    with ib.else_scope():
        # No anchors at all: every batch has zero valid boxes
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = 0

    return ib.get()
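The two for_range loops above implement a work-efficient (Blelloch) exclusive scan over each batch row of 0/1 validity flags: the up sweep builds partial sums in place, the middle kernel records each row's total in valid_count and zeroes the last element, and the down sweep turns the partial sums into exclusive prefix sums. As an illustration only, here is a minimal host-side Python model that mirrors the same loop structure, including the num_anchors == 0 guard this commit adds; exclusive_scan_rows is a hypothetical helper for exposition, not part of nms.py.

import math
import numpy as np

def exclusive_scan_rows(valid_boxes):
    """Host-side model (illustrative only) of what get_valid_indices_ir computes.

    valid_boxes: int array of shape (batch_size, num_anchors) holding 0/1 flags.
    Returns (valid_indices, valid_count): valid_indices[b, i] is the number of
    valid boxes before anchor i in batch b, valid_count[b] is the row total.
    """
    batch_size, num_anchors = valid_boxes.shape
    valid_indices = valid_boxes.astype(np.int64).copy()
    valid_count = np.zeros(batch_size, dtype=np.int64)

    if num_anchors == 0:
        # Mirrors the else branch added in this commit: nothing to scan.
        return valid_indices, valid_count

    lim = int(math.ceil(math.log2(num_anchors)))

    for b in range(batch_size):
        row = valid_indices[b]
        # Up sweep: build partial sums in place, doubling the stride each pass.
        for l2_width in range(lim):
            width = 2 << l2_width
            for start in range(0, num_anchors, width):
                middle = start + width // 2
                end = min(start + width, num_anchors)
                if middle < num_anchors:
                    row[end - 1] += row[middle - 1]
        # The last element now holds the row total; record it and clear it
        # so the down sweep yields an exclusive (not inclusive) scan.
        valid_count[b] = row[num_anchors - 1]
        row[num_anchors - 1] = 0
        # Down sweep: halve the stride each pass, swapping and accumulating.
        for l2_width in range(lim):
            width = 2 << (lim - l2_width - 1)
            for start in range(0, num_anchors, width):
                middle = start + width // 2
                end = min(start + width, num_anchors)
                if middle < num_anchors:
                    tmp = row[middle - 1]
                    row[middle - 1] = row[end - 1]
                    row[end - 1] += tmp

    return valid_indices, valid_count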

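A quick sanity check of that sketch against a NumPy reference, including the zero-box case the commit title refers to (again illustrative only, reusing the hypothetical exclusive_scan_rows above):

import numpy as np

flags = np.array([[1, 0, 1, 1, 0],
                  [0, 0, 1, 0, 1]], dtype=np.int64)
idx, cnt = exclusive_scan_rows(flags)
ref_idx = np.cumsum(flags, axis=1) - flags  # exclusive prefix sum per row
assert np.array_equal(idx, ref_idx)
assert np.array_equal(cnt, flags.sum(axis=1))

# The edge case this commit guards against: zero boxes per batch.
empty = np.zeros((2, 0), dtype=np.int64)
idx0, cnt0 = exclusive_scan_rows(empty)
assert idx0.shape == (2, 0) and np.array_equal(cnt0, [0, 0])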