diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ff5cc0681ad2..ca832ef0ef36 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -23,6 +23,7 @@
 from ..transform import strided_slice, transpose
 from .. import tag
 from ..utils import ceil_div, swap
+from ..math import cast
 
 
 def _schedule_sort(outs):
@@ -142,6 +143,8 @@ def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even
         """
         # pylint: disable=arguments-out-of-order
         # initialize iterators
+        i = ib.allocate("int64", (1,), name="i", scope="local")
+        j = ib.allocate("int64", (1,), name="j", scope="local")
         i[0] = start
         j[0] = middle
         # set up indexes
@@ -189,12 +192,13 @@ def assign_j():
 
     def mergesort(source, dest, source_idx, dest_idx, size, width, even):
         # calculate the start, mid, and end points of this section
-        start[0] = width * tid
-        with ib.if_scope(start[0] < size):
-            middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
-            end[0] = tvm.te.min(start[0] + width, size)
-            ## merge the start->middle and middle->end arrays
-            bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even)
+        start = width * tid
+
+        with ib.if_scope(start < size):
+            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
+            end = cast(tvm.te.min(start + width, size), "int64")
+            # merge the start->middle and middle->end arrays
+            bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even)
 
     lim = tvm.tir.generic.cast(
         tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
@@ -203,11 +207,6 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
         width = 2 << l2_width
         # Define and launch the cuda kernel
         with ib.new_scope():
-            i = ib.allocate("int64", (1,), name="i", scope="local")
-            j = ib.allocate("int64", (1,), name="j", scope="local")
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
             tx = te.thread_axis("threadIdx.x")
             bx = te.thread_axis("blockIdx.x")
             ib.scope_attr(tx, "thread_extent", nthread_tx)
diff --git a/tests/python/topi/python/test_topi_sort.py b/tests/python/topi/python/test_topi_sort.py
index 626218f30144..85a35488ab22 100644
--- a/tests/python/topi/python/test_topi_sort.py
+++ b/tests/python/topi/python/test_topi_sort.py
@@ -75,7 +75,7 @@ def check_device(device):
         f(tvm_data, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_sort, rtol=1e0)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
         check_device(device)
 
 
@@ -115,7 +115,7 @@ def check_device(device):
         f(tvm_data, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
         check_device(device)
 
 
@@ -167,7 +167,7 @@ def check_device(device):
         else:
             tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
         check_device(device)
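
For local verification, a smoke test along the lines of the sketch below can exercise the rewritten merge-sort IR on the newly enabled targets. This is not part of the patch: it assumes a TVM build contemporary with this change (where `tvm.context` and `asnumpy()` are the current spellings) and reuses `topi.cuda.sort` / `topi.cuda.schedule_sort`, the same GPU implementation the existing tests dispatch to; targets are skipped unless `tvm.testing.device_enabled` reports them as usable.

```python
# Minimal smoke-test sketch (not part of this patch): run the GPU sort
# kernel on each newly covered target and compare against numpy.
import numpy as np

import tvm
import tvm.testing
from tvm import te, topi


def run_sort(target, shape=(1, 2048), axis=-1):
    data = te.placeholder(shape, name="data", dtype="float32")
    np_data = np.random.uniform(size=shape).astype("float32")

    # Build the compute and schedule under the requested target.
    with tvm.target.Target(target):
        out = topi.cuda.sort(data, axis=axis, is_ascend=True)
        s = topi.cuda.schedule_sort(out)
    f = tvm.build(s, [data, out], target)

    # Run on the device and check against numpy's sort.
    ctx = tvm.context(target, 0)
    tvm_data = tvm.nd.array(np_data, ctx)
    tvm_out = tvm.nd.array(np.zeros(shape, dtype="float32"), ctx)
    f(tvm_data, tvm_out)
    tvm.testing.assert_allclose(tvm_out.asnumpy(), np.sort(np_data, axis=axis), rtol=1e-5)


for target in ["cuda", "vulkan", "nvptx"]:
    if tvm.testing.device_enabled(target):
        run_sort(target)
```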