From 6adc2515e49fc6934ed08c2171335507d23a02c0 Mon Sep 17 00:00:00 2001 From: nk2014yj Date: Mon, 11 Mar 2024 21:26:55 +0800 Subject: [PATCH 1/2] Update puck_index.cpp fix bug --- puck/puck/puck_index.cpp | 42 ++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/puck/puck/puck_index.cpp b/puck/puck/puck_index.cpp index f1ed642..a8c9f72 100644 --- a/puck/puck/puck_index.cpp +++ b/puck/puck/puck_index.cpp @@ -649,32 +649,40 @@ int PuckIndex::puck_assign(const ThreadParams& thread_params, uint32_t* cell_ass for (auto quantized : quantizations) { auto& cur_params = quantized->get_quantization_params(); - std::unique_ptr sub_residual(new float[FLAGS_thread_chunk_size * cur_params.lsq]); - + //std::unique_ptr sub_residual(new float[FLAGS_thread_chunk_size * cur_params.lsq]); + std::unique_ptr pq_distance_table(new float[cur_params.nsq * cur_params.ks]); + for (uint32_t i = 0; i < real_thread_chunk_size; i++) { + float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim; + quantized->get_dist_table(cur_point_fea, pq_distance_table.get()); + + u_int64_t true_point_id = cur_start_point_id + i; + + auto* quantized_fea = quantized->get_quantized_feature(true_point_id); + quantized_fea += quantized->get_fea_offset(); + + for (uint32_t n = 0; n < (uint32_t)cur_params.nsq; ++n) { + float min_distance = std::sqrt(std::numeric_limits::max()); + const float* sub_dist_table = pq_distance_table.get() + n * cur_params.ks; + + for (uint32_t k = 0; k < (uint32_t)cur_params.ks; ++k) { + if (sub_dist_table[k] < min_distance){ + min_distance = sub_dist_table[k]; + quantized_fea[n] = (unsigned char)k; + } + } + } + } + for (uint32_t k = 0; k < (uint32_t)cur_params.nsq; ++k) { uint32_t cur_lsq = std::min(cur_params.lsq, cur_params.dim - k * cur_params.lsq); - memset(sub_residual.get(), 0, FLAGS_thread_chunk_size * cur_params.lsq * sizeof(float)); - - for (uint32_t i = 0; i < real_thread_chunk_size; i++) { - const float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim; - memcpy(sub_residual.get() + i * cur_params.lsq, cur_point_fea + k * cur_params.lsq, - sizeof(float) * cur_lsq); - } - - int distance_type = 2; float* cur_pq_centroids = quantized->get_sub_coodbooks(k); - //knn_full_thread(distance_type, real_thread_chunk_size, cur_params.ks, cur_params.lsq, 1, - // cur_pq_centroids, sub_residual.get(), nullptr, pq_assign.get(), pq_distance.get(), 1); - nearest_center(cur_params.lsq, cur_pq_centroids, cur_params.ks, sub_residual.get(), real_thread_chunk_size, - pq_assign.get(), pq_distance.get()); for (uint32_t i = 0; i < real_thread_chunk_size; i++) { u_int64_t true_point_id = cur_start_point_id + i; //point i 在第k子空间对应的聚类中心id auto* quantized_fea = quantized->get_quantized_feature(true_point_id); quantized_fea += quantized->get_fea_offset(); - quantized_fea[k] = (unsigned char)pq_assign.get()[i]; - int cur_assign = pq_assign.get()[i]; + int cur_assign = (int)quantized_fea[k]; float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim; //量化使用残差 cblas_saxpy(cur_lsq, -1.0, From 6534a7e4df5bad4e34cdc87c2e2ead38c18df0ed Mon Sep 17 00:00:00 2001 From: nk2014yj Date: Fri, 24 May 2024 14:52:17 +0800 Subject: [PATCH 2/2] fix bug --- puck/index_conf.cpp | 2 ++ puck/index_conf.h | 2 ++ puck/puck/puck_index.cpp | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/puck/index_conf.cpp b/puck/index_conf.cpp index 786389f..5dc68e4 100644 --- a/puck/index_conf.cpp +++ b/puck/index_conf.cpp @@ -91,6 +91,8 @@ IndexConf::IndexConf() { pq_data_file_name = index_path + "/" + FLAGS_pq_data_file_name; index_type = IndexType::PUCK; + tinker_neighborhood = FLAGS_tinker_neighborhood; + tinker_construction = FLAGS_tinker_construction; tinker_search_range = FLAGS_tinker_search_range; } diff --git a/puck/index_conf.h b/puck/index_conf.h index bdec9d1..4763898 100644 --- a/puck/index_conf.h +++ b/puck/index_conf.h @@ -69,6 +69,8 @@ struct IndexConf { std::string pq_data_file_name; std::string index_path; //tinker的检索参数 + uint32_t tinker_neighborhood; + uint32_t tinker_construction; uint32_t tinker_search_range; IndexConf(); diff --git a/puck/puck/puck_index.cpp b/puck/puck/puck_index.cpp index a8c9f72..27ccf46 100644 --- a/puck/puck/puck_index.cpp +++ b/puck/puck/puck_index.cpp @@ -568,7 +568,7 @@ int PuckIndex::search(const Request* request, Response* response) { LOG(ERROR) << "init search context has error."; return -1; } - + context->set_request(request); const float* feature = normalization(context.get(), request->feature); //输出query与一级聚类中心的top-search-cell个ID和距离 int ret = search_nearest_coarse_cluster(context.get(), feature,