Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378

Open
wants to merge 29 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0faf889
[Feat] CAGRA filtering with BFKNN when sparsity matching threshold
rhdong Oct 2, 2024
f3388f0
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 3, 2024
f14be71
revert: update_dataset on strided matrix
rhdong Oct 3, 2024
062ca87
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 4, 2024
a9fd8d8
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 22, 2024
8e27b74
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 28, 2024
5378827
Support strided matrix on queries & respond to the review comments
rhdong Oct 29, 2024
651387f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 29, 2024
757c222
fix a style issue
rhdong Oct 29, 2024
018879f
Merge remote-tracking branch 'rhdong/rhdong/cagra-bf' into rhdong/cag…
rhdong Oct 29, 2024
bddae7f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 30, 2024
caab88b
fix: don't invoke 'copy_with_padding' from `src/neighbors/detail`
rhdong Oct 30, 2024
bac646d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 31, 2024
f4c1922
optimize by review comments
rhdong Oct 31, 2024
0dc10a2
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 13, 2024
a73ba1f
move calling down to branch & replace copy_with_padding
rhdong Nov 14, 2024
2552d8d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
ef734d4
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
0036127
fix: RAFT_LOG_DEBUG %f for double & other optimization
rhdong Nov 15, 2024
2876506
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
b5dcc02
benchmark: support pre-filter on CAGRA
rhdong Nov 18, 2024
5c9c5de
adjust the kernel selection condition to be 0.9f
rhdong Nov 18, 2024
d190b9d
expose the threshold-to-bf to callers & test cases
rhdong Nov 19, 2024
9aa1bb1
move the threshold-to-bf into search_params
rhdong Nov 19, 2024
a0fba17
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
e29d74d
skip the test on half when cusparse version is unsupported.
rhdong Nov 20, 2024
4d0fc8e
Revert "benchmark: support pre-filter on CAGRA"
rhdong Nov 20, 2024
1bcba66
if (params.threshold_to_bf >= 1.0) { return false; }
rhdong Nov 20, 2024
0cafa23
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ struct search_params : cuvs::neighbors::search_params {
* impact on the throughput.
*/
float persistent_device_usage = 1.0;

/** A sparsity threshold; brute force is used when sparsity exceeds this threshold, in the range
* [0, 1] */
double threshold_to_bf = 0.9;
};

/**
Expand Down
120 changes: 120 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
#include "sample_filter_utils.cuh"
#include "search_plan.cuh"
#include "search_single_cta_inst.cuh"
#include "utils.hpp"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/unary_op.cuh>

#include <cuvs/distance/distance.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/cagra.hpp>

// TODO: Fix these when ivf methods are moved over
Expand Down Expand Up @@ -108,6 +111,114 @@ void search_main_core(raft::resources const& res,
}
}

/**
* @brief Performs ANN search using brute force when filter sparsity exceeds a specified threshold.
*
* This function switches to a brute force search approach to improve recall rate when the
* `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level
* (proportion of unfiltered samples) exceeding the specified `params.threshold_to_bf`.
*
* @tparam T data element type
* @tparam IdxT type of database vector indices
* @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need
* separate kernels for int/uint.
*
* @param[in] handle
* @param[in] params configure the search
* @param[in] strided_dataset CAGRA strided dataset
* @param[in] metric distance type
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
*
* @return true If the brute force search was applied successfully.
* @return false If the brute force search was not applied.
*/
template <typename T,
typename InternalIdxT,
typename CagraSampleFilterT,
typename IdxT = uint32_t,
typename DistanceT = float>
bool search_using_brute_force(
raft::resources const& res,
const search_params& params,
const strided_dataset<T, IdxT>& strided_dataset,
cuvs::distance::DistanceType metric,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, raft::row_major> distances,
CagraSampleFilterT& sample_filter)
{
achirkin marked this conversation as resolved.
Show resolved Hide resolved
if (params.threshold_to_bf >= 1.0) { return false; };

auto n_queries = queries.extent(0);
auto n_dataset = strided_dataset.n_rows();

auto bitset_filter_view = sample_filter.bitset_view_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here if the 2d bitmap isn't able to be converted to a 1d bitet without losing information?

auto sparsity = bitset_filter_view.sparsity(res);
Copy link
Member

@cjnolet cjnolet Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the number of positive bits in the bitmap also needed to compute this? But then we compute n_elements below. We should cache off the n_elements. We should consider making the bitmap immutable. That way the sparsity and the number of positive elements can be safely cached. This isn't something users are going to be updating, like ever.


if (sparsity < params.threshold_to_bf) { return false; }

// TODO: Support host dataset in `brute_force::build`
RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity);
using bitmap_view_t = cuvs::core::bitmap_view<const uint32_t, int64_t>;

auto stream = raft::resource::get_cuda_stream(res);
auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries);

rmm::device_uvector<uint32_t> raw_bitmap(bitmap_n_elements, stream);
rmm::device_uvector<int64_t> raw_neighbors(neighbors.size(), stream);

