Skip to content

Commit

Permalink
feat(search): Hnsw Vector Search Plan Operator & Executor (#2434)
Browse files Browse the repository at this point in the history
Co-authored-by: Twice <twice.mliu@gmail.com>
  • Loading branch information
Beihao-Zhou and PragmaTwice authored Jul 24, 2024
1 parent 45ba475 commit 79a740c
Show file tree
Hide file tree
Showing 9 changed files with 634 additions and 253 deletions.
76 changes: 76 additions & 0 deletions src/search/executors/hnsw_vector_field_knn_scan_executor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <string>

#include "db_util.h"
#include "encoding.h"
#include "search/hnsw_indexer.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
#include "storage/redis_db.h"
#include "storage/redis_metadata.h"
#include "storage/storage.h"
#include "string_util.h"

namespace kqir {

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
struct HnswVectorFieldKnnScanExecutor : ExecutorNode {
HnswVectorFieldKnnScan *scan;
redis::LatestSnapShot ss;
bool initialized = false;

IndexInfo *index;
redis::SearchKey search_key;
redis::HnswVectorFieldMetadata field_metadata;
redis::HnswIndex hnsw_index;
std::vector<redis::KeyWithDistance> row_keys;
decltype(row_keys)::iterator row_keys_iter;

HnswVectorFieldKnnScanExecutor(ExecutorContext *ctx, HnswVectorFieldKnnScan *scan)
: ExecutorNode(ctx),
scan(scan),
ss(ctx->storage),
index(scan->field->info->index),
search_key(index->ns, index->name, scan->field->name),
field_metadata(*(scan->field->info->MetadataAs<redis::HnswVectorFieldMetadata>())),
hnsw_index(redis::HnswIndex(search_key, &field_metadata, ctx->storage)) {}

StatusOr<Result> Next() override {
if (!initialized) {
row_keys = GET_OR_RET(hnsw_index.KnnSearch(scan->vector, scan->k));
row_keys_iter = row_keys.begin();
initialized = true;
}

if (row_keys_iter == row_keys.end()) {
return end;
}

auto key_str = row_keys_iter->second;
row_keys_iter++;
return RowType{key_str, {}, scan->field->info->index};
}
};

} // namespace kqir
86 changes: 86 additions & 0 deletions src/search/executors/hnsw_vector_field_range_scan_executor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <string>

#include "db_util.h"
#include "encoding.h"
#include "search/hnsw_indexer.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
#include "storage/redis_db.h"
#include "storage/redis_metadata.h"
#include "storage/storage.h"
#include "string_util.h"

namespace kqir {

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
struct HnswVectorFieldRangeScanExecutor : ExecutorNode {
HnswVectorFieldRangeScan *scan;
redis::LatestSnapShot ss;
bool initialized = false;

IndexInfo *index;
redis::SearchKey search_key;
redis::HnswVectorFieldMetadata field_metadata;
redis::HnswIndex hnsw_index;
std::vector<redis::KeyWithDistance> row_keys;
std::unordered_set<std::string> visited;
decltype(row_keys)::iterator row_keys_iter;

HnswVectorFieldRangeScanExecutor(ExecutorContext *ctx, HnswVectorFieldRangeScan *scan)
: ExecutorNode(ctx),
scan(scan),
ss(ctx->storage),
index(scan->field->info->index),
search_key(index->ns, index->name, scan->field->name),
field_metadata(*(scan->field->info->MetadataAs<redis::HnswVectorFieldMetadata>())),
hnsw_index(redis::HnswIndex(search_key, &field_metadata, ctx->storage)) {}

StatusOr<Result> Next() override {
if (!initialized) {
row_keys = GET_OR_RET(hnsw_index.KnnSearch(scan->vector, field_metadata.ef_runtime));
row_keys_iter = row_keys.begin();
initialized = true;
}

auto effective_range = scan->range * (1 + field_metadata.epsilon);
if (row_keys_iter == row_keys.end() || row_keys_iter->first > abs(effective_range) ||
row_keys_iter->first < -abs(effective_range)) {
row_keys = GET_OR_RET(hnsw_index.ExpandSearchScope(scan->vector, std::move(row_keys), visited));
if (row_keys.empty()) return end;
row_keys_iter = row_keys.begin();
}

if (row_keys_iter->first > abs(effective_range) || row_keys_iter->first < -abs(effective_range)) {
return end;
}

auto key_str = row_keys_iter->second;
row_keys_iter++;
visited.insert(key_str);
return RowType{key_str, {}, scan->field->info->index};
}
};

} // namespace kqir
99 changes: 89 additions & 10 deletions src/search/hnsw_indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,13 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem& v
return selected_vs;
}

StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const VectorItem& target_vector,
uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) const {
std::vector<VectorItem> candidates;
StatusOr<std::vector<VectorItemWithDistance>> HnswIndex::SearchLayerInternal(
uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) const {
std::vector<VectorItemWithDistance> result;
std::unordered_set<NodeKey> visited;
std::priority_queue<std::pair<double, VectorItem>, std::vector<std::pair<double, VectorItem>>, std::greater<>>
explore_heap;
std::priority_queue<std::pair<double, VectorItem>> result_heap;
std::priority_queue<VectorItemWithDistance, std::vector<VectorItemWithDistance>, std::greater<>> explore_heap;
std::priority_queue<VectorItemWithDistance> result_heap;

for (const auto& entry_point_key : entry_points) {
HnswNode entry_node = HnswNode(entry_point_key, level);
Expand Down Expand Up @@ -330,13 +329,25 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}
}

result.resize(result_heap.size());
auto idx = result_heap.size() - 1;
while (!result_heap.empty()) {
candidates.push_back(result_heap.top().second);
result[idx] = result_heap.top();
result_heap.pop();
idx--;
}
return result;
}

std::reverse(candidates.begin(), candidates.end());
return candidates;
StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const VectorItem& target_vector,
uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) const {
std::vector<VectorItem> result;
auto result_with_distance = GET_OR_RET(SearchLayerInternal(level, target_vector, ef_runtime, entry_points));
for (auto& [_, vector_item] : result_with_distance) {
result.push_back(std::move(vector_item));
}
return result;
}

Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector,
Expand Down Expand Up @@ -549,4 +560,72 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<ro
return Status::OK();
}

StatusOr<std::vector<KeyWithDistance>> HnswIndex::KnnSearch(const kqir::NumericArray& query_vector, uint32_t k) const {
VectorItem query_vector_item;
GET_OR_RET(VectorItem::Create({}, query_vector, metadata, &query_vector_item));

if (metadata->num_levels == 0) {
return {Status::NotFound, fmt::format("No vector found in the HNSW index")};
}

auto level = metadata->num_levels - 1;
auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
std::vector<NodeKey> entry_points{default_entry_node};
std::vector<VectorItem> nearest_vec_items;

for (; level > 0; level--) {
nearest_vec_items = GET_OR_RET(SearchLayer(level, query_vector_item, metadata->ef_runtime, entry_points));
entry_points = {nearest_vec_items[0].key};
}

uint32_t effective_ef = std::max(metadata->ef_runtime, k); // Ensure ef_runtime is at least k
auto nearest_vec_with_distance = GET_OR_RET(SearchLayerInternal(0, query_vector_item, effective_ef, entry_points));

uint32_t result_length = std::min(k, static_cast<uint32_t>(nearest_vec_with_distance.size()));
std::vector<KeyWithDistance> nearest_neighbours;
for (uint32_t result_idx = 0; result_idx < result_length; result_idx++) {
nearest_neighbours.emplace_back(nearest_vec_with_distance[result_idx].first,
std::move(nearest_vec_with_distance[result_idx].second.key));
}
return nearest_neighbours;
}

