Skip to content

Commit

Permalink
Support keyword analyzer (#2168)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support keyword analyzer
Support boolean similarity for columns with keyword analyzer

Issue link:#2139

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
- [x] Test cases
  • Loading branch information
yangzq50 authored Nov 4, 2024
1 parent 4ea5fcb commit 96e1bb9
Show file tree
Hide file tree
Showing 13 changed files with 397 additions and 176 deletions.
4 changes: 4 additions & 0 deletions src/common/analyzer/analyzer_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import korean_analyzer;
import standard_analyzer;
import ngram_analyzer;
import rag_analyzer;
import keyword_analyzer;
import logger;

namespace infinity {
Expand Down Expand Up @@ -267,6 +268,9 @@ Tuple<UniquePtr<Analyzer>, Status> AnalyzerPool::GetAnalyzer(const std::string_v
}
return {MakeUnique<NGramAnalyzer>(ngram), Status::OK()};
}
case Str2Int(KEYWORD.data()): {
return {MakeUnique<KeywordAnalyzer>(), Status::OK()};
}
default: {
if(std::filesystem::is_regular_file(name)) {
// Suppose it is a customized Python script analyzer
Expand Down
1 change: 1 addition & 0 deletions src/common/analyzer/analyzer_pool.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public:
static constexpr std::string_view STANDARD = "standard";
static constexpr std::string_view NGRAM = "ngram";
static constexpr std::string_view RAG = "rag";
static constexpr std::string_view KEYWORD = "keyword";

private:
CacheType cache_{};
Expand Down
37 changes: 37 additions & 0 deletions src/common/analyzer/keyword_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include <sstream>
#include <string>
module keyword_analyzer;

import stl;
import term;
import analyzer;

namespace infinity {

int KeywordAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
std::istringstream is(input.text_);
std::string t;
u32 offset = 0;
while (is >> t) {
func(data, t.data(), t.size(), offset++, 0, Term::AND, 0, false);
}
return 0;
}

} // namespace infinity
32 changes: 32 additions & 0 deletions src/common/analyzer/keyword_analyzer.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module keyword_analyzer;
import stl;
import term;
import analyzer;

namespace infinity {
export class KeywordAnalyzer : public Analyzer {
public:
KeywordAnalyzer() = default;
~KeywordAnalyzer() override = default;

protected:
int AnalyzeImpl(const Term &input, void *data, HookType func) override;
};

} // namespace infinity
26 changes: 13 additions & 13 deletions src/parser/search_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,22 +807,22 @@ namespace infinity {
case 4: // query: query clause
#line 91 "search_parser.y"
{
auto query = driver.GetMultiQueryNodeByOperatorOption();
auto *multi_query_ptr = dynamic_cast<MultiQueryNode *>(query.get());
multi_query_ptr->Add(std::move(yystack_[1].value.as < std::unique_ptr<QueryNode> > ()));
multi_query_ptr->Add(std::move(yystack_[0].value.as < std::unique_ptr<QueryNode> > ()));
yylhs.value.as < std::unique_ptr<QueryNode> > () = std::move(query);
assert(driver.operator_option_ == FulltextQueryOperatorOption::kInfinitySyntax);
auto q = std::make_unique<OrQueryNode>();
q->Add(std::move(yystack_[1].value.as < std::unique_ptr<QueryNode> > ()));
q->Add(std::move(yystack_[0].value.as < std::unique_ptr<QueryNode> > ()));
yylhs.value.as < std::unique_ptr<QueryNode> > () = std::move(q);
}
#line 817 "search_parser.cpp"
break;

case 5: // query: query OR clause
#line 98 "search_parser.y"
{
auto query = std::make_unique<OrQueryNode>();
query->Add(std::move(yystack_[2].value.as < std::unique_ptr<QueryNode> > ()));
query->Add(std::move(yystack_[0].value.as < std::unique_ptr<QueryNode> > ()));
yylhs.value.as < std::unique_ptr<QueryNode> > () = std::move(query);
auto q = std::make_unique<OrQueryNode>();
q->Add(std::move(yystack_[2].value.as < std::unique_ptr<QueryNode> > ()));
q->Add(std::move(yystack_[0].value.as < std::unique_ptr<QueryNode> > ()));
yylhs.value.as < std::unique_ptr<QueryNode> > () = std::move(q);
}
#line 828 "search_parser.cpp"
break;
Expand Down Expand Up @@ -901,7 +901,7 @@ namespace infinity {
YYERROR;
}
std::string text = SearchDriver::Unescape(yystack_[0].value.as < InfString > ().text_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, std::move(text), yystack_[0].value.as < InfString > ().from_quoted_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, text, yystack_[0].value.as < InfString > ().from_quoted_);
}
#line 907 "search_parser.cpp"
break;
Expand All @@ -911,7 +911,7 @@ namespace infinity {
{
std::string field = SearchDriver::Unescape(yystack_[2].value.as < InfString > ().text_);
std::string text = SearchDriver::Unescape(yystack_[0].value.as < InfString > ().text_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(std::move(field), std::move(text), yystack_[0].value.as < InfString > ().from_quoted_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, text, yystack_[0].value.as < InfString > ().from_quoted_);
}
#line 917 "search_parser.cpp"
break;
Expand All @@ -925,7 +925,7 @@ namespace infinity {
YYERROR;
}
std::string text = SearchDriver::Unescape(yystack_[1].value.as < InfString > ().text_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, std::move(text), yystack_[1].value.as < InfString > ().from_quoted_, yystack_[0].value.as < unsigned long > ());
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, text, yystack_[1].value.as < InfString > ().from_quoted_, yystack_[0].value.as < unsigned long > ());
}
#line 931 "search_parser.cpp"
break;
Expand All @@ -935,7 +935,7 @@ namespace infinity {
{
std::string field = SearchDriver::Unescape(yystack_[3].value.as < InfString > ().text_);
std::string text = SearchDriver::Unescape(yystack_[1].value.as < InfString > ().text_);
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(std::move(field), std::move(text), yystack_[1].value.as < InfString > ().from_quoted_, yystack_[0].value.as < unsigned long > ());
yylhs.value.as < std::unique_ptr<QueryNode> > () = driver.AnalyzeAndBuildQueryNode(field, text, yystack_[1].value.as < InfString > ().from_quoted_, yystack_[0].value.as < unsigned long > ());
}
#line 941 "search_parser.cpp"
break;
Expand Down
26 changes: 13 additions & 13 deletions src/parser/search_parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,17 @@ topLevelQuery
query
: clause { $$ = std::move($1); }
| query clause {
auto query = driver.GetMultiQueryNodeByOperatorOption();
auto *multi_query_ptr = dynamic_cast<MultiQueryNode *>(query.get());
multi_query_ptr->Add(std::move($1));
multi_query_ptr->Add(std::move($2));
$$ = std::move(query);
assert(driver.operator_option_ == FulltextQueryOperatorOption::kInfinitySyntax);
auto q = std::make_unique<OrQueryNode>();
q->Add(std::move($1));
q->Add(std::move($2));
$$ = std::move(q);
}
| query OR clause {
auto query = std::make_unique<OrQueryNode>();
query->Add(std::move($1));
query->Add(std::move($3));
$$ = std::move(query);
auto q = std::make_unique<OrQueryNode>();
q->Add(std::move($1));
q->Add(std::move($3));
$$ = std::move(q);
};

clause
Expand Down Expand Up @@ -141,12 +141,12 @@ basic_filter
YYERROR;
}
std::string text = SearchDriver::Unescape($1.text_);
$$ = driver.AnalyzeAndBuildQueryNode(field, std::move(text), $1.from_quoted_);
$$ = driver.AnalyzeAndBuildQueryNode(field, text, $1.from_quoted_);
}
| STRING OP_COLON STRING {
std::string field = SearchDriver::Unescape($1.text_);
std::string text = SearchDriver::Unescape($3.text_);
$$ = driver.AnalyzeAndBuildQueryNode(std::move(field), std::move(text), $3.from_quoted_);
$$ = driver.AnalyzeAndBuildQueryNode(field, text, $3.from_quoted_);
};
| STRING TILDE {
const std::string &field = default_field;
Expand All @@ -155,12 +155,12 @@ basic_filter
YYERROR;
}
std::string text = SearchDriver::Unescape($1.text_);
$$ = driver.AnalyzeAndBuildQueryNode(field, std::move(text), $1.from_quoted_, $2);
$$ = driver.AnalyzeAndBuildQueryNode(field, text, $1.from_quoted_, $2);
}
| STRING OP_COLON STRING TILDE {
std::string field = SearchDriver::Unescape($1.text_);
std::string text = SearchDriver::Unescape($3.text_);
$$ = driver.AnalyzeAndBuildQueryNode(std::move(field), std::move(text), $3.from_quoted_, $4);
$$ = driver.AnalyzeAndBuildQueryNode(field, text, $3.from_quoted_, $4);
};

