Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support customizable BM25 score parameters in fulltext search #2410

Merged
merged 5 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/bin/infinity_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <exception>
#ifdef ENABLE_JEMALLOC_PROF
#include <jemalloc/jemalloc.h>
#endif
Expand Down Expand Up @@ -196,6 +197,25 @@ void RegisterSignal() {
sigaction(SIGSEGV, &sig_action, NULL);
}

void TerminateHandler() {
infinity::String message = "TerminateHandler: ";
try {
std::exception_ptr eptr{std::current_exception()};
if (eptr) {
std::rethrow_exception(eptr);
} else {
message += "Exiting without exception";
}
} catch (const std::exception &ex) {
message += "Unhandled Exception: ";
message += ex.what();
} catch (...) {
message += "Unknown Unhandled Exception";
}
infinity::PrintStacktrace(message);
std::abort();
}

} // namespace

auto main(int argc, char **argv) -> int {
Expand Down Expand Up @@ -254,6 +274,8 @@ auto main(int argc, char **argv) -> int {

RegisterSignal();

std::set_terminate(TerminateHandler);

InfinityContext::instance().InitPhase2();

shutdown_thread.join();
Expand Down
5 changes: 3 additions & 2 deletions src/executor/operator/physical_match.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ bool PhysicalMatch::ExecuteInner(QueryContext *query_context, OperatorState *ope
static_cast<TimeDurationType>(finish_init_query_builder_time - execute_start_time).count()));

// 2 build query iterator
FullTextQueryContext full_text_query_context(ft_similarity_, minimum_should_match_option_, top_n_, match_expr_->index_names_);
FullTextQueryContext full_text_query_context(ft_similarity_, bm25_params_, minimum_should_match_option_, top_n_, match_expr_->index_names_);
full_text_query_context.query_tree_ = MakeUnique<FilterQueryNode>(common_query_filter_.get(), std::move(query_tree_));
const auto query_iterators = CreateQueryIterators(query_builder, full_text_query_context, early_term_algo_, begin_threshold_, score_threshold_);
const auto finish_query_builder_time = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -331,14 +331,15 @@ PhysicalMatch::PhysicalMatch(const u64 id,
MinimumShouldMatchOption &&minimum_should_match_option,
const f32 score_threshold,
const FulltextSimilarity ft_similarity,
const BM25Params &bm25_params,
const u64 match_table_index,
SharedPtr<Vector<LoadMeta>> load_metas,
const bool cache_result)
: PhysicalOperator(PhysicalOperatorType::kMatch, nullptr, nullptr, id, std::move(load_metas), cache_result), table_index_(match_table_index),
base_table_ref_(std::move(base_table_ref)), match_expr_(std::move(match_expr)), index_reader_(std::move(index_reader)),
query_tree_(std::move(query_tree)), begin_threshold_(begin_threshold), early_term_algo_(early_term_algo), top_n_(top_n),
common_query_filter_(common_query_filter), minimum_should_match_option_(std::move(minimum_should_match_option)),
score_threshold_(score_threshold), ft_similarity_(ft_similarity) {}
score_threshold_(score_threshold), ft_similarity_(ft_similarity), bm25_params_(bm25_params) {}

PhysicalMatch::~PhysicalMatch() = default;

Expand Down
2 changes: 2 additions & 0 deletions src/executor/operator/physical_match.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public:
MinimumShouldMatchOption &&minimum_should_match_option,
f32 score_threshold,
FulltextSimilarity ft_similarity,
const BM25Params &bm25_params,
u64 match_table_index,
SharedPtr<Vector<LoadMeta>> load_metas,
bool cache_result);
Expand Down Expand Up @@ -115,6 +116,7 @@ private:
MinimumShouldMatchOption minimum_should_match_option_{};
f32 score_threshold_{};
FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25};
BM25Params bm25_params_;

