Skip to content

Commit

Permalink
Add device argsort. (#6749)
Browse files Browse the repository at this point in the history
This is part of #6747.
  • Loading branch information
trivialfis committed Mar 16, 2021
1 parent 325bc93 commit 1a73a28
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/xgboost/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ namespace common {
// Assertion macro: evaluates `cond` and, when it is false, calls
// __assert_fail with the stringified condition, the file, the line and the
// enclosing function's name. When `cond` is true the expression is a no-op.
//
// The message is built from the macro's own parameter `cond`; stringifying
// any other token (such as a stale `e`) would not name the failing expression.
#define KERNEL_CHECK(cond)                                              \
  (XGBOOST_EXPECT((cond), true)                                         \
       ? static_cast<void>(0)                                           \
       : __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
                       __PRETTY_FUNCTION__))

#endif // defined(_MSC_VER)
Expand Down
73 changes: 73 additions & 0 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
}

template <typename Container>
void Iota(Container array, int32_t device = CurrentDevice()) {
LaunchN(device, array.size(), [=] __device__(size_t i) { array[i] = i; });
}

namespace detail {
/** \brief Keeps track of global device memory allocations. Thread safe.*/
class MemoryLogger {
Expand Down Expand Up @@ -1179,4 +1184,72 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
}
return aggregate;
}

/**
 * \brief Compute the permutation that sorts `values`, using cub's device
 *        radix sort on (key, index) pairs.
 *
 * `sorted_idx` is first filled with 0, 1, 2, ... and then permuted alongside
 * the keys. The sorted keys are written to a scratch buffer, so `values` is
 * only read.
 *
 * \tparam accending Sort keys in ascending order when true, descending
 *                   otherwise.
 * \tparam IdxT      Element type of the index output.
 * \tparam U         Key type.
 *
 * \param values     Keys to sort; must hold at least sorted_idx.size() items.
 * \param sorted_idx Output indices; its size is the number of sorted items
 *                   and must be smaller than 2^31.
 */
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> values, xgboost::common::Span<IdxT> sorted_idx) {
  // `1 << 31` overflows a signed int; compared against an unsigned size the
  // resulting negative value becomes huge and the check could never fail.
  CHECK_LT(sorted_idx.size(), static_cast<size_t>(1) << 31);
  // The sort reads sorted_idx.size() keys from `values`.
  CHECK_GE(values.size(), sorted_idx.size());
  auto const n_items = static_cast<int>(sorted_idx.size());

  Iota(sorted_idx);
  // cub's radix sort does not support in-place operation: the input and output
  // ranges must not overlap. Sort from a copy of the indices into
  // `sorted_idx` instead of using `sorted_idx` for both.
  TemporaryArray<IdxT> idx_in(sorted_idx.size());
  dh::safe_cuda((cudaMemcpyAsync(idx_in.data().get(), sorted_idx.data(),
                                 sorted_idx.size() * sizeof(IdxT),
                                 cudaMemcpyDeviceToDevice)));
  TemporaryArray<U> out(values.size());

  size_t bytes = 0;
  if (accending) {
    // First call with a null workspace only queries the required size.
    dh::safe_cuda((cub::DeviceRadixSort::SortPairs(
        nullptr, bytes, values.data(), out.data().get(), idx_in.data().get(),
        sorted_idx.data(), n_items)));
    dh::TemporaryArray<char> storage(bytes);
    dh::safe_cuda((cub::DeviceRadixSort::SortPairs(
        storage.data().get(), bytes, values.data(), out.data().get(),
        idx_in.data().get(), sorted_idx.data(), n_items)));
  } else {
    dh::safe_cuda((cub::DeviceRadixSort::SortPairsDescending(
        nullptr, bytes, values.data(), out.data().get(), idx_in.data().get(),
        sorted_idx.data(), n_items)));
    dh::TemporaryArray<char> storage(bytes);
    dh::safe_cuda((cub::DeviceRadixSort::SortPairsDescending(
        storage.data().get(), bytes, values.data(), out.data().get(),
        idx_in.data().get(), sorted_idx.data(), n_items)));
  }
}

namespace detail {
// Wrapper around cub sort for easier `descending` sort
template <bool descending, typename KeyT, typename ValueT, typename OffsetIteratorT>
void DeviceSegmentedRadixSortPair(
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, // NOLINT
KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets, int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8) {
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
d_values_out);
using OffsetT = size_t;
dh::safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, OffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
d_values, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, nullptr, false)));
}
} // namespace detail

