Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378

Open
wants to merge 29 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0faf889
[Feat] CAGRA filtering with BFKNN when sparsity matching threshold
rhdong Oct 2, 2024
f3388f0
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 3, 2024
f14be71
revert: update_dataset on strided matrix
rhdong Oct 3, 2024
062ca87
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 4, 2024
a9fd8d8
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 22, 2024
8e27b74
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 28, 2024
5378827
Support strided matrix on queries & respond to the review comments
rhdong Oct 29, 2024
651387f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 29, 2024
757c222
fix a style issue
rhdong Oct 29, 2024
018879f
Merge remote-tracking branch 'rhdong/rhdong/cagra-bf' into rhdong/cag…
rhdong Oct 29, 2024
bddae7f
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 30, 2024
caab88b
fix: don't invoke 'copy_with_padding' from `src/neighbors/detail`
rhdong Oct 30, 2024
bac646d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Oct 31, 2024
f4c1922
optimize by review comments
rhdong Oct 31, 2024
0dc10a2
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 13, 2024
a73ba1f
move calling down to branch & replace copy_with_padding
rhdong Nov 14, 2024
2552d8d
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
ef734d4
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 14, 2024
0036127
fix: RAFT_LOG_DEBUG %f for double & other optimization
rhdong Nov 15, 2024
2876506
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
b5dcc02
benchmark: support pre-filter on CAGRA
rhdong Nov 18, 2024
5c9c5de
adjust the kernel selection condition to be 0.9f
rhdong Nov 18, 2024
d190b9d
expose the threshold-to-bf to callers & test cases
rhdong Nov 19, 2024
9aa1bb1
move the threshold-to-bf into search_params
rhdong Nov 19, 2024
a0fba17
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 19, 2024
e29d74d
skip the test on half when cusparse version is unsupported.
rhdong Nov 20, 2024
4d0fc8e
Revert "benchmark: support pre-filter on CAGRA"
rhdong Nov 20, 2024
1bcba66
if (params.threshold_to_bf >= 1.0) { return false; }
rhdong Nov 20, 2024
0cafa23
Merge branch 'branch-24.12' into rhdong/cagra-bf
rhdong Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,13 @@ struct index : cuvs::neighbors::index {
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_graph(res, knn_graph);
if constexpr (raft::is_device_mdspan_v<decltype(dataset)>) {
contiguous_dataset_ =
raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
} else {
contiguous_dataset_ =
raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
}

raft::resource::sync_stream(res);
}
Expand All @@ -417,13 +424,16 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
                    raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
{
  // Remember the caller's row-major device view; it is what contiguous_dataset() hands back.
  contiguous_dataset_ = dataset;
  // Build the stored dataset exactly once. The stale pre-change assignment that used to
  // precede this one made make_aligned_dataset run twice for every update.
  dataset_ = make_aligned_dataset(res, dataset, 16);
}

/**
 * Set the dataset reference explicitly to a device matrix view with padding.
 *
 * Besides storing the dataset itself, this records a row-major view over the same memory whose
 * row width equals the row stride of `dataset`, i.e. the padding columns are part of each row.
 *
 * @param[in] res raft resources
 * @param[in] dataset strided device matrix view of the dataset
 */
void update_dataset(raft::resources const& res,
                    raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset)
{
  // The view is [extent(0), stride(0)], not [extent(0), extent(1)].
  // NOTE(review): this presumably relies on stride(1) == 1 -- confirm against callers.
  contiguous_dataset_ =
    raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.stride(0));
  dataset_ = make_aligned_dataset(res, dataset, 16);
}

Expand All @@ -436,7 +446,8 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
                    raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
{
  // Remember the caller's row-major host view; it is what contiguous_dataset() hands back.
  contiguous_dataset_ = dataset;
  // Build the stored dataset exactly once. The stale pre-change assignment that used to
  // precede this one made make_aligned_dataset run twice for every update.
  dataset_ = make_aligned_dataset(res, dataset, 16);
}

/**
Expand All @@ -447,14 +458,16 @@ struct index : cuvs::neighbors::index {
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
  -> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
{
  // This overload records no matrix view, so the cached one is reset to std::monostate.
  contiguous_dataset_ = std::monostate{};
  // Move the dataset into owned storage exactly once. A second
  // make_unique<DatasetT>(std::move(dataset)) would construct from a moved-from object.
  dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

/**
 * Take ownership of an already heap-allocated dataset object.
 *
 * @param[in] res raft resources (not used by this overload)
 * @param[in] dataset owning pointer to the dataset; it is moved into the index
 */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
  -> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
  // This overload records no matrix view, so the cached one is reset to std::monostate.
  contiguous_dataset_ = std::monostate{};
  // Transfer ownership exactly once. Assigning std::move(dataset) a second time would
  // overwrite dataset_ with the already-emptied pointer and destroy the dataset.
  dataset_ = std::move(dataset);
}

