Skip to content

Commit

Permalink
[THRUST] Faster multi dimensional argsort by segmented sort (#7195)
Browse files Browse the repository at this point in the history
* remove sort nms

* add segmented sort by key impl

* bug fix, test pass

* updated fast path condition to work for all dims
  • Loading branch information
masahi authored Jan 13, 2021
1 parent 86479ba commit 1d07f1a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 124 deletions.
6 changes: 2 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,11 +819,9 @@ def non_max_suppression(
if (
target
and target.kind.name == "cuda"
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
and tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
):
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)

Expand Down
73 changes: 2 additions & 71 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,68 +409,6 @@ def sort_by_key_ir(
)


def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
    """Argsort for NMS, backed by the "tvm.contrib.thrust.sort_nms" packed func.

    Sorts ``data`` along ``axis`` and returns an index tensor of the same
    shape whose entries index ``data`` in sorted order.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input array.

    valid_count : tvm.te.Tensor, optional
        The number of valid elements to be sorted.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices.

    Returns
    -------
    out : tvm.te.Tensor
        Indices of ``data`` in sorted order along ``axis``.
    """
    ndim = len(data.shape)
    # Normalize a negative axis to its non-negative equivalent.
    if axis < 0:
        axis = ndim + axis

    # The Thrust kernel sorts along the innermost axis only, so move the
    # requested axis to the end first (and undo the transpose afterwards).
    needs_transpose = axis != ndim - 1
    if needs_transpose:
        data = transpose(data, swap(list(range(ndim)), axis))

    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    valid_count_buf = tvm.tir.decl_buffer(
        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
    )
    # Outputs: (sorted values, indices). NOTE: ``dtype`` is accepted for API
    # compatibility, but the indices buffer is always declared as int32 here.
    out_bufs = [
        tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
        tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8),
    ]

    sorted_out = te.extern(
        [data.shape, data.shape],
        [data, valid_count],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend
        ),
        in_buffers=[data_buf, valid_count_buf],
        out_buffers=out_bufs,
        dtype=[data.dtype, "int32"],
        name="nms_argsort_gpu",
        tag="nms_argsort_gpu",
    )

    if needs_transpose:
        # Move the sorted axis back to its original position.
        sorted_out = [transpose(o, swap(list(range(ndim)), axis)) for o in sorted_out]

    # Only the indices (second output) form the argsort result.
    return sorted_out[1]


def sort(data, axis=-1, is_ascend=1):
"""Performs sorting along the given axis and returns an array of
sorted values with the same shape as the input data.
Expand Down Expand Up @@ -602,7 +540,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
return out


def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Expand All @@ -611,9 +549,6 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
data: tvm.te.Tensor
The input array.
valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.
axis : int, optional
        Axis along which to sort the input tensor.
Expand All @@ -628,11 +563,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
out : tvm.te.Tensor
The output of this function.
"""
if valid_count is not None:
out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
return topk_thrust(data, 0, axis, "indices", is_ascend, dtype)


def schedule_sort(outs):
Expand Down
117 changes: 68 additions & 49 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/gather.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
Expand All @@ -41,85 +43,122 @@ void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend,
const std::function<int(int)> &get_sort_len) {
int n_values) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

int n_values = input->shape[input->ndim - 1];
int n_iter = 1;
for (int i = 0; i < input->ndim - 1; ++i) {
n_iter *= input->shape[i];
size_t size = 1;
for (int i = 0; i < input->ndim; ++i) {
size *= input->shape[i];
}
thrust::copy(data_ptr, data_ptr + size, values_ptr);

thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
n_values = get_sort_len(i);
if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
// A fast path for single segment case
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
values_ptr += n_values;
indices_ptr += n_values;
} else {
// segmented sort by key
// Follow the back-to-back stable_sort_by_key strategy explained below
// https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY
thrust::device_vector<int64_t> argsort_order(size);
thrust::sequence(argsort_order.begin(), argsort_order.end());

// First, sort values and store the sorted order in argsort_order.
if (is_ascend) {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
} else {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
thrust::greater<DataType>());
}

// The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2
// without materializing it
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
return i % n_values;
}; // NOLINT(*)
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
linear_index_to_sort_axis_index);

// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);

thrust::device_vector<int> segment_ids(size);
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
return i / n_values;
}; // NOLINT(*)
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
linear_index_to_segment_id);

// The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
    // values_ptr and indices_ptr will also be sorted in the order of segment_ids above
// Since sorting has been done in a stable way, relative orderings of values and indices
// in the segment do not change and hence they remain sorted.
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
}
}

void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
const std::function<int(int)> &get_sort_len,
int sort_len,
std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
Expand All @@ -128,25 +167,6 @@ void thrust_sort_common(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  // Packed-call layout: data, valid_count, out_values, out_indices, is_ascend.
  ICHECK_GE(args.num_args, 5);
  DLTensor* data = args[0];
  DLTensor* valid_count = args[1];
  DLTensor* out_values = args[2];
  DLTensor* out_indices = args[3];
  bool ascending = args[4];

  const auto in_dtype = DLDataType2String(data->dtype);
  const auto indices_dtype = DLDataType2String(out_indices->dtype);

  // Per-segment sort length comes from valid_count[i]. Note that indexing a
  // thrust::device_ptr from host code implies a device-to-host read per call.
  thrust::device_ptr<int> valid_counts(static_cast<int *>(valid_count->data));
  auto segment_sort_len = [&valid_counts](int i) { return valid_counts[i]; };
  thrust_sort_common(data, out_values, out_indices, ascending, segment_sort_len,
                     in_dtype, indices_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 4);
Expand All @@ -159,8 +179,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
auto get_sort_len = [=](int i) { return n_values; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
data_dtype, out_dtype);
});

Expand Down

0 comments on commit 1d07f1a

Please sign in to comment.