Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Dec 15, 2020
1 parent f2af723 commit 1771a20
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
from .. import tag


def ceil_div(a, b):
    """Return ceil(a / b) using integer arithmetic.

    Works on plain Python ints as well as symbolic TIR expressions,
    since it only relies on ``+`` and ``//``.
    """
    numerator = a + b - 1
    return numerator // b


def swap(arr, axis):
    """Return a copy of ``arr`` with ``arr[axis]`` and ``arr[-1]`` exchanged.

    Parameters
    ----------
    arr : list
        The list whose elements are rearranged. ``arr`` itself is not mutated.
    axis : int
        Index of the element to move to the end; may be negative.

    Returns
    -------
    list
        A new list of the same length with the two elements swapped.
    """
    # When axis already names the last element, the slice-concatenation
    # below would duplicate it (arr[:axis] + [arr[-1]] + [] + [arr[axis]]
    # has length len(arr) + 1), so return an unchanged copy instead.
    if axis % len(arr) == len(arr) - 1:
        return list(arr)
    return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
Expand Down Expand Up @@ -145,16 +141,20 @@ def ceil_div(a, b):
ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
idx = (by * shape[axis] + tid) * axis_mul_after + bz
values_out[idx] = data[idx]
if indices_out is not None:
indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)

## we are looping over the array doing mergesort from the bottom up.
## The outer loop runs on the host and launches a cuda kernel for each iteration
## of the algorithm.
## The basic idea is that at iteration 0, each thread does sort on 2 elements. On iteration 1, each thread merges 2 sorted arrays of 2 elements, to deal with 4 total elements. On iteration 2, each thread merges 2 sorted arrays of 4 elements, to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
## On the final iteration of the algorithm, one thread will merge two sorted lists to sort the entire array
## The basic idea is that at iteration 0, each thread does sort on 2 elements.
## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
## to deal with 4 total elements.
## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
## On the final iteration of the algorithm, one thread will merge two sorted lists
## to sort the entire array
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
)
Expand Down

0 comments on commit 1771a20

Please sign in to comment.