diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h similarity index 76% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_query.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h index 52503a2b979f02..370a8a390947c3 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_query.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h @@ -19,27 +19,27 @@ #include -#include "olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_weight.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_weight.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" #include "roaring/roaring.hh" namespace doris::segment_v2::inverted_index::query_v2 { -class BitmapQuery : public Query { +class BitSetQuery : public Query { public: - explicit BitmapQuery(std::shared_ptr bitmap) : _bitmap(std::move(bitmap)) {} - BitmapQuery(const roaring::Roaring& bitmap) + explicit BitSetQuery(std::shared_ptr bitmap) : _bitmap(std::move(bitmap)) {} + BitSetQuery(const roaring::Roaring& bitmap) : _bitmap(std::make_shared(bitmap)) {} - ~BitmapQuery() override = default; + ~BitSetQuery() override = default; WeightPtr weight(bool /*enable_scoring*/) override { - return std::make_shared(_bitmap); + return std::make_shared(_bitmap); } private: std::shared_ptr _bitmap; }; -using BitmapQueryPtr = std::shared_ptr; +using BitSetQueryPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h similarity index 87% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_scorer.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h index 931015e2539aba..a834df1c8a2532 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h @@ -20,23 +20,22 @@ #include #include #include -#include #include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" #include "roaring/roaring.hh" namespace doris::segment_v2::inverted_index::query_v2 { -class BitmapScorer final : public Scorer { +class BitSetScorer final : public Scorer { public: - BitmapScorer(std::shared_ptr bitmap, + BitSetScorer(std::shared_ptr bitmap, std::shared_ptr null_bitmap = nullptr) - : _bitmap(std::move(bitmap)), + : _bit_set(std::move(bitmap)), _null_bitmap(std::move(null_bitmap)), - _it(_bitmap->begin()) { + _it(_bit_set->begin()) { _doc = (_it.i.has_value) ? *_it : TERMINATED; } - ~BitmapScorer() override = default; + ~BitSetScorer() override = default; uint32_t advance() override { if (_doc == TERMINATED) { @@ -70,7 +69,7 @@ class BitmapScorer final : public Scorer { uint32_t doc() const override { return _doc; } uint32_t size_hint() const override { - uint64_t card = _bitmap->cardinality(); + uint64_t card = _bit_set->cardinality(); return static_cast( std::min(card, std::numeric_limits::max())); } @@ -87,10 +86,11 @@ class BitmapScorer final : public Scorer { } private: - std::shared_ptr _bitmap; + std::shared_ptr _bit_set; std::shared_ptr _null_bitmap; roaring::Roaring::const_iterator _it; uint32_t _doc = TERMINATED; }; +using BitSetScorerPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_weight.h similarity index 81% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_weight.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_weight.h index 44a62fa1b925b3..f1cd63f013cbaf 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_weight.h @@ -19,31 +19,30 @@ #include -#include "olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" #include "roaring/roaring.hh" namespace doris::segment_v2::inverted_index::query_v2 { -class BitmapWeight final : public Weight { +class BitSetWeight final : public Weight { public: - BitmapWeight(std::shared_ptr bitmap, + BitSetWeight(std::shared_ptr bitmap, std::shared_ptr null_bitmap = nullptr) : _bitmap(std::move(bitmap)), _null_bitmap(std::move(null_bitmap)) {} - ~BitmapWeight() override = default; + ~BitSetWeight() override = default; ScorerPtr scorer(const QueryExecutionContext& /*context*/) override { if (_bitmap == nullptr || _bitmap->isEmpty()) { return std::make_shared(); } - return std::make_shared(_bitmap, _null_bitmap); + return std::make_shared(_bitmap, _null_bitmap); } private: std::shared_ptr _bitmap; std::shared_ptr _null_bitmap; }; - -using BitmapWeightPtr = std::shared_ptr; +using BitSetWeightPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h index 872cf83815dbab..b427386f85a1e6 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h @@ -22,7 +22,7 @@ #include #include -#include "olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/intersection_scorer.h" @@ -271,7 +271,7 @@ class BooleanWeight : public Weight { if (!result.null_bitmap.isEmpty()) { null_ptr = std::make_shared(std::move(result.null_bitmap)); } - return std::make_shared(std::move(true_ptr), std::move(null_ptr)); + return std::make_shared(std::move(true_ptr), std::move(null_ptr)); } OperatorType _type; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/const_score_query/const_score_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/const_score_query/const_score_scorer.h new file mode 100644 index 00000000000000..1ebf47f1d5d5c7 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/const_score_query/const_score_scorer.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class ConstScoreScorer : public Scorer { +public: + ConstScoreScorer(ScorerPtrT scorer) : _scorer(std::move(scorer)) {} + ~ConstScoreScorer() override = default; + + uint32_t advance() override { return _scorer->advance(); } + uint32_t seek(uint32_t target) override { return _scorer->seek(target); } + uint32_t doc() const override { return _scorer->doc(); } + uint32_t size_hint() const override { return _scorer->size_hint(); } + + float score() override { return _score; } + +private: + ScorerPtrT _scorer; + + float _score = 1.0F; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h index 120ba74e4486ba..55cddec4551790 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h @@ -50,6 +50,80 @@ class DocSet { throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, "size_hint() method not implemented in base DocSet class"); } + + virtual uint32_t freq() const { + throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, + "freq() method not implemented in base DocSet class"); + } + + virtual uint32_t norm() const { + throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, + "norm() method not implemented in base DocSet class"); + } }; +class MockDocSet : public DocSet { +public: + MockDocSet(std::vector docs, uint32_t size_hint_val = 0, uint32_t norm_val = 1) + : _docs(std::move(docs)), _size_hint_val(size_hint_val), _norm_val(norm_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::ranges::sort(_docs.begin(), _docs.end()); + _current_doc = _docs[0]; + } + if (_size_hint_val == 0) { + _size_hint_val = static_cast(_docs.size()); + } + } + + uint32_t advance() override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + ++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + uint32_t seek(uint32_t target) override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + if (_current_doc >= target) { + return _current_doc; + } + auto it = std::lower_bound(_docs.begin() + _index, _docs.end(), target); + if (it == _docs.end()) { + _index = _docs.size(); + _current_doc = TERMINATED; + return TERMINATED; + } + _index = static_cast(it - _docs.begin()); + _current_doc = *it; + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + + uint32_t size_hint() const override { return _size_hint_val; } + + uint32_t norm() const override { return _norm_val; } + +private: + std::vector _docs; + size_t _index = 0; + uint32_t _current_doc = TERMINATED; + uint32_t _size_hint_val = 0; + uint32_t _norm_val = 1; +}; + +using MockDocSetPtr = std::shared_ptr; + } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp new file mode 100644 index 00000000000000..b15c6ec7a1fb97 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/intersection.h" + +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +template +std::enable_if_t, IntersectionPtr> +Intersection::create(std::vector& docsets) { + size_t num_docsets = docsets.size(); + if (num_docsets < 2) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "At least 2 docsets are required for intersection"); + } + + std::sort(docsets.begin(), docsets.end(), + [](const TDocSet& a, const TDocSet& b) { return a->size_hint() < b->size_hint(); }); + go_to_first_doc(docsets); + + TDocSet left = std::move(docsets[0]); + TDocSet right = std::move(docsets[1]); + docsets.erase(docsets.begin(), docsets.begin() + 2); + return std::make_shared>(std::move(left), std::move(right), + std::move(docsets)); +} + +template +Intersection::Intersection(TDocSet left, TDocSet right, + std::vector others) + : _left(std::move(left)), _right(std::move(right)), _others(std::move(others)) {} + +template +uint32_t Intersection::advance() { + uint32_t candidate = _left->advance(); + + while (true) { + while (true) { + uint32_t right_doc = _right->seek(candidate); + candidate = _left->seek(right_doc); + if (candidate == right_doc) { + break; + } + } + + bool need_continue = false; + for (const auto& docset : _others) { + uint32_t seek_doc = docset->seek(candidate); + if (seek_doc > candidate) { + candidate = _left->seek(seek_doc); + need_continue = true; + break; + } + } + + if (!need_continue) { + return candidate; + } + } +} + +template +uint32_t Intersection::seek(uint32_t target) { + _left->seek(target); + std::vector docsets; + docsets.push_back(_left); + docsets.push_back(_right); + for (auto& docset : _others) { + docsets.push_back(docset); + } + return go_to_first_doc(docsets); +} + +template +uint32_t Intersection::doc() const { + return _left->doc(); +} + +template +uint32_t Intersection::size_hint() const { + return _left->size_hint(); +} + +template +uint32_t Intersection::norm() const { + return _left->norm(); +} + +template +uint32_t Intersection::go_to_first_doc(const std::vector& docsets) { + if (docsets.empty()) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "At least 1 docset is required for intersection"); + } + + uint32_t candidate = docsets.front()->doc(); + for (size_t i = 1; i < docsets.size(); ++i) { + candidate = std::max(candidate, docsets[i]->seek(candidate)); + } + + while (true) { + bool need_continue = false; + + for (const auto& docset : docsets) { + uint32_t seek_doc = docset->seek(candidate); + if (seek_doc > candidate) { + candidate = docset->doc(); + need_continue = true; + break; + } + } + + if (!need_continue) { + return candidate; + } + } +} + +template +template +std::enable_if_t, TDocSet&> +Intersection::docset_mut_specialized(size_t ord) { + switch (ord) { + case 0: + return _left; + case 1: + return _right; + default: + return _others[ord - 2]; + } +} + +template class Intersection; +template class Intersection; + +// create +template std::enable_if_t< + std::is_same_v, + IntersectionPtr> +Intersection::create< + PositionPostingsWithOffsetPtr>(std::vector& docsets); + +template std::enable_if_t, + IntersectionPtr> +Intersection::create( + std::vector& docsets); + +// docset_mut_specialized +template std::enable_if_t< + std::is_same_v, + PositionPostingsWithOffsetPtr&> +Intersection::docset_mut_specialized< + PositionPostingsWithOffsetPtr>(size_t ord); + +template std::enable_if_t, MockDocSetPtr&> +Intersection::docset_mut_specialized(size_t ord); + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h new file mode 100644 index 00000000000000..8dd2430fd3f3d0 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class Intersection; + +template +using IntersectionPtr = std::shared_ptr>; + +template +class Intersection final : public DocSet { +public: + Intersection(TDocSet left, TDocSet right, std::vector others); + ~Intersection() override = default; + + template + static std::enable_if_t, IntersectionPtr> create( + std::vector& docsets); + + uint32_t advance() override; + uint32_t seek(uint32_t target) override; + uint32_t doc() const override; + uint32_t size_hint() const override; + uint32_t norm() const override; + + template + std::enable_if_t, TDocSet&> docset_mut_specialized(size_t ord); + +private: + static uint32_t go_to_first_doc(const std::vector& docsets); + + TDocSet _left; + TDocSet _right; + std::vector _others; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h new file mode 100644 index 00000000000000..133cf71afe4b62 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class PhraseQuery : public Query { +public: + PhraseQuery(IndexQueryContextPtr context, std::wstring field, std::vector terms) + : _context(std::move(context)), _field(std::move(field)), _terms(std::move(terms)) {} + ~PhraseQuery() override = default; + + WeightPtr weight(bool enable_scoring) override { + if (_terms.size() < 2) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Phrase query requires at least 2 terms"); + } + + SimilarityPtr bm25_similarity; + if (enable_scoring) { + bm25_similarity = std::make_shared(); + bm25_similarity->for_terms(_context, _field, _terms); + } + return std::make_shared(_context, _field, _terms, bm25_similarity, + enable_scoring); + } + +private: + IndexQueryContextPtr _context; + + std::wstring _field; + std::vector _terms; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp new file mode 100644 index 00000000000000..d1f766a665de7d --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +ScorerPtr PhraseScorer::create_with_offset( + const std::vector>& term_postings_with_offset, + const SimilarityPtr& similarity, uint32_t slop, size_t offset) { + size_t max_offset = offset; + for (const auto& [term_offset, _] : term_postings_with_offset) { + max_offset = std::max(max_offset, term_offset + offset); + } + + size_t num_docsets = term_postings_with_offset.size(); + std::vector> postings_with_offsets; + postings_with_offsets.reserve(num_docsets); + for (const auto& [term_offset, postings] : term_postings_with_offset) { + auto adjusted_offset = static_cast(max_offset - term_offset); + auto postings_with_offset = std::make_shared>( + std::move(postings), adjusted_offset); + postings_with_offsets.emplace_back(std::move(postings_with_offset)); + } + + using IntersectionType = + Intersection, PostingsWithOffsetPtr>; + auto intersection_docset = IntersectionType::create(postings_with_offsets); + std::vector left_positions(100); + std::vector right_positions(100); + auto scorer = std::make_shared>( + std::move(intersection_docset), num_docsets, std::move(left_positions), + std::move(right_positions), 0, similarity, slop); + if (scorer->doc() != TERMINATED && !scorer->phrase_match()) { + scorer->advance(); + } + return scorer; +} + +template +uint32_t PhraseScorer::advance() { + while (true) { + uint32_t doc = _intersection_docset->advance(); + if (doc == TERMINATED || phrase_match()) { + return doc; + } + } +} + +template +uint32_t PhraseScorer::seek(uint32_t target) { + assert(target > doc()); + uint32_t doc = _intersection_docset->seek(target); + if (doc == TERMINATED || phrase_match()) { + return doc; + } + return advance(); +} + +template +uint32_t PhraseScorer::doc() const { + return _intersection_docset->doc(); +} + +template +uint32_t PhraseScorer::size_hint() const { + return _intersection_docset->size_hint(); +} + +template +uint32_t PhraseScorer::norm() const { + return _intersection_docset->norm(); +} + +template +float PhraseScorer::score() { + if (_similarity) { + return _similarity->score(_phrase_count, norm()); + } else { + return 1.0F; + } +} + +template +bool PhraseScorer::phrase_match() { + if (_similarity) { + uint32_t count = compute_phrase_count(); + _phrase_count = count; + return count > 0; + } else { + return phrase_exists(); + } +} + +template +uint32_t PhraseScorer::compute_phrase_count() { + compute_phrase_match(); + if (has_slop()) { + // TODO: Implement sloppy phrase matching logic + return 0; + } else { + return static_cast(intersection_count(_left_positions, _right_positions)); + } +} + +template +bool PhraseScorer::phrase_exists() { + compute_phrase_match(); + if (has_slop()) { + // TODO: Implement sloppy phrase matching logic + return false; + } else { + return intersection_exists(_left_positions, _right_positions); + } +} + +template +void PhraseScorer::compute_phrase_match() { + _intersection_docset->docset_mut_specialized(0)->postings(_left_positions); + for (size_t i = 1; i < _num_terms - 1; ++i) { + _intersection_docset->docset_mut_specialized(i)->postings(_right_positions); + intersection(_left_positions, _right_positions); + if (_left_positions.empty()) { + return; + } + } + _intersection_docset->docset_mut_specialized(_num_terms - 1)->postings(_right_positions); +} + +template +size_t PhraseScorer::intersection_count(const std::vector& left, + const std::vector& right) { + size_t left_index = 0; + size_t right_index = 0; + size_t count = 0; + while (left_index < left.size() && right_index < right.size()) { + uint32_t left_val = left[left_index]; + uint32_t right_val = right[right_index]; + if (left_val < right_val) { + ++left_index; + } else if (left_val == right_val) { + ++count; + ++left_index; + ++right_index; + } else { + ++right_index; + } + } + return count; +} + +template +bool PhraseScorer::intersection_exists(const std::vector& left, + const std::vector& right) { + size_t left_index = 0; + size_t right_index = 0; + while (left_index < left.size() && right_index < right.size()) { + uint32_t left_val = left[left_index]; + uint32_t right_val = right[right_index]; + if (left_val < right_val) { + ++left_index; + } else if (left_val == right_val) { + return true; + } else { + ++right_index; + } + } + return false; +} + +template class PhraseScorer; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h new file mode 100644 index 00000000000000..494dfa90b0eefe --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/intersection.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/similarity/similarity.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class PhraseScorer; + +template +using PhraseScorerPtr = std::shared_ptr>; + +template +class PhraseScorer : public Scorer { +public: + using IntersectionDocSetPtr = + IntersectionPtr, PostingsWithOffsetPtr>; + + PhraseScorer(IntersectionDocSetPtr intersection_docset, size_t num_terms, + std::vector left_positions, std::vector right_positions, + uint32_t phrase_count, SimilarityPtr similarity, uint32_t slop) + : _intersection_docset(std::move(intersection_docset)), + _num_terms(num_terms), + _left_positions(std::move(left_positions)), + _right_positions(std::move(right_positions)), + _phrase_count(phrase_count), + _similarity(std::move(similarity)), + _slop(slop) {} + ~PhraseScorer() override = default; + + static ScorerPtr create(const std::vector>& term_postings, + const SimilarityPtr& similarity, uint32_t slop) { + return create_with_offset(term_postings, similarity, slop, 0); + } + + uint32_t advance() override; + uint32_t seek(uint32_t target) override; + uint32_t doc() const override; + uint32_t size_hint() const override; + uint32_t norm() const override; + + float score() override; + + bool phrase_match(); + +private: + static ScorerPtr create_with_offset( + const std::vector>& term_postings_with_offset, + const SimilarityPtr& similarity, uint32_t slop, size_t offset); + + bool phrase_exists(); + uint32_t compute_phrase_count(); + void compute_phrase_match(); + size_t intersection_count(const std::vector& left, + const std::vector& right); + bool intersection_exists(const std::vector& left, const std::vector& right); + void intersection(std::vector& left, const std::vector& right); + + bool has_slop() const { return _slop > 0; } + + IntersectionDocSetPtr _intersection_docset; + size_t _num_terms = 0; + std::vector _left_positions; + std::vector _right_positions; + uint32_t _phrase_count = 0; + SimilarityPtr _similarity; + uint32_t _slop = 0; +}; + +template +inline void PhraseScorer::intersection(std::vector& left, + const std::vector& right) { + size_t left_index = 0; + size_t right_index = 0; + size_t count = 0; + const size_t left_len = left.size(); + const size_t right_len = right.size(); + while (left_index < left_len && right_index < right_len) { + uint32_t left_val = left[left_index]; + uint32_t right_val = right[right_index]; + if (left_val < right_val) { + ++left_index; + } else if (left_val == right_val) { + left[count] = left_val; + ++count; + ++left_index; + ++right_index; + } else { + ++right_index; + } + } + left.resize(count); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h new file mode 100644 index 00000000000000..7fe68e33995af8 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class PhraseWeight : public Weight { +public: + PhraseWeight(IndexQueryContextPtr context, std::wstring field, std::vector terms, + SimilarityPtr similarity, bool enable_scoring) + : _context(std::move(context)), + _field(std::move(field)), + _terms(std::move(terms)), + _similarity(std::move(similarity)), + _enable_scoring(enable_scoring) {} + ~PhraseWeight() override = default; + + ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { + auto scorer = phrase_scorer(ctx, binding_key); + if (scorer) { + return scorer; + } else { + return std::make_shared(); + } + } + +private: + ScorerPtr phrase_scorer(const QueryExecutionContext& ctx, const std::string& binding_key) { + auto reader = lookup_reader(_field, ctx, binding_key); + if (!reader) { + throw Exception(ErrorCode::NOT_FOUND, "Reader not found for field '{}'", + StringHelper::to_string(_field)); + } + + std::vector> term_postings_list; + for (size_t offset = 0; offset < _terms.size(); ++offset) { + const auto& term = _terms[offset]; + auto t = make_term_ptr(_field.c_str(), term.c_str()); + auto iter = make_term_positions_ptr(reader.get(), t.get(), _enable_scoring, + _context->io_ctx); + if (iter) { + auto segment_postings = + std::make_shared>(std::move(iter)); + term_postings_list.emplace_back(offset, std::move(segment_postings)); + } else { + return nullptr; + } + } + return PhraseScorer::create(term_postings_list, _similarity, 0); + } + + IndexQueryContextPtr _context; + + std::wstring _field; + std::vector _terms; + SimilarityPtr _similarity; + bool _enable_scoring = false; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h new file mode 100644 index 00000000000000..0302bc36081c8e --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h" +#include "olap/rowset/segment_v2/inverted_index_common.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class PostingsWithOffset : public DocSet { +public: + PostingsWithOffset(TPostings postings, uint32_t offset) + : _postings(std::move(postings)), _offset(offset) {} + + void postings(std::vector& output) { + _postings->positions_with_offset(_offset, output); + } + + uint32_t advance() override { return _postings->advance(); } + uint32_t seek(uint32_t target) override { return _postings->seek(target); } + uint32_t doc() const override { return _postings->doc(); } + uint32_t size_hint() const override { return _postings->size_hint(); } + uint32_t freq() const override { return _postings->freq(); } + uint32_t norm() const override { return _postings->norm(); } + +private: + TPostings _postings; + uint32_t _offset = 0; +}; + +template +using PostingsWithOffsetPtr = std::shared_ptr>; + +using PositionPostings = std::shared_ptr>; +using PositionPostingsWithOffsetPtr = std::shared_ptr>; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h new file mode 100644 index 00000000000000..9f1e7491b50eea --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class RegexpQuery : public Query { +public: + RegexpQuery(IndexQueryContextPtr context, std::wstring field, std::string pattern) + : _context(std::move(context)), + _field(std::move(field)), + _pattern(std::move(pattern)) {} + ~RegexpQuery() override = default; + + WeightPtr weight(bool enable_scoring) override { + return std::make_shared(std::move(_context), std::move(_field), + std::move(_pattern), enable_scoring); + } + +private: + IndexQueryContextPtr _context; + + std::wstring _field; + std::string _pattern; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp new file mode 100644 index 00000000000000..5404abaddb0a4d --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "regexp_weight.h" + +#include +#include +#include + +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/const_score_query/const_score_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" + +CL_NS_USE(index) + +namespace doris::segment_v2::inverted_index::query_v2 { + +RegexpWeight::RegexpWeight(IndexQueryContextPtr context, std::wstring field, std::string pattern, + bool enable_scoring) + : _context(std::move(context)), + _field(std::move(field)), + _pattern(std::move(pattern)), + _enable_scoring(enable_scoring) { + // _max_expansions = _context->runtime_state->query_options().inverted_index_max_expansions; +} + +ScorerPtr RegexpWeight::scorer(const QueryExecutionContext& context, + const std::string& binding_key) { + auto prefix = get_regex_prefix(_pattern); + + hs_database_t* database = nullptr; + hs_compile_error_t* compile_err = nullptr; + hs_scratch_t* scratch = nullptr; + + if (hs_compile(_pattern.data(), HS_FLAG_DOTALL | HS_FLAG_ALLOWEMPTY | HS_FLAG_UTF8, + HS_MODE_BLOCK, nullptr, &database, &compile_err) != HS_SUCCESS) { + LOG(ERROR) << "hyperscan compilation failed: " << compile_err->message; + hs_free_compile_error(compile_err); + return std::make_shared(); + } + + if (hs_alloc_scratch(database, &scratch) != HS_SUCCESS) { + LOG(ERROR) << "hyperscan could not allocate scratch space."; + hs_free_database(database); + return std::make_shared(); + } + + std::vector matching_terms; + try { + collect_matching_terms(context, binding_key, matching_terms, database, scratch, prefix); + } catch (...) { + hs_free_scratch(scratch); + hs_free_database(database); + throw; + } + + hs_free_scratch(scratch); + hs_free_database(database); + + if (matching_terms.empty()) { + return std::make_shared(); + } + + auto doc_bitset = std::make_shared(); + for (const auto& term : matching_terms) { + auto t = make_term_ptr(_field.c_str(), term.c_str()); + auto reader = lookup_reader(_field, context, binding_key); + auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); + auto segment_postings = std::make_shared>(std::move(iter)); + + uint32_t doc = segment_postings->doc(); + while (doc != TERMINATED) { + doc_bitset->add(doc); + doc = segment_postings->advance(); + } + } + + auto bit_set = std::make_shared(doc_bitset); + auto const_score = std::make_shared>(std::move(bit_set)); + return const_score; +} + +std::optional RegexpWeight::get_regex_prefix(const std::string& pattern) { + DBUG_EXECUTE_IF("RegexpQuery.get_regex_prefix", { return std::nullopt; }); + + if (pattern.empty() || pattern[0] != '^') { + return std::nullopt; + } + + re2::RE2 re(pattern); + if (!re.ok()) { + return std::nullopt; + } + + std::string min_prefix, max_prefix; + if (!re.PossibleMatchRange(&min_prefix, &max_prefix, 256)) { + return std::nullopt; + } + + if (min_prefix.empty() || max_prefix.empty() || min_prefix[0] != max_prefix[0]) { + return std::nullopt; + } + + auto [it1, it2] = std::ranges::mismatch(min_prefix, max_prefix); + + const size_t common_len = std::distance(min_prefix.begin(), it1); + if (common_len == 0) { + return std::nullopt; + } + + return min_prefix.substr(0, common_len); +} + +void RegexpWeight::collect_matching_terms(const QueryExecutionContext& context, + const std::string& binding_key, + std::vector& terms, hs_database_t* database, + hs_scratch_t* scratch, + const std::optional& prefix) { + auto on_match = [](unsigned int id, unsigned long long from, unsigned long long to, + unsigned int flags, void* context) -> int { + *((bool*)context) = true; + return 0; + }; + + auto reader = lookup_reader(_field, context, binding_key); + if (reader == nullptr) { + return; + } + + int32_t count = 0; + Term* term = nullptr; + TermEnum* enumerator = nullptr; + + try { + if (prefix) { + std::wstring ws_prefix = StringUtil::string_to_wstring(*prefix); + Term prefix_term(_field.c_str(), ws_prefix.c_str()); + enumerator = reader->terms(&prefix_term, _context->io_ctx); + } else { + enumerator = reader->terms(nullptr, _context->io_ctx); + if (enumerator) { + enumerator->next(); + } + } + + if (!enumerator) { + return; + } + + do { + term = enumerator->term(); + if (term == nullptr) { + break; + } + + if (_field != term->field()) { + _CLDECDELETE(term); + break; + } + + auto term_text = + StringHelper::to_string(std::wstring(term->text(), term->textLength())); + + if (prefix && !term_text.starts_with(*prefix)) { + _CLDECDELETE(term); + break; + } + + bool is_match = false; + if (hs_scan(database, term_text.data(), static_cast(term_text.size()), 0, + scratch, on_match, (void*)&is_match) != HS_SUCCESS) { + LOG(ERROR) << "hyperscan match failed: " << term_text; + _CLDECDELETE(term); + break; + } + + if (is_match) { + if (_max_expansions > 0 && count >= _max_expansions) { + _CLDECDELETE(term); + break; + } + + terms.emplace_back(term->text(), term->textLength()); + count++; + } + + _CLDECDELETE(term); + } while (enumerator->next()); + + _CLDECDELETE(term); + if (enumerator) { + enumerator->close(); + _CLDELETE(enumerator); + } + } catch (...) { + _CLDECDELETE(term); + if (enumerator) { + enumerator->close(); + _CLDELETE(enumerator); + } + throw; + } +} + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h new file mode 100644 index 00000000000000..d85d42a6d298e0 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class RegexpWeight : public Weight { +public: + RegexpWeight(IndexQueryContextPtr context, std::wstring field, std::string pattern, + bool enable_scoring); + ~RegexpWeight() override = default; + + ScorerPtr scorer(const QueryExecutionContext& context, const std::string& binding_key) override; + +private: + std::optional get_regex_prefix(const std::string& pattern); + void collect_matching_terms(const QueryExecutionContext& context, + const std::string& binding_key, std::vector& terms, + hs_database_t* database, hs_scratch_t* scratch, + const std::optional& prefix); + + IndexQueryContextPtr _context; + + std::wstring _field; + std::string _pattern; + bool _enable_scoring = false; + int32_t _max_expansions = 50; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h index 3f1bc133c16d54..2f65bdf1cd1d7d 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h @@ -55,15 +55,9 @@ class SegmentPostingsBase : public DocSet { uint32_t size_hint() const override { return _iter->docFreq(); } - virtual int32_t freq() const { - throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, - "freq() method not implemented in base SegmentPostingsBase class"); - } + uint32_t freq() const override { return _iter->freq(); } - virtual int32_t norm() const { - throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, - "norm() method not implemented in base SegmentPostingsBase class"); - } + uint32_t norm() const override { return _iter->norm(); } protected: TermIterator _iter; @@ -77,8 +71,26 @@ class SegmentPostings final : public SegmentPostingsBase { public: SegmentPostings(TermIterator iter) : SegmentPostingsBase(std::move(iter)) {} - int32_t freq() const override { return this->_iter->freq(); } - int32_t norm() const override { return this->_iter->norm(); } + void positions_with_offset(uint32_t offset, std::vector& output) { + output.clear(); + append_positions_with_offset(offset, output); + } + + void append_positions_with_offset(uint32_t offset, std::vector& output) { + static_assert( + requires(TermIterator it) { + it->freq(); + it->nextPosition(); + }, "TermIterator must expose freq() and nextPosition()"); + + auto freq = this->_iter->freq(); + size_t prev_len = output.size(); + output.resize(prev_len + freq); + for (int32_t i = 0; i < freq; ++i) { + auto pos = this->_iter->nextPosition(); + output[prev_len + i] = offset + static_cast(pos); + } + } }; template @@ -86,8 +98,8 @@ class NoScoreSegmentPosting final : public SegmentPostingsBase { public: NoScoreSegmentPosting(TermIterator iter) : SegmentPostingsBase(std::move(iter)) {} - int32_t freq() const override { return 1; } - int32_t norm() const override { return 1; } + uint32_t freq() const override { return 1; } + uint32_t norm() const override { return 1; } }; template @@ -100,8 +112,8 @@ class EmptySegmentPosting final : public SegmentPostingsBase { uint32_t doc() const override { return TERMINATED; } uint32_t size_hint() const override { return 0; } - int32_t freq() const override { return 1; } - int32_t norm() const override { return 1; } + uint32_t freq() const override { return 1; } + uint32_t norm() const override { return 1; } }; } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h index 7e12eb1da7eee2..77bbc922b1b835 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h @@ -42,12 +42,10 @@ class TermScorer final : public Scorer { uint32_t seek(uint32_t target) override { return _segment_postings->seek(target); } uint32_t doc() const override { return _segment_postings->doc(); } uint32_t size_hint() const override { return _segment_postings->size_hint(); } + uint32_t freq() const override { return _segment_postings->freq(); } + uint32_t norm() const override { return _segment_postings->norm(); } - float score() override { - auto freq = _segment_postings->freq(); - auto norm = _segment_postings->norm(); - return _similarity->score(static_cast(freq), norm); - } + float score() override { return _similarity->score(freq(), norm()); } bool has_null_bitmap(const NullBitmapResolver* resolver = nullptr) override { _ensure_null_bitmap(resolver); @@ -98,7 +96,7 @@ class TermScorer final : public Scorer { SegmentPostingsPtr _segment_postings; SimilarityPtr _similarity; - std::string _logical_field = {}; + std::string _logical_field; bool _null_bitmap_checked = false; std::optional _null_bitmap; }; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h index 8eed7e3278e504..7f0e329d88e925 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h @@ -36,10 +36,8 @@ class TermWeight : public Weight { _logical_field(std::move(logical_field)) {} ~TermWeight() override = default; - ScorerPtr scorer(const QueryExecutionContext& ctx) override { return scorer(ctx, {}); } - ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { - auto reader = lookup_reader(ctx, binding_key); + auto reader = lookup_reader(_field, ctx, binding_key); auto field_name = _logical_field.empty() ? std::string(_field.begin(), _field.end()) : _logical_field; auto make_scorer = [&](auto segment_postings) -> ScorerPtr { @@ -72,23 +70,6 @@ class TermWeight : public Weight { } private: - std::shared_ptr lookup_reader( - const QueryExecutionContext& ctx, const std::string& binding_key) const { - if (!binding_key.empty()) { - if (auto it = ctx.reader_bindings.find(binding_key); it != ctx.reader_bindings.end()) { - return it->second; - } - } - if (auto it = ctx.field_reader_bindings.find(_field); - it != ctx.field_reader_bindings.end()) { - return it->second; - } - if (!ctx.readers.empty()) { - return ctx.readers.front(); - } - return nullptr; - } - IndexQueryContextPtr _context; std::wstring _field; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h index 1fcd2dbb14cf0c..c3483128912c92 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h @@ -44,11 +44,31 @@ class Weight { Weight() = default; virtual ~Weight() = default; - virtual ScorerPtr scorer(const QueryExecutionContext& context) = 0; + virtual ScorerPtr scorer(const QueryExecutionContext& context) { return scorer(context, {}); } + virtual ScorerPtr scorer(const QueryExecutionContext& context, const std::string& binding_key) { (void)binding_key; return scorer(context); } + +protected: + std::shared_ptr lookup_reader( + const std::wstring& field, const QueryExecutionContext& ctx, + const std::string& binding_key) const { + if (!binding_key.empty()) { + if (auto it = ctx.reader_bindings.find(binding_key); it != ctx.reader_bindings.end()) { + return it->second; + } + } + if (auto it = ctx.field_reader_bindings.find(field); + it != ctx.field_reader_bindings.end()) { + return it->second; + } + if (!ctx.readers.empty()) { + return ctx.readers.front(); + } + return nullptr; + } }; using WeightPtr = std::shared_ptr; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_query.h new file mode 100644 index 00000000000000..8cd92418a00ea0 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_query.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class WildcardQuery : public Query { +public: + WildcardQuery(IndexQueryContextPtr context, std::wstring field, std::string pattern) + : _context(std::move(context)), + _field(std::move(field)), + _pattern(std::move(pattern)) {} + ~WildcardQuery() override = default; + + WeightPtr weight(bool enable_scoring) override { + return std::make_shared(std::move(_context), std::move(_field), + std::move(_pattern), enable_scoring); + } + +private: + IndexQueryContextPtr _context; + + std::wstring _field; + std::string _pattern; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h new file mode 100644 index 00000000000000..da2de84eae30c7 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include + +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class WildcardWeight : public Weight { +public: + WildcardWeight(IndexQueryContextPtr context, std::wstring field, std::string pattern, + bool enable_scoring) + : _context(std::move(context)), + _field(std::move(field)), + _pattern(std::move(pattern)), + _enable_scoring(enable_scoring) {} + + ~WildcardWeight() override = default; + + ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { + std::string regex_pattern = wildcard_to_regex(_pattern); + auto regexp_weight = std::make_shared( + _context, std::move(_field), std::move(regex_pattern), _enable_scoring); + return regexp_weight->scorer(ctx, binding_key); + } + +private: + std::string wildcard_to_regex(const std::string& pattern) { + std::string escaped = RE2::QuoteMeta(pattern); + escaped = std::regex_replace(escaped, std::regex(R"(\\\*)"), ".*"); + escaped = std::regex_replace(escaped, std::regex(R"(\\\?)"), "."); + return "^" + escaped + "$"; + } + + IndexQueryContextPtr _context; + + std::wstring _field; + std::string _pattern; + bool _enable_scoring = false; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp index 37c4ca819030a0..a01b467c1e0c89 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.cpp @@ -55,6 +55,29 @@ void BM25Similarity::for_one_term(const IndexQueryContextPtr& context, compute_tf_cache(); } +void BM25Similarity::for_terms(const IndexQueryContextPtr& context, const std::wstring& field_name, + const std::vector& terms) { + if (terms.empty()) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "BM25 requires at least one term"); + } + + _avgdl = context->collection_statistics->get_or_calculate_avg_dl(field_name); + + if (terms.size() == 1) { + _idf = context->collection_statistics->get_or_calculate_idf(field_name, terms[0]); + } else { + float idf_sum = 0.0F; + for (const auto& term : terms) { + float term_idf = context->collection_statistics->get_or_calculate_idf(field_name, term); + idf_sum += term_idf; + } + _idf = idf_sum; + } + + _weight = _boost * _idf * (_k1 + 1.0F); + compute_tf_cache(); +} + float BM25Similarity::score(float freq, int64_t encoded_norm) { float norm_inverse = _cache[((uint8_t)encoded_norm) & 0xFF]; return _weight - _weight / (1.0F + freq * norm_inverse); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h index df067e5353f2b9..68a1f3a90db00a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h @@ -38,6 +38,8 @@ class BM25Similarity : public Similarity { void for_one_term(const IndexQueryContextPtr& context, const std::wstring& field_name, const std::wstring& term) override; + void for_terms(const IndexQueryContextPtr& context, const std::wstring& field_name, + const std::vector& terms) override; float score(float freq, int64_t encoded_norm) override; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h b/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h index 221b29ecd98f11..744254d410b71e 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/similarity/similarity.h @@ -32,6 +32,8 @@ class Similarity { virtual void for_one_term(const IndexQueryContextPtr& context, const std::wstring& field_name, const std::wstring& term) = 0; + virtual void for_terms(const IndexQueryContextPtr& context, const std::wstring& field_name, + const std::vector& terms) = 0; virtual float score(float freq, int64_t encoded_norm) = 0; }; diff --git a/be/src/vec/exprs/vsearch.cpp b/be/src/vec/exprs/vsearch.cpp index e32209d614b8dc..1786aa73fa2d03 100644 --- a/be/src/vec/exprs/vsearch.cpp +++ b/be/src/vec/exprs/vsearch.cpp @@ -98,7 +98,7 @@ Status collect_search_inputs(const VSearchExpr& expr, VExprContext* context, if (child->expr_name() == "element_at" && child_index < field_bindings.size() && field_bindings[child_index].__isset.is_variant_subcolumn && field_bindings[child_index].is_variant_subcolumn) { - // Variant subcolumn not materialized - skip, will create empty BitmapQuery in function_search + // Variant subcolumn not materialized - skip, will create empty BitSetQuery in function_search child_index++; continue; } diff --git a/be/src/vec/functions/function_search.cpp b/be/src/vec/functions/function_search.cpp index 650d637d0a7b70..19ec3a336128ab 100644 --- a/be/src/vec/functions/function_search.cpp +++ b/be/src/vec/functions/function_search.cpp @@ -35,7 +35,7 @@ #include "olap/rowset/segment_v2/index_file_reader.h" #include "olap/rowset/segment_v2/index_query_context.h" #include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_query.h" @@ -417,7 +417,7 @@ Status FunctionSearch::build_query_recursive(const TSearchClause& clause, std::string child_binding_key; RETURN_IF_ERROR(build_query_recursive(child_clause, context, resolver, &child_query, &child_binding_key)); - // Add all children including empty BitmapQuery + // Add all children including empty BitSetQuery // BooleanQuery will handle the logic: // - AND with empty bitmap → result is empty // - OR with empty bitmap → empty bitmap is ignored by OR logic @@ -460,9 +460,9 @@ Status FunctionSearch::build_leaf_query(const TSearchClause& clause, // Check if binding is empty (variant subcolumn not found in this segment) if (binding.lucene_reader == nullptr) { VLOG_DEBUG << "build_leaf_query: Variant subcolumn '" << field_name - << "' has no index in this segment, creating empty BitmapQuery (no matches)"; - // Variant subcolumn doesn't exist - create empty BitmapQuery (no matches) - *out = std::make_shared(roaring::Roaring()); + << "' has no index in this segment, creating empty BitSetQuery (no matches)"; + // Variant subcolumn doesn't exist - create empty BitSetQuery (no matches) + *out = std::make_shared(roaring::Roaring()); if (binding_key) { binding_key->clear(); } diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp index a9427e0f0bfa61..6e17de9eacb8a6 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query/query_helper_test.cpp @@ -40,6 +40,10 @@ class MockSimilarity : public doris::segment_v2::Similarity { const std::wstring& field_name, const std::wstring& term) override {} + MOCK_FUNCTION void for_terms(const IndexQueryContextPtr& context, + const std::wstring& field_name, + const std::vector& terms) override {} + MOCK_FUNCTION float score(float freq, int64_t encoded_norm) override { return _score_value; } private: diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp index 5e0e783ca8e4f1..e2a3cb479fefc0 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp @@ -28,7 +28,7 @@ #include "common/status.h" #include "olap/rowset/segment_v2/index_query_context.h" #include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/bitmap_query/bitmap_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_query.h" #include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" @@ -545,7 +545,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_bitmap_and_term) { query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_AND); builder.add(std::make_shared(context, field, StringHelper::to_wstring("apple"))); - builder.add(std::make_shared(bm)); + builder.add(std::make_shared(bm)); auto q = builder.build(); auto w = q->weight(false); @@ -596,7 +596,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_bitmap_or_term) { query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); builder.add(std::make_shared(context, field, StringHelper::to_wstring("apple"))); - builder.add(std::make_shared(bm)); + builder.add(std::make_shared(bm)); auto q = builder.build(); auto w = q->weight(false); diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp new file mode 100644 index 00000000000000..ebe61bc4590c2c --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/intersection.h" + +#include + +#include +#include + +#include "common/exception.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" + +namespace doris { + +using segment_v2::inverted_index::query_v2::DocSet; +using segment_v2::inverted_index::query_v2::Intersection; +using segment_v2::inverted_index::query_v2::TERMINATED; +using segment_v2::inverted_index::query_v2::MockDocSet; +using segment_v2::inverted_index::query_v2::MockDocSetPtr; + +class IntersectionTest : public ::testing::Test { +protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Test creating intersection with less than 2 docsets (should throw exception) +TEST_F(IntersectionTest, test_create_with_empty_docsets) { + std::vector docsets; + + EXPECT_THROW((Intersection::create(docsets)), Exception); +} + +// Test creating intersection with only 1 docset (should throw exception) +TEST_F(IntersectionTest, test_create_with_single_docset) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3})); + + EXPECT_THROW((Intersection::create(docsets)), Exception); +} + +// Test creating intersection with exactly 2 docsets +TEST_F(IntersectionTest, test_create_with_two_docsets) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3, 4, 5})); + docsets.push_back(std::make_shared(std::vector {2, 3, 4, 6, 7})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Should start at first matching document + EXPECT_EQ(2u, intersection->doc()); +} + +// Test intersection advance with two docsets +TEST_F(IntersectionTest, test_advance_two_docsets) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 3, 5, 7, 9})); + docsets.push_back(std::make_shared(std::vector {3, 5, 9, 11})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {3, 5, 9}; + EXPECT_EQ(expected, results); +} + +// Test intersection with three docsets +TEST_F(IntersectionTest, test_three_docsets) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3, 4, 5, 6})); + docsets.push_back(std::make_shared(std::vector {2, 4, 6, 8})); + docsets.push_back(std::make_shared(std::vector {2, 3, 4, 6, 7})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {2, 4, 6}; + EXPECT_EQ(expected, results); +} + +// Test intersection with four docsets +TEST_F(IntersectionTest, test_four_docsets) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3, 4, 5, 10})); + docsets.push_back(std::make_shared(std::vector {2, 4, 5, 10, 12})); + docsets.push_back(std::make_shared(std::vector {2, 3, 5, 10, 11})); + docsets.push_back(std::make_shared(std::vector {5, 10, 15})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {5, 10}; + EXPECT_EQ(expected, results); +} + +// Test intersection with no common documents +TEST_F(IntersectionTest, test_no_intersection) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 3, 5})); + docsets.push_back(std::make_shared(std::vector {2, 4, 6})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Should be terminated immediately + EXPECT_EQ(TERMINATED, intersection->doc()); + EXPECT_EQ(TERMINATED, intersection->advance()); +} + +// Test intersection with single common document +TEST_F(IntersectionTest, test_single_common_document) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 5, 10})); + docsets.push_back(std::make_shared(std::vector {2, 5, 8})); + docsets.push_back(std::make_shared(std::vector {5, 7, 9})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(TERMINATED, intersection->advance()); +} + +// Test seek functionality +TEST_F(IntersectionTest, test_seek) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 5, 10, 15, 20})); + docsets.push_back(std::make_shared(std::vector {5, 10, 15, 20, 25})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Seek to doc 10 + EXPECT_EQ(10u, intersection->seek(8)); + EXPECT_EQ(10u, intersection->doc()); + + // Seek to doc 20 + EXPECT_EQ(20u, intersection->seek(18)); + EXPECT_EQ(20u, intersection->doc()); + + // Seek beyond all docs + EXPECT_EQ(TERMINATED, intersection->seek(30)); +} + +// Test seek to current position +TEST_F(IntersectionTest, test_seek_current_position) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {5, 10, 15})); + docsets.push_back(std::make_shared(std::vector {5, 10, 15, 20})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + EXPECT_EQ(5u, intersection->doc()); + + // Seek to current position or before should stay at current + EXPECT_EQ(5u, intersection->seek(5)); + EXPECT_EQ(5u, intersection->doc()); + + EXPECT_EQ(5u, intersection->seek(3)); + EXPECT_EQ(5u, intersection->doc()); +} + +// Test size_hint - should return smallest docset's size hint +TEST_F(IntersectionTest, test_size_hint) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3, 4, 5}, 100)); + docsets.push_back(std::make_shared(std::vector {2, 3, 4}, 50)); + docsets.push_back(std::make_shared(std::vector {2, 3, 4, 5, 6}, 75)); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Should return the smallest size hint (from the smallest docset after sorting) + EXPECT_EQ(50u, intersection->size_hint()); +} + +// Test norm - should return left docset's norm +TEST_F(IntersectionTest, test_norm) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3}, 0, 10)); + docsets.push_back(std::make_shared(std::vector {2, 3, 4}, 0, 20)); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // After creation, docsets are sorted by size_hint, and smallest becomes left + // Both have same size (3), so order depends on stable sort + uint32_t norm = intersection->norm(); + EXPECT_TRUE(norm == 10 || norm == 20); +} + +// Test docset_mut_specialized for accessing individual docsets +TEST_F(IntersectionTest, test_docset_mut_specialized) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3})); + docsets.push_back(std::make_shared(std::vector {2, 3, 4})); + docsets.push_back(std::make_shared(std::vector {2, 3, 5})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Access left (ord 0) + auto& docset0 = intersection->docset_mut_specialized(0); + EXPECT_NE(nullptr, docset0); + EXPECT_EQ(2u, docset0->doc()); + + // Access right (ord 1) + auto& docset1 = intersection->docset_mut_specialized(1); + EXPECT_NE(nullptr, docset1); + + // Access others (ord 2+) + auto& docset2 = intersection->docset_mut_specialized(2); + EXPECT_NE(nullptr, docset2); +} + +// Test all docsets identical +TEST_F(IntersectionTest, test_all_identical_docsets) { + std::vector common_docs {1, 2, 3, 4, 5}; + + std::vector docsets; + docsets.push_back(std::make_shared(common_docs)); + docsets.push_back(std::make_shared(common_docs)); + docsets.push_back(std::make_shared(common_docs)); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + EXPECT_EQ(common_docs, results); +} + +// Test one empty docset +TEST_F(IntersectionTest, test_one_empty_docset) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 2, 3})); + docsets.push_back(std::make_shared(std::vector {})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Should be terminated immediately + EXPECT_EQ(TERMINATED, intersection->doc()); +} + +// Test consecutive advance calls after termination +TEST_F(IntersectionTest, test_advance_after_termination) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1})); + docsets.push_back(std::make_shared(std::vector {1})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + EXPECT_EQ(1u, intersection->doc()); + EXPECT_EQ(TERMINATED, intersection->advance()); + EXPECT_EQ(TERMINATED, intersection->advance()); + EXPECT_EQ(TERMINATED, intersection->advance()); +} + +// Test large document IDs +TEST_F(IntersectionTest, test_large_document_ids) { + std::vector docsets; + docsets.push_back( + std::make_shared(std::vector {1000, 10000, 100000, 1000000})); + docsets.push_back( + std::make_shared(std::vector {500, 10000, 50000, 1000000})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {10000, 1000000}; + EXPECT_EQ(expected, results); +} + +// Test sparse docsets with large gaps +TEST_F(IntersectionTest, test_sparse_docsets) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {1, 100, 200, 300, 400})); + docsets.push_back( + std::make_shared(std::vector {50, 100, 150, 200, 250, 300})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {100, 200, 300}; + EXPECT_EQ(expected, results); +} + +// Test seek after advancing +TEST_F(IntersectionTest, test_seek_after_advance) { + std::vector docsets; + docsets.push_back( + std::make_shared(std::vector {1, 5, 10, 15, 20, 25, 30})); + docsets.push_back( + std::make_shared(std::vector {5, 10, 15, 20, 25, 30, 35})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // Start at doc 5 + EXPECT_EQ(5u, intersection->doc()); + + // Advance to doc 10 + EXPECT_EQ(10u, intersection->advance()); + + // Seek to doc 25 + EXPECT_EQ(25u, intersection->seek(22)); + + // Continue advancing + EXPECT_EQ(30u, intersection->advance()); + EXPECT_EQ(TERMINATED, intersection->advance()); +} + +// Test multiple seeks without advance +TEST_F(IntersectionTest, test_multiple_seeks) { + std::vector docsets; + docsets.push_back(std::make_shared(std::vector {10, 20, 30, 40, 50})); + docsets.push_back(std::make_shared(std::vector {10, 20, 30, 40, 50, 60})); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + EXPECT_EQ(10u, intersection->doc()); + + // Seek to 25 + EXPECT_EQ(30u, intersection->seek(25)); + + // Seek to 35 + EXPECT_EQ(40u, intersection->seek(35)); + + // Seek to same position + EXPECT_EQ(40u, intersection->seek(40)); + + // Seek backwards (should stay at current) + EXPECT_EQ(40u, intersection->seek(35)); +} + +// Test docsets with different sizes are sorted correctly +TEST_F(IntersectionTest, test_docsets_sorted_by_size) { + // Create docsets with different sizes + std::vector docsets; + // Largest + docsets.push_back(std::make_shared( + std::vector {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 1000)); + // Smallest + docsets.push_back(std::make_shared(std::vector {2, 4, 6}, 100)); + // Medium + docsets.push_back(std::make_shared(std::vector {2, 3, 4, 5, 6}, 500)); + + auto intersection = Intersection::create(docsets); + ASSERT_NE(nullptr, intersection); + + // size_hint should be from smallest docset + EXPECT_EQ(100u, intersection->size_hint()); + + std::vector results; + uint32_t doc = intersection->doc(); + while (doc != TERMINATED) { + results.push_back(doc); + doc = intersection->advance(); + } + + std::vector expected {2, 4, 6}; + EXPECT_EQ(expected, results); +} + +} // namespace doris \ No newline at end of file diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query_test.cpp new file mode 100644 index 00000000000000..74831ab6917718 --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query_test.cpp @@ -0,0 +1,558 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h" + +#include + +#include +#include +#include +#include + +#include "common/status.h" +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h" +#include "olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" + +CL_NS_USE(search) +CL_NS_USE(store) +CL_NS_USE(index) + +namespace doris::segment_v2 { + +using namespace inverted_index; + +class PhraseQueryV2Test : public testing::Test { +public: + const std::string kTestDir = "./ut_dir/phrase_query_test"; + + void SetUp() override { + auto st = io::global_local_filesystem()->delete_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + st = io::global_local_filesystem()->create_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + std::string field_name = "content"; + create_test_index(field_name, kTestDir); + } + + void TearDown() override { + EXPECT_TRUE(io::global_local_filesystem()->delete_directory(kTestDir).ok()); + } + +private: + void create_test_index(const std::string& field_name, const std::string& dir) { + std::vector test_data = {"the quick brown fox jumps over the lazy dog", + "quick brown dogs are running fast", + "the brown cat sleeps peacefully", + "lazy dogs and quick cats", + "the lazy dog is very lazy", + "quick fox and brown bear", + "the quick brown horse runs", + "dogs and cats are pets", + "the fox is quick and brown", + "brown foxes jump over fences", + "lazy cat sleeps all day", + "quick brown fox in the forest", + "the dog barks loudly", + "brown and white dogs", + "quick movements of animals", + "the lazy afternoon", + "brown fox runs quickly", + "the quick test", + "brown lazy fox", + "quick brown lazy dog"}; + + CustomAnalyzerConfig::Builder builder; + builder.with_tokenizer_config("standard", {}); + auto custom_analyzer_config = builder.build(); + auto custom_analyzer = CustomAnalyzer::build_custom_analyzer(custom_analyzer_config); + + auto* indexwriter = + _CLNEW lucene::index::IndexWriter(dir.c_str(), custom_analyzer.get(), true); + indexwriter->setMaxBufferedDocs(100); + indexwriter->setRAMBufferSizeMB(-1); + indexwriter->setMaxFieldLength(0x7FFFFFFFL); + indexwriter->setMergeFactor(1000000000); + indexwriter->setUseCompoundFile(false); + + auto char_string_reader = std::make_shared>(); + + auto* doc = _CLNEW lucene::document::Document(); + int32_t field_config = lucene::document::Field::STORE_NO; + field_config |= lucene::document::Field::INDEX_NONORMS; + field_config |= lucene::document::Field::INDEX_TOKENIZED; + auto field_name_w = std::wstring(field_name.begin(), field_name.end()); + auto* field = _CLNEW lucene::document::Field(field_name_w.c_str(), field_config); + field->setOmitTermFreqAndPositions(false); + doc->add(*field); + + for (const auto& data : test_data) { + char_string_reader->init(data.data(), data.size(), false); + auto* stream = custom_analyzer->reusableTokenStream(field->name(), char_string_reader); + field->setValue(stream); + indexwriter->addDocument(doc); + } + + indexwriter->close(); + _CLLDELETE(indexwriter); + _CLLDELETE(doc); + } +}; + +static std::shared_ptr make_shared_reader( + lucene::index::IndexReader* raw_reader) { + return {raw_reader, [](lucene::index::IndexReader* reader) { + if (reader != nullptr) { + reader->close(); + _CLDELETE(reader); + } + }}; +} + +// Test basic phrase query construction +TEST_F(PhraseQueryV2Test, test_phrase_query_construction) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown")}; + + // Test query construction + auto query = std::make_shared(context, field, terms); + ASSERT_NE(query, nullptr); + + // Test weight creation without scoring + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + + // Verify weight is of correct type + auto phrase_weight = std::dynamic_pointer_cast(weight); + ASSERT_NE(phrase_weight, nullptr); +} + +// Test phrase query with scoring enabled +// TEST_F(PhraseQueryV2Test, test_phrase_query_with_scoring) { +// auto context = std::make_shared(); +// context->collection_statistics = std::make_shared(); +// context->collection_similarity = std::make_shared(); + +// std::wstring field = StringHelper::to_wstring("content"); +// std::vector terms = {StringHelper::to_wstring("quick"), +// StringHelper::to_wstring("brown"), +// StringHelper::to_wstring("fox")}; + +// auto query = std::make_shared(context, field, terms); +// ASSERT_NE(query, nullptr); + +// // Test weight creation with scoring enabled +// auto weight = query->weight(true); +// ASSERT_NE(weight, nullptr); + +// auto phrase_weight = std::dynamic_pointer_cast(weight); +// ASSERT_NE(phrase_weight, nullptr); +// } + +// Test phrase query with empty terms (should throw exception) +TEST_F(PhraseQueryV2Test, test_phrase_query_empty_terms) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms; // Empty terms + + auto query = std::make_shared(context, field, terms); + ASSERT_NE(query, nullptr); + + // Should throw exception when creating weight with empty terms + EXPECT_THROW({ auto weight = query->weight(false); }, Exception); +} + +// Test phrase query execution with two-term phrase +TEST_F(PhraseQueryV2Test, test_phrase_query_two_terms) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match documents containing "quick brown" + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query execution with three-term phrase +TEST_F(PhraseQueryV2Test, test_phrase_query_three_terms) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown"), + StringHelper::to_wstring("fox")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match documents containing "quick brown fox" + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query with single term (should throw exception) +TEST_F(PhraseQueryV2Test, test_phrase_query_single_term) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("fox")}; + + auto query = std::make_shared(context, field, terms); + ASSERT_NE(query, nullptr); + + // Should throw exception when creating weight with single term (phrase requires at least 2 terms) + EXPECT_THROW({ auto weight = query->weight(false); }, Exception); +} + +// Test phrase query with non-matching phrase +TEST_F(PhraseQueryV2Test, test_phrase_query_no_matches) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("purple"), + StringHelper::to_wstring("elephant")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_EQ(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query with scoring and verify scores +TEST_F(PhraseQueryV2Test, test_phrase_query_scoring) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown")}; + + // Fill collection statistics for scoring + context->collection_statistics->_total_num_docs = reader_holder->numDocs(); + context->collection_statistics->_total_num_tokens[field] = reader_holder->numDocs() * 8; + context->collection_statistics->_term_doc_freqs[field][StringHelper::to_wstring("quick")] = 10; + context->collection_statistics->_term_doc_freqs[field][StringHelper::to_wstring("brown")] = 10; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(true); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + float total_score = 0.0F; + uint32_t count = 0; + while (doc != query_v2::TERMINATED) { + float score = scorer->score(); + EXPECT_GT(score, 0.0F) << "Score should be positive"; + total_score += score; + result.add(doc); + ++count; + doc = scorer->advance(); + } + + if (count > 0) { + EXPECT_GT(total_score, 0.0F) << "Total score should be positive"; + } + + _CLDECDELETE(dir); +} + +// Test phrase query with binding key +TEST_F(PhraseQueryV2Test, test_phrase_query_with_binding_key) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("lazy"), + StringHelper::to_wstring("dog")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + + std::string binding_key = "content#0"; + exec_ctx.reader_bindings[binding_key] = reader_holder; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx, binding_key); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query destructor (coverage) +TEST_F(PhraseQueryV2Test, test_phrase_query_destructor) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("test"), + StringHelper::to_wstring("phrase")}; + + { + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + // Query and weight will be destroyed at scope exit + } + // If we reach here without crash, destructor works correctly + SUCCEED(); +} + +// Test phrase query with longer phrase (4+ terms) +TEST_F(PhraseQueryV2Test, test_phrase_query_long_phrase) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = { + StringHelper::to_wstring("the"), StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown"), StringHelper::to_wstring("fox"), + StringHelper::to_wstring("jumps")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query with terms that exist but not in sequence +TEST_F(PhraseQueryV2Test, test_phrase_query_terms_not_in_sequence) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + // These terms exist in documents but not necessarily in this exact sequence + std::vector terms = {StringHelper::to_wstring("dog"), + StringHelper::to_wstring("fox")}; + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // May or may not match depending on the data + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test phrase query with BM25 similarity +TEST_F(PhraseQueryV2Test, test_phrase_query_bm25_similarity) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::vector terms = {StringHelper::to_wstring("quick"), + StringHelper::to_wstring("brown"), + StringHelper::to_wstring("fox")}; + + // Setup statistics for BM25 + context->collection_statistics->_total_num_docs = reader_holder->numDocs(); + context->collection_statistics->_total_num_tokens[field] = reader_holder->numDocs() * 8; + for (const auto& term : terms) { + context->collection_statistics->_term_doc_freqs[field][term] = 5; + } + + auto query = std::make_shared(context, field, terms); + auto weight = query->weight(true); // Enable scoring + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + uint32_t doc = scorer->doc(); + bool found_match = false; + while (doc != query_v2::TERMINATED) { + float score = scorer->score(); + EXPECT_GE(score, 0.0F) << "BM25 score should be non-negative"; + found_match = true; + doc = scorer->advance(); + } + + if (found_match) { + SUCCEED() << "Found matches with BM25 scoring"; + } + + _CLDECDELETE(dir); +} + +} // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query_test.cpp new file mode 100644 index 00000000000000..3dd818764e09b0 --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query_test.cpp @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h" + +#include + +#include +#include +#include +#include + +#include "common/status.h" +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" + +CL_NS_USE(search) +CL_NS_USE(store) +CL_NS_USE(index) + +namespace doris::segment_v2 { + +using namespace inverted_index; + +class RegexpQueryV2Test : public testing::Test { +public: + const std::string kTestDir = "./ut_dir/regexp_query_test"; + + void SetUp() override { + auto st = io::global_local_filesystem()->delete_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + st = io::global_local_filesystem()->create_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + std::string field_name = "content"; + create_test_index(field_name, kTestDir); + } + + void TearDown() override { + EXPECT_TRUE(io::global_local_filesystem()->delete_directory(kTestDir).ok()); + } + +private: + void create_test_index(const std::string& field_name, const std::string& dir) { + std::vector test_data = { + "apple123", "apple456", "banana789", "test123abc", "pattern456", + "regex999", "match123", "search456", "prefix123", "suffix789", + "apple_test", "banana_data", "test_pattern", "data_regex", "apple_banana", + "test_match", "pattern_test", "prefix_suffix", "abc123xyz", "def456ghi"}; + + CustomAnalyzerConfig::Builder builder; + builder.with_tokenizer_config("standard", {}); + auto custom_analyzer_config = builder.build(); + auto custom_analyzer = CustomAnalyzer::build_custom_analyzer(custom_analyzer_config); + + auto* indexwriter = + _CLNEW lucene::index::IndexWriter(dir.c_str(), custom_analyzer.get(), true); + indexwriter->setMaxBufferedDocs(100); + indexwriter->setRAMBufferSizeMB(-1); + indexwriter->setMaxFieldLength(0x7FFFFFFFL); + indexwriter->setMergeFactor(1000000000); + indexwriter->setUseCompoundFile(false); + + auto char_string_reader = std::make_shared>(); + + auto* doc = _CLNEW lucene::document::Document(); + int32_t field_config = lucene::document::Field::STORE_NO; + field_config |= lucene::document::Field::INDEX_NONORMS; + field_config |= lucene::document::Field::INDEX_TOKENIZED; + auto field_name_w = std::wstring(field_name.begin(), field_name.end()); + auto* field = _CLNEW lucene::document::Field(field_name_w.c_str(), field_config); + field->setOmitTermFreqAndPositions(false); + doc->add(*field); + + for (const auto& data : test_data) { + char_string_reader->init(data.data(), data.size(), false); + auto* stream = custom_analyzer->reusableTokenStream(field->name(), char_string_reader); + field->setValue(stream); + indexwriter->addDocument(doc); + } + + indexwriter->close(); + _CLLDELETE(indexwriter); + _CLLDELETE(doc); + } +}; + +static std::shared_ptr make_shared_reader( + lucene::index::IndexReader* raw_reader) { + return {raw_reader, [](lucene::index::IndexReader* reader) { + if (reader != nullptr) { + reader->close(); + _CLDELETE(reader); + } + }}; +} + +// Test basic regexp query construction and weight creation +TEST_F(RegexpQueryV2Test, test_regexp_query_construction) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "apple.*"; + + // Test query construction + auto query = std::make_shared(context, field, pattern); + ASSERT_NE(query, nullptr); + + // Test weight creation without scoring + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + + // Verify weight is of correct type + auto regexp_weight = std::dynamic_pointer_cast(weight); + ASSERT_NE(regexp_weight, nullptr); +} + +// Test regexp query with scoring enabled +TEST_F(RegexpQueryV2Test, test_regexp_query_with_scoring) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = ".*123.*"; + + auto query = std::make_shared(context, field, pattern); + ASSERT_NE(query, nullptr); + + // Test weight creation with scoring enabled + auto weight = query->weight(true); + ASSERT_NE(weight, nullptr); + + auto regexp_weight = std::dynamic_pointer_cast(weight); + ASSERT_NE(regexp_weight, nullptr); +} + +// Test regexp query execution +TEST_F(RegexpQueryV2Test, test_regexp_query_execution) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "apple.*"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match documents containing terms starting with "apple" + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test regexp query with various patterns +TEST_F(RegexpQueryV2Test, test_regexp_query_different_patterns) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + + // Test different regex patterns + std::vector patterns = { + ".*123.*", // Match any term containing 123 + "test.*", // Match terms starting with test + ".*data.*", // Match terms containing data + "prefix.*", // Match terms starting with prefix + ".*xyz.*" // Match terms containing xyz + }; + + for (const auto& pattern : patterns) { + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr) << "Scorer should not be null for pattern: " << pattern; + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Each pattern should match at least some documents + EXPECT_GE(result.cardinality(), 0) + << "Pattern '" << pattern << "' should match some documents"; + } + + _CLDECDELETE(dir); +} + +// Test regexp query with non-matching pattern +TEST_F(RegexpQueryV2Test, test_regexp_query_no_matches) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "nonexistent.*pattern"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match no documents + EXPECT_EQ(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test regexp query with binding key +TEST_F(RegexpQueryV2Test, test_regexp_query_with_binding_key) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "banana.*"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + + std::string binding_key = "content#0"; + exec_ctx.reader_bindings[binding_key] = reader_holder; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx, binding_key); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test regexp query destructor (coverage) +TEST_F(RegexpQueryV2Test, test_regexp_query_destructor) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "test.*"; + + { + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + // Query and weight will be destroyed at scope exit + } + // If we reach here without crash, destructor works correctly + SUCCEED(); +} + +// Test regexp query with complex pattern +TEST_F(RegexpQueryV2Test, test_regexp_query_complex_pattern) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("content"); + // Match terms that start with alphanumeric and contain digits + std::string pattern = "[a-z]+[0-9]+.*"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test move semantics in weight() method +TEST_F(RegexpQueryV2Test, test_regexp_query_move_semantics) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("content"); + std::string pattern = "test.*"; + + // Create query and immediately call weight() to test move semantics + auto query = std::make_shared(context, field, pattern); + auto weight1 = query->weight(false); + ASSERT_NE(weight1, nullptr); + + // Create another query to verify weight can be called multiple times + auto query2 = std::make_shared(context, field, pattern); + auto weight2 = query2->weight(true); + ASSERT_NE(weight2, nullptr); +} + +} // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query_test.cpp new file mode 100644 index 00000000000000..0e52dc0bc7dca1 --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query_test.cpp @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_query.h" + +#include + +#include +#include +#include +#include + +#include "common/status.h" +#include "olap/rowset/segment_v2/index_query_context.h" +#include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/wildcard_query/wildcard_weight.h" +#include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" + +CL_NS_USE(search) +CL_NS_USE(store) +CL_NS_USE(index) + +namespace doris::segment_v2 { + +using namespace inverted_index; + +class WildcardQueryV2Test : public testing::Test { +public: + const std::string kTestDir = "./ut_dir/wildcard_query_test"; + + void SetUp() override { + auto st = io::global_local_filesystem()->delete_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + st = io::global_local_filesystem()->create_directory(kTestDir); + ASSERT_TRUE(st.ok()) << st; + std::string field_name = "field"; + create_test_index(field_name, kTestDir); + } + + void TearDown() override { + EXPECT_TRUE(io::global_local_filesystem()->delete_directory(kTestDir).ok()); + } + +private: + void create_test_index(const std::string& field_name, const std::string& dir) { + std::vector test_data = { + "apple", "application", "apply", "apricot", "banana", "band", + "bandana", "can", "candy", "cat", "catalog", "dog", + "document", "door", "hello", "help", "helpful", "test", + "testing", "testcase", "wildcard", "wild", "wilderness", "card", + "cardboard", "abc", "abcd", "abcde", "prefix123", "123suffix", + "pre123", "123suf", "both123both", "star", "start", "started", + "starter", "question", "quest", "query", "queries"}; + + CustomAnalyzerConfig::Builder builder; + builder.with_tokenizer_config("standard", {}); + auto custom_analyzer_config = builder.build(); + auto custom_analyzer = CustomAnalyzer::build_custom_analyzer(custom_analyzer_config); + + auto* indexwriter = + _CLNEW lucene::index::IndexWriter(dir.c_str(), custom_analyzer.get(), true); + indexwriter->setMaxBufferedDocs(100); + indexwriter->setRAMBufferSizeMB(-1); + indexwriter->setMaxFieldLength(0x7FFFFFFFL); + indexwriter->setMergeFactor(1000000000); + indexwriter->setUseCompoundFile(false); + + auto char_string_reader = std::make_shared>(); + + auto* doc = _CLNEW lucene::document::Document(); + int32_t field_config = lucene::document::Field::STORE_NO; + field_config |= lucene::document::Field::INDEX_NONORMS; + field_config |= lucene::document::Field::INDEX_TOKENIZED; + auto field_name_w = std::wstring(field_name.begin(), field_name.end()); + auto* field = _CLNEW lucene::document::Field(field_name_w.c_str(), field_config); + field->setOmitTermFreqAndPositions(false); + doc->add(*field); + + for (const auto& data : test_data) { + char_string_reader->init(data.data(), data.size(), false); + auto* stream = custom_analyzer->reusableTokenStream(field->name(), char_string_reader); + field->setValue(stream); + indexwriter->addDocument(doc); + } + + indexwriter->close(); + _CLLDELETE(indexwriter); + _CLLDELETE(doc); + } +}; + +static std::shared_ptr make_shared_reader( + lucene::index::IndexReader* raw_reader) { + return {raw_reader, [](lucene::index::IndexReader* reader) { + if (reader != nullptr) { + reader->close(); + _CLDELETE(reader); + } + }}; +} + +// Test basic wildcard query construction +TEST_F(WildcardQueryV2Test, test_wildcard_query_construction) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "app*"; + + // Test query construction + auto query = std::make_shared(context, field, pattern); + ASSERT_NE(query, nullptr); + + // Test weight creation without scoring + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + + // Verify weight is of correct type + auto wildcard_weight = std::dynamic_pointer_cast(weight); + ASSERT_NE(wildcard_weight, nullptr); +} + +// Test wildcard query with scoring enabled +TEST_F(WildcardQueryV2Test, test_wildcard_query_with_scoring) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "test*"; + + auto query = std::make_shared(context, field, pattern); + ASSERT_NE(query, nullptr); + + // Test weight creation with scoring enabled + auto weight = query->weight(true); + ASSERT_NE(weight, nullptr); + + auto wildcard_weight = std::dynamic_pointer_cast(weight); + ASSERT_NE(wildcard_weight, nullptr); +} + +// Test wildcard query with asterisk prefix pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_prefix_pattern) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "app*"; // Match apple, application, apply, apricot + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match multiple documents starting with "app" + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with asterisk suffix pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_suffix_pattern) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "*log"; // Match catalog, dog, etc. + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with asterisk middle pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_middle_pattern) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "c*d"; // Match card, cardboard + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with multiple asterisks +TEST_F(WildcardQueryV2Test, test_wildcard_query_multiple_asterisks) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "*t*t*"; // Match terms with multiple 't's + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with no wildcard (exact match) +TEST_F(WildcardQueryV2Test, test_wildcard_query_exact_match) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "apple"; // Exact match + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with non-matching pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_no_matches) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "xyz*abc"; // Non-existent pattern + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_EQ(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with binding key +TEST_F(WildcardQueryV2Test, test_wildcard_query_with_binding_key) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "ban*"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + + std::string binding_key = "field#0"; + exec_ctx.reader_bindings[binding_key] = reader_holder; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx, binding_key); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GT(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query destructor (coverage) +TEST_F(WildcardQueryV2Test, test_wildcard_query_destructor) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "test*"; + + { + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + ASSERT_NE(weight, nullptr); + // Query and weight will be destroyed at scope exit + } + // If we reach here without crash, destructor works correctly + SUCCEED(); +} + +// Test wildcard query with special characters in pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_special_characters) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + // Test with numbers + std::string pattern = "*123*"; + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + EXPECT_GE(result.cardinality(), 0); + + _CLDECDELETE(dir); +} + +// Test wildcard query with all asterisk pattern +TEST_F(WildcardQueryV2Test, test_wildcard_query_all_asterisk) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + auto* dir = FSDirectory::getDirectory(kTestDir.c_str()); + auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); + ASSERT_TRUE(reader_holder != nullptr); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "*"; // Match everything + + auto query = std::make_shared(context, field, pattern); + auto weight = query->weight(false); + + query_v2::QueryExecutionContext exec_ctx; + exec_ctx.segment_num_rows = reader_holder->maxDoc(); + exec_ctx.readers = {reader_holder}; + exec_ctx.field_reader_bindings.emplace(field, reader_holder); + + auto scorer = weight->scorer(exec_ctx); + ASSERT_NE(scorer, nullptr); + + roaring::Roaring result; + uint32_t doc = scorer->doc(); + while (doc != query_v2::TERMINATED) { + result.add(doc); + doc = scorer->advance(); + } + + // Should match all documents + EXPECT_EQ(result.cardinality(), reader_holder->numDocs()); + + _CLDECDELETE(dir); +} + +// Test move semantics in weight() method +TEST_F(WildcardQueryV2Test, test_wildcard_query_move_semantics) { + auto context = std::make_shared(); + context->collection_statistics = std::make_shared(); + context->collection_similarity = std::make_shared(); + + std::wstring field = StringHelper::to_wstring("field"); + std::string pattern = "test*"; + + // Create query and immediately call weight() to test move semantics + auto query = std::make_shared(context, field, pattern); + auto weight1 = query->weight(false); + ASSERT_NE(weight1, nullptr); + + // Create another query to verify weight can be called with scoring + auto query2 = std::make_shared(context, field, pattern); + auto weight2 = query2->weight(true); + ASSERT_NE(weight2, nullptr); +} + +} // namespace doris::segment_v2 \ No newline at end of file