diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index f657070df4..e6533eaf51 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -1065,6 +1065,14 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // related function calls after editing this function definition. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. + const int capacity = bound_by_power_of_two(k); select_interleaved_scan_kernel::run(capacity, index.veclen(), diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index f244d5875c..aedfc42698 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -117,6 +117,14 @@ void refine_device(raft::device_resources const& handle, n_queries, n_candidates); + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // and adjust the extern template definition and the instantiation when the + // below function call is edited. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 013c7359e5..161f3462c9 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -20,6 +20,13 @@ namespace raft::neighbors::ivf_flat { +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided here. Please check related function +// calls after editing template definition below. Search for +// `greppable-id-specializations-ivf-flat-search` to find them. #define RAFT_INST(T, IdxT) \ extern template auto build(raft::device_resources const& handle, \ const index_params& params, \ @@ -44,7 +51,23 @@ namespace raft::neighbors::ivf_flat { const raft::neighbors::ivf_flat::index&, \ raft::device_matrix_view, \ raft::device_matrix_view, \ - raft::device_matrix_view); + raft::device_matrix_view); \ + \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); RAFT_INST(float, int64_t); RAFT_INST(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu index 6de65546c8..dce7083139 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu @@ -18,12 +18,37 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided. To make sure +// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate +// it below. Please check related function calls after editing template +// definition below. Search for `greppable-id-specializations-ivf-flat-search` +// to find them. +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu index 8eda240ccd..b03d878bae 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu index 8ff6533628..2d42bae0d1 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(uint8_t, int64_t);