%%
Expand Down
1 change: 1 addition & 0 deletions src/storage/invertedindex/search/doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export enum class DocIteratorType : u8 {
kBMWIterator,
kFilterIterator,
kScoreThresholdIterator,
kKeywordIterator,
};

export struct DocIteratorEstimateIterateCost {
Expand Down
71 changes: 71 additions & 0 deletions src/storage/invertedindex/search/keyword_iterator.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module keyword_iterator;

import stl;
import index_defines;
import doc_iterator;
import multi_doc_iterator;
import or_iterator;
import internal_types;

namespace infinity {

export class KeywordIterator final : public MultiDocIterator {
public:
KeywordIterator(Vector<UniquePtr<DocIterator>> iterators, const float weight) : MultiDocIterator(std::move(iterators)), weight_(weight) {}

DocIteratorType GetType() const override { return DocIteratorType::kKeywordIterator; }

String Name() const override { return "KeywordIterator"; }

/* pure virtual methods implementation */
bool Next(const RowID doc_id) override {
if (doc_id_ == INVALID_ROWID) {
for (u32 i = 0; i < children_.size(); ++i) {
children_[i]->Next();
DocIteratorEntry entry = {children_[i]->DocID(), i};
heap_.AddEntry(entry);
}
heap_.BuildHeap();
doc_id_ = heap_.TopEntry().doc_id_;
}
if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id) {
return true;
}
while (doc_id > heap_.TopEntry().doc_id_) {
DocIterator *top = children_[heap_.TopEntry().entry_id_].get();
top->Next(doc_id);
heap_.TopEntry().doc_id_ = top->DocID();
heap_.AdjustDown(1);
}
doc_id_ = heap_.TopEntry().doc_id_;
return doc_id_ != INVALID_ROWID;
}

float Score() override { return weight_; }

void UpdateScoreThreshold(float threshold) override { /* do nothing */ }

u32 MatchCount() const override { return 0; }

private:
const float weight_ = 1.0f;
DocIteratorHeap heap_{};
};

} // namespace infinity
Loading

0 comments on commit 96e1bb9

Please sign in to comment.