From 46ce5f04f8fd9480a750d0beefa0f22dff3937b4 Mon Sep 17 00:00:00 2001
From: masahi
Date: Thu, 24 Dec 2020 07:33:47 +0900
Subject: [PATCH] [TOPI] GPU sort IR refactor to enable sort by keys (#7157)

* sort refactor initial import

* sort test working

* scatter 1d with positive indices working

* remove negative indices, using extern for now

* minor fix

* minor fix

* add sort by key test

* revert scatter change

* add document

* fix py format

Co-authored-by: masa
---
 python/tvm/topi/cuda/nms.py       |   4 +-
 python/tvm/topi/cuda/scatter.py   |  12 +-
 python/tvm/topi/cuda/sort.py      | 561 ++++++++++++++++--------
 tests/python/contrib/test_sort.py |  35 +-
 4 files changed, 343 insertions(+), 269 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index cea287edd62ea..020cf9b5bc631 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -737,9 +737,7 @@ def non_max_suppression(
             score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
         )
     else:
-        sort_tensor = argsort(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
-        )
+        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
 
     sort_tensor_buf = tvm.tir.decl_buffer(
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 9916e2a7fa6d5..be602c8ab7a30 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -424,6 +424,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
     Sorting of indices, and sorting of updates with respect to indices, can
     be done at the same time by thrust's sort_by_key function. It is important
     that sorting be done in a "stable" way via stable_sort, to guarantee deterministic output.
+    Negative indices are assumed to have been converted to corresponding positive
+    indices.
 
     Parameters
     ----------
@@ -473,12 +475,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
 
     ni = indices_sorted.shape[0]
 
-    def do_update(ib, index, update):
-        with ib.if_scope(index < 0):
-            out_ptr[index + n] = update
-        with ib.else_scope():
-            out_ptr[index] = update
-
     with ib.new_scope():
         nthread_bx = ceil_div(ni, nthread_tx)
         tx = te.thread_axis("threadIdx.x")
@@ -491,7 +487,7 @@ def do_update(ib, index, update):
             # The last element can always update.
             index = indices_ptr[tid]
             update = updates_ptr[tid]
-            do_update(ib, index, update)
+            out_ptr[index] = update
 
         with ib.else_scope():
             with ib.if_scope(tid < ni - 1):
@@ -503,7 +499,7 @@ def do_update(ib, index, update):
                 # This thread can update the output.
                 with ib.if_scope(index != index_next):
                     update = updates_ptr[tid]
-                    do_update(ib, index, update)
+                    out_ptr[index] = update
 
     return ib.get()
 
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 039ebe3aea4ea..18872a242160f 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -21,7 +21,6 @@
 from tvm._ffi import get_global_func
 from .injective import schedule_injective_from_existing
-from ..math import identity
 from ..transform import strided_slice, transpose
 from .. import tag
 
@@ -62,46 +61,14 @@ def traverse(op):
     return s
 
 
-def sort_ir(
-    data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
-):
-    """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
-
-    Parameters
-    ----------
-    data: Buffer
-        Buffer of input data. 
Data will be sorted in place. - - values_out : Buffer - Output buffer of values of sorted tensor with same shape as data. - - values_out_swap : Buffer - Output buffer of values with same shape as data to use as swap. - - axis : Int - Axis long which to sort the input tensor. - - is_ascend : Boolean - Whether to sort in ascending or descending order. - - indicess_out : Buffer - Output buffer of indices of sorted tensor with same shape as data. - - indices_out_swap : Buffer - Output buffer of indices with same shape as data to use as swap. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ +def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) - def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) +def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): + """Initialize the output buffers by copying from inputs""" axis_mul_before = 1 axis_mul_after = 1 - shape = data.shape if axis < 0: axis = len(shape) + axis for i, value in enumerate(shape, 0): @@ -110,16 +77,6 @@ def ceil_div(a, b): elif i > axis: axis_mul_after *= value - ib = tvm.tir.ir_builder.create() - - data = ib.buffer_ptr(data) - values_out = ib.buffer_ptr(values_out) - values_out_swap = ib.buffer_ptr(values_out_swap) - if indices_out is not None: - indices_out = ib.buffer_ptr(indices_out) - assert indices_out_swap is not None - indices_out_swap = ib.buffer_ptr(indices_out_swap) - # Set up threading max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads @@ -127,7 +84,7 @@ def ceil_div(a, b): nthread_by = axis_mul_before nthread_bz = axis_mul_after - # Copy the data to initial output + # Copy the keys_in to initial output with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -141,9 +98,25 @@ def ceil_div(a, b): ib.scope_attr(bz, "thread_extent", nthread_bz) idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): - values_out[idx] = data[idx] - if indices_out is not None: - indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) + keys_out[idx] = keys_in[idx] + if values_out is not None: + values_out[idx] = value_init_func(idx, tid) + + return axis_mul_before, axis_mul_after + + +def _sort_common( + ib, + size, + axis_mul_before, + axis_mul_after, + is_ascend, + keys, + keys_swap, + values=None, + values_swap=None, +): + """Either sort only values or sort values by keys.""" ## we are looping over the array doing mergesort from the bottom up. ## The outer loop runs on the host and launches a cuda kernel for each iteration @@ -155,8 +128,85 @@ def ceil_div(a, b): ## to deal with 8 total elements. 
On iteration 3, each thread deals with 16 elements, etc ## On the final iteration of the algorithm, one thread will merge two sorted lists ## to sort the entire array + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(size, max_threads) + nthread_by = axis_mul_before + nthread_bz = axis_mul_after + + def compare(a, b): + """ + Compare a and b in proper ascending or descending order + """ + if is_ascend: + out = a <= b + else: + out = b <= a + return out + + def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): + """ + Merge the two sections of the array assigned to this thread + """ + # pylint: disable=arguments-out-of-order + # initialize iterators + i[0] = start + j[0] = middle + # set up indexes + base_idx = by * size * axis_mul_after + bz + # iterate over the output loop + with ib.for_range(0, end - start) as k: + i_idx = base_idx + i[0] * axis_mul_after + j_idx = base_idx + j[0] * axis_mul_after + k_idx = base_idx + (k + start) * axis_mul_after + + def swap_values(source, dest, source_idx, dest_idx): + def assign_i(): + """assign i value to current output""" + dest[k_idx] = source[i_idx] + if values is not None: + dest_idx[k_idx] = source_idx[i_idx] + i[0] += 1 + + def assign_j(): + """assign j value to current output""" + dest[k_idx] = source[j_idx] + if values is not None: + dest_idx[k_idx] = source_idx[j_idx] + j[0] += 1 + + ## if both of the iterators are in range + with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): + # compare them and insert whichever is next into the output + with ib.if_scope(compare(source[i_idx], source[j_idx])): + assign_i() + with ib.else_scope(): + assign_j() + # otherwise, simply copy the remainder of the valid iterator to the output + with ib.else_scope(): + with ib.if_scope(i[0] < middle): + assign_i() + with ib.else_scope(): + assign_j() + + # Switch which input is the source and which is the destination each iteration + with ib.if_scope(even): + swap_values(source, dest, source_idx, dest_idx) + with ib.else_scope(): + swap_values(dest, source, dest_idx, source_idx) + + def mergesort(source, dest, source_idx, dest_idx, size, width, even): + # calculate the start, mid, and end points of this section + start[0] = width * tid + with ib.if_scope(start[0] < size): + middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start[0] + width, size) + ## merge the start->middle and middle->end arrays + bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even) + lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width @@ -174,7 +224,7 @@ def ceil_div(a, b): ib.scope_attr( bx, "thread_extent", - tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"), + tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"), ) tid = bx * nthread_tx + tx @@ -183,85 +233,13 @@ def ceil_div(a, b): ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(bz, "thread_extent", nthread_bz) - def compare(a, b): - """ - Compare a and b in proper ascending or descending order - """ - if is_ascend: - out = a <= b - else: - out = b <= a - return out - - def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): - """ - Merge the two sections of the 
array assigned to this thread - """ - # pylint: disable=arguments-out-of-order - # initialize iterators - i[0] = start - j[0] = middle - # set up indexes - base_idx = by * shape[axis] * axis_mul_after + bz - # iterate over the output loop - with ib.for_range(0, end - start) as k: - i_idx = base_idx + i[0] * axis_mul_after - j_idx = base_idx + j[0] * axis_mul_after - k_idx = base_idx + (k + start) * axis_mul_after - - def swap_values(source, dest, source_idx, dest_idx): - def assign_i(): - """assign i value to current output""" - dest[k_idx] = source[i_idx] - if indices_out is not None: - dest_idx[k_idx] = source_idx[i_idx] - i[0] += 1 - - def assign_j(): - """assign j value to current output""" - dest[k_idx] = source[j_idx] - if indices_out is not None: - dest_idx[k_idx] = source_idx[j_idx] - j[0] += 1 - - ## if both of the iterators are in range - with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): - # compare them and insert whichever is next into the output - with ib.if_scope(compare(source[i_idx], source[j_idx])): - assign_i() - with ib.else_scope(): - assign_j() - # otherwise, simply copy the remainder of the valid iterator to the output - with ib.else_scope(): - with ib.if_scope(i[0] < middle): - assign_i() - with ib.else_scope(): - assign_j() - - # Switch which input is the source and which is the destination each iteration - with ib.if_scope(even): - swap_values(source, dest, source_idx, dest_idx) - with ib.else_scope(): - swap_values(dest, source, dest_idx, source_idx) - - def mergesort(source, dest, source_idx, dest_idx, size, width, even): - # calculate the start, mid, and end points of this section - start[0] = width * tid - with ib.if_scope(start[0] < size): - middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) - end[0] = tvm.te.min(start[0] + width, size) - ## merge the start->middle and middle->end arrays - bottom_up_merge( - source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even - ) - # Call the kernel mergesort( - values_out, - values_out_swap, - indices_out, - indices_out_swap, - shape[axis], + keys, + keys_swap, + values, + values_swap, + size, width, tvm.tir.indexmod(l2_width, 2) == 0, ) @@ -279,29 +257,31 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): bz = te.thread_axis("blockIdx.z") ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(bz, "thread_extent", nthread_bz) - idx = (by * shape[axis] + tid) * axis_mul_after + bz - with ib.if_scope(tid < shape[axis]): - idx = (by * shape[axis] + tid) * axis_mul_after + bz - values_out[idx] = values_out_swap[idx] - if indices_out is not None: - indices_out[idx] = indices_out_swap[idx] + idx = (by * size + tid) * axis_mul_after + bz + with ib.if_scope(tid < size): + idx = (by * size + tid) * axis_mul_after + bz + keys[idx] = keys_swap[idx] + if values is not None: + values[idx] = values_swap[idx] return ib.get() -def sort_nms_ir(data, valid_count, output, axis, is_ascend): - """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. +def sort_ir( + data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None +): + """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. Parameters ---------- data: Buffer - Buffer of input data. + Buffer of input data. Data will be sorted in place. - valid_count : Buffer - 1D Buffer of number of valid number of boxes. + values_out : Buffer + Output buffer of values of sorted tensor with same shape as data. 
-    output : Buffer
-        Output buffer of indicies of sorted tensor with same shape as data.
+    values_out_swap : Buffer
+        Output buffer of values with same shape as data to use as swap.
 
     axis : Int
         Axis along which to sort the input tensor.
 
     is_ascend : Boolean
         Whether to sort in ascending or descending order.
 
+    indices_out : Buffer
+        Output buffer of indices of sorted tensor with same shape as data.
+
+    indices_out_swap : Buffer
+        Output buffer of indices with same shape as data to use as swap.
+
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
-
-    size = 1
-    axis_mul_before = 1
-    axis_mul_after = 1
-    shape = data.shape
-    if axis < 0:
-        axis = len(shape) + axis
-    for i, value in enumerate(shape, 0):
-        size *= value
-        if i < axis:
-            axis_mul_before *= value
-        elif i > axis:
-            axis_mul_after *= value
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     ib = tvm.tir.ir_builder.create()
+    shape = data.shape
+
     data = ib.buffer_ptr(data)
-    valid_count = ib.buffer_ptr(valid_count)
-    output = ib.buffer_ptr(output)
-    nthread_tx = max_threads
-    nthread_bx = size // max_threads + 1
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
-    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
-    is_ascend = tvm.tir.IntImm("int32", is_ascend)
-
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    with ib.for_range(0, axis_mul_before) as i:
-        with ib.for_range(0, axis_mul_after) as j:
-            current_sort_num = valid_count[i * axis_mul_after + j]
-            base_idx = i * shape[axis] * axis_mul_after + j
-            with ib.if_scope(tid < shape[axis]):
-                output[base_idx + tid * axis_mul_after] = tid
-            # OddEvenTransposeSort
-            with ib.for_range(0, current_sort_num) as k:
-                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
-                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            is_ascend == 1,
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            data[offset] > data[offset + axis_mul_after],
-                        )
-                    ):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            is_ascend == 0,
-                            2 * tid + idxm(k, 2) + 1 < current_sort_num,
-                            data[offset] < data[offset + axis_mul_after],
-                        )
-                    ):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
-            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+    values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
+    if indices_out is not None:
+        indices_out = ib.buffer_ptr(indices_out)
+        assert indices_out_swap is not None
+        indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    return ib.get()
+    axis_mul_before, axis_mul_after = _sort_init(
+        ib,
+        shape,
+        axis,
+        data,
+        values_out,
+        indices_out,
+        value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
+    )
+
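+    # values_out now holds a copy of the data and indices_out the index of each
+    # element; the merge sort below reorders values_out in place and carries
+    # indices_out along as its payload, which is exactly what argsort needs.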
+    return _sort_common(
+        ib,
+        shape[axis],
+        axis_mul_before,
+        axis_mul_after,
+        is_ascend,
+        values_out,
+        values_out_swap,
+        values=indices_out,
+        values_swap=indices_out_swap,
+    )
+
+
+def sort_by_key_ir(
+    keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap, axis, is_ascend
+):
+    """Low level IR to do sort by key on the GPU.
+
+    Parameters
+    ----------
+    keys_in: Buffer
+        Buffer of input keys.
+
+    values_in: Buffer
+        Buffer of input values.
+
+    keys_out : Buffer
+        Buffer of output sorted keys.
+
+    values_out : Buffer
+        Buffer of output sorted values.
+
+    keys_out_swap : Buffer
+        Output buffer of keys with same shape as keys_in to use as swap.
+
+    values_out_swap : Buffer
+        Output buffer of values with same shape as values_in to use as swap.
+
+    axis : Int
+        Axis along which to sort the input tensor.
+
+    is_ascend : Boolean
+        Whether to sort in ascending or descending order.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    shape = keys_in.shape
+
+    keys_in = ib.buffer_ptr(keys_in)
+    values_in = ib.buffer_ptr(values_in)
+    keys_out = ib.buffer_ptr(keys_out)
+    keys_out_swap = ib.buffer_ptr(keys_out_swap)
+    values_out = ib.buffer_ptr(values_out)
+    values_out_swap = ib.buffer_ptr(values_out_swap)
+
+    axis_mul_before, axis_mul_after = _sort_init(
+        ib,
+        shape,
+        axis,
+        keys_in,
+        keys_out,
+        values_out,
+        value_init_func=lambda idx, _: values_in[idx],
+    )
+
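+    # keys_out now holds a copy of the keys and values_out a copy of the values;
+    # the merge sort below reorders keys_out and permutes values_out in lockstep.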
""" - if valid_count is not None: - sorted_data = identity(data) - sorted_data_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "sorted_data_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4 - ) - out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) - out = te.extern( - [data.shape], - [sorted_data, valid_count], - lambda ins, outs: sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend), - dtype="int32", - in_buffers=[sorted_data_buf, valid_count_buf], - out_buffers=[out_buf], - name="argsort_nms_gpu", - tag="argsort_nms_gpu", - ) - else: - value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - value_swap_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "value_swap_buf", data_alignment=8 - ) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) - indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) - out = te.extern( - [data.shape, data.shape, data.shape, data.shape], - [data], - lambda ins, outs: sort_ir( - ins[0], - outs[0], - outs[2], - axis, - is_ascend, - indices_out=outs[1], - indices_out_swap=outs[3], - ), - out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], - name="argsort_gpu", - tag="argsort_gpu", - )[1] + value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8) + indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) + out = te.extern( + [data.shape, data.shape, data.shape, data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], + name="argsort_gpu", + tag="argsort_gpu", + )[1] return out @@ -862,6 +859,56 @@ def schedule_topk(outs): return _schedule_sort(outs) +def sort_by_key(keys, values, axis=-1, is_ascend=1): + """Sort values with respect to keys. Both keys and values will + be sorted and returned. + + Parameters + ---------- + keys: tvm.te.Tensor + The input keys. + + values : tvm.te.Tensor, + The input values. + + axis : int, optional + Axis long which to sort the input tensor. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. 
+ + Returns + ------- + keys_sorted : tvm.te.Tensor + The sorted keys + + values_sorted : tvm.te.Tensor + The values sorted with respect to the keys + """ + keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8) + values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8) + + out_bufs = [ + tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8), + tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8), + tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", data_alignment=8), + tvm.tir.decl_buffer(values.shape, values.dtype, "values_swap_buf", data_alignment=8), + ] + out = te.extern( + [keys.shape, values.shape, keys.shape, values.shape], + [keys, values], + lambda ins, outs: sort_by_key_ir( + ins[0], ins[1], outs[0], outs[1], outs[2], outs[3], axis, is_ascend + ), + in_buffers=[keys_buf, values_buf], + out_buffers=out_bufs, + dtype=[keys.dtype, values.dtype], + name="sort_by_key", + tag="sort_by_key", + ) + return out[0], out[1] + + def stable_sort_by_key_thrust(keys, values, for_scatter=False): """Sort values with respect to keys using thrust. Both keys and values will be sorted and returned. diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 9d6eb7cb3a1e4..f338276ca1189 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available +from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available, sort_by_key import numpy as np @@ -123,7 +123,40 @@ def test_thrust_stable_sort_by_key(): tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) +def test_sort_by_key_gpu(): + size = 6 + keys = te.placeholder((size,), name="keys", dtype="int32") + values = te.placeholder((size,), name="values", dtype="int32") + + for target in ["cuda", "nvptx", "opencl", "rocm"]: + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + continue + + with tvm.target.Target(target): + keys_out, values_out = sort_by_key(keys, values) + ctx = tvm.context(target) + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + if __name__ == "__main__": test_sort() test_sort_np() test_thrust_stable_sort_by_key() + test_sort_by_key_gpu()
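
Usage note: the snippet below is a minimal sketch of how the new sort_by_key
entry point is driven end to end. It mirrors test_sort_by_key_gpu above and
uses the same TVM APIs that appear in this patch (tvm.context, te.placeholder,
te.create_schedule, tvm.build); the printed outputs assume the tie behaviour
of the underlying merge sort, where values with equal keys keep their input
order.

import numpy as np
import tvm
from tvm import te
from tvm.topi.cuda import sort_by_key

# Symbolic inputs: six int32 keys and six int32 values.
keys = te.placeholder((6,), name="keys", dtype="int32")
values = te.placeholder((6,), name="values", dtype="int32")

with tvm.target.Target("cuda"):
    keys_out, values_out = sort_by_key(keys, values)
    s = te.create_schedule([keys_out.op, values_out.op])
    f = tvm.build(s, [keys, values, keys_out, values_out], "cuda")

ctx = tvm.context("cuda")
keys_nd = tvm.nd.array(np.array([1, 4, 2, 8, 2, 7], np.int32), ctx)
values_nd = tvm.nd.array(np.arange(6, dtype=np.int32), ctx)
keys_sorted = tvm.nd.array(np.zeros(6, np.int32), ctx)
values_sorted = tvm.nd.array(np.zeros(6, np.int32), ctx)

# Keys come back sorted; values are permuted in lockstep with their keys.
f(keys_nd, values_nd, keys_sorted, values_sorted)
print(keys_sorted.asnumpy())    # [1 2 2 4 7 8]
print(values_sorted.asnumpy())  # [0 2 4 1 5 3]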