Skip to content

Commit

Permalink
Adding sort option to matrix::select_k api (#1615)
Browse files Browse the repository at this point in the history
The current `raft::matrix::select_k` API doesn't always return sorted neighborhoods, which can easily cause issues when the sorting is expected. This PR provides a new argument to the select_k API to guarantee the output is always sorted.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1615
  • Loading branch information
cjnolet authored Jul 3, 2023
1 parent fda9cd1 commit 744881e
Show file tree
Hide file tree
Showing 20 changed files with 312 additions and 144 deletions.
32 changes: 17 additions & 15 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cstdint> // uint32_t
#include <cuda_fp16.h> // __half
#include <raft/core/device_resources.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm::cuda_stream_view
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource
Expand All @@ -27,32 +28,33 @@
namespace raft::matrix::detail {

// Declaration (no body) of detail::select_k, marked RAFT_EXPLICIT.
// NOTE(review): this span is a rendered diff. The two `void select_k(` openers and the
// `stream` / `mr = nullptr)` vs. `mr = nullptr,` / `sorted = false)` tails are the before and
// after forms of one signature: the change adds a leading `handle`, drops the explicit `stream`
// argument, and appends a defaulted `sorted` flag.
template <typename T, typename IdxT>
void select_k(const T* in_val,
void select_k(raft::resources const& handle,
const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT;
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

// Expands to an `extern template` declaration of raft::matrix::detail::select_k for one
// (T, IdxT) pair.
// NOTE(review): this span is a rendered diff, so the macro appears twice. The first definition
// is the pre-change signature (with `stream`); the second is the post-change signature (leading
// `handle`, no `stream`, trailing `sorted`). No default arguments are repeated in either form.
#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::mr::device_memory_resource* mr, \
bool sorted)
// One declaration per (value type, index type) pair.
// NOTE(review): the list appears to continue past the end of this diff hunk.
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
171 changes: 157 additions & 14 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
#include "select_radix.cuh"
#include "select_warpsort.cuh"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/matrix/init.cuh>

#include <raft/core/resource/thrust_policy.hpp>
#include <raft/neighbors/detail/selection_faiss.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <thrust/scan.h>

namespace raft::matrix::detail {

Expand Down Expand Up @@ -116,6 +121,121 @@ inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
}
}

/**
 * Sorts the key/value pairs of each segment independently by key, in place.
 *
 * Segment `i` covers the element range `[offsets[i], offsets[i + 1])` of both `keys` and
 * `values`. The pairs are sorted into temporary buffers with
 * `cub::DeviceSegmentedRadixSort` and then copied back over `keys` and `values`. All work is
 * enqueued on the CUDA stream obtained from `handle`.
 *
 * @tparam KeyT type of the keys the pairs are ordered by
 * @tparam ValT type of the values carried along with the keys; also the type of the offsets
 * @param[in] handle container of reusable resources (provides the CUDA stream)
 * @param[inout] keys device-accessible array of `n_elements` keys; overwritten with sorted keys
 * @param[inout] values device-accessible array of `n_elements` values; permuted with the keys
 * @param[in] n_segments number of segments
 * @param[in] n_elements total number of key/value pairs across all segments
 * @param[in] offsets device-accessible array of `n_segments + 1` segment boundaries
 * @param[in] asc sort each segment in ascending (true) or descending (false) key order
 */
template <typename KeyT, typename ValT>
void segmented_sort_by_key(raft::resources const& handle,
                           KeyT* keys,
                           ValT* values,
                           size_t n_segments,
                           size_t n_elements,
                           const ValT* offsets,
                           bool asc)
{
  // Nothing to sort. Returning here also prevents the copy-back below from overwriting
  // `keys`/`values` with scratch buffers that no segment ever wrote to.
  if (n_segments == 0 || n_elements == 0) { return; }

  auto stream    = raft::resource::get_cuda_stream(handle);
  auto out_inds  = raft::make_device_vector<ValT, ValT>(handle, n_elements);
  auto out_dists = raft::make_device_vector<KeyT, ValT>(handle, n_elements);

  // The radix sort has to look at every bit of the *key* type (not the value type).
  constexpr int kBeginBit = 0;
  constexpr int kEndBit   = static_cast<int>(sizeof(KeyT) * 8);

  // Single place that issues the CUB call. With `d_temp == nullptr` CUB only reports the
  // required scratch size in `n_bytes`; otherwise it performs the sort.
  // NOTE(review): the cudaError_t returned by CUB is not checked here (as before) — consider
  // wrapping these calls in the project's CUDA error-checking macro.
  auto run_sort = [&](void* d_temp, size_t& n_bytes) {
    if (asc) {
      cub::DeviceSegmentedRadixSort::SortPairs(d_temp,
                                               n_bytes,
                                               keys,
                                               out_dists.data_handle(),
                                               values,
                                               out_inds.data_handle(),
                                               n_elements,
                                               n_segments,
                                               offsets,
                                               offsets + 1,
                                               kBeginBit,
                                               kEndBit,
                                               stream);
    } else {
      cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp,
                                                         n_bytes,
                                                         keys,
                                                         out_dists.data_handle(),
                                                         values,
                                                         out_inds.data_handle(),
                                                         n_elements,
                                                         n_segments,
                                                         offsets,
                                                         offsets + 1,
                                                         kBeginBit,
                                                         kEndBit,
                                                         stream);
    }
  };

  // Determine temporary device storage requirements
  size_t temp_storage_bytes = 0;
  run_sort(nullptr, temp_storage_bytes);

  auto d_temp_storage = raft::make_device_vector<char, int>(handle, temp_storage_bytes);

  // Run sorting operation
  run_sort(static_cast<void*>(d_temp_storage.data_handle()), temp_storage_bytes);

  // Write the sorted pairs back over the caller's buffers.
  raft::copy(values, out_inds.data_handle(), out_inds.size(), stream);
  raft::copy(keys, out_dists.data_handle(), out_dists.size(), stream);
}

/**
 * mdspan convenience overload: sorts the key/value pairs of each segment by key, in place.
 *
 * Validates the views and forwards their raw pointers and sizes to the pointer-based
 * `segmented_sort_by_key`. The number of segments is `offsets.size() - 1`.
 *
 * @tparam KeyT type of the keys the pairs are ordered by
 * @tparam ValT type of the values and of the segment offsets
 * @param[in] handle container of reusable resources
 * @param[in] offsets segment boundaries; must hold at least one entry
 * @param[inout] keys keys to sort
 * @param[inout] values values permuted together with the keys; same length as `keys`
 * @param[in] asc sort each segment in ascending (true) or descending (false) key order
 */
template <typename KeyT, typename ValT>
void segmented_sort_by_key(raft::resources const& handle,
                           raft::device_vector_view<const ValT, ValT> offsets,
                           raft::device_vector_view<KeyT, ValT> keys,
                           raft::device_vector_view<ValT, ValT> values,
                           bool asc)
{
  RAFT_EXPECTS(keys.size() == values.size(),
               "Keys and values must contain the same number of elements.");
  // Without this check `offsets.size() - 1` below would wrap around for an empty view.
  RAFT_EXPECTS(offsets.size() > 0, "Offsets must contain at least one element.");

  const size_t n_segments = offsets.size() - 1;
  segmented_sort_by_key<KeyT, ValT>(handle,
                                    keys.data_handle(),
                                    values.data_handle(),
                                    n_segments,
                                    keys.size(),
                                    offsets.data_handle(),
                                    asc);
}

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down Expand Up @@ -154,40 +274,63 @@ inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
* memory pool here to avoid memory allocations within the call).
*/
// Dispatches k-selection to the algorithm picked by choose_select_k_algorithm and, on the radix
// path, optionally sorts each output row when `sorted` is set.
// NOTE(review): this span is a rendered diff, so pre-change and post-change lines are
// interleaved (two `void select_k(` openers, two `auto algo` lines, and two radix calls: the
// old one prefixed with `return`, the new one not).
// NOTE(review): `mr` is not referenced anywhere in the body visible here.
template <typename T, typename IdxT>
void select_k(const T* in_val,
void select_k(raft::resources const& handle,
const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

auto algo = choose_select_k_algorithm(batch_size, len, k);
// Post-change: the stream is taken from `handle` instead of being passed in.
auto stream = raft::resource::get_cuda_stream(handle);
auto algo = choose_select_k_algorithm(batch_size, len, k);

switch (algo) {
case Algo::kRadix11bits:
return detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream);
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream);

// Optional post-pass: order the k selected pairs of every row with a segmented sort.
if (sorted) {
// Segment boundaries for the batch_size output rows: batch_size + 1 entries, all set to k
// and then exclusive-scanned in place with initial value 0, giving 0, k, 2k, ...
auto offsets = raft::make_device_vector<IdxT, IdxT>(handle, (IdxT)(batch_size + 1));

raft::matrix::fill(handle, offsets.view(), (IdxT)k);

thrust::exclusive_scan(raft::resource::get_thrust_policy(handle),
offsets.data_handle(),
offsets.data_handle() + offsets.size(),
offsets.data_handle(),
0);

// View the output buffers as flat vectors of batch_size * k elements (no copy is made here).
auto keys = raft::make_device_vector_view<T, IdxT>(out_val, (IdxT)(batch_size * k));
auto vals = raft::make_device_vector_view<IdxT, IdxT>(out_idx, (IdxT)(batch_size * k));

// `select_min` is passed as the `asc` flag: ascending when selecting the smallest keys,
// descending when selecting the largest.
segmented_sort_by_key<T, IdxT>(
handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min);
}
return;
// NOTE(review): the branches below return without consulting `sorted` — presumably these
// algorithms already emit sorted rows; confirm.
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
} // namespace raft::matrix::detail
16 changes: 11 additions & 5 deletions cpp/include/raft/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace raft::matrix {
* @tparam IdxT
* the index type (what is being selected together with the keys).
*
* @param[in] handle
* @param[in] handle container of reusable resources
* @param[in] in_val
* inputs values [batch_size, len];
* these are compared and selected.
Expand All @@ -74,14 +74,17 @@ namespace raft::matrix {
* the payload selected together with `out_val`.
* @param[in] select_min
* whether to select k smallest (true) or largest (false) keys.
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
*/
template <typename T, typename IdxT>
void select_k(const resources& handle,
void select_k(raft::resources const& handle,
raft::device_matrix_view<const T, int64_t, row_major> in_val,
std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_idx,
raft::device_matrix_view<T, int64_t, row_major> out_val,
raft::device_matrix_view<IdxT, int64_t, row_major> out_idx,
bool select_min)
bool select_min,
bool sorted = false)
{
RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits<int>::max()),
"output k must fit the int type.");
Expand All @@ -95,15 +98,18 @@ void select_k(const resources& handle,
RAFT_EXPECTS(len == in_idx->extent(1), "value and index input lengths must be equal");
}
RAFT_EXPECTS(int64_t(k) == out_idx.extent(1), "value and index output lengths must be equal");
return detail::select_k<T, IdxT>(in_val.data_handle(),

return detail::select_k<T, IdxT>(handle,
in_val.data_handle(),
in_idx.has_value() ? in_idx->data_handle() : nullptr,
batch_size,
len,
k,
out_val.data_handle(),
out_idx.data_handle(),
select_min,
resource::get_cuda_stream(handle));
nullptr,
sorted);
}