StatusOr<std::vector<KeyWithDistance>> HnswIndex::ExpandSearchScope(const kqir::NumericArray& query_vector,
std::vector<redis::KeyWithDistance>&& initial_keys,
std::unordered_set<std::string>& visited) const {
constexpr uint16_t level = 0;
VectorItem query_vector_item;
GET_OR_RET(VectorItem::Create({}, query_vector, metadata, &query_vector_item));
std::vector<KeyWithDistance> result;

while (!initial_keys.empty()) {
auto current_key = initial_keys.front().second;
initial_keys.erase(initial_keys.begin());

auto current_node = HnswNode(current_key, level);
current_node.DecodeNeighbours(search_key, storage);

for (const auto& neighbour_key : current_node.neighbours) {
if (visited.find(neighbour_key) != visited.end()) {
continue;
}
visited.insert(neighbour_key);

auto neighbour_node = HnswNode(neighbour_key, level);
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));

VectorItem neighbour_node_vector;
GET_OR_RET(VectorItem::Create(neighbour_key, std::move(neighbour_node_metadata.vector), metadata,
&neighbour_node_vector));

auto dist = GET_OR_RET(ComputeSimilarity(query_vector_item, neighbour_node_vector));
result.emplace_back(dist, neighbour_key);
}
}
std::sort(result.begin(), result.end(),
[](const KeyWithDistance& a, const KeyWithDistance& b) { return a.first < b.first; });

return result;
}

} // namespace redis
11 changes: 11 additions & 0 deletions src/search/hnsw_indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ struct VectorItem {

StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right);

using VectorItemWithDistance = std::pair<double, VectorItem>;
using KeyWithDistance = std::pair<double, std::string>;

// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
struct HnswIndex {
using NodeKey = HnswNode::NodeKey;

Expand All @@ -103,13 +107,20 @@ struct HnswIndex {

StatusOr<std::vector<VectorItem>> SelectNeighbors(const VectorItem& vec, const std::vector<VectorItem>& vectors,
uint16_t layer) const;
StatusOr<std::vector<VectorItemWithDistance>> SearchLayerInternal(uint16_t level, const VectorItem& target_vector,
uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) const;
StatusOr<std::vector<VectorItem>> SearchLayer(uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) const;
Status InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t layer) const;
Status InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
StatusOr<std::vector<KeyWithDistance>> KnnSearch(const kqir::NumericArray& query_vector, uint32_t k) const;
StatusOr<std::vector<KeyWithDistance>> ExpandSearchScope(const kqir::NumericArray& query_vector,
std::vector<redis::KeyWithDistance>&& initial_keys,
std::unordered_set<std::string>& visited) const;
};

} // namespace redis
37 changes: 37 additions & 0 deletions src/search/ir_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "ir.h"
#include "search/interval.h"
#include "search/ir_sema_checker.h"
#include "search/value.h"
#include "string_util.h"

namespace kqir {
Expand Down Expand Up @@ -96,6 +97,42 @@ struct TagFieldScan : FieldScan {
}
};

struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
uint16_t k;

HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint16_t k)
: FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}

std::string_view Name() const override { return "HnswVectorFieldKnnScan"; };
std::string Content() const override {
return fmt::format("[{}], {}", util::StringJoin(vector, [](auto v) { return std::to_string(v); }), k);
};
std::string Dump() const override { return fmt::format("hnsw-vector-knn-scan {}, {}", field->name, Content()); }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<HnswVectorFieldKnnScan>(field->CloneAs<FieldRef>(), vector, k);
}
};

struct HnswVectorFieldRangeScan : FieldScan {
kqir::NumericArray vector;
uint32_t range;

HnswVectorFieldRangeScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint32_t range)
: FieldScan(std::move(field)), vector(std::move(vector)), range(range) {}

std::string_view Name() const override { return "HnswVectorFieldRangeScan"; };
std::string Content() const override {
return fmt::format("[{}], {}", util::StringJoin(vector, [](auto v) { return std::to_string(v); }), range);
};
std::string Dump() const override { return fmt::format("hnsw-vector-range-scan {}, {}", field->name, Content()); }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<HnswVectorFieldRangeScan>(field->CloneAs<FieldRef>(), vector, range);
}
};

struct Filter : PlanOperator {
std::unique_ptr<PlanOperator> source;
std::unique_ptr<QueryExpr> filter_expr;
Expand Down
Loading

0 comments on commit 79a740c

Please sign in to comment.