Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support applying BlockMaxWand algorithm to PhraseDocIterator #2369

Merged
merged 9 commits into from
Dec 13, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add bmw_iterator_interface
yangzq50 committed Dec 13, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 57dd9c12dc9e1caa6cb3e61df66e07ec832ff9d8
42 changes: 42 additions & 0 deletions src/storage/invertedindex/search/blockmax_leaf_iterator.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module blockmax_leaf_iterator;

import stl;
import internal_types;
import doc_iterator;

namespace infinity {

// Abstract interface for leaf doc iterators that expose block-level information
// (block doc-id range and block-max BM25 score) in addition to the DocIterator API.
// Every member declared here is pure virtual; concrete iterators must override all of them.
export class BlockMaxLeafIterator : public DocIterator {
public:
// Lower end of the doc-id range of the current block.
// NOTE(review): presumably the smallest doc id that may still occur in the block — confirm with implementers.
virtual RowID BlockMinPossibleDocID() const = 0;

// Last doc id of the current block.
virtual RowID BlockLastDocID() const = 0;

// BM25 score bound for the current block.
// NOTE(review): presumably an upper bound on BM25Score() for any doc in the block — confirm.
// Non-const, so implementations are allowed to update internal state.
virtual float BlockMaxBM25Score() = 0;

// Moves the block cursor so that the last doc id of the current block is no less than the given doc_id.
// Returns false and updates doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip list and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
virtual bool NextShallow(RowID doc_id) = 0;

// BM25 score of the iterator.
// NOTE(review): presumably evaluated for the current doc_id_ — confirm with implementers.
virtual float BM25Score() = 0;
};

} // namespace infinity
68 changes: 68 additions & 0 deletions src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ module;

#include <cassert>
#include <iostream>
#include <vector>

module phrase_doc_iterator;

@@ -31,6 +32,8 @@ PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters,
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
estimate_iterate_cost_ = {1, estimate_doc_freq_};
block_max_bm25_score_cache_part_info_end_ids_.resize(pos_iters_.size(), INVALID_ROWID);
block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size());
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
@@ -44,6 +47,9 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
float smooth_idf = std::log1p((total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F);
bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len);
f1 = k1 * (1.0F - b);
f2 = k1 * b / avg_column_len;
f3 = f2 * std::numeric_limits<u16>::max();
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
oss << "TermDocIterator: ";
@@ -80,13 +86,74 @@ bool PhraseDocIterator::Next(const RowID doc_id) {
bool found = GetPhraseMatchData();
if (found && (threshold_ <= 0.0f || BM25Score() > threshold_)) {
doc_id_ = target_doc_id;
UpdateBlockRangeDocID();
return true;
}
++target_doc_id;
}
}
}

void PhraseDocIterator::UpdateBlockRangeDocID() {
RowID min_doc_id = 0;
RowID max_doc_id = INVALID_ROWID;
for (const auto &it : pos_iters_) {
min_doc_id = std::max(min_doc_id, it->BlockLowestPossibleDocID());
max_doc_id = std::min(max_doc_id, it->BlockLastDocID());
}
block_min_possible_doc_id_ = min_doc_id;
block_last_doc_id_ = max_doc_id;
}

float PhraseDocIterator::BlockMaxBM25Score() {
if (const auto last_doc_id = BlockLastDocID(); last_doc_id != block_max_bm25_score_cache_end_id_) {
block_max_bm25_score_cache_end_id_ = last_doc_id;
// bm25_common_score_ / (1.0F + k1 * ((1.0F - b) / block_max_tf + b / block_max_percentage / avg_column_len));
// block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + f1 / block_max_tf + f3 / block_max_percentage_u16);
float div_add_min = std::numeric_limits<float>::max();
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
const auto *iter = pos_iters_[i].get();
float current_div_add_min = {};
if (const auto iter_block_last_doc_id = iter->BlockLastDocID();
iter_block_last_doc_id == block_max_bm25_score_cache_part_info_end_ids_[i]) {
current_div_add_min = block_max_bm25_score_cache_part_info_vals_[i];
} else {
block_max_bm25_score_cache_part_info_end_ids_[i] = iter_block_last_doc_id;
const auto [block_max_tf, block_max_percentage_u16] = iter->GetBlockMaxInfo();
current_div_add_min = f1 / block_max_tf + f3 / block_max_percentage_u16;
block_max_bm25_score_cache_part_info_vals_[i] = current_div_add_min;
}
div_add_min = std::min(div_add_min, current_div_add_min);
}
block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min);
}
return block_max_bm25_score_cache_;
}

