Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378

Open
wants to merge 29 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0faf889
[Feat] CAGRA filtering with BFKNN when sparsity matching threshold
rhdong Oct 2, 2024
f3388f0
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 3, 2024
f14be71
revert: update_dataset on strided matrix
rhdong Oct 3, 2024
062ca87
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 4, 2024
a9fd8d8
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 22, 2024
8e27b74
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 28, 2024
5378827
Support strided matrix on queries & respond to the review comments
rhdong Oct 29, 2024
651387f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 29, 2024
757c222
fix a style issue
rhdong Oct 29, 2024
018879f
Merge remote-tracking branch 'rhdong/rhdong/cagra-bf' into rhdong/cag…
rhdong Oct 29, 2024
bddae7f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 30, 2024
caab88b
fix: don't invoke 'copy_with_padding' from `src/neighbors/detail`
rhdong Oct 30, 2024
bac646d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 31, 2024
f4c1922
optimize by review comments
rhdong Oct 31, 2024
0dc10a2
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 13, 2024
a73ba1f
move calling down to branch & replace copy_with_padding
rhdong Nov 14, 2024
2552d8d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
ef734d4
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
0036127
fix: RAFT_LOG_DEBUG %f for double & other optimization
rhdong Nov 15, 2024
2876506
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
b5dcc02
benchmark: support pre-filter on CAGRA
rhdong Nov 18, 2024
5c9c5de
adjust the kernel selection condition to be 0.9f
rhdong Nov 18, 2024
d190b9d
expose the threshold-to-bf to callers & test cases
rhdong Nov 19, 2024
9aa1bb1
move the threshold-to-bf into search_params
rhdong Nov 19, 2024
a0fba17
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
e29d74d
skip the test on half when cusparse version is unsupported.
rhdong Nov 20, 2024
4d0fc8e
Revert "benchmark: support pre-filter on CAGRA"
rhdong Nov 20, 2024
1bcba66
if (params.threshold_to_bf >= 1.0) { return false; }
rhdong Nov 20, 2024
0cafa23
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's perhaps remove this from the benchmark wrapper for now?
The reason I'm suggesting this is that filter creation should probably be part of the common benchmarking infrastructure rather than specific to CAGRA, and is therefore a little out of scope for this PR.

achirkin marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <iostream>
#include <memory>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand All @@ -52,10 +53,13 @@ namespace cuvs::bench {
// Memory kind handed to cuvs_cagra::get_mr() to select a memory resource.
enum class AllocatorType {
  kHostPinned,
  kHostHugePage,
  kDevice
};

// Build-algorithm selector (presumably for the CAGRA graph build; usage is not visible here).
enum class CagraBuildAlgo {
  kAuto,
  kIvfPq,
  kNnDescent
};

// Sparsity forwarded to create_sparse_bitset() when the benchmark filter is created;
// 0.0 takes that function's "all bits set" path.
// TODO: make this a benchmark argument instead of a compile-time constant.
constexpr double sparsity = 0.0;

template <typename T, typename IdxT>
class cuvs_cagra : public algo<T>, public algo_gpu {
public:
using search_param_base = typename algo<T>::search_param;
// TODO: Move to arguments

struct search_param : public search_param_base {
cuvs::neighbors::cagra::search_params p;
Expand Down Expand Up @@ -91,6 +95,40 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
}
};

/**
 * @brief Fill a host-side bitset with a random filter of the requested sparsity.
 *
 * Bit `i` is stored in word `i / 32` at position `i % 32`. On return roughly
 * `total * (1 - sparsity)` randomly chosen bits in `[0, total)` are set and all other
 * bits in that range are cleared.
 *
 * The function does not touch any instance state, hence it is `static`; existing
 * unqualified calls from member functions keep working unchanged.
 *
 * @param[in]  total    number of addressable bits (dataset rows)
 * @param[in]  sparsity fraction of bits that should stay cleared; values outside [0, 1]
 *                      (and NaN) are clamped so the function always terminates
 * @param[out] bitset   pre-sized storage of at least `ceil(total / 32)` words; it is
 *                      overwritten completely
 * @return the number of bits set in `[0, total)`; `total` when `sparsity == 0`
 * @throws std::invalid_argument if `bitset` is too small to hold `total` bits
 */