/** @} */ // end of group select_k
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ void search_impl(raft::resources const& handle,
stream);

RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min<uint32_t>(20, index.n_lists()));
matrix::detail::select_k<AccT, uint32_t>(distance_buffer_dev.data(),
matrix::detail::select_k<AccT, uint32_t>(handle,
distance_buffer_dev.data(),
nullptr,
n_queries,
index.n_lists(),
n_probes,
coarse_distances_dev.data(),
coarse_indices_dev.data(),
select_min,
stream,
search_mr);
RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes);
RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes);
Expand Down Expand Up @@ -191,15 +191,15 @@ void search_impl(raft::resources const& handle,

// Merge topk values from different blocks
if (grid_dim_x > 1) {
matrix::detail::select_k<AccT, IdxT>(refined_distances_dev.data(),
matrix::detail::select_k<AccT, IdxT>(handle,
refined_distances_dev.data(),
refined_indices_dev.data(),
n_queries,
k * grid_dim_x,
k,
distances,
neighbors,
select_min,
stream,
search_mr);
}
}
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ void select_clusters(raft::resources const& handle,

// Select neighbor clusters for each query.
rmm::device_uvector<float> cluster_dists(n_queries * n_probes, stream, mr);
matrix::detail::select_k<float, uint32_t>(qc_distances.data(),
matrix::detail::select_k<float, uint32_t>(handle,
qc_distances.data(),
nullptr,
n_queries,
n_lists,
n_probes,
cluster_dists.data(),
clusters_to_probe,
true,
stream,
mr);
}

Expand Down Expand Up @@ -581,15 +581,15 @@ void ivfpq_search_worker(raft::resources const& handle,

// Select topk vectors for each query
rmm::device_uvector<ScoreT> topk_dists(n_queries * topK, stream, mr);
matrix::detail::select_k<ScoreT, uint32_t>(distances_buf.data(),
matrix::detail::select_k<ScoreT, uint32_t>(handle,
distances_buf.data(),
neighbors_ptr,
n_queries,
topk_len,
topK,
topk_dists.data(),
neighbors_uint32,
true,
stream,
mr);

// Postprocessing
Expand Down
Loading

0 comments on commit 744881e

Please sign in to comment.