bool ExecuteInner(QueryContext *query_context, OperatorState *operator_state);
};
Expand Down
1 change: 1 addition & 0 deletions src/executor/physical_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ UniquePtr<PhysicalOperator> PhysicalPlanner::BuildMatch(const SharedPtr<LogicalN
std::move(logical_match->minimum_should_match_option_),
logical_match->score_threshold_,
logical_match->ft_similarity_,
logical_match->bm25_params_,
logical_match->TableIndex(),
logical_operator->load_metas(),
true /*cache_result*/);
Expand Down
37 changes: 37 additions & 0 deletions src/planner/bound_select_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,43 @@ SharedPtr<LogicalNode> BoundSelectStatement::BuildPlan(QueryContext *query_conte
RecoverableError(Status::SyntaxError(R"(similarity option must be "BM25" or "boolean".)"));
}
}
// option: bm25_params
if (iter = search_ops.options_.find("bm25_param_k1"); iter != search_ops.options_.end()) {
const auto k1_v = DataType::StringToValue<FloatT>(iter->second);
if (k1_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_k1 must be a non-negative float. default value: 1.2"));
}
match_node->bm25_params_.k1 = k1_v;
}
if (iter = search_ops.options_.find("bm25_param_b"); iter != search_ops.options_.end()) {
const auto b_v = DataType::StringToValue<FloatT>(iter->second);
if (b_v < 0.0f || b_v > 1.0f) {
RecoverableError(Status::SyntaxError("bm25_param_b must be in the range [0.0f, 1.0f]. default value: 0.75"));
}
match_node->bm25_params_.b = b_v;
}
if (iter = search_ops.options_.find("bm25_param_delta"); iter != search_ops.options_.end()) {
const auto delta_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta must be a non-negative float. default value: 0.0"));
}
match_node->bm25_params_.delta_term = delta_v;
match_node->bm25_params_.delta_phrase = delta_v;
}
if (iter = search_ops.options_.find("bm25_param_delta_term"); iter != search_ops.options_.end()) {
const auto delta_term_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_term_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta_term must be a non-negative float. default value: 0.0"));
}
match_node->bm25_params_.delta_term = delta_term_v;
}
if (iter = search_ops.options_.find("bm25_param_delta_phrase"); iter != search_ops.options_.end()) {
const auto delta_phrase_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_phrase_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta_phrase must be a non-negative float. default value: 0.0"));
}
match_node->bm25_params_.delta_phrase = delta_phrase_v;
}

SearchDriver search_driver(column2analyzer, default_field, query_operator_option);
UniquePtr<QueryNode> query_tree =
Expand Down
1 change: 1 addition & 0 deletions src/planner/node/logical_match.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public:
MinimumShouldMatchOption minimum_should_match_option_{};
f32 score_threshold_{};
FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25};
BM25Params bm25_params_;
};

} // namespace infinity
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ class IndexScanFilterExpressionPushDownMethod {
MinimumShouldMatchOption minimum_should_match_option;
f32 score_threshold = {};
FulltextSimilarity ft_similarity = FulltextSimilarity::kBM25;
BM25Params bm25_params;
Vector<String> index_names;
{
SearchOptions search_ops(filter_fulltext_expr->options_text_);
Expand Down Expand Up @@ -528,6 +529,43 @@ class IndexScanFilterExpressionPushDownMethod {
RecoverableError(Status::SyntaxError(R"(similarity option must be "BM25" or "boolean".)"));
}
}
// option: bm25_params
if (iter = search_ops.options_.find("bm25_param_k1"); iter != search_ops.options_.end()) {
const auto k1_v = DataType::StringToValue<FloatT>(iter->second);
if (k1_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_k1 must be a non-negative float. default value: 1.2"));
}
bm25_params.k1 = k1_v;
}
if (iter = search_ops.options_.find("bm25_param_b"); iter != search_ops.options_.end()) {
const auto b_v = DataType::StringToValue<FloatT>(iter->second);
if (b_v < 0.0f || b_v > 1.0f) {
RecoverableError(Status::SyntaxError("bm25_param_b must be in the range [0.0f, 1.0f]. default value: 0.75"));
}
bm25_params.b = b_v;
}
if (iter = search_ops.options_.find("bm25_param_delta"); iter != search_ops.options_.end()) {
const auto delta_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta must be a non-negative float. default value: 0.0"));
}
bm25_params.delta_term = delta_v;
bm25_params.delta_phrase = delta_v;
}
if (iter = search_ops.options_.find("bm25_param_delta_term"); iter != search_ops.options_.end()) {
const auto delta_term_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_term_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta_term must be a non-negative float. default value: 0.0"));
}
bm25_params.delta_term = delta_term_v;
}
if (iter = search_ops.options_.find("bm25_param_delta_phrase"); iter != search_ops.options_.end()) {
const auto delta_phrase_v = DataType::StringToValue<FloatT>(iter->second);
if (delta_phrase_v < 0.0f) {
RecoverableError(Status::SyntaxError("bm25_param_delta_phrase must be a non-negative float. default value: 0.0"));
}
bm25_params.delta_phrase = delta_phrase_v;
}