static int64_t create_sparse_bitset(int64_t total, float sparsity, std::vector<uint32_t>& bitset)
{
  constexpr int64_t kBitsPerWord = static_cast<int64_t>(8 * sizeof(uint32_t));

  // No filtering requested: every word (including padding bits past `total`) is all ones.
  if (sparsity == 0.0f) {
    bitset.assign(bitset.size(), static_cast<uint32_t>(0xffffffff));
    return total;
  }

  bitset.assign(bitset.size(), static_cast<uint32_t>(0));
  if (total <= 0) { return 0; }  // nothing to sample; also keeps the distribution bounds valid

  const int64_t words_needed = (total + kBitsPerWord - 1) / kBitsPerWord;
  if (static_cast<int64_t>(bitset.size()) < words_needed) {
    throw std::invalid_argument("create_sparse_bitset: bitset is too small for `total` bits");
  }

  // Compute the target count in double precision and clamp it to [0, total].
  // Single-precision arithmetic cannot represent every integer above 2^24, so the
  // product could round to a value larger than `total`, which would make the sampling
  // loop below spin forever once all bits are set.
  const double expected = static_cast<double>(total) * (1.0 - static_cast<double>(sparsity));
  int64_t num_ones      = 0;
  if (expected >= static_cast<double>(total)) {
    num_ones = total;
  } else if (expected > 0.0) {  // false for NaN and negatives -> zero bits set
    num_ones = static_cast<int64_t>(expected + 0.5);  // round to nearest
    if (num_ones > total) { num_ones = total; }
  }

  // Rejection sampling is cheap only while most draws hit a bit that still needs
  // flipping, so for dense filters start from "all set" and clear the complement.
  const bool mostly_ones = num_ones > total - num_ones;
  if (mostly_ones) {
    const int64_t full_words = total / kBitsPerWord;
    for (int64_t w = 0; w < full_words; ++w) {
      bitset[static_cast<std::size_t>(w)] = static_cast<uint32_t>(0xffffffff);
    }
    const int64_t tail_bits = total % kBitsPerWord;
    if (tail_bits > 0) {
      bitset[static_cast<std::size_t>(full_words)] = (static_cast<uint32_t>(1) << tail_bits) - 1u;
    }
  }
  int64_t remaining = mostly_ones ? (total - num_ones) : num_ones;

  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<int64_t> dis(0, total - 1);

  while (remaining > 0) {
    const int64_t index = dis(gen);
    uint32_t& word      = bitset[static_cast<std::size_t>(index / kBitsPerWord)];
    const uint32_t mask = static_cast<uint32_t>(1) << (index % kBitsPerWord);
    const bool is_set   = (word & mask) != 0;
    // Flip only bits that are still in the starting state, so each flip is a new index.
    if (is_set == mostly_ones) {
      word ^= mask;
      --remaining;
    }
  }
  return num_ones;
}

cuvs_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1)
: algo<T>(metric, dim),
index_params_(param),
Expand All @@ -102,8 +140,9 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
std::move(raft::make_device_matrix<IdxT, int64_t>(handle_, 0, 0)))),
input_dataset_v_(
std::make_shared<raft::device_matrix_view<const T, int64_t, raft::row_major>>(
nullptr, 0, 0))

