Skip to content

Commit

Permalink
Support IVF search: Part 1 (#1954)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Adds initial support for KNN search on the IVF index.

Issue link: #1917

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
yangzq50 authored Sep 30, 2024
1 parent 7af04e5 commit 51fdbfa
Show file tree
Hide file tree
Showing 8 changed files with 443 additions and 93 deletions.
26 changes: 25 additions & 1 deletion src/executor/operator/physical_scan/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ import segment_entry;
import abstract_hnsw;
import physical_match_tensor_scan;
import hnsw_alg;
import ivf_index_data_in_mem;
import ivf_index_data;
import ivf_index_search;

namespace infinity {

Expand Down Expand Up @@ -568,7 +571,28 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataTypeAndQueryDataType(QueryConte
if (has_some_result) {
switch (segment_index_entry->table_index_entry()->index_base()->index_type_) {
case IndexType::kIVF: {
UnrecoverableError("Not supported now");
const SegmentOffset max_segment_offset = block_index->GetSegmentOffset(segment_id);
const auto ivf_search_params = IVF_Search_Params::Make(knn_scan_shared_data);
auto ivf_result_handler =
GetIVFSearchHandler<t, C, DistanceDataType>(ivf_search_params, use_bitmask, bitmask, max_segment_offset);
ivf_result_handler->Begin();
const auto [chunk_index_entries, memory_ivf_index] = segment_index_entry->GetIVFIndexSnapshot();
for (auto &chunk_index_entry : chunk_index_entries) {
if (chunk_index_entry->CheckVisible(txn)) {
BufferHandle index_handle = chunk_index_entry->GetIndex();
const auto *ivf_chunk = static_cast<const IVFIndexInChunk *>(index_handle.GetData());
ivf_result_handler->Search(ivf_chunk);
}
}
if (memory_ivf_index) {
ivf_result_handler->Search(memory_ivf_index.get());
}
auto [result_n, d_ptr, offset_ptr] = ivf_result_handler->EndWithoutSort();
auto row_ids = MakeUniqueForOverwrite<RowID[]>(result_n);
for (SizeT i = 0; i < result_n; ++i) {
row_ids[i] = RowID{segment_id, offset_ptr[i]};
}
merge_heap->Search(0, d_ptr.get(), row_ids.get(), result_n);
break;
}
case IndexType::kHnsw: {
Expand Down
4 changes: 3 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_data.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ export class IVFIndexInChunk : protected IVF_Index_Storage {
public:
using IVF_Index_Storage::GetMemData;

IVF_Index_Storage *BasePtr() { return this; }
// Returns this object as a pointer to its IVF_Index_Storage base. The base is inherited
// with protected access, so this accessor is how outside code reaches the storage interface.
IVF_Index_Storage *GetIVFIndexStoragePtr() { return this; }

// Const overload: read-only pointer to the IVF_Index_Storage base of this object.
const IVF_Index_Storage *GetIVFIndexStoragePtr() const { return this; }

void BuildIVFIndex(RowID base_rowid,
u32 row_count,
Expand Down
22 changes: 21 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,18 @@ class IVFIndexInMemT final : public IVFIndexInMem {
BufferHandle handle = new_chunk_index_entry->GetIndex();
auto *data_ptr = static_cast<IVFIndexInChunk *>(handle.GetDataMut());
data_ptr->GetMemData(std::move(*ivf_index_storage_));
ivf_index_storage_ = data_ptr->BasePtr();
ivf_index_storage_ = data_ptr->GetIVFIndexStoragePtr();
own_ivf_index_storage_ = false;
dump_handle_ = std::move(handle);
return new_chunk_index_entry;
}

// Search hook that IVFIndexInMem::SearchIndex calls when have_ivf_index_ is not set.
// Currently a stub: the body is empty, so add_result_func is never invoked and this
// call contributes no results to the query.
void SearchIndexInMem(KnnDistanceType knn_distance_type,
const void *query_ptr,
EmbeddingDataType query_element_type,
std::function<void(f32, SegmentOffset)> add_result_func) const override {
// TODO: not implemented yet; no candidates are reported.
}
};

template <LogicalType column_logical_type>
Expand Down Expand Up @@ -260,4 +267,17 @@ SharedPtr<IVFIndexInMem> IVFIndexInMem::NewIVFIndexInMem(const ColumnDef *column
return {};
}

// Runs one query against this in-memory IVF index while holding rw_mutex_ in shared mode.
//
// If have_ivf_index_ is set, the query is forwarded to ivf_index_storage_ together with
// nprobe. Otherwise it is handed to the SearchIndexInMem() hook, which does not receive
// nprobe. In both cases results are delivered through add_result_func.
void IVFIndexInMem::SearchIndex(const KnnDistanceType knn_distance_type,
                                const void *query_ptr,
                                const EmbeddingDataType query_element_type,
                                const u32 nprobe,
                                std::function<void(f32, SegmentOffset)> add_result_func) const {
    std::shared_lock lock(rw_mutex_);
    const bool storage_ready = have_ivf_index_.test(std::memory_order_acquire);
    if (!storage_ready) {
        // Flag not set: use the in-memory search hook.
        SearchIndexInMem(knn_distance_type, query_ptr, query_element_type, add_result_func);
        return;
    }
    ivf_index_storage_->SearchIndex(knn_distance_type, query_ptr, query_element_type, nprobe, add_result_func);
}

} // namespace infinity
13 changes: 12 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import ivf_index_storage;
import column_def;
import logical_type;
import buffer_handle;
import knn_expr;

namespace infinity {

Expand Down Expand Up @@ -60,8 +61,18 @@ public:
u32 row_offset,
u32 row_count) = 0;
virtual SharedPtr<ChunkIndexEntry> Dump(SegmentIndexEntry *segment_index_entry, BufferManager *buffer_mgr) = 0;
// TODO: query
// Public, non-virtual search entry point for the in-memory IVF index.
//   knn_distance_type  - distance metric requested by the query.
//   query_ptr          - type-erased pointer to the query data; its element type is
//                        given by query_element_type.
//   nprobe             - forwarded to the index storage search when that path is taken.
//   add_result_func    - callback taking (f32, SegmentOffset), used to report results.
void SearchIndex(KnnDistanceType knn_distance_type,
const void *query_ptr,
EmbeddingDataType query_element_type,
u32 nprobe,
std::function<void(f32, SegmentOffset)> add_result_func) const;
static SharedPtr<IVFIndexInMem> NewIVFIndexInMem(const ColumnDef *column_def, const IndexBase *index_base, RowID begin_row_id);

private:
// Customisation point implemented by concrete index classes. SearchIndex() calls it
// when have_ivf_index_ is not set. It takes the same arguments as SearchIndex()
// except nprobe, which is not passed to this hook.
virtual void SearchIndexInMem(KnnDistanceType knn_distance_type,
const void *query_ptr,
EmbeddingDataType query_element_type,
std::function<void(f32, SegmentOffset)> add_result_func) const = 0;
};

} // namespace infinity
60 changes: 60 additions & 0 deletions src/storage/knn_index/knn_ivf/ivf_index_search.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include <string>
module ivf_index_search;

import stl;
import index_ivf;
import internal_types;
import logical_type;
import data_type;
import knn_expr;
import knn_scan_data;
import infinity_exception;
import status;
import third_party;
import ivf_index_data;
import ivf_index_data_in_mem;
import ivf_index_storage;

namespace infinity {

// Builds the parameter bundle for an IVF search from the shared KNN scan data.
//
// Copies topk, the query embedding pointer, the query element type and the distance type
// from knn_scan_shared_data. nprobe defaults to 1 and is overridden by an "nprobe" entry
// in opt_params_ (if several are present, the last one wins).
// The knn_scan_shared_data_ member of the result is not assigned here and keeps its default.
//
// Reports a SyntaxError status through RecoverableError when:
//   - query_count_ is not exactly 1,
//   - the parsed nprobe is zero or negative,
//   - topk is not in the range [1, u32::max()].
IVF_Search_Params IVF_Search_Params::Make(const KnnScanSharedData *knn_scan_shared_data) {
    IVF_Search_Params params;
    if (knn_scan_shared_data->query_count_ != 1) {
        RecoverableError(Status::SyntaxError(fmt::format("Invalid query_count: {} which is not 1.", knn_scan_shared_data->query_count_)));
    }
    params.topk_ = knn_scan_shared_data->topk_;
    params.query_embedding_ = knn_scan_shared_data->query_embedding_;
    params.query_elem_type_ = knn_scan_shared_data->query_elem_type_;
    params.knn_distance_type_ = knn_scan_shared_data->knn_distance_type_;
    params.nprobe_ = 1;
    for (const auto &opt_param : knn_scan_shared_data->opt_params_) {
        if (opt_param.param_name_ == "nprobe") {
            params.nprobe_ = DataType::StringToValue<IntegerT>(opt_param.param_value_);
            if (params.nprobe_ <= 0) {
                // Report the offending value; the option name is always "nprobe" in this branch.
                RecoverableError(Status::SyntaxError(fmt::format("Invalid negative nprobe value: {}", opt_param.param_value_)));
            }
        }
    }
    if (params.topk_ <= 0 || params.topk_ > std::numeric_limits<u32>::max()) {
        RecoverableError(Status::SyntaxError(fmt::format("Invalid topk which is out of range: {}.", params.topk_)));
    }
    return params;
}

} // namespace infinity
180 changes: 180 additions & 0 deletions src/storage/knn_index/knn_ivf/ivf_index_search.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include <cassert>
#include <functional>
export module ivf_index_search;

import stl;
import internal_types;
import data_type;
import knn_expr;
import knn_scan_data;
import logical_type;
import knn_result_handler;
import multivector_result_handler;
import infinity_exception;
import status;
import third_party;
import roaring_bitmap;
import knn_filter;
import ivf_index_data;
import ivf_index_data_in_mem;
import ivf_index_storage;
import search_top_1;
import search_top_k;

namespace infinity {

// Plain parameter bundle describing one IVF KNN search.
// Instances are normally produced by IVF_Search_Params::Make().
export struct IVF_Search_Params {
    // Shared scan data associated with the search; null by default.
    KnnScanSharedData *knn_scan_shared_data_ = nullptr;
    // Number of results requested by the query.
    i64 topk_ = 0;
    // Type-erased pointer to the query data; element type is given by query_elem_type_.
    void *query_embedding_ = nullptr;
    // Element type of the query data.
    EmbeddingDataType query_elem_type_ = EmbeddingDataType::kElemInvalid;
    // Distance metric requested by the query.
    KnnDistanceType knn_distance_type_ = KnnDistanceType::kInvalid;
    // Value of the "nprobe" search option; 1 unless overridden.
    i32 nprobe_ = 1;

    // Builds a parameter bundle from the shared KNN scan data.
    static IVF_Search_Params Make(const KnnScanSharedData *knn_scan_shared_data);
};

export template <typename DistanceDataType>
class IVF_Search_Handler {
protected:
IVF_Search_Params ivf_params_;
UniquePtr<DistanceDataType[]> distance_output_ptr_{};
UniquePtr<SegmentOffset[]> segment_offset_output_ptr_{};

explicit IVF_Search_Handler(const IVF_Search_Params &ivf_params) : ivf_params_(ivf_params) {
distance_output_ptr_ = MakeUniqueForOverwrite<DistanceDataType[]>(ivf_params_.topk_);
segment_offset_output_ptr_ = MakeUniqueForOverwrite<SegmentOffset[]>(ivf_params_.topk_);
}
virtual SizeT EndWithoutSortAndGetResultSize() = 0;

public:
virtual ~IVF_Search_Handler() = default;
virtual void Begin() = 0;
virtual void Search(const IVFIndexInChunk *ivf_index_in_chunk) = 0;
virtual void Search(const IVFIndexInMem *ivf_index_in_mem) = 0;
Tuple<SizeT, UniquePtr<DistanceDataType[]>, UniquePtr<SegmentOffset[]>> EndWithoutSort() {
const auto result_cnt = EndWithoutSortAndGetResultSize();
return {result_cnt, std::move(distance_output_ptr_), std::move(segment_offset_output_ptr_)};
}
};

// Candidate filter selected at compile time by use_bitmask.
// Both specializations share the same constructor signature so callers can build either
// one the same way; each uses only the argument it needs.
template <bool use_bitmask>
struct IVF_Filter;

// use_bitmask == true: wraps a BitmaskFilter built from the bitmask.
// max_segment_offset is accepted but not used.
template <>
struct IVF_Filter<true> {
    BitmaskFilter<SegmentOffset> filter_;

    IVF_Filter(const Bitmask &bitmask, [[maybe_unused]] const SegmentOffset max_segment_offset) : filter_(bitmask) {}

    // Forwards the offset to the wrapped filter and returns its verdict.
    bool operator()(const SegmentOffset &segment_offset) const {
        return filter_(segment_offset);
    }
};

// use_bitmask == false: wraps an AppendFilter built from max_segment_offset.
// The bitmask is accepted but not used.
template <>
struct IVF_Filter<false> {
    AppendFilter filter_;

    IVF_Filter([[maybe_unused]] const Bitmask &bitmask, const SegmentOffset max_segment_offset) : filter_(max_segment_offset) {}

    // Forwards the offset to the wrapped filter and returns its verdict.
    bool operator()(const SegmentOffset &segment_offset) const {
        return filter_(segment_offset);
    }
};

// Concrete IVF search handler.
//
// Template parameters:
//   t                             - column logical type; only kEmbedding and kMultiVector are allowed.
//   C                             - comparator template of the requested ordering.
//   DistanceDataType              - distance value type.
//   use_bitmask                   - selects the IVF_Filter specialization.
//   MultiVectorInnerTopnIndexType - index type passed to MultiVectorResultHandler (multivector only).
//
// Results are collected by a HeapResultHandler (embedding) or MultiVectorResultHandler
// (multivector) that writes into the output arrays owned by the base class.
//
// NOTE(review): NEED_FLIP and filter_ are not referenced anywhere in this class as shown;
// presumably distance flipping and candidate filtering are wired up later. Confirm.
template <LogicalType t,
          template <typename, typename>
          typename C,
          typename DistanceDataType,
          bool use_bitmask,
          typename MultiVectorInnerTopnIndexType = void>
class IVF_Search_HandlerT final : public IVF_Search_Handler<DistanceDataType> {
    static_assert(t == LogicalType::kEmbedding || t == LogicalType::kMultiVector);
    // True when C is not CompareMax over (DistanceDataType, SegmentOffset).
    static constexpr bool NEED_FLIP = !std::is_same_v<CompareMax<DistanceDataType, SegmentOffset>, C<DistanceDataType, SegmentOffset>>;
    using ResultHandler = std::conditional_t<t == LogicalType::kEmbedding,
                                             HeapResultHandler<CompareMax<DistanceDataType, SegmentOffset>>,
                                             MultiVectorResultHandler<DistanceDataType, SegmentOffset, MultiVectorInnerTopnIndexType>>;
    IVF_Filter<use_bitmask> filter_;
    ResultHandler result_handler_;

public:
    // The base class is constructed first, so its output arrays already exist when
    // result_handler_ is given their raw pointers.
    IVF_Search_HandlerT(const IVF_Search_Params &ivf_params, const Bitmask &bitmask, SegmentOffset max_segment_offset)
        : IVF_Search_Handler<DistanceDataType>(ivf_params), filter_(bitmask, max_segment_offset),
          result_handler_(1, this->ivf_params_.topk_, this->distance_output_ptr_.get(), this->segment_offset_output_ptr_.get()) {}

    void Begin() override { result_handler_.Begin(); }

    // Searches one chunk index and routes every reported candidate to AddResult().
    void Search(const IVFIndexInChunk *ivf_index_in_chunk) override {
        const auto &params = this->ivf_params_;
        auto on_candidate = [this](const DistanceDataType distance, const SegmentOffset offset) { this->AddResult(distance, offset); };
        const auto *storage = ivf_index_in_chunk->GetIVFIndexStoragePtr();
        storage->SearchIndex(params.knn_distance_type_, params.query_embedding_, params.query_elem_type_, params.nprobe_, on_candidate);
    }

    // Searches the in-memory index and routes every reported candidate to AddResult().
    void Search(const IVFIndexInMem *ivf_index_in_mem) override {
        const auto &params = this->ivf_params_;
        auto on_candidate = [this](const DistanceDataType distance, const SegmentOffset offset) { this->AddResult(distance, offset); };
        ivf_index_in_mem->SearchIndex(params.knn_distance_type_, params.query_embedding_, params.query_elem_type_, params.nprobe_, on_candidate);
    }

    // Hands one (distance, offset) candidate to the result handler. The embedding handler
    // additionally takes a query index, which is always 0 here.
    void AddResult(DistanceDataType d, SegmentOffset i) {
        if constexpr (t == LogicalType::kMultiVector) {
            result_handler_.AddResult(d, i);
        } else {
            static_assert(t == LogicalType::kEmbedding);
            result_handler_.AddResult(0, d, i);
        }
    }

    // Finalises the result handler without sorting and returns its size for query 0.
    SizeT EndWithoutSortAndGetResultSize() override {
        result_handler_.EndWithoutSort();
        return result_handler_.GetSize(0);
    }
};

// Creates the search handler for logical type t with use_bitmask fixed at compile time.
//
// For kEmbedding the handler is created directly. For kMultiVector the inner top-n index
// type is the narrowest of u8/u16/u32 that can hold topk_. A topk_ that is not positive,
// or that exceeds u32::max(), is reported as a SyntaxError through RecoverableError, with
// nullptr as the return value on those paths. Any other logical type fails to compile.
export template <LogicalType t, template <typename, typename> typename C, typename DistanceDataType, bool use_bitmask>
UniquePtr<IVF_Search_Handler<DistanceDataType>>
GetIVFSearchHandler(const IVF_Search_Params &ivf_params, const Bitmask &bitmask, const SegmentOffset max_segment_offset) {
    if constexpr (t == LogicalType::kEmbedding) {
        return MakeUnique<IVF_Search_HandlerT<t, C, DistanceDataType, use_bitmask>>(ivf_params, bitmask, max_segment_offset);
    } else if constexpr (t == LogicalType::kMultiVector) {
        const auto requested_topk = ivf_params.topk_;
        // Reject out-of-range values first, then choose from the widest index type down.
        if (requested_topk <= 0) {
            RecoverableError(Status::SyntaxError(fmt::format("Invalid topk: {}", requested_topk)));
            return nullptr;
        }
        if (requested_topk > std::numeric_limits<u32>::max()) {
            RecoverableError(Status::SyntaxError(fmt::format("Unsupported topk : {}, which is larger than u32::max()", requested_topk)));
            return nullptr;
        }
        if (requested_topk > std::numeric_limits<u16>::max()) {
            return MakeUnique<IVF_Search_HandlerT<t, C, DistanceDataType, use_bitmask, u32>>(ivf_params, bitmask, max_segment_offset);
        }
        if (requested_topk > std::numeric_limits<u8>::max()) {
            return MakeUnique<IVF_Search_HandlerT<t, C, DistanceDataType, use_bitmask, u16>>(ivf_params, bitmask, max_segment_offset);
        }
        return MakeUnique<IVF_Search_HandlerT<t, C, DistanceDataType, use_bitmask, u8>>(ivf_params, bitmask, max_segment_offset);
    } else {
        static_assert(false, "Invalid LogicalType for KNN");
        return nullptr;
    }
}

// Runtime dispatcher: turns the use_bitmask flag into the matching compile-time
// instantiation of GetIVFSearchHandler and returns the handler it creates.
export template <LogicalType t, template <typename, typename> typename C, typename DistanceDataType>
UniquePtr<IVF_Search_Handler<DistanceDataType>>
GetIVFSearchHandler(const IVF_Search_Params &ivf_params, const bool use_bitmask, const Bitmask &bitmask, const SegmentOffset max_segment_offset) {
    return use_bitmask ? GetIVFSearchHandler<t, C, DistanceDataType, true>(ivf_params, bitmask, max_segment_offset)
                       : GetIVFSearchHandler<t, C, DistanceDataType, false>(ivf_params, bitmask, max_segment_offset);
}

} // namespace infinity
Loading

0 comments on commit 51fdbfa

Please sign in to comment.