Skip to content

Commit

Permalink
Merge pull request #30 from nk2014yj/main
Browse files Browse the repository at this point in the history
Update puck_index.cpp
  • Loading branch information
nk2014yj authored Jun 4, 2024
2 parents 8acbc16 + 6534a7e commit 0f03b40
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
2 changes: 2 additions & 0 deletions puck/index_conf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ IndexConf::IndexConf() {
pq_data_file_name = index_path + "/" + FLAGS_pq_data_file_name;

index_type = IndexType::PUCK;
tinker_neighborhood = FLAGS_tinker_neighborhood;
tinker_construction = FLAGS_tinker_construction;
tinker_search_range = FLAGS_tinker_search_range;
}

Expand Down
2 changes: 2 additions & 0 deletions puck/index_conf.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ struct IndexConf {
std::string pq_data_file_name;
std::string index_path;
//tinker的检索参数
uint32_t tinker_neighborhood;
uint32_t tinker_construction;
uint32_t tinker_search_range;

IndexConf();
Expand Down
44 changes: 26 additions & 18 deletions puck/puck/puck_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ int PuckIndex::search(const Request* request, Response* response) {
LOG(ERROR) << "init search context has error.";
return -1;
}

context->set_request(request);
const float* feature = normalization(context.get(), request->feature);
//输出query与一级聚类中心的top-search-cell个ID和距离
int ret = search_nearest_coarse_cluster(context.get(), feature,
Expand Down Expand Up @@ -649,32 +649,40 @@ int PuckIndex::puck_assign(const ThreadParams& thread_params, uint32_t* cell_ass

for (auto quantized : quantizations) {
auto& cur_params = quantized->get_quantization_params();
std::unique_ptr<float[]> sub_residual(new float[FLAGS_thread_chunk_size * cur_params.lsq]);

//std::unique_ptr<float[]> sub_residual(new float[FLAGS_thread_chunk_size * cur_params.lsq]);
std::unique_ptr<float[]> pq_distance_table(new float[cur_params.nsq * cur_params.ks]);
for (uint32_t i = 0; i < real_thread_chunk_size; i++) {
float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim;
quantized->get_dist_table(cur_point_fea, pq_distance_table.get());

u_int64_t true_point_id = cur_start_point_id + i;

auto* quantized_fea = quantized->get_quantized_feature(true_point_id);
quantized_fea += quantized->get_fea_offset();

for (uint32_t n = 0; n < (uint32_t)cur_params.nsq; ++n) {
float min_distance = std::sqrt(std::numeric_limits<float>::max());
const float* sub_dist_table = pq_distance_table.get() + n * cur_params.ks;

for (uint32_t k = 0; k < (uint32_t)cur_params.ks; ++k) {
if (sub_dist_table[k] < min_distance){
min_distance = sub_dist_table[k];
quantized_fea[n] = (unsigned char)k;
}
}
}
}

for (uint32_t k = 0; k < (uint32_t)cur_params.nsq; ++k) {
uint32_t cur_lsq = std::min(cur_params.lsq, cur_params.dim - k * cur_params.lsq);
memset(sub_residual.get(), 0, FLAGS_thread_chunk_size * cur_params.lsq * sizeof(float));

for (uint32_t i = 0; i < real_thread_chunk_size; i++) {
const float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim;
memcpy(sub_residual.get() + i * cur_params.lsq, cur_point_fea + k * cur_params.lsq,
sizeof(float) * cur_lsq);
}

int distance_type = 2;
float* cur_pq_centroids = quantized->get_sub_coodbooks(k);
//knn_full_thread(distance_type, real_thread_chunk_size, cur_params.ks, cur_params.lsq, 1,
// cur_pq_centroids, sub_residual.get(), nullptr, pq_assign.get(), pq_distance.get(), 1);
nearest_center(cur_params.lsq, cur_pq_centroids, cur_params.ks, sub_residual.get(), real_thread_chunk_size,
pq_assign.get(), pq_distance.get());

for (uint32_t i = 0; i < real_thread_chunk_size; i++) {
u_int64_t true_point_id = cur_start_point_id + i;
//point i 在第k子空间对应的聚类中心id
auto* quantized_fea = quantized->get_quantized_feature(true_point_id);
quantized_fea += quantized->get_fea_offset();
quantized_fea[k] = (unsigned char)pq_assign.get()[i];
int cur_assign = pq_assign.get()[i];
int cur_assign = (int)quantized_fea[k];
float* cur_point_fea = chunk_points.get() + i * _conf.feature_dim;
//量化使用残差
cblas_saxpy(cur_lsq, -1.0,
Expand Down

0 comments on commit 0f03b40

Please sign in to comment.