@@ -92,10 +92,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
9292 using ItemsInputIt = cub::NullType*;
9393 ItemsInputIt items = nullptr ;
9494
95- using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, Size, CompareOp> ;
95+ cudaError_t status = cudaSuccess ;
9696
97- return DispatchMergeSortT::Dispatch (
98- d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
97+ using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint32_t , CompareOp>;
98+ using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint64_t , CompareOp>;
99+
100+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
101+ status,
102+ dispatch32_t ::Dispatch,
103+ dispatch64_t ::Dispatch,
104+ keys_count,
105+ (d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));
106+
107+ return status;
99108}
100109
101110template <class KeysIt , class ItemsIt , class Size , class CompareOp >
@@ -109,10 +118,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
109118 cudaStream_t stream,
110119 thrust::detail::integral_constant<bool , true > /* sort_items */ )
111120{
112- using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, Size, CompareOp>;
121+ cudaError_t status = cudaSuccess;
122+
123+ using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint32_t , CompareOp>;
124+ using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint64_t , CompareOp>;
125+
126+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
127+ status,
128+ dispatch32_t ::Dispatch,
129+ dispatch64_t ::Dispatch,
130+ keys_count,
131+ (d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));
113132
114- return DispatchMergeSortT::Dispatch (
115- d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
133+ return status;
116134}
117135
118136template <class SORT_ITEMS , class /* STABLE */ , class KeysIt , class ItemsIt , class Size , class CompareOp >
0 commit comments