bitset_filter_view.repeat(res, n_queries, raw_bitmap.data());

auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset);

auto brute_force_neighbors = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>(
raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1));
auto brute_force_dataset = raft::make_device_matrix_view<const T, int64_t, raft::row_major>(
strided_dataset.view().data_handle(), strided_dataset.n_rows(), strided_dataset.stride());

auto brute_force_idx = cuvs::neighbors::brute_force::build(res, brute_force_dataset, metric);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is being called each and every time a user performs a search? There's overhead in this call, and this should cache off the built brute-force index because for many common distances this computes a set of norms.


auto brute_force_queries = queries;
auto padding_queries = raft::make_device_matrix<T, int64_t>(res, 0, 0);

// Happens when the original dataset is a strided matrix.
if (brute_force_dataset.extent(1) != queries.extent(1)) {
padding_queries = raft::make_device_mdarray<T, int64_t>(
res,
raft::resource::get_workspace_resource(res),
raft::make_extents<int64_t>(n_queries, brute_force_dataset.extent(1)));
// Copy the queries and fill the padded elements with zeros
raft::linalg::map_offset(
res,
padding_queries.view(),
[queries, stride = brute_force_dataset.extent(1)] __device__(int64_t i) {
auto row_ix = i / stride;
auto el_ix = i % stride;
return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0};
});
brute_force_queries = raft::make_device_matrix_view<const T, int64_t, raft::row_major>(
padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1));
}
cuvs::neighbors::brute_force::search(
res,
brute_force_idx,
brute_force_queries,
brute_force_neighbors,
distances,
cuvs::neighbors::filtering::bitmap_filter(brute_force_filter));
raft::linalg::unaryOp(neighbors.data_handle(),
brute_force_neighbors.data_handle(),
neighbors.size(),
raft::cast_op<InternalIdxT>(),
raft::resource::get_cuda_stream(res));
return true;
}

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -126,6 +237,7 @@ void search_main_core(raft::resources const& res,
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
*/
template <typename T,
typename InternalIdxT,
Expand All @@ -150,6 +262,14 @@ void search_main(raft::resources const& res,
// Dispatch search parameters based on the dataset kind.
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data());
strided_dset != nullptr) {
if constexpr (!std::is_same_v<CagraSampleFilterT,
cuvs::neighbors::filtering::none_sample_filter> &&
(std::is_same_v<T, float> || std::is_same_v<T, half>)) {
bool bf_search_done = search_using_brute_force(
res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter);
if (bf_search_done) return;
}

// Search using a plain (strided) row-major dataset
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
res, params, *strided_dset, index.metric());
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,9 @@ void brute_force_search_filtered(
raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(res, stream);
float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset));
float sparsity = (1.0f - (1.0f * nnz_h) / (1.0f * n_queries * n_dataset));

