Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cub/benchmarks/bench/merge_sort/keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*
******************************************************************************/

#include <cub/detail/choose_offset.cuh>
#include <cub/device/device_merge_sort.cuh>

#include <nvbench_helper.cuh>
Expand Down Expand Up @@ -84,7 +85,7 @@ void keys(nvbench::state& state, nvbench::type_list<T, OffsetT>)
using value_input_it_t = value_t*;
using key_it_t = key_t*;
using value_it_t = value_t*;
using offset_t = OffsetT;
using offset_t = cub::detail::choose_offset_t<OffsetT>;
using compare_op_t = less_t;

#if !TUNE_BASE
Expand Down
3 changes: 2 additions & 1 deletion cub/benchmarks/bench/merge_sort/pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*
******************************************************************************/

#include <cub/detail/choose_offset.cuh>
#include <cub/device/device_merge_sort.cuh>

#include <nvbench_helper.cuh>
Expand Down Expand Up @@ -81,7 +82,7 @@ void pairs(nvbench::state& state, nvbench::type_list<KeyT, ValueT, OffsetT>)
using value_input_it_t = value_t*;
using key_it_t = key_t*;
using value_it_t = value_t*;
using offset_t = OffsetT;
using offset_t = cub::detail::choose_offset_t<OffsetT>;
using compare_op_t = less_t;

#if !TUNE_BASE
Expand Down
30 changes: 24 additions & 6 deletions thrust/thrust/system/cuda/detail/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
using ItemsInputIt = cub::NullType*;
ItemsInputIt items = nullptr;

using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, Size, CompareOp>;
cudaError_t status = cudaSuccess;

return DispatchMergeSortT::Dispatch(
d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint32_t, CompareOp>;
using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsInputIt, KeysIt, ItemsInputIt, std::uint64_t, CompareOp>;

THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
status,
dispatch32_t::Dispatch,
dispatch64_t::Dispatch,
keys_count,
(d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));

return status;
}

template <class KeysIt, class ItemsIt, class Size, class CompareOp>
Expand All @@ -109,10 +118,19 @@ THRUST_RUNTIME_FUNCTION cudaError_t doit_step(
cudaStream_t stream,
thrust::detail::integral_constant<bool, true> /* sort_items */)
{
using DispatchMergeSortT = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, Size, CompareOp>;
cudaError_t status = cudaSuccess;

using dispatch32_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint32_t, CompareOp>;
using dispatch64_t = cub::DispatchMergeSort<KeysIt, ItemsIt, KeysIt, ItemsIt, std::uint64_t, CompareOp>;

THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
status,
dispatch32_t::Dispatch,
dispatch64_t::Dispatch,
keys_count,
(d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count_fixed, compare_op, stream));

return DispatchMergeSortT::Dispatch(
d_temp_storage, temp_storage_bytes, keys, items, keys, items, keys_count, compare_op, stream);
return status;
}

template <class SORT_ITEMS, class /* STABLE */, class KeysIt, class ItemsIt, class Size, class CompareOp>
Expand Down
Loading