Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialization API to brute-force #461

Open
wants to merge 12 commits into
base: branch-24.12
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ compile_commands.json
.clangd/

# serialized ann indexes
brute_force_index
cagra_index
ivf_flat_index
ivf_pq_index
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ if(BUILD_SHARED_LIBS)
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
src/neighbors/brute_force.cu
src/neighbors/brute_force_serialize.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
src/neighbors/cagra_build_int8.cu
Expand Down
60 changes: 60 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,66 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup bruteforce_c_serialize BRUTEFORCE C-API serialize functions
* @{
*/
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.c}
* #include <cuvs/neighbors/brute_force.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsBruteForceBuild`
* cuvsBruteForceSerialize(res, "/path/to/index", index);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index BRUTEFORCE index
* @return cuvsError_t status of the operation
*
*/
cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.c}
* #include <cuvs/neighbors/brute_force.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // Deserialize an index previously built with `cuvsBruteForceBuild`
* cuvsBruteForceIndex_t index;
* cuvsBruteForceIndexCreate(&index);
* cuvsBruteForceDeserialize(res, "/path/to/index", index);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index BRUTEFORCE index loaded from disk
* @return cuvsError_t status of the operation
*/
lowener marked this conversation as resolved.
Show resolved Hide resolved
cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* @}
*/
#ifdef __cplusplus
}
#endif
243 changes: 243 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ struct index : cuvs::neighbors::index {
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* The resulting index holds no data: it must either be populated with `build`
* or loaded from a saved copy with `deserialize` before it is used.
*
* @param[in] handle the raft handle
*/
index(raft::resources const& handle);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
Expand Down Expand Up @@ -479,4 +487,239 @@ void search(raft::resources const& handle,
/**
* @}
*/

/**
* @defgroup bruteforce_cpp_index_serialize Bruteforce index serialize functions
* @{
*/
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* This overload accepts an index whose data element type is `half`.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output (defaults to true)
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* This overload accepts an index whose data element type is `float`.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output (defaults to true)
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* This overload accepts an index whose data element type is `half`.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset whether to write the dataset to the stream
* (defaults to true)
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* This overload accepts an index whose data element type is `float`.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset whether to write the dataset to the stream
* (defaults to true)
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = half; // data element type
* cuvs::neighbors::brute_force::index<T, float> index(handle);
* // the index is passed by pointer
* cuvs::neighbors::brute_force::deserialize(handle, filename, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index pointer to the brute force index to load into
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* cuvs::neighbors::brute_force::index<T, float> index(handle);
* // the index is passed by pointer
* cuvs::neighbors::brute_force::deserialize(handle, filename, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index pointer to the brute force index to load into
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* Load index from input stream.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = half; // data element type
* cuvs::neighbors::brute_force::index<T, float> index(handle);
* // the index is passed by pointer
* cuvs::neighbors::brute_force::deserialize(handle, is, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index pointer to the brute force index to load into
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from input stream.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* cuvs::neighbors::brute_force::index<T, float> index(handle);
* // the index is passed by pointer
* cuvs::neighbors::brute_force::deserialize(handle, is, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index pointer to the brute force index to load into
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* @}
*/

} // namespace cuvs::neighbors::brute_force
15 changes: 15 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
#include <raft/core/copy.hpp>

namespace cuvs::neighbors::brute_force {

// Construct a placeholder index: metric L2Expanded, a 0 x 0 dataset matrix
// allocated through `res`, no precomputed norms, and a metric argument of 0.
template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res)
// This constructor is just for a temporary index, for use in the deserialization
// API. All the values set here are expected to be replaced with loaded values,
// which aren't necessarily known before deserialization.
// TODO: do we even need a handle here - could just construct one?
: cuvs::neighbors::index(),
metric_(cuvs::distance::DistanceType::L2Expanded),
dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
norms_(std::nullopt),
metric_arg_(0)
{
}

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset,
Expand Down
Loading
Loading