feat(search): Hnsw Vector Search Plan Operator & Executor #2434

Merged
merged 13 commits into from
Jul 24, 2024
75 changes: 75 additions & 0 deletions src/search/executors/hnsw_vector_field_knn_scan_executor.h
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <string>

#include "db_util.h"
#include "encoding.h"
#include "search/hnsw_indexer.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
#include "storage/redis_db.h"
#include "storage/redis_metadata.h"
#include "storage/storage.h"
#include "string_util.h"

namespace kqir {

struct HnswVectorFieldKnnScanExecutor : ExecutorNode {
HnswVectorFieldKnnScan *scan;
redis::LatestSnapShot ss;
bool initialized = false;

IndexInfo *index;
redis::SearchKey search_key;
redis::HnswVectorFieldMetadata field_metadata;
@PragmaTwice (Member), Jul 23, 2024:

It seems we can just keep a pointer here, e.g. const redis::HnswVectorFieldMetadata *field_metadata.

Member Author:

This is because the HnswIndex constructor does not accept a const HnswVectorFieldMetadata*:

HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage);

HnswIndex needs to modify the HnswVectorFieldMetadata in other functions, so I made a copy here. Do you have any ideas on how we can improve this?

Member:

Let me think. If the metadata is modified during a vector query, it seems that the global metadata of the vector field (in server->index_mgr->index_map) will not change. Will this cause any problems?

@Beihao-Zhou (Member Author), Jul 23, 2024:

The HnswVectorFieldMetadata is only modified when a node is inserted into or deleted from a layer higher than all other nodes, in which case num_levels in HnswVectorFieldMetadata : IndexFieldMetadata changes. So the affected field is server->index_mgr->index_map[<index_key>]->fields[<field>]->metadata.

In indexer.cc, we do

auto *metadata = iter->second.metadata.get();
if (auto vector = dynamic_cast<HnswVectorFieldMetadata *>(metadata)) {
    GET_OR_RET(UpdateHnswVectorIndex(key, original, current, search_key, vector));
}

So *metadata changes correspondingly.

But in this PR, since we don't modify the metadata, I guess copying will not cause actual problems. There may be a consistency issue, but that would be fixed later in line with #2310. There is also the expensive copy caused by the large size of HnswIndex, which I can fix right after this PR if the solution using a static member variable sounds good to you <3.
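
To make the copy-versus-pointer concern in this thread concrete, here is a minimal sketch. It assumes C++17, and FieldMetadata, InsertOnNewTopLayer, Observed and MutateSharedAfterSnapshot are hypothetical stand-ins written for illustration, not Kvrocks types. It only shows that a by-value snapshot stops tracking num_levels once the shared object is updated.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for redis::HnswVectorFieldMetadata; only num_levels matters here.
struct FieldMetadata {
  uint16_t num_levels = 0;
};

// Hypothetical stand-in for an insert that lands above the current top layer.
inline void InsertOnNewTopLayer(FieldMetadata *metadata) {
  assert(metadata != nullptr);
  metadata->num_levels++;
}

// What each holder observes after the shared metadata is mutated.
struct Observed {
  uint16_t shared_levels;
  uint16_t snapshot_levels;
};

inline Observed MutateSharedAfterSnapshot(uint16_t initial_levels) {
  FieldMetadata shared;  // plays the role of the entry in the index map
  shared.num_levels = initial_levels;

  FieldMetadata snapshot = shared;  // executor-style copy taken at construction
  InsertOnNewTopLayer(&shared);     // indexer-style update through a pointer

  return Observed{shared.num_levels, snapshot.num_levels};
}
```

Calling MutateSharedAfterSnapshot(2) yields a shared value of 3 and a snapshot value of 2, which is the stale-read scenario the reviewers are weighing.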

Member:

I see what you mean.

It's good for now, but we can enhance it later. For example, we could create a const version of HnswIndex that takes a const pointer to HnswVectorFieldMetadata.

Member Author:

That sounds good!
I wanted to add a const overload, but some nested calls also take HnswVectorFieldMetadata* as a parameter, so in the end I didn't implement it, to avoid making this PR too messy. I'll make a note in tracking issue #2426 and improve it in the future.
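
One possible shape for the const version suggested above is to split the read-only operations into a view type that holds a const pointer. The sketch below assumes C++17. ConstIndexView, MutableIndex, GrowToLevel and the simplified FieldMetadata are all hypothetical names invented for illustration; this is not the design adopted in Kvrocks.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for redis::HnswVectorFieldMetadata.
struct FieldMetadata {
  uint32_t ef_runtime = 10;
  uint16_t num_levels = 0;
};

// Read-only view: holds a const pointer, so it cannot change the metadata
// and does not need a private copy.
class ConstIndexView {
 public:
  explicit ConstIndexView(const FieldMetadata *metadata) : metadata_(metadata) {
    assert(metadata_ != nullptr);
  }

  // Mirrors the "ef_runtime is at least k" rule used by KnnSearch.
  uint32_t EffectiveEf(uint32_t k) const { return std::max(metadata_->ef_runtime, k); }
  bool Empty() const { return metadata_->num_levels == 0; }

 private:
  const FieldMetadata *metadata_;
};

// Mutable index: the only type in this sketch that may change num_levels.
class MutableIndex {
 public:
  explicit MutableIndex(FieldMetadata *metadata) : metadata_(metadata) {
    assert(metadata_ != nullptr);
  }

  // Raise num_levels so that `level` becomes a valid layer.
  void GrowToLevel(uint16_t level) {
    auto needed = static_cast<uint16_t>(level + 1);
    metadata_->num_levels = std::max(metadata_->num_levels, needed);
  }

 private:
  FieldMetadata *metadata_;
};
```

Because the view points at the shared metadata rather than copying it, a query built on it would observe later changes to num_levels. Whether that is desirable depends on the isolation work referenced in the TODO comments of this PR.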

std::vector<std::string> row_keys;
decltype(row_keys)::iterator row_keys_iter;

HnswVectorFieldKnnScanExecutor(ExecutorContext *ctx, HnswVectorFieldKnnScan *scan)
: ExecutorNode(ctx),
scan(scan),
ss(ctx->storage),
index(scan->field->info->index),
search_key(index->ns, index->name, scan->field->name),
field_metadata(*(scan->field->info->MetadataAs<redis::HnswVectorFieldMetadata>())) {}

StatusOr<Result> Next() override {
if (!initialized) {
// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
auto hnsw_index = redis::HnswIndex(search_key, &field_metadata, ctx->storage);
row_keys = GET_OR_RET(hnsw_index.KnnSearch(scan->vector, scan->k));
row_keys_iter = row_keys.begin();
initialized = true;
}

if (row_keys_iter == row_keys.end()) {
return end;
}

auto key_str = *row_keys_iter;
row_keys_iter++;
return RowType{key_str, {}, scan->field->info->index};
}
};

} // namespace kqir
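
The executor above follows a lazy pull pattern: the first Next() call runs the KNN search once, and every later call hands out one cached key until the list is exhausted. A minimal standalone sketch of that pattern follows, assuming C++17. LazyKeyCursor and its callback parameter are hypothetical and stand in for the executor and HnswIndex::KnnSearch; std::nullopt plays the role of the end marker.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical pull-based cursor: runs `search` once, on the first Next(),
// then yields one key per call.
class LazyKeyCursor {
 public:
  explicit LazyKeyCursor(std::function<std::vector<std::string>()> search)
      : search_(std::move(search)) {
    assert(search_ != nullptr);
  }

  // Returns the next key, or std::nullopt once the results are exhausted.
  std::optional<std::string> Next() {
    if (!initialized_) {
      keys_ = search_();  // the expensive search happens only here
      pos_ = 0;
      initialized_ = true;
    }
    if (pos_ == keys_.size()) return std::nullopt;
    return keys_[pos_++];
  }

 private:
  std::function<std::vector<std::string>()> search_;
  std::vector<std::string> keys_;
  std::size_t pos_ = 0;
  bool initialized_ = false;
};
```

Deferring the search to the first Next() keeps construction cheap, so a plan that is built but never executed does no index work.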
31 changes: 31 additions & 0 deletions src/search/hnsw_indexer.cc
@@ -481,12 +481,14 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::Nu
return Status::OK();
}

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
Status HnswIndex::InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto target_level = RandomizeLayer();
return InsertVectorEntryInternal(key, vector, batch, target_level);
}

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
std::string node_key(key);
for (uint16_t level = 0; level < metadata->num_levels; level++) {
@@ -549,4 +551,33 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<ro
return Status::OK();
}

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
StatusOr<std::vector<std::string>> HnswIndex::KnnSearch(const kqir::NumericArray& query_vector, uint32_t k) {
VectorItem query_vector_item;
GET_OR_RET(VectorItem::Create({}, query_vector, metadata, &query_vector_item));
uint32_t effective_ef = std::max(metadata->ef_runtime, k); // Ensure ef_runtime is at least k

if (metadata->num_levels == 0) {
return {Status::NotFound, fmt::format("No vector found in the HNSW index")};
}

auto level = metadata->num_levels - 1;
auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
std::vector<NodeKey> entry_points{default_entry_node};
std::vector<VectorItem> nearest_vec_items;

for (; level > 0; level--) {
nearest_vec_items = GET_OR_RET(SearchLayer(level, query_vector_item, effective_ef, entry_points));
entry_points = {nearest_vec_items[0].key};
}
nearest_vec_items = GET_OR_RET(SearchLayer(0, query_vector_item, effective_ef, entry_points));

uint32_t result_length = std::min(k, static_cast<uint32_t>(nearest_vec_items.size()));
std::vector<std::string> nearest_neighbours_key(result_length);
for (uint32_t result_idx = 0; result_idx < result_length; result_idx++) {
nearest_neighbours_key[result_idx] = std::move(nearest_vec_items[result_idx].key);
}
return nearest_neighbours_key;
}

} // namespace redis
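
The result-shaping rules in KnnSearch (report an empty index, raise ef to at least k, then keep the first k candidates) can be sketched in isolation. The sketch assumes C++17 and deliberately replaces the layered graph descent with a brute-force scan, so it is not the HNSW algorithm itself. Item, SquaredL2, NearestBruteForce and KnnSearchSketch are hypothetical names, and an empty std::optional stands in for the NotFound status.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a stored vector entry.
struct Item {
  std::string key;
  std::vector<double> vector;
};

// Squared L2 distance; both vectors must have the same dimension.
inline double SquaredL2(const std::vector<double> &a, const std::vector<double> &b) {
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); i++) {
    double d = a[i] - b[i];
    sum += d * d;
  }
  return sum;
}

// Brute-force stand-in for the layered search: up to `ef` items, nearest first.
inline std::vector<Item> NearestBruteForce(std::vector<Item> items, const std::vector<double> &query,
                                           uint32_t ef) {
  std::sort(items.begin(), items.end(), [&query](const Item &a, const Item &b) {
    return SquaredL2(a.vector, query) < SquaredL2(b.vector, query);
  });
  if (items.size() > ef) items.resize(ef);
  return items;
}

// Returns up to k nearest keys; an empty optional mirrors NotFound on an empty index.
inline std::optional<std::vector<std::string>> KnnSearchSketch(const std::vector<Item> &items,
                                                               const std::vector<double> &query,
                                                               uint32_t k, uint32_t ef_runtime) {
  if (items.empty()) return std::nullopt;
  uint32_t effective_ef = std::max(ef_runtime, k);  // ef is never smaller than k
  std::vector<Item> candidates = NearestBruteForce(items, query, effective_ef);

  std::size_t result_length = std::min<std::size_t>(k, candidates.size());
  std::vector<std::string> keys;
  keys.reserve(result_length);
  for (std::size_t i = 0; i < result_length; i++) {
    keys.push_back(std::move(candidates[i].key));
  }
  return keys;
}
```

With the five vectors from the KnnSearch unit test in this PR and the query {31.0, 32.0, 23.0}, the squared distances are 209 (key5), 408 (key3), 627 (key2), 803 (key4) and 900 (key1). For k = 3 the sketch therefore returns key5, key3, key2, even when ef_runtime is set to 1.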
1 change: 1 addition & 0 deletions src/search/hnsw_indexer.h
@@ -110,6 +110,7 @@ struct HnswIndex {
Status InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
StatusOr<std::vector<std::string>> KnnSearch(const kqir::NumericArray& query_vector, uint32_t k);
};

} // namespace redis
20 changes: 20 additions & 0 deletions src/search/ir_plan.h
@@ -26,6 +26,7 @@
#include "ir.h"
#include "search/interval.h"
#include "search/ir_sema_checker.h"
#include "search/value.h"
#include "string_util.h"

namespace kqir {
@@ -96,6 +97,25 @@ struct TagFieldScan : FieldScan {
}
};

struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
uint16_t k;

HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint16_t k)
: FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}

std::string_view Name() const override { return "HnswVectorFieldKnnScan"; };
std::string Content() const override { return kqir::MakeValue<kqir::NumericArray>(vector).ToString(); };
std::string Dump() const override {
return fmt::format("hnsw-vector-knn-scan {}, {}", field->name,
kqir::MakeValue<kqir::NumericArray>(vector).ToString());
}

std::unique_ptr<Node> Clone() const override {
return std::make_unique<HnswVectorFieldKnnScan>(field->CloneAs<FieldRef>(), vector, k);
}
};

struct Filter : PlanOperator {
std::unique_ptr<PlanOperator> source;
std::unique_ptr<QueryExpr> filter_expr;
7 changes: 7 additions & 0 deletions src/search/plan_executor.cc
@@ -24,6 +24,7 @@

#include "search/executors/filter_executor.h"
#include "search/executors/full_index_scan_executor.h"
#include "search/executors/hnsw_vector_field_knn_scan_executor.h"
#include "search/executors/limit_executor.h"
#include "search/executors/merge_executor.h"
#include "search/executors/mock_executor.h"
@@ -84,6 +85,10 @@ struct ExecutorContextVisitor {
return Visit(v);
}

if (auto v = dynamic_cast<HnswVectorFieldKnnScan *>(op)) {
return Visit(v);
}

if (auto v = dynamic_cast<Mock *>(op)) {
return Visit(v);
}
@@ -129,6 +134,8 @@ struct ExecutorContextVisitor {

void Visit(TagFieldScan *op) { ctx->nodes[op] = std::make_unique<TagFieldScanExecutor>(ctx, op); }

void Visit(HnswVectorFieldKnnScan *op) { ctx->nodes[op] = std::make_unique<HnswVectorFieldKnnScanExecutor>(ctx, op); }

void Visit(Mock *op) { ctx->nodes[op] = std::make_unique<MockExecutor>(ctx, op); }
};

1 change: 1 addition & 0 deletions src/search/search_encoding.h
@@ -371,6 +371,7 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {
double epsilon = 0.01; // Relative factor setting search boundaries in range queries
uint16_t num_levels = 0; // Number of levels in the HNSW graph

// TODO: Initialize with required fields?
HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}

void Encode(std::string *dst) const override {
89 changes: 89 additions & 0 deletions tests/cppunit/hnsw_index_test.cc
@@ -662,3 +662,92 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
actual_set = {node5_layer0.neighbours.begin(), node5_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);
}

TEST_F(HnswIndexTest, KnnSearch) {
std::vector<double> query_vector = {31.0, 32.0, 23.0};
uint32_t k = 3;
auto s1 = hnsw_index->KnnSearch(query_vector, k);
ASSERT_FALSE(s1.IsOK());
EXPECT_EQ(s1.GetCode(), Status::NotFound);

std::vector<double> vec1 = {11.0, 12.0, 13.0};
std::vector<double> vec2 = {14.0, 15.0, 16.0};
std::vector<double> vec3 = {17.0, 18.0, 19.0};
std::vector<double> vec4 = {12.0, 13.0, 14.0};
std::vector<double> vec5 = {30.0, 40.0, 35.0};

std::string key1 = "key1";
std::string key2 = "key2";
std::string key3 = "key3";
std::string key4 = "key4";
std::string key5 = "key5";

// Insert key1 into layer 1
uint16_t target_level = 1;
auto batch = storage_->GetWriteBatchBase();
auto s2 = hnsw_index->InsertVectorEntryInternal(key1, vec1, batch, target_level);
ASSERT_TRUE(s2.IsOK());
auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

// Search when HNSW graph contains less than k nodes
auto s3 = hnsw_index->KnnSearch(query_vector, k);
ASSERT_TRUE(s3.IsOK());
auto key_strs = s3.GetValue();
std::vector<std::string> expected = {"key1"};
EXPECT_EQ(key_strs, expected);

// Insert key2 into layer 2
target_level = 2;
batch = storage_->GetWriteBatchBase();
auto s4 = hnsw_index->InsertVectorEntryInternal(key2, vec2, batch, target_level);
ASSERT_TRUE(s4.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

// Insert key3 into layer 0
target_level = 0;
batch = storage_->GetWriteBatchBase();
auto s5 = hnsw_index->InsertVectorEntryInternal(key3, vec3, batch, target_level);
ASSERT_TRUE(s5.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

// Search when HNSW graph contains exactly k nodes
auto s6 = hnsw_index->KnnSearch(query_vector, k);
ASSERT_TRUE(s6.IsOK());
key_strs = s6.GetValue();
expected = {"key3", "key2", "key1"};
EXPECT_EQ(key_strs, expected);

// Insert key4 into layer 1
target_level = 1;
batch = storage_->GetWriteBatchBase();
auto s7 = hnsw_index->InsertVectorEntryInternal(key4, vec4, batch, target_level);
ASSERT_TRUE(s7.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

// Insert key5 into layer 0
target_level = 0;
batch = storage_->GetWriteBatchBase();
auto s8 = hnsw_index->InsertVectorEntryInternal(key5, vec5, batch, target_level);
ASSERT_TRUE(s8.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

// Search when HNSW graph contains more than k nodes
auto s9 = hnsw_index->KnnSearch(query_vector, k);
ASSERT_TRUE(s9.IsOK());
key_strs = s9.GetValue();
expected = {"key5", "key3", "key2"};
EXPECT_EQ(key_strs, expected);

// Edge case: If ef_runtime is smaller than k, enlarge ef_runtime equal to k
hnsw_index->metadata->ef_runtime = 1;
auto s10 = hnsw_index->KnnSearch(query_vector, k);
ASSERT_TRUE(s10.IsOK());
key_strs = s10.GetValue();
expected = {"key5", "key3", "key2"};
EXPECT_EQ(key_strs, expected);
}
44 changes: 44 additions & 0 deletions tests/cppunit/indexer_test.cc
@@ -56,6 +56,11 @@ struct IndexerTest : TestBase {
auto json_info = std::make_unique<kqir::IndexInfo>("jsontest", json_field_meta, ns);
json_info->Add(kqir::FieldInfo("$.x", std::make_unique<redis::TagFieldMetadata>()));
json_info->Add(kqir::FieldInfo("$.y", std::make_unique<redis::NumericFieldMetadata>()));
auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
hnsw_field_meta->dim = 3;
hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
json_info->Add(kqir::FieldInfo("$.z", std::move(hnsw_field_meta)));
json_info->prefixes.prefixes.emplace_back("idxtestjson");

map.emplace("jsontest", std::move(json_info));
@@ -280,3 +285,42 @@ TEST_F(IndexerTest, JsonTagBuildIndex) {
ASSERT_EQ(val, "");
}
}

TEST_F(IndexerTest, JsonHnswVector) {
redis::Json db(storage_.get(), ns);
auto cfhandler = storage_->GetCFHandle(ColumnFamilyID::Search);

{
auto s = indexer.Record("no_exist", ns);
ASSERT_TRUE(s.Is<Status::NoPrefixMatched>());
}

auto key3 = "idxtestjson:k3";
auto idxname = "jsontest";

{
auto s = indexer.Record(key3, ns);
ASSERT_TRUE(s);
ASSERT_EQ(s->updater.info->name, idxname);
ASSERT_TRUE(s->fields.empty());

auto s_set = db.Set(key3, "$", R"({"z": [1,2,3]})");
ASSERT_TRUE(s_set.ok());

auto s2 = indexer.Update(*s);
EXPECT_EQ(s2.Msg(), Status::ok_msg);

auto search_key = redis::SearchKey(ns, idxname, "$.z").ConstructHnswNode(0, key3);

std::string val;
auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, search_key, &val);
ASSERT_TRUE(s3.ok());

redis::HnswNodeFieldMetadata node_meta;
Slice input(val);
node_meta.Decode(&input);
EXPECT_EQ(node_meta.num_neighbours, 0);
std::vector<double> expected = {1, 2, 3};
EXPECT_EQ(expected, node_meta.vector);
}
}