diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 8946446f3cdc..a4080e585a45 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -819,11 +819,9 @@ def non_max_suppression( if ( target and target.kind.name == "cuda" - and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) + and tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) ): - sort_tensor = argsort_thrust( - score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype - ) + sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype) else: sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 18872a242160..9b6a18a8b06b 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -409,68 +409,6 @@ def sort_by_key_ir( ) -def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"): - """Performs sorting along the given axis and returns an array of indicies - having same shape as an input array that index data in sorted order. - - Parameters - ---------- - data: tvm.te.Tensor - The input array. - - valid_count : tvm.te.Tensor, optional - The number of valid elements to be sorted. - - axis : int, optional - Axis long which to sort the input tensor. - - is_ascend : boolean, optional - Whether to sort in ascending or descending order. - - dtype : string, optional - DType of the output indices. - - Returns - ------- - out : tvm.te.Tensor - The output of this function. - """ - ndim = len(data.shape) - if axis < 0: - axis = ndim + axis - if axis != ndim - 1: - # Prepare for sorting along axis -1. - axes = swap(list(range(ndim)), axis) - data = transpose(data, axes) - - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4 - ) - out_bufs = [ - tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8), - tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8), - ] - out = te.extern( - [data.shape, data.shape], - [data, valid_count], - lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend - ), - in_buffers=[data_buf, valid_count_buf], - out_buffers=out_bufs, - dtype=[data.dtype, "int32"], - name="nms_argsort_gpu", - tag="nms_argsort_gpu", - ) - - if axis != ndim - 1: - axes = swap(list(range(ndim)), axis) - out = [transpose(o, axes) for o in out] - - return out[1] - - def sort(data, axis=-1, is_ascend=1): """Performs sorting along the given axis and returns an array of sorted values with the same shape as the input data. @@ -602,7 +540,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): return out -def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): +def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -611,9 +549,6 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32" data: tvm.te.Tensor The input array. - valid_count : tvm.te.Tensor, optional - The number of valid elements to be sorted. - axis : int, optional Axis long which to sort the input tensor. @@ -628,11 +563,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32" out : tvm.te.Tensor The output of this function. """ - if valid_count is not None: - out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype) - else: - out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype) - return out + return topk_thrust(data, 0, axis, "indices", is_ascend, dtype) def schedule_sort(outs): diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index dddbb043fddc..6a48f1ad876a 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -22,7 +22,9 @@ */ #include +#include #include +#include #include #include @@ -41,21 +43,19 @@ void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend, - const std::function &get_sort_len) { + int n_values) { thrust::device_ptr data_ptr(static_cast(input->data)); thrust::device_ptr values_ptr(static_cast(out_values->data)); thrust::device_ptr indices_ptr(static_cast(out_indices->data)); - int n_values = input->shape[input->ndim - 1]; - int n_iter = 1; - for (int i = 0; i < input->ndim - 1; ++i) { - n_iter *= input->shape[i]; + size_t size = 1; + for (int i = 0; i < input->ndim; ++i) { + size *= input->shape[i]; } + thrust::copy(data_ptr, data_ptr + size, values_ptr); - thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr); - - for (int i = 0 ; i < n_iter; ++i) { - n_values = get_sort_len(i); + if (size == static_cast(input->shape[input->ndim - 1])) { + // A fast path for single segment case thrust::sequence(indices_ptr, indices_ptr + n_values); if (is_ascend) { thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr); @@ -63,8 +63,47 @@ void thrust_sort(DLTensor* input, thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr, thrust::greater()); } - values_ptr += n_values; - indices_ptr += n_values; + } else { + // segmented sort by key + // Follow the back-to-back stable_sort_by_key strategy explained below + // https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY + thrust::device_vector argsort_order(size); + thrust::sequence(argsort_order.begin(), argsort_order.end()); + + // First, sort values and store the sorted order in argsort_order. + if (is_ascend) { + thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin()); + } else { + thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(), + thrust::greater()); + } + + // The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2 + // without materializing it + auto counting_iter = thrust::counting_iterator(0); + auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) { + return i % n_values; + }; // NOLINT(*) + auto init_indices_iter = thrust::make_transform_iterator(counting_iter, + linear_index_to_sort_axis_index); + + // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr + thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr); + + thrust::device_vector segment_ids(size); + auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) { + return i / n_values; + }; // NOLINT(*) + // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr + thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), + linear_index_to_segment_id); + + // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ... + // values_ptr and indices_ptr will also be sorted in the order of segmend_ids above + // Since sorting has been done in a stable way, relative orderings of values and indices + // in the segment do not change and hence they remain sorted. + auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr)); + thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip); } } @@ -72,54 +111,54 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out, bool is_ascend, - const std::function &get_sort_len, + int sort_len, std::string data_dtype, std::string out_dtype) { if (data_dtype == "float32") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } @@ -128,25 +167,6 @@ void thrust_sort_common(DLTensor* input, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_GE(args.num_args, 5); - DLTensor* input = args[0]; - DLTensor* valid_count = args[1]; - DLTensor* values_out = args[2]; - DLTensor* indices_out = args[3]; - bool is_ascend = args[4]; - - auto data_dtype = DLDataType2String(input->dtype); - auto out_dtype = DLDataType2String(indices_out->dtype); - - thrust::device_ptr valid_count_ptr(static_cast(valid_count->data)); - auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; }; - thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, - data_dtype, out_dtype); -}); - - TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") .set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_GE(args.num_args, 4); @@ -159,8 +179,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") auto out_dtype = DLDataType2String(indices_out->dtype); int n_values = input->shape[input->ndim - 1]; - auto get_sort_len = [=](int i) { return n_values; }; - thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype); });