Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multi-CTA algorithm #492

Open
wants to merge 7 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ void add_node_core(
raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

// Check search results
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
(uint64_t)vec_i,
(uint64_t)invalid_edges,
(uint64_t)degree,
(uint64_t)base_degree);
}
}

// Step 2: rank-based reordering
#pragma omp parallel
{
Expand All @@ -136,9 +153,13 @@ void add_node_core(
for (std::uint32_t i = 0; i < base_degree; i++) {
std::uint32_t detourable_node_count = 0;
const auto a_id = host_neighbor_indices(vec_i, i);
if (a_id >= idx.size()) {
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1);
continue;
}
for (std::uint32_t j = 0; j < i; j++) {
const auto b0_id = host_neighbor_indices(vec_i, j);
assert(b0_id < idx.size());
if (b0_id >= idx.size()) { continue; }
for (std::uint32_t k = 0; k < degree; k++) {
const auto b1_id = updated_graph(b0_id, k);
if (a_id == b1_id) {
Expand All @@ -149,6 +170,7 @@ void add_node_core(
}
detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count);
}

std::sort(detourable_node_count_list.begin(),
detourable_node_count_list.end(),
[&](const std::pair<IdxT, std::size_t> a, const std::pair<IdxT, std::size_t> b) {
Expand All @@ -170,13 +192,18 @@ void add_node_core(
const auto target_new_node_id = old_size + batch.offset() + vec_i;
for (std::size_t i = 0; i < num_rev_edges; i++) {
const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i);

if (target_node_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id);
}
IdxT replace_id = new_size;
IdxT replace_id_j = 0;
std::size_t replace_num_incoming_edges = 0;
for (std::int32_t j = degree - 1; j >= static_cast<std::int32_t>(rev_edge_search_range);
j--) {
const auto neighbor_id = updated_graph(target_node_id, j);
const auto neighbor_id = updated_graph(target_node_id, j);
if (neighbor_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id);
}
const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id);
if (num_incoming_edges > replace_num_incoming_edges) {
// Check duplication
Expand All @@ -195,10 +222,6 @@ void add_node_core(
replace_id_j = j;
}
}
if (replace_id >= new_size) {
std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id);
return;
}
updated_graph(target_node_id, replace_id_j) = target_new_node_id;
rev_edges[i] = replace_id;
}
Expand All @@ -210,13 +233,15 @@ void add_node_core(
const auto rank_based_list_ptr =
updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree;
const auto rev_edges_return_list_ptr = rev_edges.data();
while (num_add < degree) {
while ((num_add < degree) &&
((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) {
const auto node_list_ptr =
interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr;
auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i;
const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges;
for (; node_list_index < max_node_list_index; node_list_index++) {
const auto candidate = node_list_ptr[node_list_index];
if (candidate >= new_size) { continue; }
// Check duplication
bool dup = false;
for (std::uint32_t j = 0; j < num_add; j++) {
Expand All @@ -233,6 +258,12 @@ void add_node_core(
}
interleave_switch = 1 - interleave_switch;
}
if (num_add < degree) {
RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)",
target_new_node_id,
num_add,
degree);
}
for (std::uint32_t i = 0; i < degree; i++) {
updated_graph(target_new_node_id, i) = temp[i];
}
Expand All @@ -248,7 +279,9 @@ void add_graph_nodes(
raft::host_matrix_view<IdxT, std::int64_t> updated_graph_view,
const cagra::extend_params& params)
{
assert(input_updated_dataset_view.extent(0) >= index.size());
if (input_updated_dataset_view.extent(0) < index.size()) {
RAFT_FAIL("Updated dataset must be not smaller than the previous index state.");
}

const std::size_t initial_dataset_size = index.size();
const std::size_t new_dataset_size = input_updated_dataset_view.extent(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
39 changes: 28 additions & 11 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by others.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}
Expand All @@ -168,13 +177,15 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = raft::upper_bound<IndexT>();
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);

// Read child indices of parents from knn graph and check if the distance
// computation is necessary.
Expand All @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
87 changes: 73 additions & 14 deletions cpp/src/neighbors/detail/cagra/hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cstdint>

#define HASHMAP_LINEAR_PROBING

// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored
// #pragma GCC diagnostic pop
Expand All @@ -38,19 +40,19 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table,
{
if (threadIdx.x < FIRST_TID) return;
for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) {
table[i] = utils::get_max_value<IdxT>();
table[i] = ~static_cast<IdxT>(0);
}
}

template <class IdxT>
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
{
// Open addressing is used for collision resolution
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#if 1
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
Expand All @@ -59,32 +61,89 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
uint32_t index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT old = atomicCAS(&table[index], ~static_cast<IdxT>(0), key);
if (old == ~static_cast<IdxT>(0)) {
const IdxT old = atomicCAS(&table[index], hashval_empty, key);
if (old == hashval_empty) {
return 1;
} else if (old == key) {
return 0;
} else if (SUPPORT_REMOVE) {
// Checks if this key has been removed before.
const uint32_t old = atomicCAS(&table[index], removed_key, key);
if (old == removed_key) {
return 1;
} else if (old == key) {
return 0;
}
}
index = (index + stride) & bit_mask;
}
return 0;
}

template <unsigned TEAM_SIZE, class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t bitlen,
const IdxT key)
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key)
{
IdxT ret = 0;
if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); }
for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) {
ret |= __shfl_xor_sync(0xffffffff, ret, offset);
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT val = table[index];
if (val == key) {
return 1;
} else if (val == hashval_empty) {
return 0;
} else if (SUPPORT_REMOVE) {
// Check if this key has been removed.
if (val == removed_key) { return 0; }
}
index = (index + stride) & bit_mask;
}
return ret;
return 0;
}

/**
 * @brief Mark `key` as removed in the open-addressing hash table.
 *
 * Starting from the slot derived from `key`, the table is probed (linear probing when
 * HASHMAP_LINEAR_PROBING is defined, double hashing otherwise). A slot that holds `key`
 * is atomically rewritten to `key` with its most significant bit set, which is the
 * "removed" marker used by this function.
 *
 * @tparam IdxT   Integer type of the keys stored in the table.
 * @param table   Pointer to the hash table storage.
 * @param bitlen  Bit length passed to `get_size()` to obtain the number of slots.
 * @param key     Key to remove.
 * @return 1 if a slot holding `key` was found and marked as removed; 0 if an empty slot
 *         was reached first or every slot was probed without finding `key`.
 */
template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key)
{
  const uint32_t size = get_size(bitlen);
  // NOTE(review): masking with `size - 1` presumes `size` is a power of two - confirm
  // against get_size().
  const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
  // Linear probing
  IdxT index                = (key ^ (key >> bitlen)) & bit_mask;
  constexpr uint32_t stride = 1;
#else
  // Double hashing
  IdxT index            = key & bit_mask;
  const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
  constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
  const IdxT removed_key       = key | utils::gen_index_msb_1_mask<IdxT>::value;
  for (unsigned i = 0; i < size; i++) {
    // To remove a key, set the MSB to 1.
    // The previous slot value is kept as IdxT: narrowing it to uint32_t would truncate
    // 64-bit keys and make the comparisons below fail even when the swap succeeded.
    const IdxT old = atomicCAS(&table[index], key, removed_key);
    if (old == key) {
      return 1;
    } else if (old == hashval_empty) {
      // An empty slot ends the probe sequence: the key is not in the table.
      return 0;
    }
    index = (index + stride) & bit_mask;
  }
  return 0;
}

template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t
insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key)
{
Expand Down
Loading
Loading