Skip to content

Commit

Permalink
Fix for function exposing KNN merge (#1418)
Browse files — browse the repository at this point in the history
Authors:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1418
  • Loading branch information
viclafargue authored Jun 26, 2023
1 parent 185e933 commit 1034a41
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
12 changes: 8 additions & 4 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,14 @@ inline void knn_merge_parts(
RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0),
"in_keys and in_values must have the same shape.");
RAFT_EXPECTS(
out_keys.extent(0) == out_values.extent(0) == n_samples,
out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples,
"Number of rows in output keys and val matrices must equal number of rows in search matrix.");
RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1),
"Number of columns in output indices and distances matrices must be equal to k");
RAFT_EXPECTS(
out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1),
"Number of columns in output indices and distances matrices must be equal to k");

idx_t* translations_ptr = nullptr;
if (translations.has_value()) { translations_ptr = translations.value().data_handle(); }

auto n_parts = in_keys.extent(0) / n_samples;
detail::knn_merge_parts(in_keys.data_handle(),
Expand All @@ -104,7 +108,7 @@ inline void knn_merge_parts(
n_parts,
in_keys.extent(1),
resource::get_cuda_stream(handle),
translations.value_or(nullptr));
translations_ptr);
}

/**
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/raft/neighbors/detail/knn_merge_parts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ template <typename value_idx = std::int64_t,
int warp_q,
int thread_q,
int tpb>
__global__ void knn_merge_parts_kernel(value_t* inK,
value_idx* inV,
__global__ void knn_merge_parts_kernel(const value_t* inK,
const value_idx* inV,
value_t* outK,
value_idx* outV,
size_t n_samples,
Expand Down Expand Up @@ -65,8 +65,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK,

int col = i % k;

value_t* inKStart = inK + (row_idx + col);
value_idx* inVStart = inV + (row_idx + col);
const value_t* inKStart = inK + (row_idx + col);
const value_idx* inVStart = inV + (row_idx + col);

int limit = Pow2<WarpSize>::roundDown(total_k);
value_idx translation = 0;
Expand Down Expand Up @@ -99,8 +99,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK,
}

template <typename value_idx = std::int64_t, typename value_t = float, int warp_q, int thread_q>
inline void knn_merge_parts_impl(value_t* inK,
value_idx* inV,
inline void knn_merge_parts_impl(const value_t* inK,
const value_idx* inV,
value_t* outK,
value_idx* outV,
size_t n_samples,
Expand Down Expand Up @@ -137,8 +137,8 @@ inline void knn_merge_parts_impl(value_t* inK,
* @param translations mapping of index offsets for each partition
*/
template <typename value_idx = std::int64_t, typename value_t = float>
inline void knn_merge_parts(value_t* inK,
value_idx* inV,
inline void knn_merge_parts(const value_t* inK,
const value_idx* inV,
value_t* outK,
value_idx* outV,
size_t n_samples,
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/spatial/knn/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ namespace raft::spatial::knn {
* @param translations
*/
template <typename idx_t = int64_t, typename value_t = float>
inline void knn_merge_parts(value_t* in_keys,
idx_t* in_values,
inline void knn_merge_parts(const value_t* in_keys,
const idx_t* in_values,
value_t* out_keys,
idx_t* out_values,
size_t n_samples,
Expand Down

0 comments on commit 1034a41

Please sign in to comment.