From c61641c00317a9ca9d3d26fa619050bc471aad0e Mon Sep 17 00:00:00 2001 From: yangzq50 <58433399+yangzq50@users.noreply.github.com> Date: Mon, 21 Oct 2024 18:56:50 +0800 Subject: [PATCH] Support IVFPQ insert and query (#2083) ### What problem does this PR solve? Support IVFPQ insert Support IVFPQ query for l2, ip, cosine metric Issue link:#2077 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- src/function/table/knn_scan_data.cpp | 5 + src/function/table/knn_scan_data.cppm | 1 + .../knn_index/knn_ivf/ivf_index_storage.cpp | 2 +- .../knn_index/knn_ivf/ivf_index_storage.cppm | 7 +- .../knn_ivf/ivf_index_storage_parts.cpp | 188 +++++++++++++++--- 5 files changed, 170 insertions(+), 33 deletions(-) diff --git a/src/function/table/knn_scan_data.cpp b/src/function/table/knn_scan_data.cpp index 4d04611cf5..491d988dd9 100644 --- a/src/function/table/knn_scan_data.cpp +++ b/src/function/table/knn_scan_data.cpp @@ -38,6 +38,7 @@ namespace infinity { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().L2Distance_func_ptr_; @@ -60,6 +61,7 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().HNSW_U8L2_ptr_; @@ -81,6 +83,7 @@ f32 hnsw_u8ip_f32_wrapper(const u8 *v1, const u8 *v2, SizeT dim) { return static template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = &hnsw_u8l2_f32_wrapper; @@ -103,6 +106,7 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().HNSW_I8L2_ptr_; @@ -124,6 +128,7 @@ f32 hnsw_i8ip_f32_wrapper(const i8 *v1, const i8 *v2, SizeT dim) { return static template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = &hnsw_i8l2_f32_wrapper; diff --git a/src/function/table/knn_scan_data.cppm b/src/function/table/knn_scan_data.cppm index 686dde6a50..cf2130a3fb 100644 --- a/src/function/table/knn_scan_data.cppm +++ b/src/function/table/knn_scan_data.cppm @@ -104,6 +104,7 @@ public: using DistFunc = DistType (*)(const QueryDataType *, const QueryDataType *, SizeT); DistFunc dist_func_{}; + KnnDistanceType dist_type_{}; }; template <> diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp index 62bdb3dfd0..a92eeef079 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp @@ -271,7 +271,7 @@ void IVF_Index_Storage::SearchIndex(const KnnDistanceBase1 *knn_distance, search_top_k_with_dis(nprobe, dimension, 1, query_f32_ptr, centroids_num, centroids_data, nprobe_result.data(), centroid_dists.get(), false); } for (const auto part_id : nprobe_result) { - ivf_parts_storage_->SearchIndex(part_id, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); + ivf_parts_storage_->SearchIndex(part_id, this, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); } } diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm index 91ad8849be..463225390e 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm @@ -29,6 +29,8 @@ namespace infinity { class LocalFileHandle; class KnnDistanceBase1; +export class IVF_Index_Storage; + // always use float for centroids class IVF_Centroids_Storage { u32 embedding_dimension_ = 0; @@ -68,6 +70,7 @@ public: AppendOneEmbedding(u32 part_id, const void *embedding_ptr, SegmentOffset segment_offset, const IVF_Centroids_Storage *ivf_centroids_storage) = 0; virtual void SearchIndex(u32 part_id, + const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, EmbeddingDataType query_element_type, @@ -75,7 +78,7 @@ public: const std::function &add_result_func) const = 0; }; -export class IVF_Index_Storage { +class IVF_Index_Storage { const IndexIVFOption ivf_option_ = {}; const LogicalType column_logical_type_ = LogicalType::kInvalid; const EmbeddingDataType embedding_data_type_ = EmbeddingDataType::kElemInvalid; @@ -94,6 +97,8 @@ public: [[nodiscard]] LogicalType column_logical_type() const { return column_logical_type_; } [[nodiscard]] EmbeddingDataType embedding_data_type() const { return embedding_data_type_; } [[nodiscard]] u32 embedding_dimension() const { return embedding_dimension_; } + [[nodiscard]] const IVF_Centroids_Storage &ivf_centroids_storage() const { return ivf_centroids_storage_; } + [[nodiscard]] const IVF_Parts_Storage &ivf_parts_storage() const { return *ivf_parts_storage_; } void Train(u32 training_embedding_num, const f32 *training_data, u32 expect_centroid_num = 0); void AddEmbedding(SegmentOffset segment_offset, const void *embedding_ptr); diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp index 5638883410..5348d3809d 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp @@ -37,6 +37,8 @@ import ivf_index_util_func; import mlas_matrix_multiply; import vector_distance; import index_base; +import knn_expr; +import simd_functions; namespace infinity { @@ -74,12 +76,12 @@ class IVF_Part_Storage { const IVF_Centroids_Storage *ivf_centroids_storage, const IVF_Parts_Storage *ivf_parts_storage) = 0; - virtual void SearchIndex(const KnnDistanceBase1 *knn_distance, + virtual void SearchIndex(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *ivf_parts_storage) const = 0; + const std::function &add_result_func) const = 0; }; template @@ -114,6 +116,9 @@ class IVF_Parts_Storage_Info // size: real_subspace_centroid_num_ * subspace_num_ UniquePtr subspace_centroid_norms_neg_half_ = {}; +public: + auto real_subspace_centroid_num() const { return real_subspace_centroid_num_; } + const f32 *subspace_centroids_data_at_subspace(const u32 subspace_id) const { return subspace_centroids_data_.get() + subspace_id * subspace_dimension_ * real_subspace_centroid_num_; } @@ -128,7 +133,6 @@ class IVF_Parts_Storage_Info NON_CONST_VERSION_MEMBER_FUNC(subspace_centroids_data_at_subspace); NON_CONST_VERSION_MEMBER_FUNC(subspace_centroid_norms_neg_half_at_subspace); -public: IVF_Parts_Storage_Info(const u32 embedding_dim, const u32 centroids_num, const EmbeddingDataType embedding_data_type, @@ -249,6 +253,42 @@ class IVF_Parts_Storage_Info } } } + + void EncodeResidual(const f32 *residual, u32 *encode_output) const { + const auto xy_buffer = MakeUniqueForOverwrite(real_subspace_centroid_num_); + for (u32 j = 0; j < subspace_num_; ++j) { + matrixA_multiply_transpose_matrixB_output_to_C(residual + j * subspace_dimension_, + subspace_centroids_data_at_subspace(j), + 1, + real_subspace_centroid_num_, + subspace_dimension_, + xy_buffer.get()); + // find max id (for every embedding, find centroid with min l2 distance, and equivalently max (x*y - 0.5*y^2)) + const auto *c_norm_data = subspace_centroid_norms_neg_half_at_subspace(j); + f32 max_neg_distance = std::numeric_limits::lowest(); + u32 max_id = 0; + for (u32 k = 0; k < real_subspace_centroid_num_; ++k) { + if (const f32 neg_distance = xy_buffer[k] + c_norm_data[k]; neg_distance > max_neg_distance) { + max_neg_distance = neg_distance; + max_id = k; + } + } + encode_output[j] = max_id; + } + } + + UniquePtr GetIPTable(const f32 *query) const { + auto ip_table = MakeUniqueForOverwrite(subspace_num_ * real_subspace_centroid_num_); + for (u32 i = 0; i < subspace_num_; ++i) { + matrixA_multiply_matrixB_output_to_C(subspace_centroids_data_at_subspace(i), + query + i * subspace_dimension_, + real_subspace_centroid_num_, + 1, + subspace_dimension_, + ip_table.get() + i * real_subspace_centroid_num_); + } + return ip_table; + } }; template @@ -294,12 +334,14 @@ class IVF_Parts_Storage_T final : public IVF_Parts_Storage_Info { } void SearchIndex(const u32 part_id, + const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, const std::function &add_result_func) const override { - return ivf_part_storages_[part_id]->SearchIndex(knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func, this); + return ivf_part_storages_[part_id] + ->SearchIndex(ivf_index_storage, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); } }; @@ -378,12 +420,12 @@ class IVF_Part_Storage_Plain final : public IVF_Part_Storage { ++embedding_num_; } - void SearchIndex(const KnnDistanceBase1 *knn_distance, + void SearchIndex(const IVF_Index_Storage *, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *) const override { + const std::function &add_result_func) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && @@ -617,26 +659,39 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { void AppendOneEmbedding(const void *embedding_ptr, const SegmentOffset segment_offset, - const IVF_Centroids_Storage *, - const IVF_Parts_Storage *) override { - const auto *src_embedding_data = static_cast(embedding_ptr); - (void)(src_embedding_data); - // TODO + const IVF_Centroids_Storage *ivf_centroids_storage, + const IVF_Parts_Storage *ivf_parts_storage) override { + const auto dimension = ivf_centroids_storage->embedding_dimension(); + const auto residual = MakeUniqueForOverwrite(dimension); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + { + const auto [src_embedding_f32, _] = GetF32Ptr(static_cast(embedding_ptr), dimension); + const auto centroid_data = ivf_centroids_storage->data() + part_id() * dimension; + for (u32 i = 0; i < dimension; ++i) { + residual[i] = src_embedding_f32[i] - centroid_data[i]; + } + } + const auto *ivf_parts_storage_info = + dynamic_cast *>(ivf_parts_storage); + assert(ivf_parts_storage_info); + ivf_parts_storage_info->EncodeResidual(residual.get(), encoded_codes.get()); + pq_code_storage_->AppendCodes(encoded_codes.get()); embedding_segment_offsets_.push_back(segment_offset); ++embedding_num_; } - void SearchIndex(const KnnDistanceBase1 *knn_distance, + void SearchIndex(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *) const override { + const std::function &add_result_func) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && (query_element_type == EmbeddingDataType::kElemInt8 || query_element_type == EmbeddingDataType::kElemUInt8))) { - return SearchIndexT(knn_distance, + return SearchIndexT(ivf_index_storage, + knn_distance, static_cast *>(query_ptr), satisfy_filter_func, add_result_func); @@ -661,7 +716,8 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { } template - void SearchIndexT(const KnnDistanceBase1 *knn_distance, + void SearchIndexT(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const EmbeddingDataTypeToCppTypeT *query_ptr, const std::function &satisfy_filter_func, const std::function &add_result_func) const { @@ -670,19 +726,89 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { if (!knn_distance_1) [[unlikely]] { UnrecoverableError("Invalid KnnDistance1"); } - // TODO - // auto dist_func = knn_distance_1->dist_func_; - // const auto total_embedding_num = embedding_num(); - // for (u32 i = 0; i < total_embedding_num; ++i) { - // const auto segment_offset = embedding_segment_offset(i); - // if (!satisfy_filter_func(segment_offset)) { - // continue; - // } - // auto v_ptr = data_.data() + i * embedding_dimension(); - // auto [calc_ptr, _] = GetSearchCalcPtr(v_ptr, embedding_dimension()); - // auto d = dist_func(calc_ptr, query_ptr, embedding_dimension()); - // add_result_func(d, segment_offset); - // } + const auto &ivf_parts_storage = + static_cast &>(ivf_index_storage->ivf_parts_storage()); + const auto subspace_num = subspace_num_; + const auto real_subspace_centroid_num = ivf_parts_storage.real_subspace_centroid_num(); + const auto dimension = ivf_index_storage->embedding_dimension(); + const auto [query_f32, _] = GetF32Ptr(query_ptr, dimension); + const auto centroid_data = ivf_index_storage->ivf_centroids_storage().data() + part_id() * dimension; + const auto ip_func = GetSIMD_FUNCTIONS().IPDistance_func_ptr_; + switch (const KnnDistanceType dist_type = knn_distance_1->dist_type_; dist_type) { + case KnnDistanceType::kInnerProduct: { + const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); + const auto ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 d = query_centroid_ip; + for (u32 j = 0; j < subspace_num; ++j) { + d += ip_table[j * real_subspace_centroid_num + encoded_codes[j]]; + } + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kCosine: { + const auto query_l2 = L2NormSquare(query_f32, dimension); + const auto centroid_l2 = L2NormSquare(centroid_data, dimension); + const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); + const auto query_ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto centroid_ip_table = ivf_parts_storage.GetIPTable(centroid_data); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 ip = query_centroid_ip; + f32 target_l2 = centroid_l2; + for (u32 j = 0; j < subspace_num; ++j) { + ip += query_ip_table[j * real_subspace_centroid_num + encoded_codes[j]]; + target_l2 -= 2.0f * (centroid_ip_table[j * real_subspace_centroid_num + encoded_codes[j]] + + ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]); + } + const auto d = ip / std::sqrt(query_l2 * target_l2); + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kL2: { + const auto residual_query = MakeUniqueForOverwrite(dimension); + for (u32 i = 0; i < dimension; ++i) { + residual_query[i] = query_f32[i] - centroid_data[i]; + } + const auto residual_query_l2 = L2NormSquare(residual_query.get(), dimension); + const auto residual_ip_table = ivf_parts_storage.GetIPTable(residual_query.get()); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 d = residual_query_l2; + for (u32 j = 0; j < subspace_num; ++j) { + d -= 2.0f * (residual_ip_table[j * real_subspace_centroid_num + encoded_codes[j]] + + ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]); + } + add_result_func(d, segment_offset); + } + break; + } + default: { + RecoverableError(Status::SyntaxError(fmt::format("IVFPQ does not support {} metric now.", KnnExpr::KnnDistanceType2Str(dist_type)))); + break; + } + } } };