Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Get Valid Counts when the number of boxes is zero #7229

Merged
merged 1 commit into from
Jan 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 89 additions & 83 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,98 +151,104 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
valid_indices = ib.buffer_ptr(valid_indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid]

nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size

## The following algorithm performs parallel exclusive scan to get
## a tensor that can later be used to select valid indices
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.if_scope(num_anchors > 0):
# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < num_anchors):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
by * num_anchors + middle[0] - 1
]

# Down Sweep of exclusive scan
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
valid_indices[(bx + 1) * num_anchors - 1] = 0
tid = bx * nthread_tx + tx
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid]

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size

## The following algorithm performs parallel exclusive scan to get
## a tensor that can later be used to select valid indices
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < num_anchors):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
by * num_anchors + middle[0] - 1
]

# Down Sweep of exclusive scan
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tmp = ib.allocate("int32", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.tir.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
by * num_anchors + end[0] - 1
]
valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
valid_indices[(bx + 1) * num_anchors - 1] = 0

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tmp = ib.allocate("int32", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.tir.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
by * num_anchors + end[0] - 1
]
valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
with ib.else_scope():
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = 0

return ib.get()

Expand Down