/**
Expand Down Expand Up @@ -492,11 +505,17 @@ struct index : cuvs::neighbors::index {
graph_view_ = graph_.view();
}

/**
 * Returns a copy of the row-major dataset view recorded by the last `update_dataset` call.
 *
 * The result is a `std::variant` holding either a device matrix view, a host matrix view, or
 * `std::monostate` when no such view was recorded.
 */
auto contiguous_dataset() const { return contiguous_dataset_; }

private:
cuvs::distance::DistanceType metric_;
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
// Row-major view of the dataset memory as recorded by the `update_dataset` overloads:
//  - a device view when the dataset was given as a device matrix view,
//  - a host view when it was given as a host matrix view,
//  - std::monostate when it was given as a `neighbors::dataset` object (or never set).
std::variant<std::monostate,
raft::device_matrix_view<const T, int64_t, raft::row_major>,
raft::host_matrix_view<const T, int64_t, raft::row_major>>
contiguous_dataset_ = std::monostate{};
};
/**
* @}
Expand Down
115 changes: 115 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
#include "sample_filter_utils.cuh"
#include "search_plan.cuh"
#include "search_single_cta_inst.cuh"
#include "utils.hpp"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/unary_op.cuh>

#include <cuvs/distance/distance.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/cagra.hpp>

// TODO: Fix these when ivf methods are moved over
Expand Down Expand Up @@ -108,6 +111,109 @@ void search_main_core(raft::resources const& res,
}
}

/**
* @brief Performs ANN search using brute force when filter sparsity reaches a specified threshold.
*
* This function switches to a brute force search approach to improve recall rate when the
* `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level
* (proportion of filtered-out samples) greater than or equal to the specified `threshold_to_bf`.
*
* @tparam T data element type
* @tparam InternalIdxT during search we map IdxT to InternalIdxT, this way we do not need
* separate kernels for int/uint.
* @tparam CagraSampleFilterT type of the sample filter; it must expose a `bitset_view_` member
* @tparam IdxT type of database vector indices
* @tparam DistanceT type of the returned distances
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index CAGRA constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
* @param[in] threshold_to_bf A sparsity threshold in the range [0, 1]; brute force is used when
* the filter sparsity is greater than or equal to this threshold
*
* @return true If the brute force search was applied successfully.
* @return false If the brute force search was not applied (the sparsity is below the threshold, or
* the index does not hold a contiguous device view of the dataset).
*/
template <typename T,
typename InternalIdxT,
typename CagraSampleFilterT,
typename IdxT = uint32_t,
typename DistanceT = float>
bool search_using_brute_force(
raft::resources const& res,
search_params& params,
achirkin marked this conversation as resolved.
Show resolved Hide resolved
const index<T, IdxT>& index,
raft::device_matrix_view<const T, int64_t, raft::row_major>& queries,
raft::device_matrix_view<InternalIdxT, int64_t, raft::row_major>& neighbors,
raft::device_matrix_view<DistanceT, int64_t, raft::row_major>& distances,
achirkin marked this conversation as resolved.
Show resolved Hide resolved
CagraSampleFilterT& sample_filter,
double threshold_to_bf = 0.9)
Copy link
Contributor

