From 136eafd650b5bb765b22664bcbea388e4b12ef56 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 11 Nov 2024 07:57:46 -0800 Subject: [PATCH 1/8] Add serialization API to brute-force --- cpp/include/cuvs/neighbors/brute_force.hpp | 126 +++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 428fa592a..8512f3f68 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -47,6 +47,14 @@ struct index : cuvs::neighbors::index { index& operator=(index&&) = default; ~index() = default; + /** + * @brief Construct an empty index. + * + * Constructs an empty index. This index will either need to be trained with `build` + * or loaded from a saved copy with `deserialize` + */ + index(raft::resources const& handle); + /** Construct a brute force index from dataset * * Constructs a brute force index from a dataset. This lets us precompute norms for @@ -371,6 +379,124 @@ void search(raft::resources const& handle, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = cuvs::neighbors::filtering::none_sample_filter{}); + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + * + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + * + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = half; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index); +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index); + +/**@}*/ + +} // namespace raft::neighbors::brute_force + /** * @} */ From 622e3e3f2c3f6d7ad128318275020118a17f7933 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 11 Nov 2024 10:37:22 -0800 Subject: [PATCH 2/8] Add implementation and test --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/neighbors/brute_force.hpp | 115 +++++++++++++- cpp/src/neighbors/brute_force.cu | 15 ++ cpp/src/neighbors/brute_force_serialize.cu | 167 +++++++++++++++++++++ cpp/test/neighbors/ann_brute_force.cuh | 18 ++- 5 files changed, 308 insertions(+), 8 deletions(-) create mode 100644 cpp/src/neighbors/brute_force_serialize.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c493af488..aa11187bf 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -370,6 +370,7 @@ if(BUILD_SHARED_LIBS) src/distance/distance.cu src/distance/pairwise_distance.cu src/neighbors/brute_force.cu + src/neighbors/brute_force_serialize.cu src/neighbors/cagra_build_float.cu src/neighbors/cagra_build_half.cu src/neighbors/cagra_build_int8.cu diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 8512f3f68..b355d7264 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -404,7 +404,6 @@ void search(raft::resources const& handle, * @param[in] index brute force index * @param[in] include_dataset whether to include the dataset in the serialized * output - * */ void serialize(raft::resources const& handle, const std::string& filename, @@ -440,6 +439,61 @@ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::brute_force::index& index, bool include_dataset = true); + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + std::ostream& os, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + std::ostream& os, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); + /** * Load index from file. * @@ -473,7 +527,7 @@ void deserialize(raft::resources const& handle, * * @code{.cpp} * #include - * #include + * #include * * raft::resources handle; * @@ -492,11 +546,58 @@ void deserialize(raft::resources const& handle, void deserialize(raft::resources const& handle, const std::string& filename, cuvs::neighbors::brute_force::index* index); - -/**@}*/ - -} // namespace raft::neighbors::brute_force - +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = half; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, is, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index); +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, is, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index); /** * @} */ diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index b0f87e9ac..d534676e3 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -21,6 +21,21 @@ #include namespace cuvs::neighbors::brute_force { + +template +index::index(raft::resources const& res) + // this constructor is just for a temporary index, for use in the deserialization + // api. all the parameters here will get replaced with loaded values - that aren't + // necessarily known ahead of time before deserialization. + // TODO: do we even need a handle here - could just construct one? + : cuvs::neighbors::index(), + metric_(cuvs::distance::DistanceType::L2Expanded), + dataset_(raft::make_device_matrix(res, 0, 0)), + norms_(std::nullopt), + metric_arg_(0) +{ +} + template index::index(raft::resources const& res, raft::host_matrix_view dataset, diff --git a/cpp/src/neighbors/brute_force_serialize.cu b/cpp/src/neighbors/brute_force_serialize.cu new file mode 100644 index 000000000..c427647f1 --- /dev/null +++ b/cpp/src/neighbors/brute_force_serialize.cu @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::brute_force { + +int constexpr serialization_version = 0; + +template +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) +{ + RAFT_LOG_DEBUG( + "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); + + auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + + raft::serialize_scalar(handle, os, serialization_version); + raft::serialize_scalar(handle, os, index.size()); + raft::serialize_scalar(handle, os, index.dim()); + raft::serialize_scalar(handle, os, index.metric()); + raft::serialize_scalar(handle, os, index.metric_arg()); + raft::serialize_scalar(handle, os, include_dataset); + if (include_dataset) { raft::serialize_mdspan(handle, os, index.dataset()); } + auto has_norms = index.has_norms(); + raft::serialize_scalar(handle, os, has_norms); + if (has_norms) { raft::serialize_mdspan(handle, os, index.norms()); } + raft::resource::sync_stream(handle); +} + +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset) +{ + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset) +{ + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset) +{ + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset) +{ + serialize(handle, os, index, include_dataset); +} + +template +auto deserialize(raft::resources const& handle, std::istream& is) +{ + auto dtype_string = std::array{}; + is.read(dtype_string.data(), 4); + + auto ver = raft::deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + std::int64_t rows = raft::deserialize_scalar(handle, is); + std::int64_t dim = raft::deserialize_scalar(handle, is); + auto metric = raft::deserialize_scalar(handle, is); + auto metric_arg = raft::deserialize_scalar(handle, is); + + auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); + auto include_dataset = raft::deserialize_scalar(handle, is); + if (include_dataset) { + dataset_storage = raft::make_host_matrix(rows, dim); + raft::deserialize_mdspan(handle, is, dataset_storage.view()); + } + + auto has_norms = raft::deserialize_scalar(handle, is); + auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} + : std::optional>{}; + // TODO(wphicks): Use mdbuffer here when available + auto norms_storage_dev = + has_norms ? std::optional{raft::make_device_vector(handle, rows)} + : std::optional>{}; + if (has_norms) { + raft::deserialize_mdspan(handle, is, norms_storage->view()); + raft::copy(handle, norms_storage_dev->view(), norms_storage->view()); + } + + auto result = index(handle, + raft::make_const_mdspan(dataset_storage.view()), + std::move(norms_storage_dev), + metric, + metric_arg); + raft::resource::sync_stream(handle); + + return result; +} + +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index) +{ + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index) +{ + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index) +{ + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index) +{ + *index = deserialize(handle, is); +} + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index c2afa4e8b..03d6e820c 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -114,12 +114,28 @@ class AnnBruteForceTest : public ::testing::TestWithParam(handle_); + brute_force::deserialize(handle_, std::string{"brute_force_index"}, &index_loaded); + brute_force::search(handle_, - idx, + index_loaded, search_queries_view, indices_out_view, dists_out_view, cuvs::neighbors::filtering::none_sample_filter{}); + raft::resource::sync_stream(handle_); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); } } From 50b24ecb2efd25bdae5006cd493c04d9a6418442 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 14 Nov 2024 12:24:05 -0800 Subject: [PATCH 3/8] Add serialization API to C and Python + pytest --- cpp/include/cuvs/neighbors/brute_force.h | 45 ++++++++++++ cpp/src/neighbors/brute_force_c.cpp | 54 +++++++++++++- .../cuvs/neighbors/brute_force/__init__.py | 4 +- .../neighbors/brute_force/brute_force.pxd | 8 +++ .../neighbors/brute_force/brute_force.pyx | 71 +++++++++++++++++++ python/cuvs/cuvs/test/test_serialization.py | 38 +++++++--- 6 files changed, 206 insertions(+), 14 deletions(-) diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h index c9e172f62..0e134eb59 100644 --- a/cpp/include/cuvs/neighbors/brute_force.h +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -166,6 +166,51 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, * @} */ +/** + * @defgroup bruteforce_c_serialize BRUTEFORCE C-API serialize functions + * @{ + */ +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // create an index with `cuvsBruteforceBuild` + * cuvsBruteForceSerialize(res, "/path/to/index", index, true); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the file name for saving the index + * @param[in] index BRUTEFORCE index + * + */ +cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the name of the file that stores the index + * @param[out] index BRUTEFORCE index loaded disk + */ +cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index); + +/** + * @} + */ #ifdef __cplusplus } #endif diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index eda79aa31..8447bead2 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -91,6 +92,22 @@ void _search(cuvsResources_t res, } } +template +void _serialize(cuvsResources_t res, const char* filename, cuvsBruteForceIndex index) +{ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index.addr); + cuvs::neighbors::brute_force::serialize(*res_ptr, std::string(filename), *index_ptr); +} + +template +void* _deserialize(cuvsResources_t res, const char* filename) +{ + auto res_ptr = reinterpret_cast(res); + auto index = new cuvs::neighbors::brute_force::index(*res_ptr); + cuvs::neighbors::brute_force::deserialize(*res_ptr, std::string(filename), index); + return index; +} } // namespace extern "C" cuvsError_t cuvsBruteForceIndexCreate(cuvsBruteForceIndex_t* index) @@ -129,7 +146,7 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res, if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { index->addr = reinterpret_cast(_build(res, dataset_tensor, metric, metric_arg)); - index->dtype.code = kDLFloat; + index->dtype = dataset.dtype; } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, @@ -174,3 +191,38 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, } }); } + +extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index) +{ + return cuvs::core::translate_exceptions([=] { + // read the numpy dtype from the beginning of the file + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename); } + char dtype_string[4]; + is.read(dtype_string, 4); + auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4)); + + index->dtype.bits = dtype.itemsize * 8; + if (dtype.kind == 'f' && dtype.itemsize == 4) { + index->dtype.code = kDLFloat; + index->addr = reinterpret_cast(_deserialize(res, filename)); + } else { + RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index) +{ + return cuvs::core::translate_exceptions([=] { + if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { + _serialize(res, filename, *index); + } else { + RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); + } + }); +} \ No newline at end of file diff --git a/python/cuvs/cuvs/neighbors/brute_force/__init__.py b/python/cuvs/cuvs/neighbors/brute_force/__init__.py index b88c4b464..6aa0e4bb2 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/__init__.py +++ b/python/cuvs/cuvs/neighbors/brute_force/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. -from .brute_force import Index, build, search +from .brute_force import Index, build, load, save, search -__all__ = ["Index", "build", "search"] +__all__ = ["Index", "build", "search", "save", "load"] diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd index 183827916..f1fc14ba7 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd @@ -47,3 +47,11 @@ cdef extern from "cuvs/neighbors/brute_force.h" nogil: DLManagedTensor* neighbors, DLManagedTensor* distances, cuvsFilter filter) except + + + cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char * filename, + cuvsBruteForceIndex_t index) except + + + cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char * filename, + cuvsBruteForceIndex_t index) except + diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 559302ccc..3201700b9 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -24,6 +24,7 @@ from cuvs.common.resources import auto_sync_resources from cython.operator cimport dereference as deref from libc.stdint cimport uint32_t from libcpp cimport bool +from libcpp.string cimport string from cuvs.common cimport cydlpack from cuvs.distance_type cimport cuvsDistanceType @@ -256,3 +257,73 @@ def search(Index index, )) return (distances, neighbors) + + +@auto_sync_resources +def save(filename, Index index, bool include_dataset=True, resources=None): + """ + Saves the index to a file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained IVF-PQ index. + {resources_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import ivf_pq + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset) + >>> # Serialize and deserialize the ivf_pq index built + >>> ivf_pq.save("my_index.bin", index) + >>> index_loaded = ivf_pq.load("my_index.bin") + """ + cdef string c_filename = filename.encode('utf-8') + cdef cuvsResources_t res = resources.get_c_obj() + check_cuvs(cuvsBruteForceSerialize(res, + c_filename.c_str(), + index.index)) + + +@auto_sync_resources +def load(filename, resources=None): + """ + Loads index from file. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of cuvs is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + {resources_docstring} + + Returns + ------- + index : Index + + """ + cdef Index idx = Index() + cdef cuvsResources_t res = resources.get_c_obj() + cdef string c_filename = filename.encode('utf-8') + + check_cuvs(cuvsBruteForceDeserialize( + res, + c_filename.c_str(), + idx.index + )) + idx.trained = True + return idx diff --git a/python/cuvs/cuvs/test/test_serialization.py b/python/cuvs/cuvs/test/test_serialization.py index 4ffccf121..1f4a54e87 100644 --- a/python/cuvs/cuvs/test/test_serialization.py +++ b/python/cuvs/cuvs/test/test_serialization.py @@ -17,7 +17,7 @@ import pytest from pylibraft.common import device_ndarray -from cuvs.neighbors import cagra, ivf_flat, ivf_pq +from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq from cuvs.test.ann_utils import generate_data @@ -35,6 +35,10 @@ def test_save_load_ivf_pq(): run_save_load(ivf_pq, np.float32) +def test_save_load_brute_force(): + run_save_load(brute_force, np.float32) + + def run_save_load(ann_module, dtype): n_rows = 10000 n_cols = 50 @@ -43,8 +47,11 @@ def run_save_load(ann_module, dtype): dataset = generate_data((n_rows, n_cols), dtype) dataset_device = device_ndarray(dataset) - build_params = ann_module.IndexParams() - index = ann_module.build(build_params, dataset_device) + if ann_module == brute_force: + index = ann_module.build(dataset_device) + else: + build_params = ann_module.IndexParams() + index = ann_module.build(build_params, dataset_device) assert index.trained filename = "my_index.bin" @@ -54,20 +61,29 @@ def run_save_load(ann_module, dtype): queries = generate_data((n_queries, n_cols), dtype) queries_device = device_ndarray(queries) - search_params = ann_module.SearchParams() k = 10 - - distance_dev, neighbors_dev = ann_module.search( - search_params, index, queries_device, k - ) + if ann_module == brute_force: + distance_dev, neighbors_dev = ann_module.search( + index, queries_device, k + ) + else: + search_params = ann_module.SearchParams() + distance_dev, neighbors_dev = ann_module.search( + search_params, index, queries_device, k + ) neighbors = neighbors_dev.copy_to_host() dist = distance_dev.copy_to_host() del index - distance_dev, neighbors_dev = ann_module.search( - search_params, loaded_index, queries_device, k - ) + if ann_module == brute_force: + distance_dev, neighbors_dev = ann_module.search( + loaded_index, queries_device, k + ) + else: + distance_dev, neighbors_dev = ann_module.search( + search_params, loaded_index, queries_device, k + ) neighbors2 = neighbors_dev.copy_to_host() dist2 = distance_dev.copy_to_host() From 2f5813ad936052e4f71271c9f2fcf095099c7b3b Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Nov 2024 08:53:30 -0800 Subject: [PATCH 4/8] Include documentation in Doxygen --- .gitignore | 1 + cpp/include/cuvs/neighbors/brute_force.hpp | 7 +++++++ docs/source/c_api/neighbors_bruteforce_c.rst | 8 ++++++++ docs/source/c_api/neighbors_hnsw_c.rst | 4 ++-- docs/source/c_api/neighbors_ivf_flat_c.rst | 8 ++++++++ docs/source/c_api/neighbors_ivf_pq_c.rst | 8 ++++++++ docs/source/cpp_api/neighbors_bruteforce.rst | 8 ++++++++ docs/source/python_api/neighbors_brute_force.rst | 10 ++++++++++ docs/source/python_api/neighbors_cagra.rst | 10 ++++++++++ docs/source/python_api/neighbors_hnsw.rst | 10 ++++++++++ docs/source/python_api/neighbors_ivf_flat.rst | 10 ++++++++++ docs/source/python_api/neighbors_ivf_pq.rst | 10 ++++++++++ 12 files changed, 92 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 97eab287d..da6eb07f6 100644 --- a/.gitignore +++ b/.gitignore @@ -75,6 +75,7 @@ compile_commands.json .clangd/ # serialized ann indexes +brute_force_index cagra_index ivf_flat_index ivf_pq_index diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index b355d7264..a998c30eb 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -379,7 +379,14 @@ void search(raft::resources const& handle, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = cuvs::neighbors::filtering::none_sample_filter{}); +/** + * @} + */ +/** + * @defgroup bruteforce_cpp_index_serialize Bruteforce index serialize functions + * @{ + */ /** * Save the index to file. * diff --git a/docs/source/c_api/neighbors_bruteforce_c.rst b/docs/source/c_api/neighbors_bruteforce_c.rst index af0356eee..a12175209 100644 --- a/docs/source/c_api/neighbors_bruteforce_c.rst +++ b/docs/source/c_api/neighbors_bruteforce_c.rst @@ -32,3 +32,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: bruteforce_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_hnsw_c.rst b/docs/source/c_api/neighbors_hnsw_c.rst index 4d83cd3e3..988e5b6f3 100644 --- a/docs/source/c_api/neighbors_hnsw_c.rst +++ b/docs/source/c_api/neighbors_hnsw_c.rst @@ -29,13 +29,13 @@ Index Index search ------------ -.. doxygengroup:: cagra_c_index_search +.. doxygengroup:: hnsw_c_index_search :project: cuvs :members: :content-only: Index serialize ------------- +--------------- .. doxygengroup:: hnsw_c_index_serialize :project: cuvs diff --git a/docs/source/c_api/neighbors_ivf_flat_c.rst b/docs/source/c_api/neighbors_ivf_flat_c.rst index 9e1ccc0d1..1254d70ef 100644 --- a/docs/source/c_api/neighbors_ivf_flat_c.rst +++ b/docs/source/c_api/neighbors_ivf_flat_c.rst @@ -48,3 +48,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_flat_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_ivf_pq_c.rst b/docs/source/c_api/neighbors_ivf_pq_c.rst index 070719609..260057b8c 100644 --- a/docs/source/c_api/neighbors_ivf_pq_c.rst +++ b/docs/source/c_api/neighbors_ivf_pq_c.rst @@ -48,3 +48,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_pq_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/cpp_api/neighbors_bruteforce.rst b/docs/source/cpp_api/neighbors_bruteforce.rst index 3adcb01c5..f75e26b3c 100644 --- a/docs/source/cpp_api/neighbors_bruteforce.rst +++ b/docs/source/cpp_api/neighbors_bruteforce.rst @@ -34,3 +34,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: bruteforce_cpp_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/python_api/neighbors_brute_force.rst b/docs/source/python_api/neighbors_brute_force.rst index 5fdc3658f..d756a6c80 100644 --- a/docs/source/python_api/neighbors_brute_force.rst +++ b/docs/source/python_api/neighbors_brute_force.rst @@ -20,3 +20,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.brute_force.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.brute_force.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.brute_force.load diff --git a/docs/source/python_api/neighbors_cagra.rst b/docs/source/python_api/neighbors_cagra.rst index 09b2e2694..e7155efb8 100644 --- a/docs/source/python_api/neighbors_cagra.rst +++ b/docs/source/python_api/neighbors_cagra.rst @@ -34,3 +34,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.cagra.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.cagra.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.cagra.load diff --git a/docs/source/python_api/neighbors_hnsw.rst b/docs/source/python_api/neighbors_hnsw.rst index 9922805b3..64fe5493b 100644 --- a/docs/source/python_api/neighbors_hnsw.rst +++ b/docs/source/python_api/neighbors_hnsw.rst @@ -28,3 +28,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.hnsw.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.hnsw.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.hnsw.load diff --git a/docs/source/python_api/neighbors_ivf_flat.rst b/docs/source/python_api/neighbors_ivf_flat.rst index 5514e5e43..f2c21e68a 100644 --- a/docs/source/python_api/neighbors_ivf_flat.rst +++ b/docs/source/python_api/neighbors_ivf_flat.rst @@ -32,3 +32,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.ivf_flat.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.ivf_flat.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.ivf_flat.load diff --git a/docs/source/python_api/neighbors_ivf_pq.rst b/docs/source/python_api/neighbors_ivf_pq.rst index e3625ba67..57668fbc3 100644 --- a/docs/source/python_api/neighbors_ivf_pq.rst +++ b/docs/source/python_api/neighbors_ivf_pq.rst @@ -32,3 +32,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.ivf_pq.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.ivf_pq.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.ivf_pq.load From 474ccbb7cec3d150f33ccd5816a59bc5e1959c7c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Nov 2024 14:56:21 -0800 Subject: [PATCH 5/8] fix include fstream --- cpp/src/neighbors/brute_force_serialize.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/neighbors/brute_force_serialize.cu b/cpp/src/neighbors/brute_force_serialize.cu index c427647f1..1b5b5111e 100644 --- a/cpp/src/neighbors/brute_force_serialize.cu +++ b/cpp/src/neighbors/brute_force_serialize.cu @@ -20,6 +20,8 @@ #include #include +#include + namespace cuvs::neighbors::brute_force { int constexpr serialization_version = 0; From f14a1456ab4e4448bbe6dc0020973c7efb4247a0 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Nov 2024 15:07:39 -0800 Subject: [PATCH 6/8] fix include fstream c --- cpp/src/neighbors/brute_force_c.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index 8447bead2..f1a8c995d 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include From 740a8f341aea7229c3b922de132fe9c04c0f7810 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 21 Nov 2024 05:38:22 -0800 Subject: [PATCH 7/8] Add examples of bruteforce serialization in python/c --- cpp/include/cuvs/neighbors/brute_force.h | 23 +++++++++-- cpp/include/cuvs/neighbors/brute_force.hpp | 40 +++++++++++-------- .../neighbors/brute_force/brute_force.pyx | 30 +++++++++----- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h index 0e134eb59..33b92f11b 100644 --- a/cpp/include/cuvs/neighbors/brute_force.h +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -172,8 +172,9 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, */ /** * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.c} * #include @@ -183,7 +184,7 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, * cuvsError_t res_create_status = cuvsResourcesCreate(&res); * * // create an index with `cuvsBruteforceBuild` - * cuvsBruteForceSerialize(res, "/path/to/index", index, true); + * cuvsBruteForceSerialize(res, "/path/to/index", index); * @endcode * * @param[in] res cuvsResources_t opaque C handle @@ -197,8 +198,22 @@ cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, /** * Load index from file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); * - * Experimental, both the API and the serialization format are subject to change. + * // Deserialize an index previously built with `cuvsBruteforceBuild` + * cuvsBruteForceIndex_t index; + * cuvsBruteForceIndexCreate(&index); + * cuvsBruteForceDeserialize(res, "/path/to/index", index); + * @endcode * * @param[in] res cuvsResources_t opaque C handle * @param[in] filename the name of the file that stores the index diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index a998c30eb..af6841646 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -389,8 +389,9 @@ void search(raft::resources const& handle, */ /** * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -418,8 +419,9 @@ void serialize(raft::resources const& handle, bool include_dataset = true); /** * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -449,8 +451,9 @@ void serialize(raft::resources const& handle, /** * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -476,8 +479,9 @@ void serialize(raft::resources const& handle, /** * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -503,8 +507,9 @@ void serialize(raft::resources const& handle, /** * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -529,8 +534,9 @@ void deserialize(raft::resources const& handle, cuvs::neighbors::brute_force::index* index); /** * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -555,8 +561,9 @@ void deserialize(raft::resources const& handle, cuvs::neighbors::brute_force::index* index); /** * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include @@ -581,8 +588,9 @@ void deserialize(raft::resources const& handle, cuvs::neighbors::brute_force::index* index); /** * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. * * @code{.cpp} * #include diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 471b16938..4fe19fa63 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -264,30 +264,31 @@ def save(filename, Index index, bool include_dataset=True, resources=None): """ Saves the index to a file. - Saving / loading the index is experimental. The serialization format is - subject to change. + The serialization format can be subject to changes, therefore loading + an index saved with a previous version of cuvs is not guaranteed + to work. Parameters ---------- filename : string Name of the file. index : Index - Trained IVF-PQ index. + Trained Brute Force index. {resources_docstring} Examples -------- >>> import cupy as cp - >>> from cuvs.neighbors import ivf_pq + >>> from cuvs.neighbors import brute_force >>> n_samples = 50000 >>> n_features = 50 >>> dataset = cp.random.random_sample((n_samples, n_features), ... dtype=cp.float32) >>> # Build index - >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset) - >>> # Serialize and deserialize the ivf_pq index built - >>> ivf_pq.save("my_index.bin", index) - >>> index_loaded = ivf_pq.load("my_index.bin") + >>> index = brute_force.build(dataset) + >>> # Serialize and deserialize the brute_force index built + >>> brute_force.save("my_index.bin", index) + >>> index_loaded = brute_force.load("my_index.bin") """ cdef string c_filename = filename.encode('utf-8') cdef cuvsResources_t res = resources.get_c_obj() @@ -301,9 +302,10 @@ def load(filename, resources=None): """ Loads index from file. - Saving / loading the index is experimental. The serialization format is - subject to change, therefore loading an index saved with a previous - version of cuvs is not guaranteed to work. + The serialization format can be subject to changes, therefore loading + an index saved with a previous version of cuvs is not guaranteed + to work. + Parameters ---------- @@ -315,6 +317,12 @@ def load(filename, resources=None): ------- index : Index + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import brute_force + >>> # Load an index previously built + >>> index = brute_force.load("my_index.bin") """ cdef Index idx = Index() cdef cuvsResources_t res = resources.get_c_obj() From 2bf9045ad1c7cb5aa9d1de7ed57d0b68e0a6bf21 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 21 Nov 2024 09:11:18 -0800 Subject: [PATCH 8/8] Fix bruteforce python example --- .../cuvs/cuvs/neighbors/brute_force/brute_force.pyx | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 4fe19fa63..9d43bfb29 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -321,8 +321,15 @@ def load(filename, resources=None): -------- >>> import cupy as cp >>> from cuvs.neighbors import brute_force - >>> # Load an index previously built - >>> index = brute_force.load("my_index.bin") + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = brute_force.build(dataset) + >>> # Serialize and deserialize the brute_force index built + >>> brute_force.save("my_index.bin", index) + >>> index_loaded = brute_force.load("my_index.bin") """ cdef Index idx = Index() cdef cuvsResources_t res = resources.get_c_obj()