From 6f6c82142925f07d3fd6c4f8c5bce1f0fe7199b6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 Mar 2021 17:20:58 +0900 Subject: [PATCH 1/4] sort started to working --- python/tvm/topi/cuda/sort.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index ff5cc0681ad2..0dc8c3e358e5 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -142,6 +142,8 @@ def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even """ # pylint: disable=arguments-out-of-order # initialize iterators + i = ib.allocate("int64", (1,), name="i", scope="local") + j = ib.allocate("int64", (1,), name="j", scope="local") i[0] = start j[0] = middle # set up indexes @@ -189,12 +191,16 @@ def assign_j(): def mergesort(source, dest, source_idx, dest_idx, size, width, even): # calculate the start, mid, and end points of this section - start[0] = width * tid - with ib.if_scope(start[0] < size): - middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) - end[0] = tvm.te.min(start[0] + width, size) + start = width * tid + + with ib.if_scope(start < size): + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + + middle[0] = tvm.te.min(start + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start + width, size) ## merge the start->middle and middle->end arrays - bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even) + bottom_up_merge(source, dest, source_idx, dest_idx, start, middle[0], end[0], even) lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" @@ -203,11 +209,6 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): width = 2 << l2_width # Define and launch the cuda kernel with ib.new_scope(): - i = ib.allocate("int64", (1,), name="i", scope="local") - j = ib.allocate("int64", (1,), name="j", scope="local") - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) From a0353caeddcfc0d8908fe51b4d179adc57c66cbd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 Mar 2021 17:30:22 +0900 Subject: [PATCH 2/4] static size sort seems to be working --- python/tvm/topi/cuda/sort.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 0dc8c3e358e5..ca832ef0ef36 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,6 +23,7 @@ from ..transform import strided_slice, transpose from .. import tag from ..utils import ceil_div, swap +from ..math import cast def _schedule_sort(outs): @@ -194,13 +195,10 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): start = width * tid with ib.if_scope(start < size): - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - - middle[0] = tvm.te.min(start + tvm.tir.indexdiv(width, 2), size) - end[0] = tvm.te.min(start + width, size) - ## merge the start->middle and middle->end arrays - bottom_up_merge(source, dest, source_idx, dest_idx, start, middle[0], end[0], even) + middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64") + end = cast(tvm.te.min(start + width, size), "int64") + # merge the start->middle and middle->end arrays + bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even) lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" From 203078692cedc7280f328e28e4c5e4fdc0acfac5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 8 Mar 2021 17:42:22 +0900 Subject: [PATCH 3/4] test sort on vulkan --- tests/python/topi/python/test_topi_sort.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/topi/python/test_topi_sort.py b/tests/python/topi/python/test_topi_sort.py index 626218f30144..0879a0aa9dfe 100644 --- a/tests/python/topi/python/test_topi_sort.py +++ b/tests/python/topi/python/test_topi_sort.py @@ -75,7 +75,7 @@ def check_device(device): f(tvm_data, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_sort, rtol=1e0) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device) @@ -115,7 +115,7 @@ def check_device(device): f(tvm_data, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device) @@ -167,7 +167,7 @@ def check_device(device): else: tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device) From fee29cdbf67d10abb09f8a6f1e4894defc59c80f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Mar 2021 05:15:43 +0900 Subject: [PATCH 4/4] add nvptx to sort test too --- tests/python/topi/python/test_topi_sort.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/topi/python/test_topi_sort.py b/tests/python/topi/python/test_topi_sort.py index 0879a0aa9dfe..85a35488ab22 100644 --- a/tests/python/topi/python/test_topi_sort.py +++ b/tests/python/topi/python/test_topi_sort.py @@ -75,7 +75,7 @@ def check_device(device): f(tvm_data, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_sort, rtol=1e0) - for device in ["llvm", "cuda", "opencl", "vulkan"]: + for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: check_device(device) @@ -115,7 +115,7 @@ def check_device(device): f(tvm_data, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0) - for device in ["llvm", "cuda", "opencl", "vulkan"]: + for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: check_device(device) @@ -167,7 +167,7 @@ def check_device(device): else: tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices) - for device in ["llvm", "cuda", "opencl", "vulkan"]: + for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]: check_device(device)