// option: indexes
if (iter = search_ops.options_.find("indexes"); iter != search_ops.options_.end()) {
Expand Down Expand Up @@ -564,6 +602,7 @@ class IndexScanFilterExpressionPushDownMethod {
std::move(minimum_should_match_option),
score_threshold,
ft_similarity,
bm25_params,
std::move(index_names));
}
case Enum::kAndExpr: {
Expand Down
3 changes: 2 additions & 1 deletion src/planner/optimizer/index_scan/index_filter_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ Bitmask IndexFilterEvaluatorFulltext::Evaluate(const SegmentID segment_id, const
result.SetAllFalse();
const RowID begin_rowid(segment_id, 0);
const RowID end_rowid(segment_id, segment_row_count);
const CreateSearchParams params{table_entry_, &index_reader_, early_term_algo_, ft_similarity_, minimum_should_match_, 0u, index_names_};
const CreateSearchParams params{table_entry_, &index_reader_, early_term_algo_, ft_similarity_, bm25_params_, minimum_should_match_, 0u, index_names_};
auto ft_iter = query_tree_->CreateSearch(params);
if (ft_iter && score_threshold_ > 0.0f) {
auto new_ft_iter = MakeUnique<ScoreThresholdIterator>(std::move(ft_iter), score_threshold_);
Expand Down Expand Up @@ -524,6 +524,7 @@ Bitmask IndexFilterEvaluatorAND::Evaluate(const SegmentID segment_id, const Segm
&(fulltext_evaluator_->index_reader_),
fulltext_evaluator_->early_term_algo_,
fulltext_evaluator_->ft_similarity_,
fulltext_evaluator_->bm25_params_,
fulltext_evaluator_->minimum_should_match_,
0u,
fulltext_evaluator_->index_names_};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator {
std::atomic_flag after_optimize_ = {};
f32 score_threshold_ = {};
FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25;
BM25Params bm25_params_;
Vector<String> index_names_;

IndexFilterEvaluatorFulltext(const FilterFulltextExpression *src_filter_fulltext_expression,
Expand All @@ -104,11 +105,12 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator {
MinimumShouldMatchOption &&minimum_should_match_option,
const f32 score_threshold,
const FulltextSimilarity ft_similarity,
const BM25Params &bm25_params,
Vector<String> &&index_names)
: IndexFilterEvaluator(Type::kFulltextIndex), src_filter_fulltext_expressions_({src_filter_fulltext_expression}), table_entry_(table_entry),
early_term_algo_(early_term_algo), index_reader_(std::move(index_reader)), query_tree_(std::move(query_tree)),
minimum_should_match_option_(std::move(minimum_should_match_option)), score_threshold_(std::max(score_threshold, 0.0f)),
ft_similarity_(ft_similarity), index_names_(std::move(index_names)) {}
ft_similarity_(ft_similarity), bm25_params_(bm25_params), index_names_(std::move(index_names)) {}
Bitmask Evaluate(SegmentID segment_id, SegmentOffset segment_row_count, Txn *txn) const override;
bool HaveMinimumShouldMatchOption() const { return !minimum_should_match_option_.empty(); }
void OptimizeQueryTree();
Expand Down
4 changes: 4 additions & 0 deletions src/storage/invertedindex/search/blockmax_leaf_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ export module blockmax_leaf_iterator;
import stl;
import internal_types;
import doc_iterator;
import column_length_io;

namespace infinity {

export class BlockMaxLeafIterator : public DocIterator {
public:
// ref: https://en.wikipedia.org/wiki/Okapi_BM25
virtual void InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader, float delta, float k1, float b) = 0;

virtual RowID BlockMinPossibleDocID() const = 0;

virtual RowID BlockLastDocID() const = 0;
Expand Down
7 changes: 7 additions & 0 deletions src/storage/invertedindex/search/parse_fulltext_options.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,11 @@ export enum class FulltextSimilarity {
kBoolean,
};

/// Tunable parameters of the BM25 scoring formula
/// (ref: https://en.wikipedia.org/wiki/Okapi_BM25).
///
/// The struct itself performs no validation; the defaults below are what a
/// query gets when it supplies no bm25_param_* search option.
export struct BM25Params {
    /// BM25 `k1` parameter. Default: 1.2.
    float k1{1.2F};
    /// BM25 `b` parameter. Default: 0.75.
    float b{0.75F};
    /// Additive `delta` for term scoring. Default: 0.0.
    /// NOTE(review): presumably consumed by the term iterators - confirm at the call sites.
    float delta_term{0.0F};
    /// Additive `delta` for phrase scoring. Default: 0.0.
    /// NOTE(review): presumably the `delta` handed to PhraseDocIterator::InitBM25Info - confirm.
    float delta_phrase{0.0F};
};

} // namespace infinity
15 changes: 6 additions & 9 deletions src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters,
block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size());
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
// BM25 parameters
constexpr float k1 = 1.2F;
constexpr float b = 0.75F;

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader, const float delta, const float k1, const float b) {
column_length_reader_ = std::move(column_length_reader);
const u64 total_df = column_length_reader_->GetTotalDF();
const float avg_column_len = column_length_reader_->GetAvgColumnLength();
Expand All @@ -50,13 +46,14 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
total_idf += std::log1p((total_df - doc_freq + 0.5F) / (doc_freq + 0.5F));
}
bm25_common_score_ = weight_ * total_idf * (k1 + 1.0F);
bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len);
bm25_score_upper_bound_ = bm25_common_score_ * (avg_column_len / (avg_column_len + k1 * b) + delta / (k1 + 1.0F));
f1 = k1 * (1.0F - b);
f2 = k1 * b / avg_column_len;
f3 = f2 * std::numeric_limits<u16>::max();
f4 = delta / (k1 + 1.0F);
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
oss << "TermDocIterator: ";
oss << "PhraseDocIterator: ";
if (column_name_ptr_ != nullptr) {
oss << "column: " << *column_name_ptr_ << ",";
}
Expand Down Expand Up @@ -129,7 +126,7 @@ float PhraseDocIterator::BlockMaxBM25Score() {
}
div_add_min = std::min(div_add_min, current_div_add_min);
}
block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min);
block_max_bm25_score_cache_ = bm25_common_score_ * (1.0F / (1.0F + div_add_min) + f4);
}
return block_max_bm25_score_cache_;
}
Expand Down Expand Up @@ -167,7 +164,7 @@ float PhraseDocIterator::BM25Score() {
// bm25_common_score_ * (tf / (tf + k1 * (1.0F - b + b * column_len / avg_column_len)) + delta / (k1 + 1.0F));
const auto doc_len = column_length_reader_->GetColumnLength(doc_id_);
const float p = f1 + f2 * doc_len;
bm25_score_cache_ = bm25_common_score_ * tf_ / (tf_ + p);
bm25_score_cache_ = bm25_common_score_ * (tf_ / (tf_ + p) + f4);
return bm25_score_cache_;
}

Expand Down
3 changes: 2 additions & 1 deletion src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public:

float GetPhraseFreq() const { return phrase_freq_; }

void InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader);
void InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader, float delta, float k1, float b) override;

DocIteratorType GetType() const override { return DocIteratorType::kPhraseIterator; }
String Name() const override { return "PhraseDocIterator"; }
Expand Down Expand Up @@ -97,6 +97,7 @@ private:
float f1 = 0.0f;
float f2 = 0.0f;
float f3 = 0.0f;
float f4 = 0.0f;
float bm25_common_score_ = 0.0f;
UniquePtr<FullTextColumnLengthReader> column_length_reader_ = nullptr;
float block_max_bm25_score_cache_ = 0.0f;
Expand Down
1 change: 1 addition & 0 deletions src/storage/invertedindex/search/query_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ UniquePtr<DocIterator> QueryBuilder::CreateSearch(FullTextQueryContext &context)
&index_reader_,
context.early_term_algo_,
context.ft_similarity_,
context.bm25_params_,
context.minimum_should_match_,
context.topn_,
context.index_names_};
Expand Down
Loading