Skip to content

Commit

Permalink
Fix for function exposing KNN merge
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Apr 14, 2023
1 parent c950854 commit fd6dacb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
12 changes: 8 additions & 4 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ inline void knn_merge_parts(
RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0),
"in_keys and in_values must have the same shape.");
RAFT_EXPECTS(
out_keys.extent(0) == out_values.extent(0) == n_samples,
out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples,
"Number of rows in output keys and val matrices must equal number of rows in search matrix.");
RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1),
"Number of columns in output indices and distances matrices must be equal to k");
RAFT_EXPECTS(
out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1),
"Number of columns in output indices and distances matrices must be equal to k");

idx_t* translations_ptr = nullptr;
if (translations.has_value()) { translations_ptr = translations.value().data_handle(); }

auto n_parts = in_keys.extent(0) / n_samples;
detail::knn_merge_parts(in_keys.data_handle(),
Expand All @@ -103,7 +107,7 @@ inline void knn_merge_parts(
n_parts,
in_keys.extent(1),
handle.get_stream(),
translations.value_or(nullptr));
translations_ptr);
}

/**
Expand Down
24 changes: 12 additions & 12 deletions cpp/include/raft/neighbors/detail/knn_merge_parts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ template <typename value_idx = std::int64_t,
int warp_q,
int thread_q,
int tpb>
__global__ void knn_merge_parts_kernel(value_t* inK,
value_idx* inV,
value_t* outK,
value_idx* outV,
__global__ void knn_merge_parts_kernel(const value_t* const inK,
const value_idx* const inV,
value_t* const outK,
value_idx* const outV,
size_t n_samples,
int n_parts,
value_t initK,
Expand Down Expand Up @@ -99,10 +99,10 @@ __global__ void knn_merge_parts_kernel(value_t* inK,
}

template <typename value_idx = std::int64_t, typename value_t = float, int warp_q, int thread_q>
inline void knn_merge_parts_impl(value_t* inK,
value_idx* inV,
value_t* outK,
value_idx* outV,
inline void knn_merge_parts_impl(const value_t* const inK,
const value_idx* const inV,
value_t* const outK,
value_idx* const outV,
size_t n_samples,
int n_parts,
int k,
Expand Down Expand Up @@ -137,10 +137,10 @@ inline void knn_merge_parts_impl(value_t* inK,
* @param translations mapping of index offsets for each partition
*/
template <typename value_idx = std::int64_t, typename value_t = float>
inline void knn_merge_parts(value_t* inK,
value_idx* inV,
value_t* outK,
value_idx* outV,
inline void knn_merge_parts(const value_t* const inK,
const value_idx* const inV,
value_t* const outK,
value_idx* const outV,
size_t n_samples,
int n_parts,
int k,
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/spatial/knn/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ namespace raft::spatial::knn {
* @param translations
*/
template <typename idx_t = int64_t, typename value_t = float>
inline void knn_merge_parts(value_t* in_keys,
idx_t* in_values,
value_t* out_keys,
idx_t* out_values,
inline void knn_merge_parts(const value_t* const in_keys,
const idx_t* const in_values,
value_t* const out_keys,
idx_t* const out_values,
size_t n_samples,
int n_parts,
int k,
Expand Down

0 comments on commit fd6dacb

Please sign in to comment.