@achirkin achirkin Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please make it possible to disable/enable this feature by the user (see #252 (comment) for the reasoning):

  1. Expose threshold_to_bf as CAGRA search parameter and set it to 1.0 by default there
  2. Add a check here: if threshold_to_bf >= 1.0 then disable further checks and proceed with CAGRA search immediately (i.e. no need to run the sparsity check).

Copy link
Member

@cjnolet cjnolet Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most users with a filter are going to be specifying the filter in batch, and will know the sparsity of the filter. I suggest instead of turning this feature off by default, we allow the user specified filter to know its own nnz unless updated.

Turning this off by default undermines the fundamental benefits of this feature.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most users are not specifying a filter, and when they do, it's expected the filter is going to be heavy. This should not impact all users.

{
achirkin marked this conversation as resolved.
Show resolved Hide resolved
bool is_applied = false;
auto n_queries = queries.extent(0);
auto n_dataset = index.size();

auto bitset_filter_view = sample_filter.bitset_view_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here if the 2d bitmap can't be converted to a 1d bitset without losing information?

auto dataset_view = index.contiguous_dataset();
achirkin marked this conversation as resolved.
Show resolved Hide resolved
auto sparsity = bitset_filter_view.sparsity(res);
Copy link
Member

@cjnolet cjnolet Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the number of positive bits in the bitmap also needed to compute this? But then we compute n_elements below. We should cache off the n_elements. We should consider making the bitmap immutable. That way the sparsity and the number of positive elements can be safely cached. This isn't something users are going to be updating, like ever.


// TODO: Support host dataset in `brute_force::build`
if (sparsity >= threshold_to_bf &&
std::holds_alternative<raft::device_matrix_view<const T, int64_t, raft::row_major>>(
dataset_view)) {
using bitmap_view_t = cuvs::core::bitmap_view<const uint32_t, int64_t>;

auto stream = raft::resource::get_cuda_stream(res);
auto bitmap_n_elements = bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries);

rmm::device_uvector<uint32_t> raw_bitmap(bitmap_n_elements, stream);
rmm::device_uvector<int64_t> raw_neighbors(neighbors.size(), stream);

bitset_filter_view.repeat(res, n_queries, raw_bitmap.data());

auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset);

auto brute_force_neighbors = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>(
raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1));
auto brute_force_dataset =
std::get_if<raft::device_matrix_view<const T, int64_t, raft::row_major>>(&dataset_view);

if (brute_force_dataset) {
// `sparsity` is a floating-point ratio (it is compared against the double threshold above),
// so it must be printed with %f; %d would mis-read the argument.
RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity);
auto brute_force_idx =
cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric());

auto brute_force_queries = queries;
auto padding_queries = raft::make_device_matrix<T, int64_t>(res, 0, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto padding_queries = raft::make_device_matrix<T, int64_t>(res, 0, 0);
// Allocate the padded queries in the workspace resource
auto padding_queries = raft::make_device_mdarray<T, int64_t>(
res,
raft::resource::get_workspace_resource(res),
raft::make_extents<int64_t>(n_queries, dataset_view.extent(1)));
// Copy the queries and fill the padded elements with zeros
raft::linalg::map_offset(res,
padding_queries.view(),
[queries, stride = dataset_view.extent(1)] __device__(int64_t i) {
auto row_ix = i / stride;
auto el_ix = i % stride;
return el_ix < queries.extent(1) ? queries(row_ix, el_ix) : T{0};
});

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion! I found a way to allocate the queries in the workspace while keeping the reuse of copy_with_padding, which could help with clean code. If I missed something, please feel free to point it out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to stick to raft primitives and eventually remove copy_with_padding, because the latter is not a commonly used primitive, so one has to read its code to understand exactly what it does.
But I don't mind to keep it here for now.


// Happens when the original dataset is a strided matrix.
achirkin marked this conversation as resolved.
Show resolved Hide resolved
if (brute_force_dataset->extent(1) != queries.extent(1)) {
cuvs::neighbors::cagra::detail::copy_with_padding(res, padding_queries, queries);
brute_force_queries = raft::make_device_matrix_view<const T, int64_t, raft::row_major>(
padding_queries.data_handle(), padding_queries.extent(0), padding_queries.extent(1));
}
cuvs::neighbors::brute_force::search(
res,
brute_force_idx,
brute_force_queries,
brute_force_neighbors,
distances,
cuvs::neighbors::filtering::bitmap_filter(brute_force_filter));
raft::linalg::unaryOp(neighbors.data_handle(),
brute_force_neighbors.data_handle(),
neighbors.size(),
raft::cast_op<InternalIdxT>(),
raft::resource::get_cuda_stream(res));
is_applied = true;
}
}
return is_applied;
}

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -126,6 +232,7 @@ void search_main_core(raft::resources const& res,
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
*/
template <typename T,
typename InternalIdxT,
Expand All @@ -140,6 +247,14 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<DistanceT, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
if constexpr (!std::is_same_v<CagraSampleFilterT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pull this out into a separate function that can be invoked here please? This search_main function is getting pretty massive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cuvs::neighbors::filtering::none_sample_filter> &&
(std::is_same_v<T, float> || std::is_same_v<T, half>)) {
bool bf_search_done =
search_using_brute_force(res, params, index, queries, neighbors, distances, sample_filter);
if (bf_search_done) return;
}

auto stream = raft::resource::get_cuda_stream(res);
const auto& graph = index.graph();
auto graph_internal = raft::make_device_matrix_view<const InternalIdxT, int64_t, raft::row_major>(
Expand Down
Loading
Loading