diff --git a/keyvi/include/keyvi/dictionary/completion/forward_backward_completion.h b/keyvi/include/keyvi/dictionary/completion/forward_backward_completion.h index f8a6cd7fb..bbf97eb2f 100644 --- a/keyvi/include/keyvi/dictionary/completion/forward_backward_completion.h +++ b/keyvi/include/keyvi/dictionary/completion/forward_backward_completion.h @@ -53,7 +53,7 @@ class ForwardBackwardCompletion final { ForwardBackwardCompletion(dictionary_t forward_dictionary, dictionary_t backward_dictionary) : forward_completions_(forward_dictionary), backward_completions_(backward_dictionary) {} - struct result_compare : public std::binary_function { + struct result_compare { bool operator()(const Match& m1, const Match& m2) const { return m1.GetScore() < m2.GetScore(); } }; diff --git a/keyvi/include/keyvi/dictionary/dictionary.h b/keyvi/include/keyvi/dictionary/dictionary.h index 3accc345e..dd497356f 100644 --- a/keyvi/include/keyvi/dictionary/dictionary.h +++ b/keyvi/include/keyvi/dictionary/dictionary.h @@ -38,6 +38,8 @@ #include "keyvi/dictionary/match_iterator.h" #include "keyvi/dictionary/matching/fuzzy_matching.h" #include "keyvi/dictionary/matching/near_matching.h" +#include "keyvi/dictionary/matching/prefix_completion_matching.h" +#include "keyvi/dictionary/util/bounded_priority_queue.h" // #define ENABLE_TRACING #include "keyvi/dictionary/util/trace.h" @@ -324,6 +326,42 @@ class Dictionary final { return MatchIterator::MakeIteratorPair(func, data->FirstMatch()); } + MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query) const { + auto data = std::make_shared>( + matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query)); + + auto func = [data]() { return data->NextMatch(); }; + // bind the shared_ptr itself (not a raw pointer) so the setter co-owns the matcher independently of func + return MatchIterator::MakeIteratorPair( + func, data->FirstMatch(), + std::bind(&matching::PrefixCompletionMatching<>::SetMinWeight, data, std::placeholders::_1)); + } + + MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query, size_t 
top_n) const { + auto data = std::make_shared>( + matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query)); + + auto best_weights = std::make_shared>(top_n); + + auto func = [data, best_weights = std::move(best_weights)]() { + auto m = data->NextMatch(); + while (!m.IsEmpty()) { + if (m.GetWeight() >= best_weights->Back()) { + best_weights->Put(m.GetWeight()); + return m; + } + + m = data->NextMatch(); + } + return Match(); + }; + + // bind the shared_ptr itself (not a raw pointer) so the setter co-owns the matcher independently of func + return MatchIterator::MakeIteratorPair( + func, data->FirstMatch(), + std::bind(&matching::PrefixCompletionMatching<>::SetMinWeight, data, std::placeholders::_1)); + } + std::string GetManifest() const { return fsa_->GetManifest(); } private: diff --git a/keyvi/include/keyvi/dictionary/fsa/automata.h b/keyvi/include/keyvi/dictionary/fsa/automata.h index bd1c11cd5..b6b2d9363 100644 --- a/keyvi/include/keyvi/dictionary/fsa/automata.h +++ b/keyvi/include/keyvi/dictionary/fsa/automata.h @@ -239,7 +239,7 @@ class Automata final { traversal::TraversalPayload* payload) const { // reset the state traversal_state->Clear(); - uint32_t parent_weight = GetWeightValue(starting_state); + uint32_t parent_weight = GetInnerWeight(starting_state); #if defined(KEYVI_SSE42) // Optimized version using SSE4.2, see http://www.strchr.com/strcmp_and_strlen_using_sse_4.2 @@ -262,7 +262,7 @@ class Automata final { if ((mask_int & 1) == 1) { TRACE("push symbol+%d", symbol + i); uint64_t child_state = ResolvePointer(starting_state, symbol + i); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + i, payload); } @@ -285,49 +285,49 @@ class Automata final { if (((xor_labels_with_mask & 0x00000000000000ffULL) == 0)) { uint64_t child_state = ResolvePointer(starting_state, symbol); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? 
weight : parent_weight; traversal_state->Add(child_state, weight, symbol, payload); } if ((xor_labels_with_mask & 0x000000000000ff00ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 1); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 1, payload); } if ((xor_labels_with_mask & 0x0000000000ff0000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 2); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 2, payload); } if ((xor_labels_with_mask & 0x00000000ff000000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 3); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 3, payload); } if ((xor_labels_with_mask & 0x000000ff00000000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 4); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 4, payload); } if ((xor_labels_with_mask & 0x0000ff0000000000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 5); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? 
weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 5, payload); } if ((xor_labels_with_mask & 0x00ff000000000000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 6); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 6, payload); } if ((xor_labels_with_mask & 0xff00000000000000ULL) == 0) { uint64_t child_state = ResolvePointer(starting_state, symbol + 7); - uint32_t weight = GetWeightValue(child_state); + uint32_t weight = GetInnerWeight(child_state); weight = weight != 0 ? weight : parent_weight; traversal_state->Add(child_state, weight, symbol + 7, payload); } @@ -356,7 +356,7 @@ class Automata final { return keyvi::util::decodeVarShort(transitions_compact_ + state + FINAL_OFFSET_TRANSITION); } - uint32_t GetWeightValue(uint64_t state) const { + uint32_t GetInnerWeight(uint64_t state) const { if (labels_[state + INNER_WEIGHT_TRANSITION_COMPACT] != 0) { return 0; } @@ -364,6 +364,11 @@ class Automata final { return (transitions_compact_[state + INNER_WEIGHT_TRANSITION_COMPACT]); } + uint32_t GetWeight(uint64_t state) const { + assert(value_store_reader_); + return value_store_reader_->GetWeight(state); + } + internal::IValueStoreReader::attributes_t GetValueAsAttributeVector(uint64_t state_value) const { assert(value_store_reader_); return value_store_reader_->GetValueAsAttributeVector(state_value); @@ -379,9 +384,13 @@ class Automata final { return value_store_reader_->GetRawValueAsString(state_value); } - std::string GetStatistics() const { return dictionary_properties_->GetStatistics(); } + std::string GetStatistics() const { + return dictionary_properties_->GetStatistics(); + } - std::string GetManifest() const { return dictionary_properties_->GetManifest(); } + std::string GetManifest() const { + return dictionary_properties_->GetManifest(); + } private: 
dictionary_properties_t dictionary_properties_; diff --git a/keyvi/include/keyvi/dictionary/fsa/bounded_weighted_state_traverser.h b/keyvi/include/keyvi/dictionary/fsa/bounded_weighted_state_traverser.h index 7333e4296..e74a36ee5 100644 --- a/keyvi/include/keyvi/dictionary/fsa/bounded_weighted_state_traverser.h +++ b/keyvi/include/keyvi/dictionary/fsa/bounded_weighted_state_traverser.h @@ -88,7 +88,7 @@ class BoundedWeightedStateTraverser final { uint64_t GetStateValue() { return fsa_->GetStateValue(current_state_); } - uint32_t GetInnerWeight() { return fsa_->GetWeightValue(current_state_); } + uint32_t GetInnerWeight() { return fsa_->GetInnerWeight(current_state_); } uint64_t GetStateId() { return current_state_; } @@ -220,7 +220,7 @@ class BoundedWeightedStateTraverser final { child_node = fsa_->TryWalkTransition(current_state_, i); if (child_node) { // todo: stop reading the weight dependent on the depth of traversal - uint32_t weight = fsa_->GetWeightValue(child_node); + uint32_t weight = fsa_->GetInnerWeight(child_node); // if weight is not set take the weight of the parent if (weight == 0) { diff --git a/keyvi/include/keyvi/dictionary/fsa/comparable_state_traverser.h b/keyvi/include/keyvi/dictionary/fsa/comparable_state_traverser.h index 450d9510b..9f480b7e3 100644 --- a/keyvi/include/keyvi/dictionary/fsa/comparable_state_traverser.h +++ b/keyvi/include/keyvi/dictionary/fsa/comparable_state_traverser.h @@ -168,6 +168,15 @@ class ComparableStateTraverser final { size_t GetOrder() const { return order_; } + /** + * Set the minimum weight states must be greater or equal to. + * + * Only available for WeightedTransition specialization. 
+ * + * @param weight minimum transition weight + */ + inline void SetMinWeight(uint32_t weight) {} + private: innerTraverserType state_traverser_; std::vector label_stack_; @@ -267,6 +276,16 @@ inline bool ComparableStateTraverser::operator<(const Compar return order_ > rhs.order_; } +/** + * Set the minimum weight states must be greater or equal to. + * + * @param weight minimum transition weight + */ +template <> +inline void ComparableStateTraverser::SetMinWeight(uint32_t weight) { + state_traverser_.SetMinWeight(weight); +} + } /* namespace fsa */ } /* namespace dictionary */ } /* namespace keyvi */ diff --git a/keyvi/include/keyvi/dictionary/fsa/internal/int_inner_weights_value_store.h b/keyvi/include/keyvi/dictionary/fsa/internal/int_inner_weights_value_store.h index 352f211cf..fff3ec46d 100644 --- a/keyvi/include/keyvi/dictionary/fsa/internal/int_inner_weights_value_store.h +++ b/keyvi/include/keyvi/dictionary/fsa/internal/int_inner_weights_value_store.h @@ -104,6 +104,8 @@ class IntInnerWeightsValueStoreReader final : public IValueStoreReader { } std::string GetValueAsString(uint64_t fsa_value) const override { return std::to_string(fsa_value); } + + uint32_t GetWeight(uint64_t fsa_value) const override { return static_cast(fsa_value); } }; template <> diff --git a/keyvi/include/keyvi/dictionary/fsa/internal/ivalue_store.h b/keyvi/include/keyvi/dictionary/fsa/internal/ivalue_store.h index 40bdc63a0..cee376dc9 100644 --- a/keyvi/include/keyvi/dictionary/fsa/internal/ivalue_store.h +++ b/keyvi/include/keyvi/dictionary/fsa/internal/ivalue_store.h @@ -126,6 +126,16 @@ class IValueStoreReader { */ virtual std::string GetValueAsString(uint64_t fsa_value) const = 0; + /** + * Get Weight + * + * This is only supported by value stores that support weights. + * + * @param fsa_value + * @return the weight + */ + virtual uint32_t GetWeight(uint64_t fsa_value) const { return 0; } + /** * Test whether this value store is compatible to the given value store. 
* Throws if they are not compatible. diff --git a/keyvi/include/keyvi/dictionary/fsa/state_traverser.h b/keyvi/include/keyvi/dictionary/fsa/state_traverser.h index e64f974ba..63270fa0d 100644 --- a/keyvi/include/keyvi/dictionary/fsa/state_traverser.h +++ b/keyvi/include/keyvi/dictionary/fsa/state_traverser.h @@ -29,6 +29,7 @@ #include "keyvi/dictionary/fsa/automata.h" #include "keyvi/dictionary/fsa/traversal/traversal_base.h" +#include "keyvi/dictionary/fsa/traversal/weighted_traversal.h" // #define ENABLE_TRACING #include "keyvi/dictionary/util/trace.h" @@ -119,13 +120,14 @@ class StateTraverser final { void operator++(int) { TRACE("statetraverser++"); + // ignore cases where we are already at the end if (current_state_ == 0) { TRACE("at the end"); return; } - current_state_ = stack_.GetStates().GetNextState(); + current_state_ = FilterByMinWeight(stack_.GetStates().GetNextState()); TRACE("next state: %ld depth: %ld", current_state_, stack_.GetDepth()); while (current_state_ == 0) { @@ -139,7 +141,7 @@ class StateTraverser final { TRACE("state is 0, go up"); --stack_; stack_.GetStates()++; - current_state_ = stack_.GetStates().GetNextState(); + current_state_ = FilterByMinWeight(stack_.GetStates().GetNextState()); TRACE("next state %ld depth %ld", current_state_, stack_.GetDepth()); } @@ -155,6 +157,15 @@ class StateTraverser final { operator bool() const { return !at_end_; } + /** + * Set the minimum weight states must be greater or equal to. + * + * Only available for WeightedTransition specialization. + * + * @param min_weight minimum transition weight + */ + inline void SetMinWeight(uint32_t min_weight) {} + bool AtEnd() const { return at_end_; } private: @@ -165,6 +176,13 @@ class StateTraverser final { bool at_end_; traversal::TraversalStack stack_; + /** + * Filter hook for weighted traversal to filter weights lower than the minimum weight (see specialization). 
+ * + * Default: no filter + */ + inline uint64_t FilterByMinWeight(uint64_t state) { return state; } + template friend class ComparableStateTraverser; const traversal::TraversalStack &GetStack() const { return stack_; } @@ -176,6 +194,32 @@ class StateTraverser final { const traversal::TraversalPayload &GetTraversalPayload() const { return stack_.traversal_stack_payload; } }; +/** + * Filter state that doesn't meet the min weight requirement. + * + * This happens when SetMinWeight has been called after the transitions were already read. + * + * @param state the state we currently look at. + * + * @return the current state if its weight is greater than or equal to the minimum weight, 0 otherwise. + */ +template <> +inline uint64_t StateTraverser::FilterByMinWeight(uint64_t state) { + TRACE("filter min weight for weighted transition specialization"); + return state > 0 && stack_.GetStates().GetNextInnerWeight() >= stack_.traversal_stack_payload.min_weight ? state : 0; +} + +/** + * Set the minimum weight states must be greater or equal to. + * + * @param min_weight minimum transition weight + */ +template <> +inline void StateTraverser::SetMinWeight(uint32_t min_weight) { + TRACE("set min weight for weighted transition specialization %d", min_weight); + stack_.traversal_stack_payload.min_weight = min_weight; +} + } /* namespace fsa */ } /* namespace dictionary */ } /* namespace keyvi */ diff --git a/keyvi/include/keyvi/dictionary/fsa/traversal/bounded_weighted_traversal.h b/keyvi/include/keyvi/dictionary/fsa/traversal/bounded_weighted_traversal.h deleted file mode 100644 index d03afd36f..000000000 --- a/keyvi/include/keyvi/dictionary/fsa/traversal/bounded_weighted_traversal.h +++ /dev/null @@ -1,67 +0,0 @@ -/* * keyvi - A key value store. - * - * Copyright 2015 Hendrik Muhs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * bounded_weighted_traversal.h - * - * Created on: Nov 17, 2015 - * Author: hendrik - */ - -#ifndef KEYVI_DICTIONARY_FSA_TRAVERSAL_BOUNDED_WEIGHTED_TRAVERSAL_H_ -#define KEYVI_DICTIONARY_FSA_TRAVERSAL_BOUNDED_WEIGHTED_TRAVERSAL_H_ - -#include - -#include "keyvi/dictionary/fsa/traversal/weighted_traversal.h" -#include "keyvi/dictionary/util/bounded_priority_queue.h" - -// #define ENABLE_TRACING -#include "keyvi/dictionary/util/trace.h" - -namespace keyvi { -namespace dictionary { -namespace fsa { -namespace traversal { - -struct BoundedWeightedTransition : public WeightedTransition { - using WeightedTransition::WeightedTransition; -}; - -template <> -struct TraversalPayload { - TraversalPayload() : current_depth(0), priority_queue(10) {} - - size_t current_depth; - util::BoundedPriorityQueue priority_queue; -}; - -template <> -inline void TraversalState::PostProcess( - TraversalPayload* payload) { - if (traversal_state_payload.transitions.size() > 0) { - std::sort(traversal_state_payload.transitions.begin(), traversal_state_payload.transitions.end(), - WeightedTransitionCompare); - } -} - -} /* namespace traversal */ -} /* namespace fsa */ -} /* namespace dictionary */ -} /* namespace keyvi */ - -#endif // KEYVI_DICTIONARY_FSA_TRAVERSAL_BOUNDED_WEIGHTED_TRAVERSAL_H_ diff --git a/keyvi/include/keyvi/dictionary/fsa/traversal/weighted_traversal.h b/keyvi/include/keyvi/dictionary/fsa/traversal/weighted_traversal.h index a1794a9bc..c655692b4 100644 --- a/keyvi/include/keyvi/dictionary/fsa/traversal/weighted_traversal.h +++ 
b/keyvi/include/keyvi/dictionary/fsa/traversal/weighted_traversal.h @@ -52,6 +52,20 @@ static bool WeightedTransitionCompare(const WeightedTransition& a, const Weighte return a.weight > b.weight; } +template <> +struct TraversalPayload { + size_t current_depth = 0; + uint32_t min_weight = 0; +}; + +template <> +inline void TraversalState::Add(uint64_t s, uint32_t weight, unsigned char l, + TraversalPayload* payload) { + if (weight >= payload->min_weight) { + traversal_state_payload.transitions.push_back(WeightedTransition(s, weight, l)); + } +} + template <> inline void TraversalState::PostProcess(TraversalPayload* payload) { if (traversal_state_payload.transitions.size() > 0) { diff --git a/keyvi/include/keyvi/dictionary/fsa/traverser_types.h b/keyvi/include/keyvi/dictionary/fsa/traverser_types.h index 58dadf1ca..cbf0d4184 100644 --- a/keyvi/include/keyvi/dictionary/fsa/traverser_types.h +++ b/keyvi/include/keyvi/dictionary/fsa/traverser_types.h @@ -26,7 +26,6 @@ #define KEYVI_DICTIONARY_FSA_TRAVERSER_TYPES_H_ #include "keyvi/dictionary/fsa/state_traverser.h" -#include "keyvi/dictionary/fsa/traversal/bounded_weighted_traversal.h" #include "keyvi/dictionary/fsa/traversal/near_traversal.h" #include "keyvi/dictionary/fsa/traversal/weighted_traversal.h" @@ -34,7 +33,6 @@ namespace keyvi { namespace dictionary { namespace fsa { -using BoundedWeightedStateTraverser2 = StateTraverser; using NearStateTraverser = StateTraverser; using WeightedStateTraverser = StateTraverser; diff --git a/keyvi/include/keyvi/dictionary/fsa/zip_state_traverser.h b/keyvi/include/keyvi/dictionary/fsa/zip_state_traverser.h index 61a85a9ea..f787e4023 100644 --- a/keyvi/include/keyvi/dictionary/fsa/zip_state_traverser.h +++ b/keyvi/include/keyvi/dictionary/fsa/zip_state_traverser.h @@ -214,6 +214,15 @@ class ZipStateTraverser final { const std::vector &GetStateLabels() const { return traverser_queue_.top()->GetStateLabels(); } + /** + * Set the minimum weight states must be greater or equal to. 
+ * + * Only available for WeightedTransition specialization. + * + * @param weight minimum transition weight + */ + inline void SetMinWeight(uint32_t weight) {} + private: heap_t traverser_queue_; bool final_ = false; @@ -226,6 +235,8 @@ class ZipStateTraverser final { automata_t fsa_; size_t equal_states_ = 1; bool pruned = false; + // this field only used for weighted traversal, ignored otherwise + uint32_t min_weight_ = 0; inline void PreIncrement() {} @@ -500,6 +511,18 @@ inline ZipStateTraverser::ZipStateTraverser( FillInValues(); } +/** + * Set the minimum weight states must be greater or equal to. + * + * @param weight minimum transition weight + */ +template <> +inline void ZipStateTraverser::SetMinWeight(uint32_t weight) { + for (const auto& t : traverser_queue_) { + t->SetMinWeight(weight); + } +} + } // namespace fsa } // namespace dictionary } // namespace keyvi diff --git a/keyvi/include/keyvi/dictionary/match.h b/keyvi/include/keyvi/dictionary/match.h index 042ab3c9d..5ad1b587e 100644 --- a/keyvi/include/keyvi/dictionary/match.h +++ b/keyvi/include/keyvi/dictionary/match.h @@ -74,12 +74,13 @@ struct Match { typedef std::shared_ptr>> attributes_t; - Match(size_t a, size_t b, const std::string& matched_item, uint32_t score = 0) + Match(size_t a, size_t b, const std::string& matched_item, uint32_t score = 0, uint32_t weight = 0) : start_(a), end_(b), matched_item_(matched_item), raw_value_(), score_(score) { TRACE("initialized Match %d->%d %s", a, b, matched_item.c_str()); } - Match(size_t a, size_t b, const std::string& matched_item, uint32_t score, const fsa::automata_t& fsa, uint64_t state) + Match(size_t a, size_t b, const std::string& matched_item, uint32_t score, const fsa::automata_t& fsa, uint64_t state, + uint32_t weight = 0) : start_(a), end_(b), matched_item_(matched_item), raw_value_(), score_(score), fsa_(fsa), state_(state) { TRACE("initialized Match %d->%d %s", a, b, matched_item.c_str()); } @@ -141,6 +142,14 @@ struct Match { 
(*attributes_)[key] = value; } + uint32_t GetWeight() const { + if (!fsa_) { + return 0; + } + + return fsa_->GetWeight(state_); + } + std::string GetValueAsString() const { if (!fsa_) { if (raw_value_.size() != 0) { @@ -175,7 +184,9 @@ struct Match { * * @param value */ - void SetRawValue(const std::string& value) { raw_value_ = value; } + void SetRawValue(const std::string& value) { + raw_value_ = value; + } private: size_t start_ = 0; @@ -193,7 +204,9 @@ struct Match { template friend Match index::internal::FirstFilteredMatch(const MatcherT&, const DeletedT&); - fsa::automata_t& GetFsa() { return fsa_; } + fsa::automata_t& GetFsa() { + return fsa_; + } }; } /* namespace dictionary */ diff --git a/keyvi/include/keyvi/dictionary/match_iterator.h b/keyvi/include/keyvi/dictionary/match_iterator.h index 73fefe78d..565acb367 100644 --- a/keyvi/include/keyvi/dictionary/match_iterator.h +++ b/keyvi/include/keyvi/dictionary/match_iterator.h @@ -55,21 +55,31 @@ class MatchIterator : public boost::iterator_facade MatchIteratorPair; - explicit MatchIterator(std::function match_functor, const Match& first_match = Match()) - : match_functor_(match_functor) { + explicit MatchIterator(std::function match_functor, const Match& first_match = Match(), + std::function set_min_weight = {}) + : match_functor_(match_functor), set_min_weight_(set_min_weight) { current_match_ = first_match; if (first_match.IsEmpty()) { increment(); } } - static MatchIteratorPair MakeIteratorPair(std::function f, const Match& first_match = Match()) { - return MatchIteratorPair(MatchIterator(f, first_match), MatchIterator()); + static MatchIteratorPair MakeIteratorPair(std::function f, const Match& first_match = Match(), + std::function set_min_weight = {}) { + return MatchIteratorPair(MatchIterator(f, first_match, set_min_weight), MatchIterator()); } static MatchIteratorPair EmptyIteratorPair() { return MatchIteratorPair(MatchIterator(), MatchIterator()); } - MatchIterator() : match_functor_(0) {} + 
MatchIterator() : match_functor_(0), set_min_weight_({}) {} + + void SetMinWeight(uint32_t min_weight) { + // ignore if a min weight setter was not provided + if (set_min_weight_) { + set_min_weight_(min_weight); + } + } + // What we implement is determined by the boost::forward_traversal_tag // template parameter private: @@ -85,6 +95,7 @@ class MatchIterator : public boost::iterator_facade match_functor_; Match current_match_; + std::function set_min_weight_; }; } /* namespace dictionary */ diff --git a/keyvi/include/keyvi/dictionary/matching/prefix_completion_matching.h b/keyvi/include/keyvi/dictionary/matching/prefix_completion_matching.h new file mode 100644 index 000000000..9d13c9064 --- /dev/null +++ b/keyvi/include/keyvi/dictionary/matching/prefix_completion_matching.h @@ -0,0 +1,220 @@ +/* keyvi - A key value store. + * + * Copyright 2024 Hendrik Muhs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * prefix_completion_matching.h + */ + +#ifndef KEYVI_DICTIONARY_MATCHING_PREFIX_COMPLETION_MATCHING_H_ +#define KEYVI_DICTIONARY_MATCHING_PREFIX_COMPLETION_MATCHING_H_ + +#include +#include +#include +#include +#include + +#include "keyvi/dictionary/fsa/automata.h" +#include "keyvi/dictionary/fsa/codepoint_state_traverser.h" +#include "keyvi/dictionary/fsa/traverser_types.h" +#include "keyvi/dictionary/fsa/zip_state_traverser.h" +#include "keyvi/dictionary/match.h" +#include "keyvi/dictionary/util/utf8_utils.h" +#include "keyvi/stringdistance/levenshtein.h" +#include "utf8.h" + +// #define ENABLE_TRACING +#include "keyvi/dictionary/util/trace.h" + +namespace keyvi { +namespace index { +namespace internal { +template +keyvi::dictionary::Match NextFilteredMatchSingle(const MatcherT&, const DeletedT&); +template +keyvi::dictionary::Match NextFilteredMatch(const MatcherT&, const DeletedT&); +} // namespace internal +} // namespace index +namespace dictionary { +namespace matching { + +template +class PrefixCompletionMatching final { + public: + /** + * Create a prefix completer from a single Fsa + * + * @param fsa the fsa + * @param query the query + */ + static PrefixCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const std::string& query) { + return FromSingleFsa(fsa, fsa->GetStartState(), query); + } + + /** + * Create a prefix completer from a single Fsa + * + * @param fsa the fsa + * @param start_state the state to start from + * @param query the query + */ + static PrefixCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const uint64_t start_state, + const std::string& query) { + if (start_state == 0) { + return PrefixCompletionMatching(); + } + + std::unique_ptr> traversal_stack = std::make_unique>(); + traversal_stack->reserve(1024); + + const size_t query_length = query.size(); + size_t depth = 0; + uint64_t state = start_state; + + Match first_match; + + TRACE("start state %d", state); + + while (state != 0 && depth != 
query_length) { + traversal_stack->push_back(query[depth]); + state = fsa->TryWalkTransition(state, query[depth++]); + } + + TRACE("state %d", state); + + if (state == 0) { + return PrefixCompletionMatching(); + } + + TRACE("matched prefix, length %d", depth); + + std::unique_ptr traverser = std::make_unique(fsa, state); + + if (fsa->IsFinalState(state)) { + first_match = Match(0, query_length, query, 0, fsa, fsa->GetStateValue(state)); + } + + TRACE("create matcher"); + return PrefixCompletionMatching(std::move(traverser), std::move(first_match), std::move(traversal_stack), + query_length); + } + + /** + * Create a prefix completer from multiple Fsas + * + * @param fsas a vector of fsas + * @param query the query + */ + static PrefixCompletionMatching FromMulipleFsas(const std::vector& fsas, const std::string& query) { + const size_t query_length = query.size(); + std::vector> fsa_start_state_pairs; + + for (const fsa::automata_t& fsa : fsas) { + uint64_t state = fsa->GetStartState(); + + size_t depth = 0; + while (state != 0 && depth != query_length) { + state = fsa->TryWalkTransition(state, query[depth++]); + } + + if (state != 0) { + fsa_start_state_pairs.emplace_back(fsa, state); + } + } + + if (fsa_start_state_pairs.size() == 0) { + return PrefixCompletionMatching(); + } + + // create the traversal stack + std::unique_ptr> traversal_stack = std::make_unique>(); + traversal_stack->reserve(1024); + + for (const char& c : query) { + traversal_stack->push_back(c); + } + + Match first_match; + // check for a match given the exact prefix + for (const auto& fsa_state : fsa_start_state_pairs) { + if (fsa_state.first->IsFinalState(fsa_state.second)) { + first_match = + Match(0, query_length, query, 0, fsa_state.first, fsa_state.first->GetStateValue(fsa_state.second)); + break; + } + } + + std::unique_ptr traverser = std::make_unique(fsa_start_state_pairs); + + return PrefixCompletionMatching(std::move(traverser), std::move(first_match), std::move(traversal_stack), + 
query_length); + } + + Match FirstMatch() const { return first_match_; } + + Match NextMatch() { + for (; traverser_ptr_ && *traverser_ptr_; (*traverser_ptr_)++) { + traversal_stack_->resize(prefix_length_ + traverser_ptr_->GetDepth() - 1); + traversal_stack_->push_back(traverser_ptr_->GetStateLabel()); + TRACE("Current depth %d (%d)", prefix_length_ + traverser_ptr_->GetDepth() - 1, traversal_stack_->size()); + + if (traverser_ptr_->IsFinalState()) { + std::string match_str = std::string(traversal_stack_->begin(), traversal_stack_->end()); + + TRACE("found final state at depth %d %s", prefix_length_ + traverser_ptr_->GetDepth(), match_str.c_str()); + Match m(0, prefix_length_ + traverser_ptr_->GetDepth(), match_str, 0, traverser_ptr_->GetFsa(), + traverser_ptr_->GetStateValue()); + + (*traverser_ptr_)++; + return m; + } + } + + return Match(); + } + + void SetMinWeight(uint32_t min_weight) { if (traverser_ptr_) traverser_ptr_->SetMinWeight(min_weight); } + + private: + PrefixCompletionMatching(std::unique_ptr&& traverser, Match&& first_match, + std::unique_ptr>&& traversal_stack, const size_t prefix_length) + : traverser_ptr_(std::move(traverser)), + first_match_(std::move(first_match)), + traversal_stack_(std::move(traversal_stack)), + prefix_length_(prefix_length) {} + + PrefixCompletionMatching() {} + + private: + std::unique_ptr traverser_ptr_; + const Match first_match_; + std::unique_ptr> traversal_stack_; + const size_t prefix_length_ = 0; + + // reset method for the index in the special case the match is deleted + template + friend Match index::internal::NextFilteredMatchSingle(const MatcherT&, const DeletedT&); + template + friend Match index::internal::NextFilteredMatch(const MatcherT&, const DeletedT&); + + void ResetLastMatch() {} +}; + +} /* namespace matching */ +} /* namespace dictionary */ +} /* namespace keyvi */ +#endif // KEYVI_DICTIONARY_MATCHING_PREFIX_COMPLETION_MATCHING_H_ diff --git a/keyvi/tests/keyvi/dictionary/dictionary_test.cpp 
b/keyvi/tests/keyvi/dictionary/dictionary_test.cpp index e97f46f3c..02004425b 100644 --- a/keyvi/tests/keyvi/dictionary/dictionary_test.cpp +++ b/keyvi/tests/keyvi/dictionary/dictionary_test.cpp @@ -223,6 +223,80 @@ BOOST_AUTO_TEST_CASE(DictGetZerobyte) { BOOST_CHECK_EQUAL("22", boost::get(m.GetAttribute("weight"))); } +BOOST_AUTO_TEST_CASE(DictGetPrefixCompletion) { + std::vector> test_data = { + {"eric a", 331}, {"eric b", 1331}, {"eric c", 1431}, {"eric d", 231}, {"eric e", 431}, + {"eric f", 531}, {"eric g", 631}, {"eric h", 731}, {"eric i", 831}, {"eric j", 131}, + }; + + testing::TempDictionary dictionary(&test_data); + dictionary_t d(new Dictionary(dictionary.GetFsa())); + + std::vector expected_matches = {"eric c", "eric b", "eric i"}; + + size_t i = 0; + + for (auto m : d->GetPrefixCompletion("eric", 3)) { + if (i >= expected_matches.size()) { + BOOST_FAIL("got more results than expected."); + } + BOOST_CHECK_EQUAL(expected_matches[i++], m.GetMatchedString()); + } + + BOOST_CHECK_EQUAL(expected_matches.size(), i); + + expected_matches = {"eric c", "eric b", "eric i", "eric h", "eric g"}; + + i = 0; + + for (auto m : d->GetPrefixCompletion("eric", 5)) { + if (i >= expected_matches.size()) { + BOOST_FAIL("got more results than expected."); + } + BOOST_CHECK_EQUAL(expected_matches[i++], m.GetMatchedString()); + } + + BOOST_CHECK_EQUAL(expected_matches.size(), i); + + for (auto m : d->GetPrefixCompletion("steve", 3)) { + BOOST_FAIL("expected no match"); + } +} + +BOOST_AUTO_TEST_CASE(DictGetPrefixCompletionCustomFilter) { + std::vector> test_data = { + {"mr. eric a", 331}, {"mr. eric b", 1331}, {"mr. max b", 1431}, {"mr. stefan b", 231}, {"mr. stefan e", 431}, + {"mr. heinz f", 531}, {"mr. karl b", 631}, {"mr. gustav b", 731}, {"mr. gustav h", 831}, {"mr. 
jeremy j", 131}, + }; + + testing::TempDictionary dictionary(&test_data); + dictionary_t d(new Dictionary(dictionary.GetFsa())); + + // all names ending with 'b' and weight > 500 + std::vector expected_matches = {"mr. max b", "mr. eric b", "mr. gustav b", "mr. karl b"}; + + size_t i = 0; + + auto completer = d->GetPrefixCompletion("mr. "); + auto completer_it = completer.begin(); + + while (completer_it != completer.end()) { + if (completer_it->GetMatchedString().back() == 'b') { + if (i >= expected_matches.size()) { + BOOST_FAIL("got more results than expected."); + } + BOOST_CHECK_EQUAL(expected_matches[i++], completer_it->GetMatchedString()); + } + completer_it.SetMinWeight(500); + completer_it++; + } + + // test that bogus call does not cause bad_function + completer.end().SetMinWeight(5); + + BOOST_CHECK_EQUAL(expected_matches.size(), i); +} + BOOST_AUTO_TEST_SUITE_END() } /* namespace dictionary */ diff --git a/keyvi/tests/keyvi/dictionary/fsa/state_traverser_test.cpp b/keyvi/tests/keyvi/dictionary/fsa/state_traverser_test.cpp index 29a4ec8bc..b8bbec821 100644 --- a/keyvi/tests/keyvi/dictionary/fsa/state_traverser_test.cpp +++ b/keyvi/tests/keyvi/dictionary/fsa/state_traverser_test.cpp @@ -23,11 +23,12 @@ * Author: hendrik */ +#include "keyvi/dictionary/fsa/state_traverser.h" + #include #include "keyvi/dictionary/fsa/automata.h" #include "keyvi/dictionary/fsa/generator.h" -#include "keyvi/dictionary/fsa/state_traverser.h" #include "keyvi/testing/temp_dictionary.h" namespace keyvi { @@ -315,6 +316,91 @@ BOOST_AUTO_TEST_CASE(zeroByte) { BOOST_CHECK(s.AtEnd()); } +BOOST_AUTO_TEST_CASE(traversal_min_weight) { + std::vector> test_data = {{"aaaa", 5}, {"aabb", 15}, {"aabc", 10}, {"aacd", 20}, + {"bbcd", 40}, {"cbcd", 18}, {"cefgh", 12}}; + + testing::TempDictionary dictionary(&test_data); + automata_t f = dictionary.GetFsa(); + + StateTraverser s(f); + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(1, s.GetDepth()); + BOOST_CHECK_EQUAL(40, 
s.GetInnerWeight()); + BOOST_CHECK(!s.AtEnd()); + + s++; + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + s++; + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + + s++; + BOOST_CHECK_EQUAL('d', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + BOOST_CHECK_EQUAL(40, s.GetInnerWeight()); + BOOST_CHECK(!s.AtEnd()); + + s.SetMinWeight(12); + s++; + BOOST_CHECK(!s.AtEnd()); + + BOOST_CHECK_EQUAL('a', s.GetStateLabel()); + BOOST_CHECK_EQUAL(1, s.GetDepth()); + BOOST_CHECK_EQUAL(20, s.GetInnerWeight()); + s++; + + BOOST_CHECK_EQUAL('a', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('d', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(1, s.GetDepth()); + BOOST_CHECK_EQUAL(18, s.GetInnerWeight()); + s++; + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('d', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('e', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + BOOST_CHECK_EQUAL(12, s.GetInnerWeight()); + s++; + + s.SetMinWeight(20); + s++; + BOOST_CHECK(s.AtEnd()); + BOOST_CHECK_EQUAL(0, s.GetStateLabel()); + BOOST_CHECK_EQUAL(0, s.GetDepth()); +} + BOOST_AUTO_TEST_SUITE_END() } /* namespace fsa */ diff --git a/keyvi/tests/keyvi/dictionary/fsa/zip_state_traverser_test.cpp b/keyvi/tests/keyvi/dictionary/fsa/zip_state_traverser_test.cpp index 31179ff33..3602adca0 100644 --- 
a/keyvi/tests/keyvi/dictionary/fsa/zip_state_traverser_test.cpp +++ b/keyvi/tests/keyvi/dictionary/fsa/zip_state_traverser_test.cpp @@ -1224,6 +1224,95 @@ BOOST_AUTO_TEST_CASE(weightedTraversal_with_prune) { s++; } +BOOST_AUTO_TEST_CASE(weightedTraversal_min_weight) { + std::vector> test_data1 = { + {"aabc", 412}, + {"aabde", 22}, + {"efde", 24}, + }; + testing::TempDictionary dictionary1(&test_data1); + automata_t f1 = dictionary1.GetFsa(); + + std::vector> test_data2 = { + {"cdbde", 444}, + {"cdef", 10}, + {"cdzzz", 56}, + {"efde", 5}, + }; + testing::TempDictionary dictionary2(&test_data2); + automata_t f2 = dictionary2.GetFsa(); + + std::vector> test_data3 = { + {"cdbde", 333}, + {"cdef", 34}, + {"cdzzz", 15}, + {"efde", 10}, + }; + testing::TempDictionary dictionary3(&test_data3); + automata_t f3 = dictionary3.GetFsa(); + + std::vector> test_data4 = { + {"aabc", 1}, {"aabde", 2}, {"cdbde", 3}, {"cdef", 4}, {"cdzzz", 5}, {"efde", 6}, + }; + testing::TempDictionary dictionary4(&test_data4); + automata_t f4 = dictionary4.GetFsa(); + + ZipStateTraverser s({f1, f2, f3, f4}); + + // we should get 'c' first + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(1, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('d', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('d', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('e', s.GetStateLabel()); + BOOST_CHECK_EQUAL(5, s.GetDepth()); + BOOST_CHECK_EQUAL(444, s.GetInnerWeight()); + s++; + + BOOST_CHECK_EQUAL('z', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + BOOST_CHECK_EQUAL(56, s.GetInnerWeight()); + s++; + + BOOST_CHECK_EQUAL('z', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + BOOST_CHECK_EQUAL(56, s.GetInnerWeight()); + s.SetMinWeight(200); + s++; + + BOOST_CHECK_EQUAL('a', s.GetStateLabel()); + BOOST_CHECK_EQUAL(1, 
s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('a', s.GetStateLabel()); + BOOST_CHECK_EQUAL(2, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('b', s.GetStateLabel()); + BOOST_CHECK_EQUAL(3, s.GetDepth()); + s++; + + BOOST_CHECK_EQUAL('c', s.GetStateLabel()); + BOOST_CHECK_EQUAL(4, s.GetDepth()); + BOOST_CHECK_EQUAL(412, s.GetInnerWeight()); + s++; + + // traverser at the end + BOOST_CHECK_EQUAL(0, s.GetStateLabel()); +} + BOOST_AUTO_TEST_CASE(nearTraversal) { std::vector test_data1 = {"aaaa", "aabb", "aabc", "aacd", "bbcd", "cdefgh"}; testing::TempDictionary dictionary1(&test_data1); diff --git a/keyvi/tests/keyvi/dictionary/matching/prefix_completion_matching_test.cpp b/keyvi/tests/keyvi/dictionary/matching/prefix_completion_matching_test.cpp new file mode 100644 index 000000000..f42c6c203 --- /dev/null +++ b/keyvi/tests/keyvi/dictionary/matching/prefix_completion_matching_test.cpp @@ -0,0 +1,156 @@ +/* * keyvi - A key value store. + * + * Copyright 2024 Hendrik Muhs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "keyvi/dictionary/matching/prefix_completion_matching.h" + +#include +#include + +#include "keyvi/dictionary/dictionary.h" +#include "keyvi/testing/temp_dictionary.h" + +namespace keyvi { +namespace dictionary { +namespace matching { + +BOOST_AUTO_TEST_SUITE(PrefixCompletionMatchingTests) + +void test_prefix_completion_matching(std::vector>* test_data, const std::string& query, + const std::vector expected) { + testing::TempDictionary dictionary(test_data); + + // test using weights + auto matcher_weights = std::make_shared>( + matching::PrefixCompletionMatching<>::FromSingleFsa(dictionary.GetFsa(), query)); + + MatchIterator::MatchIteratorPair it = MatchIterator::MakeIteratorPair( + [matcher_weights]() { return matcher_weights->NextMatch(); }, matcher_weights->FirstMatch()); + + auto expected_it = expected.begin(); + for (auto m : it) { + BOOST_CHECK(expected_it != expected.end()); + BOOST_CHECK_EQUAL(*expected_it++, m.GetMatchedString()); + } + + // test without weights + std::vector expected_sorted = expected; + std::sort(expected_sorted.begin(), expected_sorted.end()); + + auto matcher_no_weights = std::make_shared>>( + matching::PrefixCompletionMatching>::FromSingleFsa(dictionary.GetFsa(), query)); + MatchIterator::MatchIteratorPair matcher_no_weights_it = MatchIterator::MakeIteratorPair( + [matcher_no_weights]() { return matcher_no_weights->NextMatch(); }, matcher_no_weights->FirstMatch()); + + expected_it = expected_sorted.begin(); + for (auto m : matcher_no_weights_it) { + BOOST_CHECK(expected_it != expected_sorted.end()); + BOOST_CHECK_EQUAL(*expected_it++, m.GetMatchedString()); + } + BOOST_CHECK(expected_it == expected_sorted.end()); + + // test with multiple dictionaries + // split test data into 3 groups with some duplication + std::vector> test_data_1; + std::vector> test_data_2; + std::vector> test_data_3; + + for (size_t i = 0; i < test_data->size(); ++i) { + if (i % 1 == 0 || i % 5 == 0) { + test_data_1.push_back((*test_data)[i]); 
+ } + if (i % 2 == 0 || i == 3) { + test_data_2.push_back((*test_data)[i]); + } + if (i % 3 == 0) { + test_data_3.push_back((*test_data)[i]); + } + } + testing::TempDictionary d1(&test_data_1); + testing::TempDictionary d2(&test_data_2); + testing::TempDictionary d3(&test_data_3); + std::vector fsas = {d1.GetFsa(), d2.GetFsa(), d3.GetFsa()}; + + auto matcher_zipped = + std::make_shared>>( + matching::PrefixCompletionMatching>::FromMulipleFsas( + fsas, query)); + MatchIterator::MatchIteratorPair matcher_zipped_it = MatchIterator::MakeIteratorPair( + [matcher_zipped]() { return matcher_zipped->NextMatch(); }, matcher_zipped->FirstMatch()); + expected_it = expected.begin(); + for (auto m : matcher_zipped_it) { + BOOST_CHECK(expected_it != expected.end()); + BOOST_CHECK_EQUAL(*expected_it++, m.GetMatchedString()); + } + BOOST_CHECK(expected_it == expected.end()); + + auto matcher_zipped_no_weights = + std::make_shared>>>( + matching::PrefixCompletionMatching>>::FromMulipleFsas(fsas, + query)); + + MatchIterator::MatchIteratorPair matcher_zipped_no_weights_it = + MatchIterator::MakeIteratorPair([matcher_zipped_no_weights]() { return matcher_zipped_no_weights->NextMatch(); }, + matcher_zipped_no_weights->FirstMatch()); + expected_it = expected_sorted.begin(); + for (auto m : matcher_zipped_no_weights_it) { + BOOST_CHECK(expected_it != expected_sorted.end()); + BOOST_CHECK_EQUAL(*expected_it++, m.GetMatchedString()); + } + BOOST_CHECK(expected_it == expected_sorted.end()); +} + +BOOST_AUTO_TEST_CASE(prefix_0) { + std::vector> test_data = { + {"aaaa", 1000}, {"aabb", 1001}, {"aabc", 1002}, {"aacd", 1030}, {"bbcd", 1040}}; + + test_prefix_completion_matching(&test_data, "aa", std::vector{"aacd", "aabc", "aabb", "aaaa"}); +} + +BOOST_AUTO_TEST_CASE(prefix_1) { + std::vector> test_data = {{"aa", 100}, {"aaaa", 1000}, {"aabb", 1001}, + {"aabc", 1002}, {"aacd", 1030}, {"bbcd", 1040}}; + + test_prefix_completion_matching(&test_data, "aa", std::vector{"aa", "aacd", "aabc", 
"aabb", "aaaa"}); +} + +BOOST_AUTO_TEST_CASE(prefix_completion_edge_cases) { + std::vector> test_data = { + {"aaaa", 1000}, {"aabb", 1001}, {"aabc", 1002}, {"aacd", 1030}, {"bbcd", 1040}}; + + test_prefix_completion_matching(&test_data, "", std::vector{"bbcd", "aacd", "aabc", "aabb", "aaaa"}); + test_prefix_completion_matching(&test_data, "c", {}); + test_prefix_completion_matching(&test_data, "cc", {}); + test_prefix_completion_matching(&test_data, " ", {}); +} + +BOOST_AUTO_TEST_CASE(prefix_completion_cjk) { + std::vector> test_data = { + {"あsだ", 331}, {"あsだs", 23698}, {"あsaだsっdさ", 18838}, + {"あkだsdさ", 11387}, {"あsだsっd", 10189}, {"あxださ", 10188}, + }; + testing::TempDictionary dictionary(&test_data); + dictionary_t d(new Dictionary(dictionary.GetFsa())); + + test_prefix_completion_matching(&test_data, "あs", + std::vector{"あsだ", "あsだs", "あsだsっd", "あsaだsっdさ"}); +} + +BOOST_AUTO_TEST_SUITE_END() + +} /* namespace matching */ +} /* namespace dictionary */ +} /* namespace keyvi */ diff --git a/python/src/addons/Dictionary.pyx b/python/src/addons/Dictionary.pyx index ebe3f51f1..b0c79fc7b 100644 --- a/python/src/addons/Dictionary.pyx +++ b/python/src/addons/Dictionary.pyx @@ -107,3 +107,4 @@ def GetManifest(self, *args): return call_deprecated_method("GetManifest", "manifest", self.manifest, *args) + diff --git a/python/src/addons/Match.pyx b/python/src/addons/Match.pyx index e569d1e38..372714270 100644 --- a/python/src/addons/Match.pyx +++ b/python/src/addons/Match.pyx @@ -177,3 +177,7 @@ def IsEmpty(self, *args): """deprecated, use bool operator""" return not call_deprecated_method("IsEmpty", "__bool__", self.__bool__, *args) + + @property + def weight(self): + return self.inst.get().GetWeight() diff --git a/python/src/addons/autwrap_workarounds.pyx b/python/src/addons/autwrap_workarounds.pyx index 73c48a022..52a42f1ac 100644 --- a/python/src/addons/autwrap_workarounds.pyx +++ b/python/src/addons/autwrap_workarounds.pyx @@ -18,6 +18,10 @@ import sys import warnings +# 
definition of progress callback for all compilers +cdef void progress_compiler_callback(size_t a, size_t b, void* py_callback) noexcept with gil: + (py_callback)(a, b) + def get_package_root(): module_location = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(keyvi._pycore.__file__)), "..")) @@ -35,11 +39,6 @@ def get_interpreter_executable(): return executable - -# definition for all compilers -cdef void progress_compiler_callback(size_t a, size_t b, void* py_callback) noexcept with gil: - (py_callback)(a, b) - def call_deprecated_method(deprecated_method_name, new_method_name, new_method, *args): msg = f"{deprecated_method_name} is deprecated and will be removed in a future version. Use {new_method_name} instead." warnings.warn(msg, DeprecationWarning) diff --git a/python/src/addons/match_iterator.pyx b/python/src/addons/match_iterator.pyx index 70c45a160..2c3675be2 100644 --- a/python/src/addons/match_iterator.pyx +++ b/python/src/addons/match_iterator.pyx @@ -41,3 +41,6 @@ cdef class MatchIterator: py_result.inst = shared_ptr[_Match](_r) return py_result + + def set_min_weight(self, w): + self.it.SetMinWeight(w) diff --git a/python/src/converters/__init__.py b/python/src/converters/__init__.py index 3c9ad72ce..6dbfa9ec4 100644 --- a/python/src/converters/__init__.py +++ b/python/src/converters/__init__.py @@ -1,4 +1,4 @@ -from .pykeyvi_autowrap_conversion_providers import * +from .match_iterator_converter import * from autowrap.ConversionProvider import special_converters diff --git a/python/src/converters/match_iterator_converter.py b/python/src/converters/match_iterator_converter.py new file mode 100644 index 000000000..48c5bda82 --- /dev/null +++ b/python/src/converters/match_iterator_converter.py @@ -0,0 +1,24 @@ +from __future__ import print_function +from autowrap.Code import Code +from autowrap.ConversionProvider import TypeConverterBase + + +class MatchIteratorPairConverter(TypeConverterBase): + def get_base_types(self): + return 
("_MatchIteratorPair",) + + def matches(self, cpp_type): + return not cpp_type.is_ptr + + def matching_python_type(self, cpp_type): + return "MatchIterator" + + def output_conversion(self, cpp_type, input_cpp_var, output_py_var): + return Code().add( + """ + |cdef MatchIterator $output_py_var = MatchIterator.__new__(MatchIterator) + |$output_py_var.it = _r.begin() + |$output_py_var.end = _r.end() + """, + locals(), + ) diff --git a/python/src/converters/pykeyvi_autowrap_conversion_providers.py b/python/src/converters/pykeyvi_autowrap_conversion_providers.py deleted file mode 100644 index 0d80d105c..000000000 --- a/python/src/converters/pykeyvi_autowrap_conversion_providers.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import print_function -from autowrap.Code import Code -from autowrap.ConversionProvider import TypeConverterBase - - -class MatchIteratorPairConverter(TypeConverterBase): - - def get_base_types(self): - return "_MatchIteratorPair", - - def matches(self, cpp_type): - return not cpp_type.is_ptr - - def matching_python_type(self, cpp_type): - return "MatchIterator" - - #def type_check_expression(self, cpp_type, argument_var): - # if cpp_type.is_ref: - # return "isinstance(%s, String)" % (argument_var,) - # return "isinstance(%s, bytes)" % (argument_var,) - - #def input_conversion(self, cpp_type, argument_var, arg_num): - #if cpp_type.is_ref: - # call_as = "deref(%s.inst.get())" % argument_var - #else: - # call_as = "(_String(%s))" % argument_var - #code = cleanup = "" - #return code, call_as, cleanup - - def output_conversion(self, cpp_type, input_cpp_var, output_py_var): - return Code().add(""" - |cdef MatchIterator $output_py_var = MatchIterator.__new__(MatchIterator) - |$output_py_var.it = _r.begin() - |$output_py_var.end = _r.end() - """, locals()) diff --git a/python/src/pxds/dictionary.pxd b/python/src/pxds/dictionary.pxd index fdc9d4d64..6f5f69a3c 100644 --- a/python/src/pxds/dictionary.pxd +++ b/python/src/pxds/dictionary.pxd @@ -3,9 +3,11 
@@ from libcpp.string cimport string as libcpp_string from libcpp.string cimport string as libcpp_utf8_string from libcpp.string cimport string as libcpp_utf8_output_string from libc.stdint cimport int32_t +from libc.stdint cimport uint32_t from libc.stdint cimport uint64_t from libcpp cimport bool -from match cimport Match +from libcpp.pair cimport pair as libcpp_pair +from match cimport Match as _Match from match_iterator cimport MatchIteratorPair as _MatchIteratorPair cdef extern from "keyvi/dictionary/dictionary.h" namespace "keyvi::dictionary": @@ -20,18 +22,41 @@ cdef extern from "keyvi/dictionary/dictionary.h" namespace "keyvi::dictionary": populate_key_part_no_readahead_value_part # populate the key part, but disable read ahead value part cdef cppclass Dictionary: + # wrap-doc: + # Keyvi dictionary, basically a set of key values. Keyvi dictionaries + # are immutable containers, created by a previous compile run. + # Immutability has performance benefits. If you are looking for an + # updatable container, have a look at keyvi index. + # + # Keyvi dictionaries allow multiple types of approximate and completion + # matches due to their internal FST based data structure.
Dictionary (libcpp_utf8_string filename) except + Dictionary (libcpp_utf8_string filename, loading_strategy_types) except + - bool Contains (libcpp_utf8_string) # wrap-ignore - Match operator[](libcpp_utf8_string) # wrap-ignore - _MatchIteratorPair Get (libcpp_utf8_string) # wrap-as:match - _MatchIteratorPair GetNear (libcpp_utf8_string, size_t minimum_prefix_length) except + # wrap-as:match_near - _MatchIteratorPair GetNear (libcpp_utf8_string, size_t minimum_prefix_length, bool greedy) except + # wrap-as:match_near - _MatchIteratorPair GetFuzzy (libcpp_utf8_string, int32_t max_edit_distance) except + # wrap-as:match_fuzzy - _MatchIteratorPair GetFuzzy (libcpp_utf8_string, int32_t max_edit_distance, size_t minimum_exact_prefix) except + # wrap-as:match_fuzzy + bool Contains (libcpp_utf8_string key) # wrap-ignore + _Match operator[](libcpp_utf8_string key) # wrap-ignore + _MatchIteratorPair Get (libcpp_utf8_string key) # wrap-as:match + _MatchIteratorPair GetNear (libcpp_utf8_string key, size_t minimum_prefix_length) except + # wrap-as:match_near + _MatchIteratorPair GetNear (libcpp_utf8_string key, size_t minimum_prefix_length, bool greedy) except + # wrap-as:match_near + _MatchIteratorPair GetFuzzy (libcpp_utf8_string key, int32_t max_edit_distance) except + # wrap-as:match_fuzzy + _MatchIteratorPair GetFuzzy (libcpp_utf8_string key, int32_t max_edit_distance, size_t minimum_exact_prefix) except + # wrap-as:match_fuzzy + _MatchIteratorPair GetPrefixCompletion (libcpp_utf8_string key) except + # wrap-as:complete_prefix + # wrap-doc: + # complete the given key to full matches by matching the given key as + # prefix. In case the used dictionary supports inner weights, the + # completer traverses the dictionary according to weights. If weights + # are not available the dictionary gets traversed in byte-order. 
+ _MatchIteratorPair GetPrefixCompletion (libcpp_utf8_string key, size_t top_n) except + # wrap-as:complete_prefix + # wrap-doc: + # complete the given key to full matches by matching the given key as + # prefix. This version of prefix completions ensures the return of the + # top n completions. Due to depth-first traversal the traverser + # immediately yields results when it visits them. The results are + # neither in order nor limited to n. It is up to the caller to re-sort + # and truncate the list of results. + # Only the number of top completions is guaranteed. _MatchIteratorPair GetAllItems () # wrap-ignore - _MatchIteratorPair Lookup(libcpp_utf8_string) # wrap-as:search - _MatchIteratorPair LookupText(libcpp_utf8_string) # wrap-as:search_tokenized + _MatchIteratorPair Lookup(libcpp_utf8_string key) # wrap-as:search + _MatchIteratorPair LookupText(libcpp_utf8_string text) # wrap-as:search_tokenized libcpp_utf8_output_string GetManifest() except + # wrap-as:manifest libcpp_string GetStatistics() # wrap-ignore uint64_t GetSize() # wrap-ignore diff --git a/python/src/pxds/match.pxd b/python/src/pxds/match.pxd index 35b282e42..e8b261872 100644 --- a/python/src/pxds/match.pxd +++ b/python/src/pxds/match.pxd @@ -15,6 +15,7 @@ cdef extern from "keyvi/dictionary/match.h" namespace "keyvi::dictionary": void SetEnd(size_t end) # wrap-ignore float GetScore() # wrap-ignore void SetScore(float score) # wrap-ignore + uint32_t GetWeight() # wrap-ignore libcpp_utf8_output_string GetMatchedString() # wrap-ignore void SetMatchedString (libcpp_utf8_string matched_string) # wrap-ignore PyObject* GetAttributePy(libcpp_utf8_string) except + nogil # wrap-ignore diff --git a/python/src/pxds/match_iterator.pxd b/python/src/pxds/match_iterator.pxd index b4e0b94ac..7a297d055 100644 --- a/python/src/pxds/match_iterator.pxd +++ b/python/src/pxds/match_iterator.pxd @@ -1,5 +1,6 @@ # same import style as autowrap from match cimport Match as _Match +from libc.stdint cimport uint32_t cdef
extern from "keyvi/dictionary/match_iterator.h" namespace "keyvi::dictionary": cdef cppclass MatchIterator: @@ -9,6 +10,7 @@ cdef extern from "keyvi/dictionary/match_iterator.h" namespace "keyvi::dictionar MatchIterator& operator++() bint operator==(MatchIterator) bint operator!=(MatchIterator) + void SetMinWeight(uint32_t) # wrap-ignore cdef extern from "keyvi/dictionary/match_iterator.h" namespace "keyvi::dictionary::MatchIterator": cdef cppclass MatchIteratorPair: diff --git a/python/tests/dictionary/prefix_completion_test.py b/python/tests/dictionary/prefix_completion_test.py new file mode 100644 index 000000000..533fb7692 --- /dev/null +++ b/python/tests/dictionary/prefix_completion_test.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Usage: py.test tests + +import heapq +import sys +import os + +from keyvi.compiler import CompletionDictionaryCompiler + +root = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(root, "../")) +from test_tools import tmp_dictionary + + +def test_prefix_simple(): + c = CompletionDictionaryCompiler({"memory_limit_mb": "10"}) + c.Add("eric", 33) + c.Add("jeff", 33) + c.Add("eric bla", 233) + c.Add("eric blu", 113) + c.Add("eric ble", 413) + c.Add("eric blx", 223) + c.Add("eric bllllx", 193) + c.Add("eric bxxxx", 23) + c.Add("eric boox", 143) + with tmp_dictionary(c, "completion.kv") as d: + assert [m.matched_string for m in d.complete_prefix("eric")] == [ + "eric", + "eric ble", + "eric bla", + "eric blx", + "eric bllllx", + "eric blu", + "eric boox", + "eric bxxxx", + ] + + # note: we are getting one more("eric"), because of dfs traversal + assert [m.matched_string for m in d.complete_prefix("eric", 2)] == [ + "eric", + "eric ble", + "eric bla", + ] + + def filter_x(completer): + for m in completer: + if m.matched_string.endswith("x"): + completer.set_min_weight(40) + yield m + + assert [m.matched_string for m in filter_x(d.complete_prefix("eric"))] == [ + "eric blx", + "eric bllllx", + "eric boox", + ] + + 
class TopNFilter: + def __init__(self, n) -> None: + self.n = n + self.heap = [] + self.visits = 0 + + def filter(self, completer): + for m in completer: + assert m.weight == m.value + self.visits += 1 + if len(self.heap) < self.n: + heapq.heappush(self.heap, m.weight) + yield m + elif m.weight > self.heap[0]: + heapq.heappop(self.heap) + heapq.heappush(self.heap, m.weight) + completer.set_min_weight(self.heap[0]) + yield m + + top5 = TopNFilter(5) + # note: we are getting more erics, because of dfs traversal + assert [m.matched_string for m in top5.filter(d.complete_prefix("eric"))] == [ + "eric", + "eric ble", + "eric bla", + "eric blx", + "eric bllllx", + "eric blu", + "eric boox", + ] + + # by traversing using min weight, we should _not_ visit all entries + assert top5.visits < len(d) + + top3 = TopNFilter(3) + # note: getting more erics, because of dfs traversal + assert [m.matched_string for m in top3.filter(d.complete_prefix("eric"))] == [ + "eric", + "eric ble", + "eric bla", + "eric blx", + ] + + # top-3 should have less visits than top-5 + assert top3.visits < top5.visits + + +def test_mismatches(): + c = CompletionDictionaryCompiler({"memory_limit_mb": "10"}) + c.Add("a", 33) + c.Add("ab", 33) + c.Add("abcd", 233) + with tmp_dictionary(c, "completion.kv") as d: + assert [m.matched_string for m in d.complete_prefix("v")] == [] + assert [m.matched_string for m in d.complete_prefix("vwxyz")] == [] + assert [m.matched_string for m in d.complete_prefix("av")] == [] + assert [m.matched_string for m in d.complete_prefix("abcde")] == [] + assert [m.matched_string for m in d.complete_prefix(" ")] == []