Skip to content

Commit

Permalink
Support IVFPQ insert and query (#2083)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support IVFPQ insert
Support IVFPQ query for the L2, inner product (IP), and cosine metrics

Issue link: #2077

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
yangzq50 authored Oct 21, 2024
1 parent 9b8996a commit c61641c
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 33 deletions.
5 changes: 5 additions & 0 deletions src/function/table/knn_scan_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace infinity {

template <>
KnnDistance1<f32, f32>::KnnDistance1(KnnDistanceType dist_type) {
dist_type_ = dist_type;
switch (dist_type) {
case KnnDistanceType::kL2: {
dist_func_ = GetSIMD_FUNCTIONS().L2Distance_func_ptr_;
Expand All @@ -60,6 +61,7 @@ KnnDistance1<f32, f32>::KnnDistance1(KnnDistanceType dist_type) {

template <>
KnnDistance1<u8, i32>::KnnDistance1(KnnDistanceType dist_type) {
dist_type_ = dist_type;
switch (dist_type) {
case KnnDistanceType::kL2: {
dist_func_ = GetSIMD_FUNCTIONS().HNSW_U8L2_ptr_;
Expand All @@ -81,6 +83,7 @@ f32 hnsw_u8ip_f32_wrapper(const u8 *v1, const u8 *v2, SizeT dim) { return static

template <>
KnnDistance1<u8, f32>::KnnDistance1(KnnDistanceType dist_type) {
dist_type_ = dist_type;
switch (dist_type) {
case KnnDistanceType::kL2: {
dist_func_ = &hnsw_u8l2_f32_wrapper;
Expand All @@ -103,6 +106,7 @@ KnnDistance1<u8, f32>::KnnDistance1(KnnDistanceType dist_type) {

template <>
KnnDistance1<i8, i32>::KnnDistance1(KnnDistanceType dist_type) {
dist_type_ = dist_type;
switch (dist_type) {
case KnnDistanceType::kL2: {
dist_func_ = GetSIMD_FUNCTIONS().HNSW_I8L2_ptr_;
Expand All @@ -124,6 +128,7 @@ f32 hnsw_i8ip_f32_wrapper(const i8 *v1, const i8 *v2, SizeT dim) { return static

template <>
KnnDistance1<i8, f32>::KnnDistance1(KnnDistanceType dist_type) {
dist_type_ = dist_type;
switch (dist_type) {
case KnnDistanceType::kL2: {
dist_func_ = &hnsw_i8l2_f32_wrapper;
Expand Down
1 change: 1 addition & 0 deletions src/function/table/knn_scan_data.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public:
using DistFunc = DistType (*)(const QueryDataType *, const QueryDataType *, SizeT);

DistFunc dist_func_{};
KnnDistanceType dist_type_{};
};

template <>
Expand Down
2 changes: 1 addition & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ void IVF_Index_Storage::SearchIndex(const KnnDistanceBase1 *knn_distance,
search_top_k_with_dis(nprobe, dimension, 1, query_f32_ptr, centroids_num, centroids_data, nprobe_result.data(), centroid_dists.get(), false);
}
for (const auto part_id : nprobe_result) {
ivf_parts_storage_->SearchIndex(part_id, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func);
ivf_parts_storage_->SearchIndex(part_id, this, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func);
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_storage.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace infinity {
class LocalFileHandle;
class KnnDistanceBase1;

export class IVF_Index_Storage;

// always use float for centroids
class IVF_Centroids_Storage {
u32 embedding_dimension_ = 0;
Expand Down Expand Up @@ -68,14 +70,15 @@ public:
AppendOneEmbedding(u32 part_id, const void *embedding_ptr, SegmentOffset segment_offset, const IVF_Centroids_Storage *ivf_centroids_storage) = 0;

// Pure-virtual search entry point for a single IVF part (inverted list).
// IVF_Index_Storage::SearchIndex calls this once per selected part_id, passing itself
// as ivf_index_storage.
//
// part_id             - identifier of the part to search.
// ivf_index_storage   - the IVF_Index_Storage issuing the search (presumably so that
//                       implementations can reach its centroids/options — confirm in overrides).
// knn_distance        - distance descriptor for the query.
// query_ptr           - type-erased pointer to the query embedding; its element type is
//                       given by query_element_type.
// satisfy_filter_func - predicate over a SegmentOffset; presumably only offsets for which it
//                       returns true are reported — verify against implementations.
// add_result_func     - callback receiving an (f32, SegmentOffset) pair; presumably the
//                       computed distance/score and the matching row offset — verify.
virtual void SearchIndex(u32 part_id,
const IVF_Index_Storage *ivf_index_storage,
const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
const std::function<bool(SegmentOffset)> &satisfy_filter_func,
const std::function<void(f32, SegmentOffset)> &add_result_func) const = 0;
};

export class IVF_Index_Storage {
class IVF_Index_Storage {
const IndexIVFOption ivf_option_ = {};
const LogicalType column_logical_type_ = LogicalType::kInvalid;
const EmbeddingDataType embedding_data_type_ = EmbeddingDataType::kElemInvalid;
Expand All @@ -94,6 +97,8 @@ public:
// Read-only accessors; each returns the corresponding member unchanged.
// Returns the logical type of the indexed column (by value).
[[nodiscard]] LogicalType column_logical_type() const { return column_logical_type_; }
// Returns the element data type of the stored embeddings (by value).
[[nodiscard]] EmbeddingDataType embedding_data_type() const { return embedding_data_type_; }
// Returns the embedding dimension (by value).
[[nodiscard]] u32 embedding_dimension() const { return embedding_dimension_; }
// Returns a const reference to the centroids storage member; the reference is to a member
// of this object.
[[nodiscard]] const IVF_Centroids_Storage &ivf_centroids_storage() const { return ivf_centroids_storage_; }
// Returns a const reference obtained by dereferencing ivf_parts_storage_.
// NOTE(review): assumes ivf_parts_storage_ is non-null when this is called — confirm where it is initialized.
[[nodiscard]] const IVF_Parts_Storage &ivf_parts_storage() const { return *ivf_parts_storage_; }

void Train(u32 training_embedding_num, const f32 *training_data, u32 expect_centroid_num = 0);
void AddEmbedding(SegmentOffset segment_offset, const void *embedding_ptr);
Expand Down
Loading

0 comments on commit c61641c

Please sign in to comment.