Skip to content

Commit

Permalink
[TOPI][OP] Use Thrust sort for argsort and topk (#5097)
Browse files Browse the repository at this point in the history
* [TOPI][OP] Use Thrust sort for argsort and topk

The current GPU sort implementation (odd-even transposition sort) is too slow
when the number of elements is large.  This PR introduces Thrust implementation
of sort which is much faster.

Note that this change requires CMake 3.8 or later since we have to use nvcc to
compile a thrust code.

* cmake: make CUDA optional

* allow .cu file to be into the repository

* pylint fix and cleanup

* require cmake 3.8 only when thrust is enabled

* fix nvcc compiler error when passing -pthread

* add missing include

* add USE_THRUST option in config.cmake

* retrigger CI

* retrigger CI
  • Loading branch information
kazum authored Mar 20, 2020
1 parent dbd805c commit 2e8f3a9
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF)
Expand Down Expand Up @@ -101,9 +102,11 @@ else(MSVC)
message("Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
else()
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O2 -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
if (HIDE_PRIVATE_SYMBOLS)
message(STATUS "Hide private symbols...")
set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}")
Expand Down Expand Up @@ -262,6 +265,7 @@ if(NOT MSVC)
check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
message(STATUS "Build with c++14")
set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 14)
endif()

add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,6 @@ set(USE_VTA_FPGA OFF)

# Whether to build the example external runtime module
set(USE_EXAMPLE_EXT_RUNTIME OFF)

# Whether use Thrust
set(USE_THRUST OFF)
8 changes: 8 additions & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ if(USE_CUDA)
endif()
endif(USE_CUBLAS)

if(USE_THRUST)
message(STATUS "Build with Thrust support")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
enable_language(CUDA)
file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)

else(USE_CUDA)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
endif(USE_CUDA)
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.te import SpecializedCondition
from .generic import *
from .. import op as _op
from .... import get_global_func

@schedule_injective.register(["cuda", "gpu"])
def schedule_injective_cuda(attrs, outs, target):
Expand Down Expand Up @@ -328,6 +329,11 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_argsort(topi.cuda.argsort),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_argsort(topi.cuda.argsort_thrust),
wrap_topi_schedule(topi.cuda.schedule_argsort),
name="argsort_thrust.cuda",
plevel=15)
return strategy

@topk_strategy.register(["cuda", "gpu"])
Expand All @@ -337,6 +343,11 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk.cuda")
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
strategy.add_implementation(wrap_compute_topk(topi.cuda.topk_thrust),
wrap_topi_schedule(topi.cuda.schedule_topk),
name="topk_thrust.cuda",
plevel=15)
return strategy

@multibox_prior_strategy.register(["cuda", "gpu"])
Expand Down
133 changes: 133 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file Use external Thrust library call
*/

#include <thrust/device_ptr.h>
#include <thrust/sort.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>

namespace tvm {
namespace contrib {

using namespace runtime;

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

int n_values = input->shape[input->ndim - 1];
int n_iter = 1;
for (int i = 0; i < input->ndim - 1; ++i) {
n_iter *= input->shape[i];
}

thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
values_ptr += n_values;
indices_ptr += n_values;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
DLTensor* indices_out = args[2];
bool is_ascend = args[3];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
});

} // namespace contrib
} // namespace tvm
1 change: 1 addition & 0 deletions tests/lint/check_file_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"pxi",
"pyd",
"pyx",
"cu",
# relay text format
"rly",
# configurations
Expand Down
112 changes: 111 additions & 1 deletion topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .injective import schedule_injective_from_existing
from ..math import identity
from ..transform import strided_slice
from ..transform import strided_slice, transpose
from .. import tag

def _schedule_sort(outs):
Expand Down Expand Up @@ -291,6 +291,40 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
tag="argsort_gpu")[1]
return out

def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Parameters
----------
data: tvm.te.Tensor
The input array.
valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.
axis : int, optional
Axis long which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
if valid_count is not None:
# TODO: implement argsort_nms with Thrust
out = argsort(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out


def schedule_argsort(outs):
"""Schedule for argsort operator.
Expand Down Expand Up @@ -384,6 +418,82 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
return output


def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
"""Get the top k elements in an input tensor along the given axis.
Parameters
----------
data : tvm.te.Tensor
The input tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis long which to sort the input tensor.
ret_type: str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : tvm.te.Tensor or List[tvm.te.Tensor]
The computed result.
"""
assert ret_type in ["both", "values", "indices"]
ndim = len(data.shape)
axis = ndim + axis if axis < 0 else axis

def swap(arr):
""" swap arr[axis] and arr[-1] """
return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]

if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)))
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
]

out = te.extern([data.shape, data.shape],
[data],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend),
in_buffers=[data_buf],
out_buffers=out_bufs,
name="topk_gpu",
tag="topk_gpu")

if k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k]
out = [strided_slice(o, beg, end) for o in out]

if axis != ndim - 1:
axes = swap(list(range(ndim)))
out = [transpose(o, axes) for o in out]

if ret_type == "values":
out = out[0]
elif ret_type == "indices":
out = out[1]

return out


def schedule_topk(outs):
"""Schedule for argsort operator.
Expand Down

0 comments on commit 2e8f3a9

Please sign in to comment.