diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index b03b8214b..d03fe7b95 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -126,6 +126,23 @@ void add_node_core( raft::resource::get_cuda_stream(handle)); raft::resource::sync_stream(handle); + // Check search results + for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) { + std::uint32_t invalid_edges = 0; + for (std::uint32_t i = 0; i < base_degree; i++) { + if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; } + } + if (invalid_edges > 0) { + RAFT_LOG_WARN( + "Invalid edges found in search results " + "(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)", + (uint64_t)vec_i, + (uint64_t)invalid_edges, + (uint64_t)degree, + (uint64_t)base_degree); + } + } + // Step 2: rank-based reordering #pragma omp parallel { @@ -136,9 +153,13 @@ void add_node_core( for (std::uint32_t i = 0; i < base_degree; i++) { std::uint32_t detourable_node_count = 0; const auto a_id = host_neighbor_indices(vec_i, i); + if (a_id >= idx.size()) { + detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1); + continue; + } for (std::uint32_t j = 0; j < i; j++) { const auto b0_id = host_neighbor_indices(vec_i, j); - assert(b0_id < idx.size()); + if (b0_id >= idx.size()) { continue; } for (std::uint32_t k = 0; k < degree; k++) { const auto b1_id = updated_graph(b0_id, k); if (a_id == b1_id) { @@ -149,6 +170,7 @@ void add_node_core( } detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count); } + std::sort(detourable_node_count_list.begin(), detourable_node_count_list.end(), [&](const std::pair a, const std::pair b) { @@ -170,13 +192,18 @@ void add_node_core( const auto target_new_node_id = old_size + batch.offset() + vec_i; for (std::size_t i = 0; i < num_rev_edges; i++) { const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i); - + if (target_node_id >= new_size) { + RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id); + } IdxT replace_id = new_size; IdxT replace_id_j = 0; std::size_t replace_num_incoming_edges = 0; for (std::int32_t j = degree - 1; j >= static_cast(rev_edge_search_range); j--) { - const auto neighbor_id = updated_graph(target_node_id, j); + const auto neighbor_id = updated_graph(target_node_id, j); + if (neighbor_id >= new_size) { + RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id); + } const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id); if (num_incoming_edges > replace_num_incoming_edges) { // Check duplication @@ -195,10 +222,6 @@ void add_node_core( replace_id_j = j; } } - if (replace_id >= new_size) { - std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id); - return; - } updated_graph(target_node_id, replace_id_j) = target_new_node_id; rev_edges[i] = replace_id; } @@ -210,13 +233,15 @@ void add_node_core( const auto rank_based_list_ptr = updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree; const auto rev_edges_return_list_ptr = rev_edges.data(); - while (num_add < degree) { + while ((num_add < degree) && + ((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) { const auto node_list_ptr = interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr; auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i; const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges; for (; node_list_index < max_node_list_index; node_list_index++) { const auto candidate = node_list_ptr[node_list_index]; + if (candidate >= new_size) { continue; } // Check duplication bool dup = false; for (std::uint32_t j = 0; j < num_add; j++) { @@ -233,6 +258,12 @@ void add_node_core( } interleave_switch = 1 - interleave_switch; } + if (num_add < degree) { + RAFT_FAIL("Number of edges is not enough (target_new_node_id:%u, num_add:%u, degree:%u)", + target_new_node_id, + num_add, + degree); + } for (std::uint32_t i = 0; i < degree; i++) { updated_graph(target_new_node_id, i) = temp[i]; } @@ -248,7 +279,9 @@ void add_graph_nodes( raft::host_matrix_view updated_graph_view, const cagra::extend_params& params) { - assert(input_updated_dataset_view.extent(0) >= index.size()); + if (input_updated_dataset_view.extent(0) < index.size()) { + RAFT_FAIL("Updated dataset must be not smaller than the previous index state."); + } const std::size_t initial_dataset_size = index.size(); const std::size_t new_dataset_size = input_updated_dataset_view.extent(0); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5778d85a6..b4f701819 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res, using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; std::unique_ptr> plan = factory::create( - res, params, dataset_desc, queries.extent(1), graph.extent(1), topk); + res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk); plan->check(topk); diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 7ec3d4d9e..c20d58994 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const IndexT* __restrict__ seed_ptr, // [num_seeds] const uint32_t num_seeds, IndexT* __restrict__ visited_hash_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, const uint32_t block_id = 0, const uint32_t num_blocks = 1) { @@ -145,14 +147,21 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); if (valid_i && lane_id == 0) { - if (best_index_team_local != raft::upper_bound() && - hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } else { - result_distances_ptr[i] = raft::upper_bound(); - result_indices_ptr[i] = raft::upper_bound(); + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by otehrs. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; } } } @@ -168,13 +177,15 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( const uint32_t knn_k, // hashmap IndexT* __restrict__ visited_hashmap_ptr, - const uint32_t hash_bitlen, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, const IndexT* __restrict__ parent_indices, const IndexT* __restrict__ internal_topk_list, const uint32_t search_width) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; - constexpr IndexT invalid_index = raft::upper_bound(); + constexpr IndexT invalid_index = ~static_cast(0); // Read child indices of parents from knn graph and check if the distance // computaiton is necessary. @@ -186,7 +197,13 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; } if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + // Deactivate this entry as this has been already used by others. child_id = invalid_index; } } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index e6e7ff64f..d2ae5c55b 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -40,10 +40,11 @@ class factory { search_params const& params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) { - search_plan_impl_base plan(params, dim, graph_degree, topk); + search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk); return dispatch_kernel(res, plan, dataset_desc); } @@ -56,15 +57,15 @@ class factory { if (plan.algo == search_algo::SINGLE_CTA) { return std::make_unique< single_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else if (plan.algo == search_algo::MULTI_CTA) { return std::make_unique< multi_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else { return std::make_unique< multi_kernel_search::search>( - res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } } }; diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index 2c62dda90..652e1db22 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -23,6 +23,8 @@ #include +#define HASHMAP_LINEAR_PROBING + // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored // #pragma GCC diagnostic pop @@ -38,11 +40,11 @@ RAFT_DEVICE_INLINE_FUNCTION void init(IdxT* const table, { if (threadIdx.x < FIRST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = ~static_cast(0); } } -template +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) @@ -50,7 +52,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, // Open addressing is used for collision resolution const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; -#if 1 +#ifdef HASHMAP_LINEAR_PROBING // Linear probing IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; @@ -59,32 +61,89 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, uint32_t index = key & bit_mask; const uint32_t stride = (key >> bitlen) * 2 + 1; #endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { - const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); - if (old == ~static_cast(0)) { + const IdxT old = atomicCAS(&table[index], hashval_empty, key); + if (old == hashval_empty) { return 1; } else if (old == key) { return 0; + } else if (SUPPORT_REMOVE) { + // Checks if this key has been removed before. + const uint32_t old = atomicCAS(&table[index], removed_key, key); + if (old == removed_key) { + return 1; + } else if (old == key) { + return 0; + } } index = (index + stride) & bit_mask; } return 0; } -template -RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, - const uint32_t bitlen, - const IdxT key) +template +RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key) { - IdxT ret = 0; - if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } - for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { - ret |= __shfl_xor_sync(0xffffffff, ret, offset); + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + const IdxT val = table[index]; + if (val == key) { + return 1; + } else if (val == hashval_empty) { + return 0; + } else if (SUPPORT_REMOVE) { + // Check if this key has been removed. + if (val == removed_key) { return 0; } + } + index = (index + stride) & bit_mask; } - return ret; + return 0; } template +RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key) +{ + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; +#ifdef HASHMAP_LINEAR_PROBING + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; +#else + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; +#endif + constexpr IdxT hashval_empty = ~static_cast(0); + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + for (unsigned i = 0; i < size; i++) { + // To remove a key, set the MSB to 1. + const uint32_t old = atomicCAS(&table[index], key, removed_key); + if (old == key) { + return 1; + } else if (old == hashval_empty) { + return 0; + } + index = (index + stride) & bit_mask; + } + return 0; +} + +template RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(unsigned team_size, IdxT* const table, const uint32_t bitlen, const IdxT key) { diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index ecfd856f1..8a97173fa 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -102,24 +102,24 @@ struct search : public search_plan_impl& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), intermediate_indices(res), intermediate_distances(res), topk_workspace(res) - { set_params(res, params); } void set_params(raft::resources const& res, const search_params& params) { - constexpr unsigned muti_cta_itopk_size = 32; - this->itopk_size = muti_cta_itopk_size; - search_width = 1; + constexpr unsigned multi_cta_itopk_size = 32; + this->itopk_size = multi_cta_itopk_size; + search_width = 1; num_cta_per_query = - max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); + max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -127,8 +127,9 @@ struct search : public search_plan_impl +template RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents( - INDEX_T* const next_parent_indices, // [search_width] - const uint32_t search_width, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const terminate_flag) + INDEX_T* const next_parent_indices, // [num_parents] + const uint32_t num_parents, + INDEX_T* const itopk_indices, // [num_itopk] + DISTANCE_T* const itopk_distances, // [num_itopk] + const uint32_t num_itopk, // (*) num_itopk <= 32 + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const unsigned warp_id = threadIdx.x / 32; + constexpr INDEX_T invalid_index = ~static_cast(0); + + const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } + + // Initialize + if (lane_id < num_parents) { next_parent_indices[lane_id] = ~static_cast(0); } + INDEX_T index = invalid_index; + if (lane_id < num_itopk) { index = itopk_indices[lane_id]; } + + int is_candidate = 0; + if ((index != invalid_index) && ((index & index_msb_1_mask) == 0)) { + if (hashmap::search(hash_ptr, hash_bitlen, index)) { + // Deactivate nodes that have already been used by other CTAs. + itopk_indices[lane_id] = invalid_index; + itopk_distances[lane_id] = utils::get_max_value(); + index = invalid_index; + } else { + is_candidate = 1; } - const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = j; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node + } + + uint32_t num_next_parents = 0; + while (num_next_parents < num_parents) { + const uint32_t ballot_mask = __ballot_sync(0xffffffff, is_candidate); + int num_candidates = __popc(ballot_mask); + if (num_candidates == 0) { return; } + int is_found = 0; + if (is_candidate) { + const auto candidate_id = __popc(ballot_mask & ((1 << lane_id) - 1)); + if (candidate_id == 0) { + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + next_parent_indices[num_next_parents] = lane_id; + index |= index_msb_1_mask; // set most significant bit as used node + is_found = 1; + } else { + // Deactivate the node since it has been used by other CTA. + index = invalid_index; + itopk_distances[lane_id] = utils::get_max_value(); + } + itopk_indices[lane_id] = index; + is_candidate = 0; } } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } + if (__ballot_sync(0xffffffff, is_found)) { num_next_parents += 1; } } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } template @@ -116,21 +136,54 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort( val[i] = indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = ~static_cast(0); } } /* Warp Sort */ bitonic::warp_sort(key, val); - /* Store itopk sorted results */ + /* Store sorted results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; - if (j < num_itopk) { + if (j < num_elements) { distances[j] = key[i]; indices[j] = val[i]; } } } +template +RAFT_DEVICE_INLINE_FUNCTION void move_valid_entries_to_head( + INDEX_T* indices, // [num_elements] + DISTANCE_T* distances, // [num_elements] + const uint32_t num_elements // (*) num_elements must be multiple of 32 +) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr INDEX_T invalid_index = ~static_cast(0); + const uint32_t warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const uint32_t lane_id = threadIdx.x % 32; + uint32_t offset = 0; + for (uint32_t i = lane_id; i < num_elements; i += 32) { + auto index = indices[i]; + auto distance = distances[i]; + bool is_valid = (index != invalid_index); + const auto mask = __ballot_sync(0xffffffff, is_valid); + const auto j = offset + __popc(mask & ((1 << lane_id) - 1)); + if ((index != invalid_index) && (j < i)) { + indices[j] = index; + distances[j] = distance; + } + offset += __popc(mask); + __syncwarp(); + } + for (uint32_t i = offset + lane_id; i < num_elements; i += 32) { + indices[i] = invalid_index; + distances[i] = utils::get_max_value(); + } + __syncwarp(); +} + // // multiple CTAs per single query // @@ -148,9 +201,10 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const uint64_t rand_xor_mask, const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, + traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, const uint32_t itopk_size, const uint32_t search_width, const uint32_t min_iteration, @@ -185,11 +239,11 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( extern __shared__ uint8_t smem[]; // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | + // +----------------+-------------------------------+---------+ + // | internal_top_k | neighbors of parent nodes | padding | // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| + // +----------------+-------------------------------+---------+ + // |<--- result_buffer_size --->| const auto result_buffer_size = itopk_size + (search_width * graph_degree); const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); assert(result_buffer_size_32 <= MAX_ELEMENTS); @@ -201,22 +255,22 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( reinterpret_cast(smem + dataset_desc->smem_ws_size_in_bytes()); auto* __restrict__ result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto* __restrict__ parent_indices_buffer = + auto* __restrict__ local_visited_hashmap_ptr = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ terminate_flag = - reinterpret_cast(parent_indices_buffer + search_width); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); -#if 0 - /* debug */ - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - result_indices_buffer[i] = utils::get_max_value(); - result_distances_buffer[i] = utils::get_max_value(); - } -#endif + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); __syncthreads(); _CLK_REC(clk_init); @@ -235,7 +289,9 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, block_id, num_blocks); __syncthreads(); @@ -243,49 +299,93 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( uint32_t iter = 0; while (1) { - // topk with bitonic sort _CLK_START(); + // Topk with bitonic sort (1st warp only) topk_by_bitonic_sort(result_distances_buffer, result_indices_buffer, itopk_size + (search_width * graph_degree), itopk_size); + __syncthreads(); _CLK_REC(clk_topk); - if (iter + 1 == max_iteration) { - __syncthreads(); - break; - } + if (iter + 1 >= max_iteration) { break; } - // pick up next parents _CLK_START(); - pickup_next_parents( - parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag); + if (threadIdx.x < 32) { + // [1st warp] Pick up next parents + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); +#if 0 + if (parent_indices_buffer[0] == invalid_index) { + // Try again if no parent is found + move_valid_entries_to_head(result_indices_buffer, + result_distances_buffer, + result_buffer_size_32); + pickup_next_parents(parent_indices_buffer, + search_width, + result_indices_buffer, + result_distances_buffer, + itopk_size, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } +#endif + } else { + // [Other warps] Reset visited hashmap + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } + __syncthreads(); _CLK_REC(clk_pickup_parents); + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + if (threadIdx.x < 32) { + // [1st warp] Restore visited hashmap by putting itopk indices in it. + for (unsigned i = threadIdx.x; i < itopk_size; i += 32) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } else { + // [Other warps] Remove entries kicked out of the itopk list from the + // traversed hash table. + for (unsigned i = threadIdx.x - 32; i < search_width * graph_degree; i += blockDim.x - 32) { + INDEX_T index = result_indices_buffer[itopk_size + i]; + if (index == invalid_index) { continue; } + if (index & index_msb_1_mask) { + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + } + } + } __syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } - // compute the norms between child nodes and query node _CLK_START(); + // compute the norms between child nodes and query node device::compute_distance_to_child_nodes(result_indices_buffer + itopk_size, result_distances_buffer + itopk_size, *dataset_desc, knn_graph, graph_degree, local_visited_hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, parent_indices_buffer, result_indices_buffer, search_width); - _CLK_REC(clk_compute_distance); __syncthreads(); + _CLK_REC(clk_compute_distance); // Filtering if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { if (parent_indices_buffer[p] != invalid_index) { const auto parent_id = @@ -303,36 +403,65 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( iter++; } - // Post process for filtering + // Filtering if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[i] = utils::get_max_value(); + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + if (!sample_filter(query_id, index)) { result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); } } - - __syncthreads(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); __syncthreads(); } - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit + // Output search results (1st warp only). + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if (hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + // If a node that is not used as a parent can be inserted into + // the traversed hash table, it is considered a valid result. + is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = result_distances_buffer[i]; + } + } else if ((index & index_msb_1_mask) == 0) { + // If a node that was successfully inserted in the traversed + // hash table is not output as a result, the hash table is + // restored using hash remove. + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + if (offset >= itopk_size) break; + } + // If the number of outputs is insufficient, fill in with invalid results. + for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } + } } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { @@ -427,8 +556,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, @@ -441,9 +571,13 @@ void select_and_run(const dataset_descriptor_host& dat RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Initialize hash table - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap_ptr, hash_size, utils::get_max_value(), hash_size, num_queries, stream); + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); @@ -463,8 +597,9 @@ void select_and_run(const dataset_descriptor_host& dat ps.rand_xor_mask, dev_seed_ptr, num_seeds, - hashmap_ptr, - hash_bitlen, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, ps.itopk_size, ps.search_width, ps.min_iterations, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh index 1a1dcd579..e5dc29f27 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -36,8 +36,9 @@ void select_and_run(const dataset_descriptor_host& dat uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_seeds, SampleFilterT sample_filter, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index c6fe21642..be92be999 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -635,9 +635,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk), + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), result_indices(res), result_distances(res), parent_node_list(res), diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 99254aa50..7300c89c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -108,11 +108,17 @@ struct lightweight_uvector { }; struct search_plan_impl_base : public search_params { + int64_t dataset_size; int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) + search_plan_impl_base( + search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk) + : search_params(params), + dim(dim), + dataset_size(dataset_size), + graph_degree(graph_degree), + topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); @@ -141,7 +147,6 @@ struct search_plan_impl : public search_plan_impl_base { size_t small_hash_bitlen; size_t small_hash_reset_interval; size_t hashmap_size; - uint32_t dataset_size; uint32_t result_buffer_size; uint32_t smem_size; @@ -157,9 +162,10 @@ struct search_plan_impl : public search_plan_impl_base { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : search_plan_impl_base(params, dim, graph_degree, topk), + : search_plan_impl_base(params, dim, dataset_size, graph_degree, topk), hashmap(res), num_executed_iterations(res), dev_seed(res), @@ -193,10 +199,16 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) + constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_search_width = 1; + _max_iterations = mc_itopk_size / mc_search_width; } else { - _max_iterations = - 1 + std::min((itopk_size / search_width) * 1.1, (itopk_size / search_width) + 10.0); + _max_iterations = itopk_size / search_width; + } + int64_t num_reachable_nodes = 1; + while (num_reachable_nodes < dataset_size) { + num_reachable_nodes *= graph_degree / 2; + _max_iterations += 1; } } if (max_iterations < min_iterations) { _max_iterations = min_iterations; } @@ -219,88 +231,107 @@ struct search_plan_impl : public search_plan_impl_base { // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size inline void calc_hashmap_params(raft::resources const& res) { - // for multiple CTA search - uint32_t mc_num_cta_per_query = 0; - uint32_t mc_search_width = 0; - uint32_t mc_itopk_size = 0; - if (algo == search_algo::MULTI_CTA) { - mc_itopk_size = 32; - mc_search_width = 1; - mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); - RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); - RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); - RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); - } - // Determine hash size (bit length) hashmap_size = 0; hash_bitlen = 0; small_hash_bitlen = 0; small_hash_reset_interval = 1024 * 1024; float max_fill_rate = hashmap_max_fill_rate; - while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { - // - // The small-hash reduces hash table size by initializing the hash table - // for each iteration and re-registering only the nodes that should not be - // re-visited in that iteration. Therefore, the size of small-hash should - // be determined based on the internal topk size and the number of nodes - // visited per iteration. - // - const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); - unsigned min_bitlen = 8; // 256 - unsigned max_bitlen = 13; // 8K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - if (hash_bitlen > max_bitlen) { - // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. - if (hashmap_mode == hash_mode::AUTO) { - hash_bitlen = 0; - break; - } else { - RAFT_FAIL( - "small-hash cannot be used because the required hash size exceeds the limit (%u)", - hashmap::get_size(max_bitlen)); - } - } - small_hash_bitlen = hash_bitlen; + if (algo == search_algo::MULTI_CTA) { + const uint32_t mc_itopk_size = 32; + const uint32_t mc_num_cta_per_query = + max(search_width, raft::ceildiv(itopk_size, (size_t)mc_itopk_size)); + RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); + RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); // - // Sincc the hash table size is limited to a power of 2, the requirement, - // the maximum fill rate, may be satisfied even if the frequency of hash - // table reset is reduced to once every 2 or more iterations without - // changing the hash table size. In that case, reduce the reset frequency. + // [visited_hash_table] + // In the multi CTA algo, which node has been visited is managed in a hash + // table that each CTA has in the shared memory. This hash table is not + // shared among CTAs. This hash table is reset and restored in each iteration. // - small_hash_reset_interval = 1; - while (1) { - const auto max_visited_nodes = - itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); - if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } - small_hash_reset_interval += 1; + const uint32_t max_visited_nodes = mc_itopk_size + graph_degree; + small_hash_bitlen = 8; // 256 + while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { + small_hash_bitlen += 1; } - break; - } - if (hash_bitlen == 0) { + RAFT_EXPECTS(small_hash_bitlen <= 14, "small_hash_bitlen cannot be largen than 14 (16K)"); // - // The size of hash table is determined based on the maximum number of - // nodes that may be visited before the search is completed and the - // maximum fill rate of the hash table. + // [traversed_hash_table] + // Whether a node has ever been used as the starting point for a traversal + // in each iteration is managed in a separate hash table, which is shared + // among the CTAs. // - uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); - if (algo == search_algo::MULTI_CTA) { - max_visited_nodes = mc_itopk_size + (mc_search_width * graph_degree * max_iterations); - max_visited_nodes *= mc_num_cta_per_query; - } + const auto max_traversed_nodes = + mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + while (max_traversed_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { hash_bitlen += 1; } - RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); + RAFT_EXPECTS(hash_bitlen <= 25, "hash_bitlen cannot be largen than 25 (32M)"); + } else { + while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { + // + // The small-hash reduces hash table size by initializing the hash table + // for each iteration and re-registering only the nodes that should not be + // re-visited in that iteration. Therefore, the size of small-hash should + // be determined based on the internal topk size and the number of nodes + // visited per iteration. + // + const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); + unsigned min_bitlen = 8; // 256 + unsigned max_bitlen = 13; // 8K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + if (hash_bitlen > max_bitlen) { + // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. + if (hashmap_mode == hash_mode::AUTO) { + hash_bitlen = 0; + break; + } else { + RAFT_FAIL( + "small-hash cannot be used because the required hash size exceeds the limit (%u)", + hashmap::get_size(max_bitlen)); + } + } + small_hash_bitlen = hash_bitlen; + // + // Sincc the hash table size is limited to a power of 2, the requirement, + // the maximum fill rate, may be satisfied even if the frequency of hash + // table reset is reduced to once every 2 or more iterations without + // changing the hash table size. In that case, reduce the reset frequency. + // + small_hash_reset_interval = 1; + while (1) { + const auto max_visited_nodes = + itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); + if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } + small_hash_reset_interval += 1; + } + break; + } + if (hash_bitlen == 0) { + // + // The size of hash table is determined based on the maximum number of + // nodes that may be visited before the search is completed and the + // maximum fill rate of the hash table. + // + uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); + unsigned min_bitlen = 11; // 2K + if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } + hash_bitlen = min_bitlen; + while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { + hash_bitlen += 1; + } + RAFT_EXPECTS(hash_bitlen <= 20, + "hash_bitlen cannot be largen than 20 (1M). You can decrease itopk_size, " + "search_width or max_iterations to reduce the required hashmap size."); + } } - RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); RAFT_LOG_DEBUG("# parent size = %lu", search_width); RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index fa71dbaf9..0911d440c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -94,9 +94,10 @@ struct search : search_plan_impl { search_params params, const dataset_descriptor_host& dataset_desc, int64_t dim, + int64_t dataset_size, int64_t graph_degree, uint32_t topk) - : base_type(res, params, dataset_desc, dim, graph_degree, topk) + : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk) { set_params(res); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 678ed0cb4..94c97ed16 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -622,7 +622,9 @@ __device__ void search_core( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen); + hash_bitlen, + (INDEX_T*)nullptr, + 0); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -749,6 +751,8 @@ __device__ void search_core( graph_degree, local_visited_hashmap_ptr, hash_bitlen, + (INDEX_T*)nullptr, + 0, parent_list_buffer, result_indices_buffer, search_width);