Skip to content

Commit b6472f9

Browse files
committed
Use Thrust's dynamic index-type dispatch (THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2) for merge_sort
1 parent 3e1e6e0 commit b6472f9

File tree

3 files changed

28 additions and 8 deletions

3 files changed

+28
-8
lines changed

cub/benchmarks/bench/merge_sort/keys.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*
2626
******************************************************************************/
2727

28+
#include <cub/detail/choose_offset.cuh>
2829
#include <cub/device/device_merge_sort.cuh>
2930

3031
#include <nvbench_helper.cuh>
@@ -84,7 +85,7 @@ void keys(nvbench::state& state, nvbench::type_list<T, OffsetT>)
8485
using value_input_it_t = value_t*;
8586
using key_it_t = key_t*;
8687
using value_it_t = value_t*;
87-
using offset_t = OffsetT;
88+
using offset_t = cub::detail::choose_offset_t<OffsetT>;
8889
using compare_op_t = less_t;
8990

9091
#if !TUNE_BASE

cub/benchmarks/bench/merge_sort/pairs.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
******************************************************************************/
2727

2828
#include <cub/device/device_merge_sort.cuh>
29+
#include <cub/detail/choose_offset.cuh>
2930

3031
#include <nvbench_helper.cuh>
3132

@@ -81,7 +82,7 @@ void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT, OffsetT>)
8182
using value_input_it_t = value_t*;
8283
using key_it_t = key_t*;
8384
using value_it_t = value_t*;
84-
using offset_t = OffsetT;
85+
using offset_t = cub::detail::choose_offset_t<OffsetT>;
8586
using compare_op_t = less_t;
8687

8788
#if !TUNE_BASE

thrust/thrust/system/cuda/detail/sort.h

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
9292
using ItemsInputIt = cub::NullType*;
9393
ItemsInputIt items = nullptr;
9494

95-
using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, Size, CompareOp>;
95+
cudaError_t status = cudaSuccess;
9696

97-
return DispatchMergeSortT::Dispatch(
98-
d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
97+
using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint32_t, CompareOp>;
98+
using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint64_t, CompareOp>;
99+
100+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
101+
status,
102+
dispatch32_t::Dispatch,
103+
dispatch64_t::Dispatch,
104+
keys_count,
105+
(d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));
106+
107+
return status;
99108
}
100109

101110
template <class KeysIt, class ItemsIt, class Size, class CompareOp>
@@ -109,10 +118,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
109118
cudaStream_t stream,
110119
thrust::detail::integral_constant<bool, true> /* sort_items */)
111120
{
112-
using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, Size, CompareOp>;
121+
cudaError_t status = cudaSuccess;
122+
123+
using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint32_t, CompareOp>;
124+
using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint64_t, CompareOp>;
125+
126+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
127+
status,
128+
dispatch32_t::Dispatch,
129+
dispatch64_t::Dispatch,
130+
keys_count,
131+
(d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));
113132

114-
return DispatchMergeSortT::Dispatch(
115-
d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
133+
return status;
116134
}
117135

118136
template <class SORT_ITEMS, class /* STABLE */, class KeysIt, class ItemsIt, class Size, class CompareOp>

0 commit comments

Comments
 (0)