diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 445b0212426d1..1e919ab6d7c7d 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -26,10 +26,6 @@ from .. import tag
 
 
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
 def swap(arr, axis):
     """ swap arr[axis] and arr[-1] """
     return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
 
@@ -145,7 +141,6 @@ def ceil_div(a, b):
         ib.scope_attr(bz, "thread_extent", nthread_bz)
         idx = (by * shape[axis] + tid) * axis_mul_after + bz
         with ib.if_scope(tid < shape[axis]):
-            idx = (by * shape[axis] + tid) * axis_mul_after + bz
             values_out[idx] = data[idx]
             if indices_out is not None:
                 indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
@@ -153,8 +148,13 @@ def ceil_div(a, b):
     ## we are looping over the array doing mergesort from the bottom up.
     ## The outer loop runs on the host and launches a cuda kernel for each iteration
     ## of the algorithm.
-    ## The basic idea is that at iteration 0, each thread does sort on 2 elements. On iteration 1, each thread merges 2 sorted arrays of 2 elements, to deal with 4 total elements. On iteration 2, each thread merges 2 sorted arrays of 4 elements, to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
-    ## On the final iteration of the algorithm, one thread will merge two sorted lists to sort the entire array
+    ## The basic idea is that at iteration 0, each thread does sort on 2 elements.
+    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
+    ## to deal with 4 total elements.
+    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
+    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
+    ## On the final iteration of the algorithm, one thread will merge two sorted lists
+    ## to sort the entire array
     lim = tvm.tir.generic.cast(
         tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
     )
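
For reference, the rewrapped comment above describes a bottom-up mergesort schedule: pass k merges already-sorted runs of length 2**k, so ceil(log2(n)) host-driven passes sort the whole axis, which is what the `lim` expression in the last hunk computes. The following is a minimal plain-Python sketch of that schedule, not part of the patch and not TVM IR; the function name `bottom_up_mergesort` and the serial inner loop are illustrative only, since on the GPU each merge within a pass is handled by its own thread in a single kernel launch.

    # Illustrative sketch (hypothetical helper, not from the patch): a serial
    # bottom-up mergesort with ceil(log2(n)) passes, mirroring the comment above.
    import math


    def bottom_up_mergesort(values):
        """Sort `values` using ceil(log2(n)) bottom-up merge passes."""
        n = len(values)
        if n <= 1:
            return list(values)
        src, dst = list(values), [None] * n
        lim = math.ceil(math.log2(n))  # number of passes, like `lim` in the diff
        for k in range(lim):
            width = 2 ** k  # length of the sorted runs being merged this pass
            # Each pair of runs below is independent work; the CUDA kernel
            # assigns one thread per merge, so one pass is one kernel launch.
            for start in range(0, n, 2 * width):
                mid = min(start + width, n)
                end = min(start + 2 * width, n)
                i, j, out = start, mid, start
                while i < mid and j < end:
                    if src[i] <= src[j]:
                        dst[out] = src[i]
                        i += 1
                    else:
                        dst[out] = src[j]
                        j += 1
                    out += 1
                # Copy whichever run still has elements left.
                dst[out:end] = src[i:mid] if i < mid else src[j:end]
            src, dst = dst, src  # ping-pong buffers between passes
        return src


    if __name__ == "__main__":
        data = [9, 3, 7, 1, 8, 2, 5, 4]
        assert bottom_up_mergesort(data) == sorted(data)

At pass 0 each merge handles 2 elements, at pass 1 each merge handles 4, and on the final pass a single merge combines two sorted halves of the whole array, matching the comment's description.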