Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use matrix::select_k in brute_force::knn call #1463

Merged
merged 3 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ if(RAFT_COMPILE_LIBRARY)
src/matrix/detail/select_k_double_uint32_t.cu
src/matrix/detail/select_k_float_int64_t.cu
src/matrix/detail/select_k_float_uint32_t.cu
src/matrix/detail/select_k_float_int32.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/neighbors/ball_cover.cu
Expand Down Expand Up @@ -600,9 +601,7 @@ target_link_libraries(raft::raft INTERFACE
# Use `rapids_export` for 22.04 as it will have COMPONENT support
rapids_export(
INSTALL raft
EXPORT_SET raft-exports
COMPONENTS ${raft_components}
COMPONENTS_EXPORT_SET ${raft_export_sets}
EXPORT_SET raft-exports COMPONENTS ${raft_components} COMPONENTS_EXPORT_SET ${raft_export_sets}
GLOBAL_TARGETS raft compiled distributed
NAMESPACE raft::
DOCUMENTATION doc_string
Expand All @@ -613,9 +612,7 @@ rapids_export(
# * build export -------------------------------------------------------------
rapids_export(
BUILD raft
EXPORT_SET raft-exports
COMPONENTS ${raft_components}
COMPONENTS_EXPORT_SET ${raft_export_sets}
EXPORT_SET raft-exports COMPONENTS ${raft_components} COMPONENTS_EXPORT_SET ${raft_export_sets}
GLOBAL_TARGETS raft compiled distributed
DOCUMENTATION doc_string
NAMESPACE raft::
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);
// needed for brute force knn
instantiate_raft_matrix_detail_select_k(float, int);
// We did not have these two for double before, but there are tests for them. We
// therefore include them here.
instantiate_raft_matrix_detail_select_k(double, int64_t);
Expand Down
42 changes: 22 additions & 20 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@
#include <raft/linalg/map.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/neighbors/detail/faiss_select/DistanceUtils.h>
#include <raft/neighbors/detail/faiss_select/Select.cuh>
#include <raft/neighbors/detail/knn_merge_parts.cuh>
#include <raft/neighbors/detail/selection_faiss.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/haversine_distance.cuh>
#include <raft/spatial/knn/detail/processing.cuh>
Expand Down Expand Up @@ -230,15 +229,16 @@ void tiled_brute_force_knn(const raft::resources& handle,
}
}

select_k<IndexType, ElementType>(temp_distances.data(),
nullptr,
current_query_size,
current_centroid_size,
distances + i * k,
indices + i * k,
select_min,
current_k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, row_major>(
temp_distances.data(), current_query_size, current_centroid_size),
std::nullopt,
raft::make_device_matrix_view<ElementType, int64_t, row_major>(
distances + i * k, current_query_size, current_k),
raft::make_device_matrix_view<IndexType, int64_t, row_major>(
indices + i * k, current_query_size, current_k),
select_min);

// if we're tiling over columns, we need to do a couple things to fix up
// the output of select_k
Expand Down Expand Up @@ -270,15 +270,17 @@ void tiled_brute_force_knn(const raft::resources& handle,

if (tile_cols != n) {
// select the actual top-k items here from the temporary output
select_k<IndexType, ElementType>(temp_out_distances.data(),
temp_out_indices.data(),
current_query_size,
temp_out_cols,
distances + i * k,
indices + i * k,
select_min,
k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, row_major>(
temp_out_distances.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<const IndexType, int64_t, row_major>(
temp_out_indices.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<ElementType, int64_t, row_major>(
distances + i * k, current_query_size, k),
raft::make_device_matrix_view<IndexType, int64_t, row_major>(
indices + i * k, current_query_size, k),
select_min);
}
}
}
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_float_int32.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(float, int);

#undef instantiate_raft_matrix_detail_select_k