nullptr, 0, 0)),
bitset_filter_(std::make_shared<cuvs::core::bitset<std::uint32_t, int64_t>>(
std::move(cuvs::core::bitset<std::uint32_t, int64_t>(handle_, 0, false))))
{
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
Expand Down Expand Up @@ -171,6 +210,9 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
std::shared_ptr<raft::device_matrix<T, int64_t, raft::row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, raft::row_major>> input_dataset_v_;

// std::shared_ptr<raft::device_vector<std::uint32_t, int64_t>> bitset_filter_;
std::shared_ptr<cuvs::core::bitset<std::uint32_t, int64_t>> bitset_filter_;

inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
{
switch (mem_type) {
Expand Down Expand Up @@ -256,6 +298,15 @@ void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param)

need_dataset_update_ = false;
}

{ // create bitset filter in advance.
auto stream_ = raft::resource::get_cuda_stream(handle_);
size_t filter_n_elements = size_t((input_dataset_v_->extent(0) + 31) / 32);
bitset_filter_->resize(handle_, input_dataset_v_->extent(0), false);
std::vector<std::uint32_t> bitset_cpu(filter_n_elements);
create_sparse_bitset(input_dataset_v_->extent(0), sparsity, bitset_cpu);
raft::copy(bitset_filter_->data(), bitset_cpu.data(), filter_n_elements, stream_);
}
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -328,8 +379,15 @@ void cuvs_cagra<T, IdxT>::search_base(const T* queries,
raft::make_device_matrix_view<IdxT, int64_t>(neighbors_idx_t, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
if constexpr ((std::is_same_v<T, float> || std::is_same_v<T, half>)&&sparsity >= 0.0f) {
auto filter = cuvs::neighbors::filtering::bitset_filter(bitset_filter_->view());
cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, filter);

} else {
cuvs::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);
}

if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) {
if (raft::get_device_for_address(neighbors) < 0 &&
Expand Down
20 changes: 16 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,8 @@ void extend(
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf sparsity threshold in the range [0, 1]; brute-force search is used
* when the filter's sparsity exceeds this threshold
*/

void search(raft::resources const& res,
Expand All @@ -1066,7 +1068,8 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);
achirkin marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1083,6 +1086,8 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf sparsity threshold in the range [0, 1]; brute-force search is used
* when the filter's sparsity exceeds this threshold
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1091,7 +1096,8 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1108,6 +1114,8 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf sparsity threshold in the range [0, 1]; brute-force search is used
* when the filter's sparsity exceeds this threshold
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1116,7 +1124,8 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1133,6 +1142,8 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf sparsity threshold in the range [0, 1]; brute-force search is used
* when the filter's sparsity exceeds this threshold
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1141,7 +1152,8 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);

/**
* @}
Expand Down
23 changes: 17 additions & 6 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ index<T, IdxT> build(
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
* @param[in] threshold_to_bf sparsity threshold in the range [0, 1]; brute-force search is used
* when the filter's sparsity exceeds this threshold
*
*/
template <typename T, typename IdxT, typename CagraSampleFilterT>
void search_with_filtering(raft::resources const& res,
Expand All @@ -301,7 +304,8 @@ void search_with_filtering(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
double threshold_to_bf = 0.9)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
Expand All @@ -322,8 +326,14 @@ void search_with_filtering(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, raft::row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
sample_filter,
threshold_to_bf);
}

template <typename T, typename IdxT>
Expand All @@ -333,14 +343,15 @@ void search(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
const cuvs::neighbors::filtering::base_filter& sample_filter_ref,
double threshold_to_bf = 0.9)
{
try {
using none_filter_type = cuvs::neighbors::filtering::none_sample_filter;
auto& sample_filter = dynamic_cast<const none_filter_type&>(sample_filter_ref);
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
res, params, idx, queries, neighbors, distances, sample_filter_copy);
res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf);
return;
} catch (const std::bad_cast&) {
}
Expand All @@ -351,7 +362,7 @@ void search(raft::resources const& res,
sample_filter_ref);
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, decltype(sample_filter_copy)>(
res, params, idx, queries, neighbors, distances, sample_filter_copy);
res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
Expand Down
23 changes: 12 additions & 11 deletions cpp/src/neighbors/cagra_search_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
// Defines a non-template `search` overload for `index<T, IdxT>` that forwards every
// argument, including `threshold_to_bf`, to the templated
// `cuvs::neighbors::cagra::search<T, IdxT>`.
// NOTE(review): the parameter list presumably has to match the overload declared in
// cuvs/neighbors/cagra.hpp, which is also where the default arguments live — confirm.
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
}

CUVS_INST_CAGRA_SEARCH(float, uint32_t);
Expand Down
23 changes: 12 additions & 11 deletions cpp/src/neighbors/cagra_search_half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
// Defines a non-template `search` overload for `index<T, IdxT>` that forwards every
// argument, including `threshold_to_bf`, to the templated
// `cuvs::neighbors::cagra::search<T, IdxT>`.
// NOTE(review): the parameter list presumably has to match the overload declared in
// cuvs/neighbors/cagra.hpp, which is also where the default arguments live — confirm.
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
}

CUVS_INST_CAGRA_SEARCH(half, uint32_t);
Expand Down
23 changes: 12 additions & 11 deletions cpp/src/neighbors/cagra_search_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
#include <cuvs/neighbors/cagra.hpp>
namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
// Defines a non-template `search` overload for `index<T, IdxT>` that forwards every
// argument, including `threshold_to_bf`, to the templated
// `cuvs::neighbors::cagra::search<T, IdxT>`.
// NOTE(review): the parameter list presumably has to match the overload declared in
// cuvs/neighbors/cagra.hpp, which is also where the default arguments live — confirm.
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
}

CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t);
Expand Down
Loading
Loading