From a0c57f4ca33d017600176feb6d5732468378fc03 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 3 Feb 2023 16:31:58 +0100 Subject: [PATCH 1/7] Allow use of mdspan view in IVF-PQ API --- cpp/bench/neighbors/knn.cuh | 12 +++- cpp/include/raft/neighbors/ivf_pq.cuh | 48 ++++++------- .../raft/spatial/knn/detail/ann_quantized.cuh | 13 +++- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 65 ++++++++---------- cpp/src/distance/neighbors/ivfpq_build.cu | 67 +++++++++---------- cpp/src/distance/neighbors/ivfpq_search.cu | 25 ++++--- cpp/test/neighbors/ann_ivf_pq.cuh | 35 ++++++---- 7 files changed, 137 insertions(+), 128 deletions(-) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index eec1cba99e..7a2eadd096 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -183,8 +183,9 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_pq::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } void search(const raft::device_resources& handle, @@ -193,8 +194,13 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; + + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 287f0bc5f4..fe1817c203 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -59,19 +59,18 @@ namespace raft::neighbors::ivf_pq { * @param handle * @param params configure the index building * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] - * @param n_rows the number of samples - * @param dim the dimensionality of the data * * @return the constructed ivf-pq index */ template inline auto build(raft::device_resources const& handle, const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index + const raft::device_matrix_view& dataset) -> index { - return raft::spatial::knn::ivf_pq::detail::build(handle, params, dataset, n_rows, dim); + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return raft::spatial::knn::ivf_pq::detail::build( + handle, params, dataset.data_handle(), n_rows, dim); } /** @@ -102,19 +101,18 @@ inline auto build(raft::device_resources const& handle, * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples * * @return the constructed extended ivf-pq index */ template inline auto extend(raft::device_resources const& handle, const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) -> index { + IdxT n_rows = new_vectors.extent(0); return raft::spatial::knn::ivf_pq::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); + handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); } /** @@ -129,16 +127,14 @@ inline auto extend(raft::device_resources const& handle, * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples */ template inline void extend(raft::device_resources const& handle, index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + *index = extend(handle, *index, new_vectors, new_indices); } /** @@ -175,7 +171,6 @@ inline void extend(raft::device_resources const& handle, * @param params configure the search * @param index ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param n_queries the batch size * @param k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] @@ -187,15 +182,22 @@ template inline void search(raft::device_resources const& handle, const search_params& params, const index& index, - const T* queries, - uint32_t n_queries, + const raft::device_matrix_view& queries, uint32_t k, - IdxT* neighbors, - float* distances, + const raft::device_matrix_view& neighbors, + const raft::device_matrix_view& distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_pq::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + IdxT n_queries = queries.extent(0); + return raft::spatial::knn::ivf_pq::detail::search(handle, + params, + index, + queries.data_handle(), + n_queries, + k, + neighbors.data_handle(), + distances.data_handle(), + mr); } /** @} */ // end group ivf_pq diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 427e812cda..066dcaaa6b 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -78,8 +78,10 @@ void approx_knn_build_index(raft::device_resources const& handle, params.pq_bits = ivf_pq_pams->n_bits; params.pq_dim = ivf_pq_pams->M; // TODO: handle ivf_pq_pams.usePrecomputedTables ? - index->ivf_pq = std::make_unique>( - neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D)); + + auto index_view = raft::make_device_matrix_view(index_array, n, D); + index->ivf_pq = std::make_unique>( + neighbors::ivf_pq::build(handle, params, index_view)); } else { RAFT_FAIL("Unrecognized index type."); } @@ -110,8 +112,13 @@ void approx_knn_search(raft::device_resources const& handle, } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; + + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_array, n, k, indices, distances); + handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index cae32c9530..40272987c0 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -20,15 +20,14 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ +#define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + const raft::device_matrix_view&, \ + uint32_t, \ + const raft::device_matrix_view&, \ + const raft::device_matrix_view&, \ rmm::mr::device_memory_resource*); RAFT_INST_SEARCH(float, uint64_t); @@ -40,33 +39,27 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset) \ + ->raft::neighbors::ivf_pq::index; \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->raft::neighbors::ivf_pq::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices); RAFT_INST_BUILD_EXTEND(float, uint64_t) RAFT_INST_BUILD_EXTEND(int8_t, uint64_t) diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 650767f918..caa2092543 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -19,43 +19,36 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::build(handle, params, dataset, n_rows, dim); \ - } \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::extend( \ - handle, orig_index, new_vectors, new_indices, n_rows); \ - } \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx) \ - { \ - *idx = raft::neighbors::ivf_pq::build(handle, params, dataset, n_rows, dim); \ - } \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - { \ - raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices, n_rows); \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset) \ + ->raft::neighbors::ivf_pq::index \ + { \ + return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->raft::neighbors::ivf_pq::index \ + { \ + return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ + } \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + { \ + raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ } RAFT_INST_BUILD_EXTEND(float, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search.cu b/cpp/src/distance/neighbors/ivfpq_search.cu index 05ab890ea5..6ce4cd8375 100644 --- a/cpp/src/distance/neighbors/ivfpq_search.cu +++ b/cpp/src/distance/neighbors/ivfpq_search.cu @@ -20,19 +20,18 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const raft::device_matrix_view& queries, \ + uint32_t k, \ + const raft::device_matrix_view& neighbors, \ + const raft::device_matrix_view& distances, \ + rmm::mr::device_memory_resource* mr) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances, mr); \ } RAFT_SEARCH_INST(float, uint64_t); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 488041f527..bf60c2f589 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -183,7 +183,9 @@ class ivf_pq_test : public ::testing::TestWithParam { auto ipams = ps.index_params; ipams.add_data_on_build = true; - return ivf_pq::build(handle_, ipams, database.data(), ps.num_db_vecs, ps.dim); + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return ivf_pq::build(handle_, ipams, index_view); } auto build_2_extends() @@ -203,11 +205,17 @@ class ivf_pq_test : public ::testing::TestWithParam { auto ipams = ps.index_params; ipams.add_data_on_build = false; - auto index = - ivf_pq::build(handle_, ipams, database.data(), ps.num_db_vecs, ps.dim); + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_pq::build(handle_, ipams, database_view); - ivf_pq::extend(handle_, &index, vecs_2, inds_2, size_2); - return ivf_pq::extend(handle_, index, vecs_1, inds_1, size_1); + auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); + auto inds_2_view = raft::make_device_matrix_view(inds_2, size_2, 1); + ivf_pq::extend(handle_, &index, vecs_2_view, inds_2_view); + + auto vecs_1_view = raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = raft::make_device_matrix_view(inds_1, size_1, 1); + return ivf_pq::extend(handle_, index, vecs_1_view, inds_1_view); } template @@ -226,14 +234,15 @@ class ivf_pq_test : public ::testing::TestWithParam { rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); - ivf_pq::search(handle_, - ps.search_params, - index, - search_queries.data(), - ps.num_queries, - ps.k, - indices_ivf_pq_dev.data(), - distances_ivf_pq_dev.data()); + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = + raft::make_device_matrix_view(indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = + raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + ivf_pq::search( + handle_, ps.search_params, index, query_view, ps.k, inds_view, dists_view); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); From 45553a1b3364744ae7d7ceb1a639bc8f2a93fef7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 6 Feb 2023 11:12:29 +0100 Subject: [PATCH 2/7] restore legacy API --- cpp/include/raft/neighbors/ivf_pq.cuh | 128 ++++++++++++++++++ cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 36 +++++ 2 files changed, 164 insertions(+) diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index fe1817c203..27f0404b6d 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -59,6 +59,35 @@ namespace raft::neighbors::ivf_pq { * @param handle * @param params configure the index building * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] + * @param n_rows the number of samples + * @param dim the dimensionality of the data + * + * @return the constructed ivf-pq index + */ +template +inline auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index +{ + return raft::spatial::knn::ivf_pq::detail::build(handle, params, dataset, n_rows, dim); +} + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * * @return the constructed ivf-pq index */ @@ -101,6 +130,37 @@ inline auto build(raft::device_resources const& handle, * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param n_rows the number of samples + * + * @return the constructed extended ivf-pq index + */ +template +inline auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + return raft::spatial::knn::ivf_pq::detail::extend( + handle, orig_index, new_vectors, new_indices, n_rows); +} + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are unchanged. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param orig_index original index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. * * @return the constructed extended ivf-pq index */ @@ -127,6 +187,30 @@ inline auto extend(raft::device_resources const& handle, * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param n_rows the number of samples + */ +template +inline void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) +{ + *index = extend(handle, *index, new_vectors, new_indices, n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. */ template inline void extend(raft::device_resources const& handle, @@ -171,6 +255,7 @@ inline void extend(raft::device_resources const& handle, * @param params configure the search * @param index ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param n_queries the batch size * @param k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] @@ -179,6 +264,49 @@ inline void extend(raft::device_resources const& handle, * memory pool here to avoid memory allocations within search). */ template +inline void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + return raft::spatial::knn::ivf_pq::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param handle + * @param params configure the search + * @param index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param k the number of neighbors to find for each query. + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param mr an optional memory resource to use across the searches (you can provide a large enough + * memory pool here to avoid memory allocations within search). + */ +template inline void search(raft::device_resources const& handle, const search_params& params, const index& index, diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 40272987c0..7e83aa668a 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -21,6 +21,16 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + const T*, \ + uint32_t, \ + uint32_t, \ + IdxT*, \ + float*, \ + rmm::mr::device_memory_resource*); \ + \ void search(raft::device_resources const&, \ const raft::neighbors::ivf_pq::search_params&, \ const raft::neighbors::ivf_pq::index&, \ @@ -40,6 +50,32 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. #define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ auto build(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index_params& params, \ const raft::device_matrix_view& dataset) \ From c602e8d02f49b2a5055824c477dee271e5f66028 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 6 Feb 2023 14:47:00 +0100 Subject: [PATCH 3/7] row_major + assert tests --- cpp/include/raft/neighbors/ivf_pq.cuh | 24 ++-- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 130 +++++++++--------- cpp/src/distance/neighbors/ivfpq_build.cu | 12 +- cpp/src/distance/neighbors/ivfpq_search.cu | 24 ++-- 4 files changed, 98 insertions(+), 92 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 27f0404b6d..a945472e03 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -94,7 +94,7 @@ inline auto build(raft::device_resources const& handle, template inline auto build(raft::device_resources const& handle, const index_params& params, - const raft::device_matrix_view& dataset) -> index + const raft::device_matrix_view& dataset) -> index { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -167,10 +167,13 @@ inline auto extend(raft::device_resources const& handle, template inline auto extend(raft::device_resources const& handle, const index& orig_index, - const raft::device_matrix_view& new_vectors, - const raft::device_matrix_view& new_indices) -> index + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) + -> index { IdxT n_rows = new_vectors.extent(0); + ASSERT(n_rows == new_indices.extent(0), + "new_vectors and new_indices have different number of rows"); return raft::spatial::knn::ivf_pq::detail::extend( handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); } @@ -215,8 +218,8 @@ inline void extend(raft::device_resources const& handle, template inline void extend(raft::device_resources const& handle, index* index, - const raft::device_matrix_view& new_vectors, - const raft::device_matrix_view& new_indices) + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) { *index = extend(handle, *index, new_vectors, new_indices); } @@ -310,13 +313,16 @@ template inline void search(raft::device_resources const& handle, const search_params& params, const index& index, - const raft::device_matrix_view& queries, + const raft::device_matrix_view& queries, uint32_t k, - const raft::device_matrix_view& neighbors, - const raft::device_matrix_view& distances, + const raft::device_matrix_view& neighbors, + const raft::device_matrix_view& distances, rmm::mr::device_memory_resource* mr = nullptr) { - IdxT n_queries = queries.extent(0); + IdxT n_queries = queries.extent(0); + bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); + ASSERT(check_n_rows, + "queries, neighbors and distances parameters have inconsistent number of rows"); return raft::spatial::knn::ivf_pq::detail::search(handle, params, index, diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 7e83aa668a..8745d1e568 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -20,24 +20,24 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ - rmm::mr::device_memory_resource*); \ - \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - const raft::device_matrix_view&, \ - uint32_t, \ - const raft::device_matrix_view&, \ - const raft::device_matrix_view&, \ +#define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + const T*, \ + uint32_t, \ + uint32_t, \ + IdxT*, \ + float*, \ + rmm::mr::device_memory_resource*); \ + \ + void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + const raft::device_matrix_view&, \ + uint32_t, \ + const raft::device_matrix_view&, \ + const raft::device_matrix_view&, \ rmm::mr::device_memory_resource*); RAFT_INST_SEARCH(float, uint64_t); @@ -49,53 +49,53 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const raft::device_matrix_view& dataset) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const raft::device_matrix_view& new_vectors, \ - const raft::device_matrix_view& new_indices) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const raft::device_matrix_view& dataset, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const raft::device_matrix_view& new_vectors, \ - const raft::device_matrix_view& new_indices); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset) \ + ->raft::neighbors::ivf_pq::index; \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->raft::neighbors::ivf_pq::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const raft::device_matrix_view& dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices); RAFT_INST_BUILD_EXTEND(float, uint64_t) RAFT_INST_BUILD_EXTEND(int8_t, uint64_t) diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index caa2092543..828f317908 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -22,15 +22,15 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_INST_BUILD_EXTEND(T, IdxT) \ auto build(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index_params& params, \ - const raft::device_matrix_view& dataset) \ + const raft::device_matrix_view& dataset) \ ->raft::neighbors::ivf_pq::index \ { \ return raft::neighbors::ivf_pq::build(handle, params, dataset); \ } \ auto extend(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index& orig_index, \ - const raft::device_matrix_view& new_vectors, \ - const raft::device_matrix_view& new_indices) \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ ->raft::neighbors::ivf_pq::index \ { \ return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ @@ -38,15 +38,15 @@ namespace raft::runtime::neighbors::ivf_pq { \ void build(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index_params& params, \ - const raft::device_matrix_view& dataset, \ + const raft::device_matrix_view& dataset, \ raft::neighbors::ivf_pq::index* idx) \ { \ *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ } \ void extend(raft::device_resources const& handle, \ raft::neighbors::ivf_pq::index* idx, \ - const raft::device_matrix_view& new_vectors, \ - const raft::device_matrix_view& new_indices) \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ { \ raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ } diff --git a/cpp/src/distance/neighbors/ivfpq_search.cu b/cpp/src/distance/neighbors/ivfpq_search.cu index 6ce4cd8375..8c20303149 100644 --- a/cpp/src/distance/neighbors/ivfpq_search.cu +++ b/cpp/src/distance/neighbors/ivfpq_search.cu @@ -20,18 +20,18 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const raft::device_matrix_view& queries, \ - uint32_t k, \ - const raft::device_matrix_view& neighbors, \ - const raft::device_matrix_view& distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const raft::device_matrix_view& queries, \ + uint32_t k, \ + const raft::device_matrix_view& neighbors, \ + const raft::device_matrix_view& distances, \ + rmm::mr::device_memory_resource* mr) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances, mr); \ } RAFT_SEARCH_INST(float, uint64_t); From 17ef73c8344836de5b0d93c8a5dba207afdc4100 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 9 Feb 2023 17:20:37 +0100 Subject: [PATCH 4/7] addressing review --- cpp/include/raft/neighbors/ivf_pq.cuh | 275 +++++++++--------- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 39 +-- .../neighbors/ivfpq_search_float_uint64_t.cu | 24 +- .../neighbors/ivfpq_search_int8_t_uint64_t.cu | 24 +- .../ivfpq_search_uint8_t_uint64_t.cu | 5 +- .../ivfpq_build_float_uint64_t.cu | 11 +- .../ivfpq_build_int8_t_uint64_t.cu | 10 +- .../ivfpq_build_uint8_t_uint64_t.cu | 10 +- .../ivfpq_extend_float_uint64_t.cu | 24 +- .../ivfpq_extend_int8_t_uint64_t.cu | 24 +- .../ivfpq_extend_uint8_uint64_t.cu | 24 +- python/pylibraft/pylibraft/common/mdspan.pyx | 48 +++ .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 52 ++-- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 65 ++--- .../pylibraft/pylibraft/neighbors/refine.pyx | 48 +-- 15 files changed, 309 insertions(+), 374 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 4311705303..d392cfd3e5 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -32,6 +32,143 @@ namespace raft::neighbors::ivf_pq { * @{ */ +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +template +inline auto build(raft::device_resources const& handle, + const index_params& params, + const raft::device_matrix_view& dataset) -> index +{ + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return raft::spatial::knn::ivf_pq::detail::build( + handle, params, dataset.data_handle(), n_rows, dim); +} + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are unchanged. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param orig_index original index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * + * @return the constructed extended ivf-pq index + */ +template +inline auto extend(raft::device_resources const& handle, + const index& orig_index, + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) + -> index +{ + IdxT n_rows = new_vectors.extent(0); + ASSERT(n_rows == new_indices.extent(0), + "new_vectors and new_indices have different number of rows"); + ASSERT(new_vectors.extent(1) == orig_index.dim(), + "new_vectors should have the same dimension as the index"); + return raft::spatial::knn::ivf_pq::detail::extend( + handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + */ +template +inline void extend(raft::device_resources const& handle, + index* index, + const raft::device_matrix_view& new_vectors, + const raft::device_matrix_view& new_indices) +{ + *index = extend(handle, *index, new_vectors, new_indices); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param handle + * @param params configure the search + * @param index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param k the number of neighbors to find for each query. + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +inline void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const raft::device_matrix_view& queries, + uint32_t k, + const raft::device_matrix_view& neighbors, + const raft::device_matrix_view& distances) +{ + IdxT n_queries = queries.extent(0); + bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); + ASSERT(check_n_rows, + "queries, neighbors and distances parameters have inconsistent number of rows"); + return raft::spatial::knn::ivf_pq::detail::search(handle, + params, + index, + queries.data_handle(), + n_queries, + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ // end group ivf_pq + /** * @brief Build the index from the dataset for efficient search. * @@ -74,34 +211,6 @@ auto build(raft::device_resources const& handle, return raft::spatial::knn::ivf_pq::detail::build(handle, params, dataset, n_rows, dim); } -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param params configure the index building - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-pq index - */ -template -inline auto build(raft::device_resources const& handle, - const index_params& params, - const raft::device_matrix_view& dataset) -> index -{ - IdxT n_rows = dataset.extent(0); - IdxT dim = dataset.extent(1); - return raft::spatial::knn::ivf_pq::detail::build( - handle, params, dataset.data_handle(), n_rows, dim); -} - /** * @brief Build a new index containing the data of the original plus new extra vectors. * @@ -145,39 +254,6 @@ auto extend(raft::device_resources const& handle, handle, orig_index, new_vectors, new_indices, n_rows); } -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are unchanged. - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param orig_index original index - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * - * @return the constructed extended ivf-pq index - */ -template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, - const raft::device_matrix_view& new_vectors, - const raft::device_matrix_view& new_indices) - -> index -{ - IdxT n_rows = new_vectors.extent(0); - ASSERT(n_rows == new_indices.extent(0), - "new_vectors and new_indices have different number of rows"); - return raft::spatial::knn::ivf_pq::detail::extend( - handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); -} - /** * @brief Extend the index with the new data. * * @@ -202,28 +278,6 @@ void extend(raft::device_resources const& handle, *index = extend(handle, *index, new_vectors, new_indices, n_rows); } -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - */ -template -inline void extend(raft::device_resources const& handle, - index* index, - const raft::device_matrix_view& new_vectors, - const raft::device_matrix_view& new_indices) -{ - *index = extend(handle, *index, new_vectors, new_indices); -} - /** * @brief Search ANN using the constructed index. * @@ -281,59 +335,4 @@ void search(raft::device_resources const& handle, handle, params, index, queries, n_queries, k, neighbors, distances, mr); } -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param handle - * @param params configure the search - * @param index ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param k the number of neighbors to find for each query. - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param mr an optional memory resource to use across the searches (you can provide a large enough - * memory pool here to avoid memory allocations within search). - */ -template -inline void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - const raft::device_matrix_view& queries, - uint32_t k, - const raft::device_matrix_view& neighbors, - const raft::device_matrix_view& distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - IdxT n_queries = queries.extent(0); - bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); - ASSERT(check_n_rows, - "queries, neighbors and distances parameters have inconsistent number of rows"); - return raft::spatial::knn::ivf_pq::detail::search(handle, - params, - index, - queries.data_handle(), - n_queries, - k, - neighbors.data_handle(), - distances.data_handle(), - mr); -} - -/** @} */ // end group ivf_pq - } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index c8981d3f34..3b74f96f93 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -21,24 +21,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ - rmm::mr::device_memory_resource*); \ - \ void search(raft::device_resources const&, \ const raft::neighbors::ivf_pq::search_params&, \ const raft::neighbors::ivf_pq::index&, \ const raft::device_matrix_view&, \ uint32_t, \ const raft::device_matrix_view&, \ - const raft::device_matrix_view&, \ - rmm::mr::device_memory_resource*); + const raft::device_matrix_view&); RAFT_INST_SEARCH(float, uint64_t); RAFT_INST_SEARCH(int8_t, uint64_t); @@ -50,32 +39,6 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. #define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ auto build(raft::device_resources const& handle, \ const raft::neighbors::ivf_pq::index_params& params, \ const raft::device_matrix_view& dataset) \ diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu index c463aa9845..8bba357f04 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu @@ -20,19 +20,17 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const raft::device_matrix_view& queries, \ + uint32_t k, \ + const raft::device_matrix_view& neighbors, \ + const raft::device_matrix_view& distances) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(float, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu index ab0dd576b9..4e30595bb4 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu @@ -20,19 +20,17 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const raft::device_matrix_view& queries, \ + uint32_t k, \ + const raft::device_matrix_view& neighbors, \ + const raft::device_matrix_view& distances) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu index d9c69ef426..a115b1d1fd 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu @@ -27,11 +27,10 @@ namespace raft::runtime::neighbors::ivf_pq { const raft::device_matrix_view& queries, \ uint32_t k, \ const raft::device_matrix_view& neighbors, \ - const raft::device_matrix_view& distances, \ - rmm::mr::device_memory_resource* mr) \ + const raft::device_matrix_view& distances) \ { \ raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances, mr); \ + handle, params, idx, queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu index 0831311372..fcb99228f2 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu @@ -18,13 +18,12 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + const raft::device_matrix_view& dataset) \ ->index; + RAFT_INST_BUILD_EXTEND(float, uint64_t); #undef RAFT_INST_BUILD_EXTEND diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu index 1e2502fe32..addb94556d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu @@ -18,12 +18,10 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + const raft::device_matrix_view& dataset) \ ->index; RAFT_INST_BUILD_EXTEND(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu index e3336ad95e..0f87a4737a 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu @@ -18,12 +18,10 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + const raft::device_matrix_view& dataset) \ ->index; RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu index 7aa09e3f43..b4d36c9741 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu @@ -18,18 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices); RAFT_INST_BUILD_EXTEND(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu index 440fe6a4a0..2e5568d4d5 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu @@ -18,18 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices); RAFT_INST_BUILD_EXTEND(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_uint64_t.cu index 9aee2dc7d1..7add2ad88b 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_uint64_t.cu @@ -18,18 +18,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + const index& orig_index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + const raft::device_matrix_view& new_vectors, \ + const raft::device_matrix_view& new_indices); RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t); diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index ec825495f4..49c72ec16e 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -144,3 +144,51 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): X2 = np.load(f) assert np.all(X.shape == X2.shape) assert np.all(X == X2) + + +cdef device_matrix_view[float, uint64_t, row_major] \ + get_device_matrix_view_float(array, check_shape=True) except *: + cai = array + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[float, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint64_t, uint64_t, row_major] \ + get_device_matrix_view_uint64(array, check_shape=True) except *: + cai = array + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint64_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint8_t, uint64_t, row_major] \ + get_device_matrix_view_uint8(array, check_shape=True) except *: + cai = array + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[int8_t, uint64_t, row_major] \ + get_device_matrix_view_int8(array, check_shape=True) except *: + cai = array + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[int8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index c56c3e9d9b..2dbeb8115d 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -37,6 +37,7 @@ from libcpp.string cimport string from rmm._lib.memory_resource cimport device_memory_resource from pylibraft.common.handle cimport device_resources +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.distance.distance_type cimport DistanceType @@ -110,72 +111,57 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void build(const device_resources& handle, const index_params& params, - const float* dataset, - uint64_t n_rows, - uint32_t dim, + const device_matrix_view[float, uint64_t, row_major]& dataset, index[uint64_t]* index) except + cdef void build(const device_resources& handle, const index_params& params, - const int8_t* dataset, - uint64_t n_rows, - uint32_t dim, + const device_matrix_view[int8_t, uint64_t, row_major]& dataset, index[uint64_t]* index) except + cdef void build(const device_resources& handle, const index_params& params, - const uint8_t* dataset, - uint64_t n_rows, - uint32_t dim, + device_matrix_view[uint8_t, uint64_t, row_major]& dataset, index[uint64_t]* index) except + cdef void extend(const device_resources& handle, index[uint64_t]* index, - const float* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + + const device_matrix_view[float, uint64_t, row_major]& new_vectors, + const device_matrix_view[uint64_t, uint64_t, row_major]& new_indices) except + cdef void extend(const device_resources& handle, index[uint64_t]* index, - const int8_t* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + + const device_matrix_view[int8_t, uint64_t, row_major]& new_vectors, + const device_matrix_view[uint64_t, uint64_t, row_major]& new_indices) except + cdef void extend(const device_resources& handle, index[uint64_t]* index, - const uint8_t* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + + const device_matrix_view[uint8_t, uint64_t, row_major]& new_vectors, + const device_matrix_view[uint64_t, uint64_t, row_major]& new_indices) except + cdef void search(const device_resources& handle, const search_params& params, const index[uint64_t]& index, - const float* queries, - uint32_t n_queries, + const device_matrix_view[float, uint64_t, row_major]& queries, uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + + const device_matrix_view[uint64_t, uint64_t, row_major]& neighbors, + const device_matrix_view[float, uint64_t, row_major]& distances) except + cdef void search(const device_resources& handle, const search_params& params, const index[uint64_t]& index, - const int8_t* queries, - uint32_t n_queries, + const device_matrix_view[int8_t, uint64_t, row_major]& queries, uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + + const device_matrix_view[uint64_t, uint64_t, row_major]& neighbors, + const device_matrix_view[float, uint64_t, row_major]& distances) except + cdef void search(const device_resources& handle, const search_params& params, const index[uint64_t]& index, - const uint8_t* queries, - uint32_t n_queries, + const device_matrix_view[uint8_t, uint64_t, row_major]& queries, uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + + const device_matrix_view[uint64_t, uint64_t, row_major]& neighbors, + const device_matrix_view[float, uint64_t, row_major]& distances) except + cdef void serialize(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index e7b69ddbea..de6368f217 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -50,6 +50,15 @@ from pylibraft.common.handle cimport device_resources from pylibraft.common.handle import auto_sync_handle from pylibraft.common.input_validation import is_c_contiguous +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + make_device_matrix_view, + row_major, + get_device_matrix_view_float, + get_device_matrix_view_uint64, + get_device_matrix_view_uint8, + get_device_matrix_view_int8 +) from rmm._lib.memory_resource cimport ( DeviceMemoryResource, @@ -377,7 +386,6 @@ def build(IndexParams index_params, dataset, handle=None): dataset_dt = dataset_cai.dtype _check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'), np.dtype('ubyte')]) - cdef uintptr_t dataset_ptr = dataset_cai.data cdef uint64_t n_rows = dataset_cai.shape[0] cdef uint32_t dim = dataset_cai.shape[1] @@ -393,27 +401,21 @@ def build(IndexParams index_params, dataset, handle=None): with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_device_matrix_view_float(dataset_cai), idx.index) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_device_matrix_view_int8(dataset_cai), idx.index) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_device_matrix_view_uint8(dataset_cai), idx.index) idx.trained = True else: @@ -505,30 +507,24 @@ def extend(Index index, new_vectors, new_indices, handle=None): if len(idx_cai.shape)!=1: raise ValueError("Indices array is expected to be 1D") - cdef uintptr_t vecs_ptr = vecs_cai.data - cdef uintptr_t idx_ptr = idx_cai.data - if vecs_dt == np.float32: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_device_matrix_view_float(vecs_cai), + get_device_matrix_view_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_device_matrix_view_int8(vecs_cai), + get_device_matrix_view_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_device_matrix_view_uint8(vecs_cai), + get_device_matrix_view_uint64(idx_cai, check_shape=False)) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -705,7 +701,6 @@ def search(SearchParams search_params, cdef c_ivf_pq.search_params params = search_params.params - cdef uintptr_t queries_ptr = queries_cai.data cdef uintptr_t neighbors_ptr = neighbors_cai.data cdef uintptr_t distances_ptr = distances_cai.data # TODO(tfeher) pass mr_ptr arg @@ -718,34 +713,28 @@ def search(SearchParams search_params, c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_device_matrix_view_float(queries_cai), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_device_matrix_view_uint64(neighbors_cai), + get_device_matrix_view_float(distances_cai)) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_device_matrix_view_int8(queries_cai), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_device_matrix_view_uint64(neighbors_cai), + get_device_matrix_view_float(distances_cai)) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_device_matrix_view_uint8(queries_cai), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_device_matrix_view_uint64(neighbors_cai), + get_device_matrix_view_float(distances_cai)) else: raise ValueError("query dtype %s not supported" % queries_dt) diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index 5c652f7c73..2eca6a8e71 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -58,6 +58,10 @@ from pylibraft.common.cpp.mdspan cimport ( make_device_matrix_view, make_host_matrix_view, row_major, + get_device_matrix_view_float, + get_device_matrix_view_uint64, + get_device_matrix_view_uint8, + get_device_matrix_view_int8 ) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, @@ -125,50 +129,6 @@ cdef extern from "raft_runtime/neighbors/refine.hpp" \ DistanceType metric) except + -cdef device_matrix_view[float, uint64_t, row_major] \ - get_device_matrix_view_float(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.float32: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[float, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_device_matrix_view_uint64(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.uint64: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[uint64_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[uint8_t, uint64_t, row_major] \ - get_device_matrix_view_uint8(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.uint8: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[uint8_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[int8_t, uint64_t, row_major] \ - get_device_matrix_view_int8(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.int8: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[int8_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - def _get_array_params(array_interface, check_dtype=None): dtype = np.dtype(array_interface["typestr"]) if check_dtype is None and dtype != check_dtype: From d4e36609567686744cd378eed807e3fb32f260e4 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 17 Feb 2023 19:10:31 +0100 Subject: [PATCH 5/7] fix style --- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index de6368f217..5b6aabcf88 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -50,15 +50,6 @@ from pylibraft.common.handle cimport device_resources from pylibraft.common.handle import auto_sync_handle from pylibraft.common.input_validation import is_c_contiguous -from pylibraft.common.cpp.mdspan cimport ( - device_matrix_view, - make_device_matrix_view, - row_major, - get_device_matrix_view_float, - get_device_matrix_view_uint64, - get_device_matrix_view_uint8, - get_device_matrix_view_int8 -) from rmm._lib.memory_resource cimport ( DeviceMemoryResource, @@ -66,6 +57,15 @@ from rmm._lib.memory_resource cimport ( ) cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + get_device_matrix_view_float, + get_device_matrix_view_int8, + get_device_matrix_view_uint8, + get_device_matrix_view_uint64, + make_device_matrix_view, + row_major, +) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, @@ -512,19 +512,22 @@ def extend(Index index, new_vectors, new_indices, handle=None): c_ivf_pq.extend(deref(handle_), index.index, get_device_matrix_view_float(vecs_cai), - get_device_matrix_view_uint64(idx_cai, check_shape=False)) + get_device_matrix_view_uint64(idx_cai, + check_shape=False)) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, get_device_matrix_view_int8(vecs_cai), - get_device_matrix_view_uint64(idx_cai, check_shape=False)) + get_device_matrix_view_uint64(idx_cai, + check_shape=False)) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, get_device_matrix_view_uint8(vecs_cai), - get_device_matrix_view_uint64(idx_cai, check_shape=False)) + get_device_matrix_view_uint64(idx_cai, + check_shape=False)) else: raise TypeError("query dtype %s not supported" % vecs_dt) From c73cfb2fddf13f8249040a7787180d5f3e15b428 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 1 Mar 2023 16:09:17 +0100 Subject: [PATCH 6/7] moving helper funcs around --- .../pylibraft/pylibraft/common/cpp/mdspan.pxd | 1 + python/pylibraft/pylibraft/common/mdspan.pxd | 39 +++++++ python/pylibraft/pylibraft/common/mdspan.pyx | 48 ++++++++ .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 107 ++++-------------- .../pylibraft/pylibraft/neighbors/refine.pyx | 95 ++++------------ 5 files changed, 135 insertions(+), 155 deletions(-) create mode 100644 python/pylibraft/pylibraft/common/mdspan.pxd diff --git a/python/pylibraft/pylibraft/common/cpp/mdspan.pxd b/python/pylibraft/pylibraft/common/cpp/mdspan.pxd index c3e5abb47e..a8c636f0b7 100644 --- a/python/pylibraft/pylibraft/common/cpp/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/cpp/mdspan.pxd @@ -19,6 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t from libcpp.string cimport string from pylibraft.common.handle cimport device_resources diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd new file mode 100644 index 0000000000..2a0bdaca62 --- /dev/null +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -0,0 +1,39 @@ +# +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, uint8_t, uint64_t +from libcpp.string cimport string + +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major +from pylibraft.common.handle cimport device_resources + + +cdef device_matrix_view[float, uint64_t, row_major] get_dmv_float( + array, check_shape) except * + +cdef device_matrix_view[uint8_t, uint64_t, row_major] get_dmv_uint8( + array, check_shape) except * + +cdef device_matrix_view[int8_t, uint64_t, row_major] get_dmv_int8( + array, check_shape) except * + +cdef device_matrix_view[uint64_t, uint64_t, row_major] get_dmv_uint64( + array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index fb17a8b1a9..d8524d94b4 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -154,3 +154,51 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): X2 = np.load(f) assert np.all(X.shape == X2.shape) assert np.all(X == X2) + + +cdef device_matrix_view[float, uint64_t, row_major] \ + get_dmv_float(array, check_shape) except *: + cai = array + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[float, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint8_t, uint64_t, row_major] \ + get_dmv_uint8(array, check_shape) except *: + cai = array + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[int8_t, uint64_t, row_major] \ + get_dmv_int8(array, check_shape) except *: + cai = array + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[int8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint64_t, uint64_t, row_major] \ + get_dmv_uint64(array, check_shape) except *: + cai = array + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint64_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index c4512715ec..47d8e94e5f 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -23,15 +23,7 @@ import warnings import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport ( - int8_t, - int32_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t, uintptr_t from libcpp cimport bool, nullptr from libcpp.string cimport string @@ -58,10 +50,12 @@ from rmm._lib.memory_resource cimport ( ) cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq -from pylibraft.common.cpp.mdspan cimport ( - device_matrix_view, - make_device_matrix_view, - row_major, +from pylibraft.common.cpp.mdspan cimport device_matrix_view +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_uint8, + get_dmv_uint64, ) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, @@ -69,54 +63,6 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( ) -cdef device_matrix_view[float, uint64_t, row_major] \ - get_device_matrix_view_float(array, check_shape=True) except *: - cai = array - if cai.dtype != np.float32: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[float, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_device_matrix_view_uint64(array, check_shape=True) except *: - cai = array - if cai.dtype != np.uint64: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint64_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[uint8_t, uint64_t, row_major] \ - get_device_matrix_view_uint8(array, check_shape=True) except *: - cai = array - if cai.dtype != np.uint8: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint8_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[int8_t, uint64_t, row_major] \ - get_device_matrix_view_int8(array, check_shape=True) except *: - cai = array - if cai.dtype != np.int8: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[int8_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - def _get_metric(metric): SUPPORTED_DISTANCES = { "sqeuclidean": DistanceType.L2Expanded, @@ -464,21 +410,21 @@ def build(IndexParams index_params, dataset, handle=None): with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - get_device_matrix_view_float(dataset_cai), + get_dmv_float(dataset_cai, check_shape=True), idx.index) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - get_device_matrix_view_int8(dataset_cai), + get_dmv_int8(dataset_cai, check_shape=True), idx.index) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - get_device_matrix_view_uint8(dataset_cai), + get_dmv_uint8(dataset_cai, check_shape=True), idx.index) idx.trained = True else: @@ -574,23 +520,20 @@ def extend(Index index, new_vectors, new_indices, handle=None): with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - get_device_matrix_view_float(vecs_cai), - get_device_matrix_view_uint64(idx_cai, - check_shape=False)) + get_dmv_float(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - get_device_matrix_view_int8(vecs_cai), - get_device_matrix_view_uint64(idx_cai, - check_shape=False)) + get_dmv_int8(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - get_device_matrix_view_uint8(vecs_cai), - get_device_matrix_view_uint64(idx_cai, - check_shape=False)) + get_dmv_uint8(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -779,28 +722,28 @@ def search(SearchParams search_params, c_ivf_pq.search(deref(handle_), params, deref(index.index), - get_device_matrix_view_float(queries_cai), + get_dmv_float(queries_cai, check_shape=True), k, - get_device_matrix_view_uint64(neighbors_cai), - get_device_matrix_view_float(distances_cai)) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - get_device_matrix_view_int8(queries_cai), + get_dmv_int8(queries_cai, check_shape=True), k, - get_device_matrix_view_uint64(neighbors_cai), - get_device_matrix_view_float(distances_cai)) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - get_device_matrix_view_uint8(queries_cai), + get_dmv_uint8(queries_cai, check_shape=True), k, - get_device_matrix_view_uint64(neighbors_cai), - get_device_matrix_view_float(distances_cai)) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt) diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index e5c88b5c33..e01971d681 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -21,15 +21,7 @@ import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport ( - int8_t, - int32_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t from libcpp cimport bool, nullptr from pylibraft.distance.distance_type cimport DistanceType @@ -56,64 +48,21 @@ cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq from pylibraft.common.cpp.mdspan cimport ( device_matrix_view, host_matrix_view, - make_device_matrix_view, make_host_matrix_view, row_major, ) +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_uint8, + get_dmv_uint64, +) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, ) -cdef device_matrix_view[float, uint64_t, row_major] \ - get_device_matrix_view_float(array, check_shape=True) except *: - cai = array - if cai.dtype != np.float32: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[float, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_device_matrix_view_uint64(array, check_shape=True) except *: - cai = array - if cai.dtype != np.uint64: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint64_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[uint8_t, uint64_t, row_major] \ - get_device_matrix_view_uint8(array, check_shape=True) except *: - cai = array - if cai.dtype != np.uint8: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint8_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - -cdef device_matrix_view[int8_t, uint64_t, row_major] \ - get_device_matrix_view_int8(array, check_shape=True) except *: - cai = array - if cai.dtype != np.int8: - raise TypeError("dtype %s not supported" % cai.dtype) - if check_shape and len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[int8_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) - - # We omit the const qualifiers in the interface for refine, because cython # has an issue parsing it (https://github.com/cython/cython/issues/4180). cdef extern from "raft_runtime/neighbors/refine.hpp" \ @@ -338,29 +287,29 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, if dataset_cai.dtype == np.float32: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_float(dataset), - get_device_matrix_view_float(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_float(dataset, check_shape=True), + get_dmv_float(queries, check_shape=True), + get_dmv_uint64(candidates, check_shape=True), + get_dmv_uint64(indices, check_shape=True), + get_dmv_float(distances, check_shape=True), c_metric) elif dataset_cai.dtype == np.int8: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_int8(dataset), - get_device_matrix_view_int8(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_int8(dataset, check_shape=True), + get_dmv_int8(queries, check_shape=True), + get_dmv_uint64(candidates, check_shape=True), + get_dmv_uint64(indices, check_shape=True), + get_dmv_float(distances, check_shape=True), c_metric) elif dataset_cai.dtype == np.uint8: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_uint8(dataset), - get_device_matrix_view_uint8(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_uint8(dataset, check_shape=True), + get_dmv_uint8(queries, check_shape=True), + get_dmv_uint64(candidates, check_shape=True), + get_dmv_uint64(indices, check_shape=True), + get_dmv_float(distances, check_shape=True), c_metric) else: raise TypeError("dtype %s not supported" % dataset_cai.dtype) From f3f3cf750c1c11d385696e944a3163336e0323d8 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 7 Mar 2023 11:53:30 +0100 Subject: [PATCH 7/7] fix refine --- python/pylibraft/pylibraft/common/mdspan.pyx | 12 ++---- .../pylibraft/pylibraft/neighbors/refine.pyx | 38 ++++++++++--------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index d8524d94b4..22afda043d 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -157,8 +157,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): cdef device_matrix_view[float, uint64_t, row_major] \ - get_dmv_float(array, check_shape) except *: - cai = array + get_dmv_float(cai, check_shape) except *: if cai.dtype != np.float32: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: @@ -169,8 +168,7 @@ cdef device_matrix_view[float, uint64_t, row_major] \ cdef device_matrix_view[uint8_t, uint64_t, row_major] \ - get_dmv_uint8(array, check_shape) except *: - cai = array + get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: @@ -181,8 +179,7 @@ cdef device_matrix_view[uint8_t, uint64_t, row_major] \ cdef device_matrix_view[int8_t, uint64_t, row_major] \ - get_dmv_int8(array, check_shape) except *: - cai = array + get_dmv_int8(cai, check_shape) except *: if cai.dtype != np.int8: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: @@ -193,8 +190,7 @@ cdef device_matrix_view[int8_t, uint64_t, row_major] \ cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_dmv_uint64(array, check_shape) except *: - cai = array + get_dmv_uint64(cai, check_shape) except *: if cai.dtype != np.uint64: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index e01971d681..ddc6f115a3 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -272,6 +272,9 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, raise ValueError("Argument k must be specified if both indices " "and distances arg is None") + queries_cai = cai_wrapper(queries) + dataset_cai = cai_wrapper(dataset) + candidates_cai = cai_wrapper(candidates) n_queries = cai_wrapper(queries).shape[0] if indices is None: @@ -280,36 +283,37 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, if distances is None: distances = device_ndarray.empty((n_queries, k), dtype='float32') - cdef DistanceType c_metric = _get_metric(metric) + indices_cai = cai_wrapper(indices) + distances_cai = cai_wrapper(distances) - dataset_cai = cai_wrapper(dataset) + cdef DistanceType c_metric = _get_metric(metric) if dataset_cai.dtype == np.float32: with cuda_interruptible(): c_refine(deref(handle_), - get_dmv_float(dataset, check_shape=True), - get_dmv_float(queries, check_shape=True), - get_dmv_uint64(candidates, check_shape=True), - get_dmv_uint64(indices, check_shape=True), - get_dmv_float(distances, check_shape=True), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), c_metric) elif dataset_cai.dtype == np.int8: with cuda_interruptible(): c_refine(deref(handle_), - get_dmv_int8(dataset, check_shape=True), - get_dmv_int8(queries, check_shape=True), - get_dmv_uint64(candidates, check_shape=True), - get_dmv_uint64(indices, check_shape=True), - get_dmv_float(distances, check_shape=True), + get_dmv_int8(dataset_cai, check_shape=True), + get_dmv_int8(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), c_metric) elif dataset_cai.dtype == np.uint8: with cuda_interruptible(): c_refine(deref(handle_), - get_dmv_uint8(dataset, check_shape=True), - get_dmv_uint8(queries, check_shape=True), - get_dmv_uint64(candidates, check_shape=True), - get_dmv_uint64(indices, check_shape=True), - get_dmv_float(distances, check_shape=True), + get_dmv_uint8(dataset_cai, check_shape=True), + get_dmv_uint8(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), c_metric) else: raise TypeError("dtype %s not supported" % dataset_cai.dtype)