Skip to content

Commit

Permalink
[TOPI] GPU sort IR refactor to enable sort by keys (apache#7157)
Browse files Browse the repository at this point in the history
* sort refactor initial import

* sort test working

* scatter 1d with positive indices working

* remove negatiev indices, using extern for now

* minor fix

* minor fix

* add sort by key test

* revert scatter change

* add document

* fix py format

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa committed Dec 24, 2020
1 parent b5aef16 commit 25a8d00
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 269 deletions.
4 changes: 1 addition & 3 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,7 @@ def non_max_suppression(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
else:
sort_tensor = argsort(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)

sort_tensor_buf = tvm.tir.decl_buffer(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
Sorting of indices, and sorting of updates with respect to indices, can be done
at the same time by thrust's sort_by_key function. It is important that sorting
be done in a "stable" way via stable_sort, to guarantee deterministic output.
Negative indices are assumed to have been converted to corresponding positive
indices.
Parameters
----------
Expand Down Expand Up @@ -473,12 +475,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):

ni = indices_sorted.shape[0]

def do_update(ib, index, update):
with ib.if_scope(index < 0):
out_ptr[index + n] = update
with ib.else_scope():
out_ptr[index] = update

with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
Expand All @@ -491,7 +487,7 @@ def do_update(ib, index, update):
# The last element can always update.
index = indices_ptr[tid]
update = updates_ptr[tid]
do_update(ib, index, update)
out_ptr[index] = update

with ib.else_scope():
with ib.if_scope(tid < ni - 1):
Expand All @@ -503,7 +499,7 @@ def do_update(ib, index, update):
# This thread can update the output.
with ib.if_scope(index != index_next):
update = updates_ptr[tid]
do_update(ib, index, update)
out_ptr[index] = update

return ib.get()

Expand Down
Loading

0 comments on commit 25a8d00

Please sign in to comment.