Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mdspan view for IVF-PQ API #1236

Merged
merged 18 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_pq::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));

auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}

void search(const raft::device_resources& handle,
Expand All @@ -192,8 +193,13 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;

auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view);
}
};

Expand Down
139 changes: 137 additions & 2 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,143 @@ namespace raft::neighbors::ivf_pq {
* @{
*/

/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
inline auto build(raft::device_resources const& handle,
const index_params& params,
const raft::device_matrix_view<const T, IdxT, row_major>& dataset) -> index<IdxT>
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return raft::spatial::knn::ivf_pq::detail::build(
handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
inline auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
const raft::device_matrix_view<const T, IdxT, row_major>& new_vectors,
const raft::device_matrix_view<const IdxT, IdxT, row_major>& new_indices)
-> index<IdxT>
{
IdxT n_rows = new_vectors.extent(0);
ASSERT(n_rows == new_indices.extent(0),
"new_vectors and new_indices have different number of rows");
ASSERT(new_vectors.extent(1) == orig_index.dim(),
"new_vectors should have the same dimension as the index");
return raft::spatial::knn::ivf_pq::detail::extend(
handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows);
}

/**
* @brief Extend the index with the new data.
* *
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
*/
template <typename T, typename IdxT>
inline void extend(raft::device_resources const& handle,
index<IdxT>* index,
const raft::device_matrix_view<const T, IdxT, row_major>& new_vectors,
const raft::device_matrix_view<const IdxT, IdxT, row_major>& new_indices)
{
*index = extend(handle, *index, new_vectors, new_indices);
}

/**
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
* @brief Search ANN using the constructed index.
*
* See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
* The exact size of the temporary buffer depends on multiple factors and is an implementation
* detail. However, you can safely specify a small initial size for the memory pool, so that only a
* few allocations happen to grow it during the first invocations of the `search`.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param k the number of neighbors to find for each query.
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
const raft::device_matrix_view<const T, IdxT, row_major>& queries,
uint32_t k,
const raft::device_matrix_view<IdxT, IdxT, row_major>& neighbors,
const raft::device_matrix_view<float, IdxT, row_major>& distances)
{
IdxT n_queries = queries.extent(0);
bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0));
ASSERT(check_n_rows,
"queries, neighbors and distances parameters have inconsistent number of rows");
return raft::spatial::knn::ivf_pq::detail::search(handle,
params,
index,
queries.data_handle(),
n_queries,
k,
neighbors.data_handle(),
distances.data_handle(),
handle.get_workspace_resource());
}

/** @} */ // end group ivf_pq

/**
* @brief Build the index from the dataset for efficient search.
*
Expand Down Expand Up @@ -198,6 +335,4 @@ void search(raft::device_resources const& handle,
handle, params, index, queries, n_queries, k, neighbors, distances, mr);
}

/** @} */ // end group ivf_pq

} // namespace raft::neighbors::ivf_pq
13 changes: 10 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ void approx_knn_build_index(raft::device_resources const& handle,
params.pq_bits = ivf_pq_pams->n_bits;
params.pq_dim = ivf_pq_pams->M;
// TODO: handle ivf_pq_pams.usePrecomputedTables ?
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D));

auto index_view = raft::make_device_matrix_view<const T, IntType>(index_array, n, D);
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, params, index_view));
} else {
RAFT_FAIL("Unrecognized index type.");
}
Expand Down Expand Up @@ -110,8 +112,13 @@ void approx_knn_search(raft::device_resources const& handle,
} else if (index->ivf_pq) {
neighbors::ivf_pq::search_params params;
params.n_probes = index->nprobe;

auto query_view =
raft::make_device_matrix_view<const T, IntType>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<IntType, IntType>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, IntType>(distances, n, k);
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_array, n, k, indices, distances);
handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view);
} else {
RAFT_FAIL("The model is not trained");
}
Expand Down
66 changes: 29 additions & 37 deletions cpp/include/raft_runtime/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@

namespace raft::runtime::neighbors::ivf_pq {

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
const T*, \
uint32_t, \
uint32_t, \
IdxT*, \
float*, \
rmm::mr::device_memory_resource*);
#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
const raft::device_matrix_view<const T, IdxT, row_major>&, \
uint32_t, \
const raft::device_matrix_view<IdxT, IdxT, row_major>&, \
const raft::device_matrix_view<float, IdxT, row_major>&);

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
Expand All @@ -40,33 +38,27 @@ RAFT_INST_SEARCH(uint8_t, uint64_t);
// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index<IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_pq::index<IdxT>* idx, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows);
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
const raft::neighbors::ivf_pq::index_params& params, \
const raft::device_matrix_view<const T, IdxT, row_major>& dataset) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index<IdxT>& orig_index, \
const raft::device_matrix_view<const T, IdxT, row_major>& new_vectors, \
const raft::device_matrix_view<const IdxT, IdxT, row_major>& new_indices) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
const raft::device_matrix_view<const T, IdxT, row_major>& dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_pq::index<IdxT>* idx, \
const raft::device_matrix_view<const T, IdxT, row_major>& new_vectors, \
const raft::device_matrix_view<const IdxT, IdxT, row_major>& new_indices);

RAFT_INST_BUILD_EXTEND(float, uint64_t)
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t)
Expand Down
Loading