if (sparsity > 0.01f) {
if (sparsity < 0.9f) {
raft::resources stream_pool_handle(res);
raft::resource::set_cuda_stream(stream_pool_handle, stream);
auto idx_norm = idx.has_norms() ? const_cast<DistanceT*>(idx.norms().data_handle()) : nullptr;
Expand Down
123 changes: 101 additions & 22 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra {
namespace {

struct test_cagra_sample_filter {
static constexpr unsigned offset = 300;
inline _RAFT_HOST_DEVICE auto operator()(
// query index
const uint32_t query_ix,
// the index of the current sample inside the current inverted list
const uint32_t sample_ix) const
const uint32_t sample_ix,
const uint32_t offset) const
{
return sample_ix >= offset;
}
Expand Down Expand Up @@ -95,6 +95,39 @@ void RandomSuffle(raft::host_matrix_view<IdxT, int64_t> index)
}
}

template <typename T, typename data_accessor>
void copy_with_padding(
raft::resources const& res,
raft::device_matrix<T, int64_t, raft::row_major>& dst,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, data_accessor> src,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource())
{
size_t padded_dim = raft::round_up_safe<size_t>(src.extent(1) * sizeof(T), 16) / sizeof(T);

if ((dst.extent(0) != src.extent(0)) || (static_cast<size_t>(dst.extent(1)) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dst.size()) { dst = raft::make_device_matrix<T, int64_t>(res, 0, 0); }
dst =
raft::make_device_mdarray<T>(res, mr, raft::make_extents<int64_t>(src.extent(0), padded_dim));
}
if (dst.extent(1) == src.extent(1)) {
raft::copy(
dst.data_handle(), src.data_handle(), src.size(), raft::resource::get_cuda_stream(res));
} else {
// copy with padding
RAFT_CUDA_TRY(cudaMemsetAsync(
dst.data_handle(), 0, dst.size() * sizeof(T), raft::resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(),
sizeof(T) * dst.extent(1),
src.data_handle(),
sizeof(T) * src.extent(1),
sizeof(T) * src.extent(1),
src.extent(0),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(res)));
}
}

template <typename DistanceT, typename DatatT, typename IdxT>
testing::AssertionResult CheckOrder(raft::host_matrix_view<IdxT, int64_t> index_test,
raft::host_matrix_view<DatatT, int64_t> dataset)
Expand Down Expand Up @@ -276,6 +309,8 @@ struct AnnCagraInputs {
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
double threshold_to_bf = 0.9;
uint32_t filter_offset = 300;
std::optional<float> ivf_pq_search_refine_ratio = std::nullopt;
std::optional<vpq_params> compression = std::nullopt;

Expand Down Expand Up @@ -702,21 +737,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(
handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - test_cagra_sample_filter::offset,
ps.dim,
ps.k,
ps.metric);
auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - ps.filter_offset,
ps.dim,
ps.k,
ps.metric);
raft::linalg::addScalar(indices_naive_dev.data(),
indices_naive_dev.data(),
IdxT(test_cagra_sample_filter::offset),
IdxT(ps.filter_offset),
queries_size,
stream_);
raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
Expand Down Expand Up @@ -755,9 +789,10 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {

index_params.compression = ps.compression;
cagra::search_params search_params;
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.threshold_to_bf = ps.threshold_to_bf;

// TODO: setting search_params.itopk_size here breaks the filter tests, but is required for
// k>1024 skip these tests until fixed
Expand All @@ -780,14 +815,25 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {

if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

auto dataset_padding = raft::make_device_matrix<DataT, int64_t>(handle_, 0, 0);
if ((sizeof(DataT) * ps.dim % 16) != 0) {
copy_with_padding(handle_, dataset_padding, database_view);
auto database_view = raft::make_device_strided_matrix_view<const DataT, int64_t>(
dataset_padding.data_handle(),
dataset_padding.extent(0),
ps.dim,
dataset_padding.extent(1));
index.update_dataset(handle_, database_view);
}

auto search_queries_view = raft::make_device_matrix_view<const DataT, int64_t>(
search_queries.data(), ps.n_queries, ps.dim);
auto indices_out_view =
raft::make_device_matrix_view<IdxT, int64_t>(indices_dev.data(), ps.n_queries, ps.k);
auto dists_out_view = raft::make_device_matrix_view<DistanceT, int64_t>(
distances_dev.data(), ps.n_queries, ps.k);
auto removed_indices =
raft::make_device_vector<int64_t, int64_t>(handle_, test_cagra_sample_filter::offset);
raft::make_device_vector<int64_t, int64_t>(handle_, ps.filter_offset);
thrust::sequence(
raft::resource::get_thrust_policy(handle_),
thrust::device_pointer_cast(removed_indices.data_handle()),
Expand All @@ -813,8 +859,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
bool unacceptable_node = false;
for (int q = 0; q < ps.n_queries; q++) {
for (int i = 0; i < ps.k; i++) {
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n);
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node =
unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset);
}
}
EXPECT_FALSE(unacceptable_node);
Expand Down Expand Up @@ -1002,6 +1049,8 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{false, true},
{false},
{0.99},
{0.9},
{uint32_t(300)},
{1.0f, 2.0f, 3.0f});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

Expand All @@ -1028,6 +1077,36 @@ inline std::vector<AnnCagraInputs> generate_inputs()
return inputs;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
inline std::vector<AnnCagraInputs> generate_bf_inputs()
{
// Add test cases for brute force as sparsity >= 0.9.
std::vector<AnnCagraInputs> inputs_for_brute_force;
auto inputs_original = raft::util::itertools::product<AnnCagraInputs>(
{100},
{1000},
{1, 7, 8, 17},
{1, 16}, // k
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{0, 1, 10, 100},
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{true},
{1.0},
{0.1, 0.4, 0.8});
for (auto input : inputs_original) {
input.filter_offset = 0.5 * input.n_rows;
input.min_recall = input.threshold_to_bf <= 0.5 ? 1.0 : 0.6;
inputs_for_brute_force.push_back(input);
}

return inputs_for_brute_force;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
const std::vector<AnnCagraInputs> inputs_brute_force = generate_bf_inputs();

} // namespace cuvs::neighbors::cagra
3 changes: 3 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest,
AnnCagraAddNodesTestF_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest,
AnnCagraFilterTestF_U32,
::testing::ValuesIn(inputs_brute_force));

} // namespace cuvs::neighbors::cagra
7 changes: 7 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ namespace cuvs::neighbors::cagra {
typedef AnnCagraTest<float, half, std::uint32_t> AnnCagraTestF16_U32;
TEST_P(AnnCagraTestF16_U32, AnnCagra) { this->testCagra(); }

typedef AnnCagraFilterTest<float, half, std::uint32_t> AnnCagraFilterTestF16_U32;
TEST_P(AnnCagraFilterTestF16_U32, AnnCagra) { this->testCagra(); }

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF16_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF16_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest,
AnnCagraFilterTestF16_U32,
::testing::ValuesIn(inputs_brute_force));

} // namespace cuvs::neighbors::cagra
Loading
Loading