Skip to content

Commit

Permalink
Add extern template for ivfflat_interleaved_scan (#1360)
Browse files Browse the repository at this point in the history
This should cut compilation time for refine_d_int64_t_float.cu.o et al from ~900 seconds to 29 seconds.

The refine specialization contain >100 instances of the ivfflat_interleaved_scan kernel, even though these should be seperately compiled by the ivfflat_search specializations. 

The call to ivf_flat_interleaved_scan is [here](https://github.com/rapidsai/raft/blob/56ac43ad93a319a61073dce1b3b937f6f13ade63/cpp/include/raft/neighbors/detail/refine.cuh#L121). 

Depends on (so please merge after) PR #1307.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1360
  • Loading branch information
ahendriksen authored Mar 25, 2023
1 parent f4c7f1f commit 76c828d
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 19 deletions.
8 changes: 8 additions & 0 deletions cpp/include/raft/neighbors/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,14 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan
// function is used in both raft::neighbors::ivf_flat::search and
// raft::neighbors::detail::refine_device. To prevent a duplicate
// instantiation of this function (which defines ~270 kernels) in the refine
// specializations, an extern template definition is provided. Please check
// related function calls after editing this function definition. Search for
// `greppable-id-specializations-ivf-flat-search` to find them.

const int capacity = bound_by_power_of_two(k);
select_interleaved_scan_kernel<T, AccT, IdxT>::run(capacity,
index.veclen(),
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/neighbors/detail/refine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ void refine_device(raft::device_resources const& handle,
n_queries,
n_candidates);

// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan
// function is used in both raft::neighbors::ivf_flat::search and
// raft::neighbors::detail::refine_device. To prevent a duplicate
// instantiation of this function (which defines ~270 kernels) in the refine
// specializations, an extern template definition is provided. Please check
// and adjust the extern template definition and the instantiation when the
// below function call is edited. Search for
// `greppable-id-specializations-ivf-flat-search` to find them.
uint32_t grid_dim_x = 1;
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<
data_t,
Expand Down
25 changes: 24 additions & 1 deletion cpp/include/raft/neighbors/specializations/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

namespace raft::neighbors::ivf_flat {

// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan
// function is used in both raft::neighbors::ivf_flat::search and
// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation
// of this function (which defines ~270 kernels) in the refine specializations,
// an extern template definition is provided here. Please check related function
// calls after editing template definition below. Search for
// `greppable-id-specializations-ivf-flat-search` to find them.
#define RAFT_INST(T, IdxT) \
extern template auto build(raft::device_resources const& handle, \
const index_params& params, \
Expand All @@ -44,7 +51,23 @@ namespace raft::neighbors::ivf_flat {
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);
raft::device_matrix_view<float, IdxT, row_major>); \
\
extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \
T, \
typename raft::spatial::knn::detail::utils::config<T>::value_t, \
IdxT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream);

RAFT_INST(float, int64_t);
RAFT_INST(int8_t, int64_t);
Expand Down
37 changes: 31 additions & 6 deletions cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,37 @@

namespace raft::neighbors::ivf_flat {

#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan
// function is used in both raft::neighbors::ivf_flat::search and
// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation
// of this function (which defines ~270 kernels) in the refine specializations,
// an extern template definition is provided. To make sure
// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate
// it below. Please check related function calls after editing template
// definition below. Search for `greppable-id-specializations-ivf-flat-search`
// to find them.
#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \
T, \
typename raft::spatial::knn::detail::utils::config<T>::value_t, \
IdxT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream); \
\
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_MAKE_INSTANCE(float, int64_t);
Expand Down
28 changes: 22 additions & 6 deletions cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,28 @@

namespace raft::neighbors::ivf_flat {

#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \
T, \
typename raft::spatial::knn::detail::utils::config<T>::value_t, \
IdxT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream); \
\
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_MAKE_INSTANCE(int8_t, int64_t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,28 @@

namespace raft::neighbors::ivf_flat {

#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
#define RAFT_MAKE_INSTANCE(T, IdxT) \
template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \
T, \
typename raft::spatial::knn::detail::utils::config<T>::value_t, \
IdxT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream); \
\
template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_MAKE_INSTANCE(uint8_t, int64_t);
Expand Down

0 comments on commit 76c828d

Please sign in to comment.