// Advances the block cursors of all posting iterators until the last doc id of the
// combined block range is no less than doc_id and the block can pass the score threshold.
//
// Returns false and sets doc_id_ to INVALID_ROWID when the iterator is exhausted, i.e.
// when threshold_ exceeds BM25ScoreUpperBound() or any posting iterator fails SkipTo().
// On success only the block range is refreshed (via UpdateBlockRangeDocID()); doc_id_ is
// left untouched. The caller may invoke BlockMaxBM25Score() after this routine.
bool PhraseDocIterator::NextShallow(RowID doc_id) {
    if (threshold_ > BM25ScoreUpperBound()) [[unlikely]] {
        doc_id_ = INVALID_ROWID;
        return false;
    }
    // Each round that fails the threshold check restarts just past the rejected block range.
    for (RowID target = doc_id;; target = BlockLastDocID() + 1) {
        for (SizeT idx = 0; idx < pos_iters_.size(); ++idx) {
            if (!pos_iters_[idx]->SkipTo(target)) {
                doc_id_ = INVALID_ROWID;
                return false;
            }
        }
        UpdateBlockRangeDocID();
        // A non-positive threshold accepts every block without computing its score.
        const bool accept_any_block = threshold_ <= 0.0f;
        if (accept_any_block || BlockMaxBM25Score() > threshold_) {
            return true;
        }
    }
}

float PhraseDocIterator::BM25Score() {
if (doc_id_ == bm25_score_cache_docid_) [[unlikely]] {
return bm25_score_cache_;
@@ -112,6 +179,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i
}
os << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << '\n';
}

23 changes: 21 additions & 2 deletions src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
@@ -10,10 +10,11 @@ import posting_iterator;
import index_defines;
import column_length_io;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class PhraseDocIterator final : public DocIterator {
export class PhraseDocIterator final : public BlockMaxLeafIterator {
public:
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity);

@@ -32,7 +33,21 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
RowID BlockMinPossibleDocID() const override { return block_min_possible_doc_id_; }

RowID BlockLastDocID() const override { return block_last_doc_id_; }

void UpdateBlockRangeDocID();

float BlockMaxBM25Score() override;

// Moves the block cursor so that its last_doc_id is no less than the given doc_id.
// Returns false and updates doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip_list and doesn't update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id) override;

float BM25Score() override;

float Score() override {
switch (ft_similarity_) {
@@ -86,6 +101,10 @@ private:
UniquePtr<FullTextColumnLengthReader> column_length_reader_ = nullptr;
float block_max_bm25_score_cache_ = 0.0f;
RowID block_max_bm25_score_cache_end_id_ = INVALID_ROWID;
Vector<RowID> block_max_bm25_score_cache_part_info_end_ids_;
Vector<float> block_max_bm25_score_cache_part_info_vals_;
RowID block_min_possible_doc_id_ = INVALID_ROWID;
RowID block_last_doc_id_ = INVALID_ROWID;

float tf_ = 0.0f; // current doc_id_'s tf
u32 estimate_doc_freq_{0}; // estimated at the beginning
13 changes: 7 additions & 6 deletions src/storage/invertedindex/search/term_doc_iterator.cppm
Original file line number Diff line number Diff line change
@@ -27,10 +27,11 @@ import doc_iterator;
import column_length_io;
import third_party;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class TermDocIterator final : public DocIterator {
export class TermDocIterator final : public BlockMaxLeafIterator {
public:
TermDocIterator(UniquePtr<PostingIterator> &&iter, u64 column_id, float weight, FulltextSimilarity ft_similarity);

@@ -48,15 +49,15 @@ public:

void InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader);

RowID BlockMinPossibleDocID() const { return iter_->BlockLowestPossibleDocID(); }
RowID BlockLastDocID() const { return iter_->BlockLastDocID(); }
float BlockMaxBM25Score();
RowID BlockMinPossibleDocID() const override { return iter_->BlockLowestPossibleDocID(); }
RowID BlockLastDocID() const override { return iter_->BlockLastDocID(); }
float BlockMaxBM25Score() override;

// Moves the block cursor so that its last_doc_id is no less than the given doc_id.
// Returns false and updates doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip_list and doesn't update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id);
bool NextShallow(RowID doc_id) override;

// Overridden methods
DocIteratorType GetType() const override { return DocIteratorType::kTermDocIterator; }
@@ -65,7 +66,7 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
float BM25Score() override;

float Score() override {
switch (ft_similarity_) {