diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 058bd62d62267..dc9d741b27264 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" not in f.attrs or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)), + tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), tvm.tir.transform.LowerTVMBuiltin(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerCustomDatatypes(), diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index ea149054fa65f..039ebe3aea4ea 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return """Sort related operators """ import tvm from tvm import te @@ -62,7 +62,9 @@ def traverse(op): return s -def sort_ir(data, values_out, axis, is_ascend, indices_out=None): +def sort_ir( + data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None +): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. Parameters @@ -70,8 +72,11 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): data: Buffer Buffer of input data. Data will be sorted in place. - output : Buffer - Output buffer of indicies of sorted tensor with same shape as data. + values_out : Buffer + Output buffer of values of sorted tensor with same shape as data. + + values_out_swap : Buffer + Output buffer of values with same shape as data to use as swap. axis : Int Axis long which to sort the input tensor. @@ -79,11 +84,21 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): is_ascend : Boolean Whether to sort in ascending or descending order. + indicess_out : Buffer + Output buffer of indices of sorted tensor with same shape as data. + + indices_out_swap : Buffer + Output buffer of indices with same shape as data to use as swap. + Returns ------- stmt : Stmt The result IR statement. """ + + def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) + axis_mul_before = 1 axis_mul_after = 1 shape = data.shape @@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): axis_mul_before *= value elif i > axis: axis_mul_after *= value - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + ib = tvm.tir.ir_builder.create() + data = ib.buffer_ptr(data) values_out = ib.buffer_ptr(values_out) + values_out_swap = ib.buffer_ptr(values_out_swap) if indices_out is not None: indices_out = ib.buffer_ptr(indices_out) - nthread_tx = max_threads - nthread_bx = shape[axis] // max_threads + 1 + assert indices_out_swap is not None + indices_out_swap = ib.buffer_ptr(indices_out_swap) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * nthread_tx + tx - temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local") - if indices_out is not None: - temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local") + # Set up threading + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(shape[axis], max_threads) + nthread_by = axis_mul_before + nthread_bz = axis_mul_after + + # Copy the data to initial output + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + idx = (by * shape[axis] + tid) * axis_mul_after + bz + with ib.if_scope(tid < shape[axis]): + values_out[idx] = data[idx] + if indices_out is not None: + indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) + + ## we are looping over the array doing mergesort from the bottom up. + ## The outer loop runs on the host and launches a cuda kernel for each iteration + ## of the algorithm. + ## The basic idea is that at iteration 0, each thread does sort on 2 elements. + ## On iteration 1, each thread merges 2 sorted arrays of 2 elements, + ## to deal with 4 total elements. + ## On iteration 2, each thread merges 2 sorted arrays of 4 elements, + ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc + ## On the final iteration of the algorithm, one thread will merge two sorted lists + ## to sort the entire array + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" + ) + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << l2_width + # Define and launch the cuda kernel + with ib.new_scope(): + i = ib.allocate("int64", (1,), name="i", scope="local") + j = ib.allocate("int64", (1,), name="j", scope="local") + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + # Reduce the number of blocks as the work per thread grows + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + + def compare(a, b): + """ + Compare a and b in proper ascending or descending order + """ + if is_ascend: + out = a <= b + else: + out = b <= a + return out + + def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): + """ + Merge the two sections of the array assigned to this thread + """ + # pylint: disable=arguments-out-of-order + # initialize iterators + i[0] = start + j[0] = middle + # set up indexes + base_idx = by * shape[axis] * axis_mul_after + bz + # iterate over the output loop + with ib.for_range(0, end - start) as k: + i_idx = base_idx + i[0] * axis_mul_after + j_idx = base_idx + j[0] * axis_mul_after + k_idx = base_idx + (k + start) * axis_mul_after + + def swap_values(source, dest, source_idx, dest_idx): + def assign_i(): + """assign i value to current output""" + dest[k_idx] = source[i_idx] + if indices_out is not None: + dest_idx[k_idx] = source_idx[i_idx] + i[0] += 1 + + def assign_j(): + """assign j value to current output""" + dest[k_idx] = source[j_idx] + if indices_out is not None: + dest_idx[k_idx] = source_idx[j_idx] + j[0] += 1 + + ## if both of the iterators are in range + with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): + # compare them and insert whichever is next into the output + with ib.if_scope(compare(source[i_idx], source[j_idx])): + assign_i() + with ib.else_scope(): + assign_j() + # otherwise, simply copy the remainder of the valid iterator to the output + with ib.else_scope(): + with ib.if_scope(i[0] < middle): + assign_i() + with ib.else_scope(): + assign_j() + + # Switch which input is the source and which is the destination each iteration + with ib.if_scope(even): + swap_values(source, dest, source_idx, dest_idx) + with ib.else_scope(): + swap_values(dest, source, dest_idx, source_idx) + + def mergesort(source, dest, source_idx, dest_idx, size, width, even): + # calculate the start, mid, and end points of this section + start[0] = width * tid + with ib.if_scope(start[0] < size): + middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start[0] + width, size) + ## merge the start->middle and middle->end arrays + bottom_up_merge( + source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even + ) - with ib.for_range(0, axis_mul_before) as i: - with ib.for_range(0, axis_mul_after) as j: - base_idx = i * shape[axis] * axis_mul_after + j + # Call the kernel + mergesort( + values_out, + values_out_swap, + indices_out, + indices_out_swap, + shape[axis], + width, + tvm.tir.indexmod(l2_width, 2) == 0, + ) + + ## if the final sorted data ended up in the swap, copy it to the real output + with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1): + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): - values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after] + idx = (by * shape[axis] + tid) * axis_mul_after + bz + values_out[idx] = values_out_swap[idx] if indices_out is not None: - indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast( - tid, indices_out.dtype - ) - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.for_range(0, axis_mul_before) as i: - with ib.for_range(0, axis_mul_after) as j: - current_sort_num = shape[axis] - base_idx = i * shape[axis] * axis_mul_after + j - # OddEvenTransposeSort - with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): - offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after - if is_ascend: - cond = tvm.tir.all( - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - values_out[offset] > values_out[offset + axis_mul_after], - ) - else: - cond = tvm.tir.all( - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - values_out[offset] < values_out[offset + axis_mul_after], - ) - with ib.if_scope(cond): - temp_data[0] = values_out[offset] - values_out[offset] = values_out[offset + axis_mul_after] - values_out[offset + axis_mul_after] = temp_data[0] - if indices_out is not None: - temp_index[0] = indices_out[offset] - indices_out[offset] = indices_out[offset + axis_mul_after] - indices_out[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + indices_out[idx] = indices_out_swap[idx] return ib.get() @@ -336,14 +469,13 @@ def sort(data, axis=-1, is_ascend=1): out : tvm.te.Tensor The output of this function. """ - dtype = "float32" value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8) out = te.extern( [data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[value_buf, indices_buf], + lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend), + out_buffers=[value_buf, value_buf_swap], name="sort_gpu", tag="sort_gpu", )[0] @@ -449,12 +581,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): ) else: value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_swap_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "value_swap_buf", data_alignment=8 + ) indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) out = te.extern( - [data.shape, data.shape], + [data.shape, data.shape, data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[value_buf, indices_buf], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], name="argsort_gpu", tag="argsort_gpu", )[1] @@ -564,25 +708,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): axis = axis + ndim if axis < 0 else axis assert 0 <= axis < ndim values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8) + values_swap_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "values_swap_buf", data_alignment=8 + ) indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8) if ret_type == "values": output = te.extern( - [data.shape], + [data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend), - out_buffers=[values_buf], + lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend), + out_buffers=[values_buf, values_swap_buf], name="topk_gpu", tag="topk_gpu", - ) + )[0] else: output = te.extern( - [data.shape, data.shape], + [data.shape, data.shape, data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[values_buf, indices_buf], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf], name="topk_gpu", tag="topk_gpu", - ) + )[0:2] if isinstance(k, int) and k < 1: if ret_type == "indices": return output[1] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index dfc03c0cf6b11..e6812aa3bbfad 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -250,9 +250,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): check_result([data], mod, expected, flatten=True) -# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have -# to use thrust to guarantee the correct results which has been tested locally. -# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_any_argwhere(): verify_any_argwhere(any_dims(1), (5,)) verify_any_argwhere(any_dims(2), (5, 5)) @@ -839,8 +837,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): check_result(in_vals, mod, ref_out) -# TODO(kevinthesun): enable this test when Thrust is available in ci. -# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") verify_any_topk(any_dims(2), 2, (6, 3), "int32") diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index a5ce1fdcf5897..0dac69e360258 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -53,6 +53,8 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): verify_sort((2, 3, 4), axis=0, is_ascend=False, is_dyn=is_dyn) verify_sort((1, 4, 6), axis=1, is_ascend=True, is_dyn=is_dyn) verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn) + verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn) + verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn) @tvm.testing.uses_gpu @@ -66,9 +68,9 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): func = relay.Function([x], z) x_data = np.random.uniform(size=shape).astype("float32") if is_ascend: - ref_res = np.argsort(x_data, axis=axis) + ref_res = np.argsort(x_data, axis=axis, kind="stable") else: - ref_res = np.argsort(-x_data, axis=axis) + ref_res = np.argsort(-x_data, axis=axis, kind="stable") if is_dyn: backends = ["vm", "debug"] @@ -86,6 +88,8 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn) verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 433030863a43d..69993d287b794 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -60,15 +60,10 @@ def check_device(device, ctx): tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out)) for target, ctx in tvm.testing.enabled_targets(): - # TODO(zhiics) Enable argwhere gpu test after sort is fixed. - if ctx.device_type != 1: - continue check_device(target, ctx) -# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have -# to use thrust to guarantee the correct results which has been tested locally. -# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_argwhere(): verify_argwhere((1,)) verify_argwhere((100,))