Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialization API to brute-force #461

Merged
merged 12 commits into from
Nov 25, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ compile_commands.json
.clangd/

# serialized ann indexes
brute_force_index
cagra_index
ivf_flat_index
ivf_pq_index
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ if(BUILD_SHARED_LIBS)
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/neighbors/brute_force.cu
src/neighbors/brute_force_serialize.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
src/neighbors/cagra_build_int8.cu
Expand Down
45 changes: 45 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,51 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup bruteforce_c_serialize BRUTEFORCE C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.c}
* #include <cuvs/neighbors/brute_force.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsBruteforceBuild`
* cuvsBruteForceSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index BRUTEFORCE index
*
*/
cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index BRUTEFORCE index loaded disk
*/
lowener marked this conversation as resolved.
Show resolved Hide resolved
cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* @}
*/
#ifdef __cplusplus
}
#endif
234 changes: 234 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ struct index : cuvs::neighbors::index {
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& handle);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
Expand Down Expand Up @@ -375,4 +383,230 @@ void search(raft::resources const& handle,
* @}
*/

/**
* @defgroup bruteforce_cpp_index_serialize Bruteforce index serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may be able to remove this now (everywhere). This API has been around for awhile in RAFT and in the other indexes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the mention of experimental, but I still kept the mention that the serialization format can be subject to changes to not have users expect retro-compatibility.

*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* @}
*/

} // namespace cuvs::neighbors::brute_force
15 changes: 15 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
#include <raft/core/copy.hpp>

namespace cuvs::neighbors::brute_force {

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res)
// this constructor is just for a temporary index, for use in the deserialization
// api. all the parameters here will get replaced with loaded values - that aren't
// necessarily known ahead of time before deserialization.
// TODO: do we even need a handle here - could just construct one?
: cuvs::neighbors::index(),
metric_(cuvs::distance::DistanceType::L2Expanded),
dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
norms_(std::nullopt),
metric_arg_(0)
{
}

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset,
Expand Down
Loading
Loading