/**
 * \brief Argsort within segments: for every segment described by `group_ptr`,
 *        compute the indices that sort that segment's slice of `values`.
 *
 * `sorted_idx` is first filled with 0, 1, 2, ... (global positions) and then
 * permuted segment by segment alongside the keys. Sorted keys are written to
 * a scratch buffer, so `values` is only read.
 *
 * \tparam accending Sort keys in ascending order when true, descending
 *                   otherwise.
 * \tparam U         Key type.
 * \tparam V         Segment offset type.
 * \tparam IdxT      Element type of the index output.
 *
 * \param values     Keys to sort; must hold at least sorted_idx.size() items.
 * \param group_ptr  Segment boundaries: segment i spans
 *                   [group_ptr[i], group_ptr[i + 1]). Must not be empty.
 * \param sorted_idx Output indices; its size must be smaller than 2^31.
 */
template <bool accending, typename U, typename V, typename IdxT>
void SegmentedArgSort(xgboost::common::Span<U> values,
                      xgboost::common::Span<V> group_ptr,
                      xgboost::common::Span<IdxT> sorted_idx) {
  CHECK_GE(group_ptr.size(), 1ul);
  // `1 << 31` overflows a signed int; compared against an unsigned size the
  // resulting negative value becomes huge and the check could never fail.
  CHECK_LT(sorted_idx.size(), static_cast<size_t>(1) << 31);
  // The sort reads sorted_idx.size() keys from `values`.
  CHECK_GE(values.size(), sorted_idx.size());
  size_t n_groups = group_ptr.size() - 1;

  Iota(sorted_idx);
  // The radix sort must not read and write the same range. Sort from a copy
  // of the indices into `sorted_idx`; positions not covered by any segment
  // keep the value written by Iota.
  TemporaryArray<IdxT> idx_in(sorted_idx.size());
  dh::safe_cuda((cudaMemcpyAsync(idx_in.data().get(), sorted_idx.data(),
                                 sorted_idx.size() * sizeof(IdxT),
                                 cudaMemcpyDeviceToDevice)));
  TemporaryArray<U> values_out(values.size());

  // First call with a null workspace only queries the required size.
  size_t bytes = 0;
  detail::DeviceSegmentedRadixSortPair<!accending>(
      nullptr, bytes, values.data(), values_out.data().get(),
      idx_in.data().get(), sorted_idx.data(), sorted_idx.size(), n_groups,
      group_ptr.data(), group_ptr.data() + 1);
  dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
  detail::DeviceSegmentedRadixSortPair<!accending>(
      temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
      idx_in.data().get(), sorted_idx.data(), sorted_idx.size(), n_groups,
      group_ptr.data(), group_ptr.data() + 1);
}
} // namespace dh
23 changes: 23 additions & 0 deletions tests/cpp/common/test_device_helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,28 @@ TEST(Allocator, OOM) {
// Clear last error so we don't fail subsequent tests
cudaGetLastError();
}

// Checks dh::ArgSort and dh::SegmentedArgSort in descending mode on keys that
// are already in ascending order.
TEST(DeviceHelpers, ArgSort) {
  auto const descending = thrust::greater<size_t>{};

  // Keys are 0, 1, ..., 19 (ascending).
  dh::device_vector<float> values(20);
  dh::Iota(dh::ToSpan(values));
  dh::device_vector<size_t> sorted_idx(20);

  // A descending argsort of ascending keys yields descending indices.
  dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx));
  ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(),
                                sorted_idx.end(), descending));

  // Two segments of ten items each: [0, 10) and [10, 20).
  dh::Iota(dh::ToSpan(values));
  dh::device_vector<size_t> groups(3);
  groups[0] = 0;
  groups[1] = 10;
  groups[2] = 20;
  dh::SegmentedArgSort<false>(dh::ToSpan(values), dh::ToSpan(groups),
                              dh::ToSpan(sorted_idx));

  auto const first = sorted_idx.begin();
  auto const middle = sorted_idx.begin() + 10;
  auto const last = sorted_idx.end();
  // Indices are descending inside each segment but not across the boundary.
  ASSERT_FALSE(thrust::is_sorted(thrust::device, first, last, descending));
  ASSERT_TRUE(thrust::is_sorted(first, middle, descending));
  ASSERT_TRUE(thrust::is_sorted(middle, last, descending));
}
} // namespace common
} // namespace xgboost

0 comments on commit 1a73a28

Please sign in to comment.