diff --git a/keyvi/include/keyvi/dictionary/dictionary.h b/keyvi/include/keyvi/dictionary/dictionary.h index 60f412ae6..05cf91829 100644 --- a/keyvi/include/keyvi/dictionary/dictionary.h +++ b/keyvi/include/keyvi/dictionary/dictionary.h @@ -37,6 +37,7 @@ #include "keyvi/dictionary/match.h" #include "keyvi/dictionary/match_iterator.h" #include "keyvi/dictionary/matching/fuzzy_matching.h" +#include "keyvi/dictionary/matching/fuzzy_multiword_completion_matching.h" #include "keyvi/dictionary/matching/multiword_completion_matching.h" #include "keyvi/dictionary/matching/near_matching.h" #include "keyvi/dictionary/matching/prefix_completion_matching.h" @@ -397,6 +398,20 @@ class Dictionary final { std::bind(&matching::MultiwordCompletionMatching<>::SetMinWeight, &(*data), std::placeholders::_1)); } + MatchIterator::MatchIteratorPair GetFuzzyMultiwordCompletion(const std::string& query, + const int32_t max_edit_distance, + const size_t minimum_exact_prefix = 0, + const unsigned char multiword_separator = 0x1b) const { + auto data = std::make_shared>( + matching::FuzzyMultiwordCompletionMatching<>::FromSingleFsa(fsa_, query, max_edit_distance, + minimum_exact_prefix, multiword_separator)); + + auto func = [data]() { return data->NextMatch(); }; + return MatchIterator::MakeIteratorPair( + func, data->FirstMatch(), + std::bind(&matching::FuzzyMultiwordCompletionMatching<>::SetMinWeight, &(*data), std::placeholders::_1)); + } + std::string GetManifest() const { return fsa_->GetManifest(); } private: diff --git a/keyvi/include/keyvi/dictionary/fsa/codepoint_state_traverser.h b/keyvi/include/keyvi/dictionary/fsa/codepoint_state_traverser.h index cff616dad..f2dac8136 100644 --- a/keyvi/include/keyvi/dictionary/fsa/codepoint_state_traverser.h +++ b/keyvi/include/keyvi/dictionary/fsa/codepoint_state_traverser.h @@ -29,6 +29,7 @@ #include #include "keyvi/dictionary/fsa/automata.h" +#include "keyvi/dictionary/fsa/traverser_types.h" #include "keyvi/dictionary/util/utf8_utils.h" // #define ENABLE_TRACING @@ -130,6 +131,15 @@ class CodePointStateTraverser final { operator bool() const { return wrapped_state_traverser_; } + /** + * Set the minimum weight states must be greater or equal to. + * + * Only available for WeightedTransition specialization. + * + * @param weight minimum transition weight + */ + inline void SetMinWeight(uint32_t weight) {} + private: innerTraverserType wrapped_state_traverser_; std::vector transitions_stack_; @@ -178,6 +188,16 @@ class CodePointStateTraverser final { } }; +/** + * Set the minimum weight states must be greater or equal to. + * + * @param weight minimum transition weight + */ +template <> +inline void CodePointStateTraverser::SetMinWeight(uint32_t weight) { + wrapped_state_traverser_.SetMinWeight(weight); +} + } /* namespace fsa */ } /* namespace dictionary */ } /* namespace keyvi */ diff --git a/keyvi/include/keyvi/dictionary/matching/fuzzy_multiword_completion_matching.h b/keyvi/include/keyvi/dictionary/matching/fuzzy_multiword_completion_matching.h new file mode 100644 index 000000000..4b6390199 --- /dev/null +++ b/keyvi/include/keyvi/dictionary/matching/fuzzy_multiword_completion_matching.h @@ -0,0 +1,308 @@ +/* keyvi - A key value store. + * + * Copyright 2024 Hendrik Muhs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * fuzzy_multiword_completion_matching.h + */ + +#ifndef KEYVI_DICTIONARY_MATCHING_FUZZY_MULTIWORD_COMPLETION_MATCHING_H_ +#define KEYVI_DICTIONARY_MATCHING_FUZZY_MULTIWORD_COMPLETION_MATCHING_H_ + +#include +#include +#include +#include +#include +#include + +#include "keyvi/dictionary/fsa/automata.h" +#include "keyvi/dictionary/fsa/codepoint_state_traverser.h" +#include "keyvi/dictionary/fsa/traverser_types.h" +#include "keyvi/dictionary/fsa/zip_state_traverser.h" +#include "keyvi/dictionary/match.h" +#include "keyvi/dictionary/util/transform.h" +#include "keyvi/dictionary/util/utf8_utils.h" +#include "keyvi/stringdistance/levenshtein.h" +#include "utf8.h" + +// #define ENABLE_TRACING +#include "keyvi/dictionary/util/trace.h" + +namespace keyvi { +namespace index { +namespace internal { +template +keyvi::dictionary::Match NextFilteredMatchSingle(const MatcherT&, const DeletedT&); +template +keyvi::dictionary::Match NextFilteredMatch(const MatcherT&, const DeletedT&); +} // namespace internal +} // namespace index +namespace dictionary { +namespace matching { + +template > +class FuzzyMultiwordCompletionMatching final { + public: + /** + * Create a fuzzy multiword completer from a single Fsa + * + * @param fsa the fsa + * @param query the query + */ + static FuzzyMultiwordCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const std::string& query, + const int32_t max_edit_distance, + const size_t minimum_exact_prefix = 0, + const unsigned char multiword_separator = 0x1b) { + return FromSingleFsa(fsa, fsa->GetStartState(), query, max_edit_distance, minimum_exact_prefix, + multiword_separator); + } + + /** + * Create a fuzzy multiword completer from a single Fsa + * + * @param fsa the fsa + * @param start_state the state to start from + * @param query the query + */ + static FuzzyMultiwordCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const uint64_t start_state, + const std::string& query, const int32_t max_edit_distance, + const size_t minimum_exact_prefix = 0, + const unsigned char multiword_separator = 0x1b) { + if (start_state == 0) { + return FuzzyMultiwordCompletionMatching(); + } + size_t number_of_tokens; + std::string query_bow = util::Transform::BagOfWordsPartial(query, number_of_tokens); + + std::vector codepoints; + utf8::unchecked::utf8to32(query_bow.begin(), query_bow.end(), back_inserter(codepoints)); + const size_t utf8_query_length = codepoints.size(); + + if (utf8_query_length < minimum_exact_prefix) { + return FuzzyMultiwordCompletionMatching(); + } + + std::unique_ptr metric = + std::make_unique(codepoints, 20, max_edit_distance); + + size_t depth = 0; + uint64_t state = start_state; + size_t utf8_depth = 0; + // match exact + while (state != 0 && depth < minimum_exact_prefix) { + const size_t code_point_length = util::Utf8Utils::GetCharLength(query[utf8_depth]); + for (size_t i = 0; i < code_point_length; ++i, ++utf8_depth) { + state = fsa->TryWalkTransition(state, query[utf8_depth]); + if (0 == state) { + break; + } + } + metric->Put(codepoints[depth], depth); + ++depth; + } + + if (state == 0 || depth != minimum_exact_prefix) { + return FuzzyMultiwordCompletionMatching(); + } + + TRACE("matched prefix, length %d", depth); + + std::unique_ptr traverser = std::make_unique(fsa, state); + + Match first_match; + if (depth == utf8_query_length && fsa->IsFinalState(state)) { + TRACE("first_match %d %s", utf8_query_length, query); + first_match = Match(0, utf8_query_length, query, 0, fsa, fsa->GetStateValue(state)); + } + + return FuzzyMultiwordCompletionMatching(std::move(traverser), std::move(first_match), std::move(metric), + max_edit_distance, minimum_exact_prefix, number_of_tokens, + multiword_separator); + } + + /** + * Create a fuzzy multiword completer from multiple Fsas + * + * @param fsas a vector of fsas + * @param query the query + */ + static FuzzyMultiwordCompletionMatching FromMulipleFsas(const std::vector& fsas, + const std::string& query, const int32_t max_edit_distance, + const size_t minimum_exact_prefix = 0, + const unsigned char multiword_separator = 0x1b) { + size_t number_of_tokens; + std::string query_bow = util::Transform::BagOfWordsPartial(query, number_of_tokens); + + std::vector codepoints; + utf8::unchecked::utf8to32(query_bow.begin(), query_bow.end(), back_inserter(codepoints)); + const size_t query_length = codepoints.size(); + + std::unique_ptr metric = + std::make_unique(codepoints, 20, max_edit_distance); + + std::vector> fsa_start_state_pairs; + + // match the exact prefix on all fsas + for (const fsa::automata_t& fsa : fsas) { + uint64_t state = fsa->GetStartState(); + size_t depth, utf8_depth = 0; + + while (state != 0 && depth < minimum_exact_prefix) { + const size_t code_point_length = util::Utf8Utils::GetCharLength(query[utf8_depth]); + for (size_t i = 0; i < code_point_length; ++i, ++utf8_depth) { + state = fsa->TryWalkTransition(state, query[utf8_depth]); + if (0 == state) { + break; + } + } + ++depth; + } + + if (state != 0 && depth == minimum_exact_prefix) { + fsa_start_state_pairs.emplace_back(fsa, state); + } + } + + if (fsa_start_state_pairs.size() == 0) { + return FuzzyMultiwordCompletionMatching(); + } + + // fill the metric + for (size_t utf8_depth = 0; utf8_depth < minimum_exact_prefix; ++utf8_depth) { + metric->Put(codepoints[utf8_depth], utf8_depth); + } + + Match first_match; + // check for a match given the exact prefix + for (const auto& fsa_state : fsa_start_state_pairs) { + if (fsa_state.first->IsFinalState(fsa_state.second)) { + first_match = + Match(0, query_length, query, 0, fsa_state.first, fsa_state.first->GetStateValue(fsa_state.second)); + break; + } + } + + std::unique_ptr traverser = std::make_unique(fsa_start_state_pairs); + + return FuzzyMultiwordCompletionMatching(std::move(traverser), std::move(first_match), std::move(metric), + minimum_exact_prefix, number_of_tokens, multiword_separator); + } + + Match FirstMatch() const { return first_match_; } + + Match NextMatch() { + for (; traverser_ptr_ && *traverser_ptr_; (*traverser_ptr_)++) { + uint64_t label = traverser_ptr_->GetStateLabel(); + TRACE("label [%c] prefix length %ld traverser depth: %ld", label, prefix_length_, traverser_ptr_->GetDepth()); + + while (token_start_positions_.size() > 0 && traverser_ptr_->GetDepth() <= token_start_positions_.back()) { + TRACE("pop token stack"); + token_start_positions_.pop_back(); + } + + if (label == multiword_separator_) { + TRACE("found MW boundary at %d", traverser_ptr_->GetDepth()); + if (token_start_positions_.size() != number_of_tokens_ - 1) { + TRACE("found MW boundary before seeing enough tokens (%d/%d)", token_start_positions_.size(), + number_of_tokens_); + traverser_ptr_->Prune(); + TRACE("pruned, now at %d", traverser_ptr_->GetDepth()); + continue; + } + + multiword_boundary_ = traverser_ptr_->GetDepth(); + } else if (traverser_ptr_->GetDepth() <= multiword_boundary_) { + // reset the multiword boundary if we went up + multiword_boundary_ = 0; + TRACE("reset MW boundary at %d %d", traverser_ptr_->GetDepth(), multiword_boundary_); + } + + // only match up to the number of tokens in input + if (label == 0x20 && multiword_boundary_ == 0) { + // todo: should every token be matched with the exact prefix, except for the last token? + TRACE("push space(%d)", token_start_positions_.size()); + token_start_positions_.push_back(traverser_ptr_->GetDepth()); + } + + int32_t intermediate_score = distance_metric_->Put(label, prefix_length_ + traverser_ptr_->GetDepth() - 1); + + TRACE("Candidate: [%s] %ld intermediate score: %d(%d)", distance_metric_->GetCandidate().c_str(), + prefix_length_ + traverser_ptr_->GetDepth() - 1, intermediate_score, max_edit_distance_); + + if (intermediate_score > max_edit_distance_) { + traverser_ptr_->Prune(); + continue; + } + + if (traverser_ptr_->IsFinalState()) { + std::string match_str = multiword_boundary_ > 0 + ? distance_metric_->GetCandidate(prefix_length_ + multiword_boundary_) + : distance_metric_->GetCandidate(); + + TRACE("found final state at depth %d %s", prefix_length_ + traverser_ptr_->GetDepth(), match_str.c_str()); + Match m(0, prefix_length_ + traverser_ptr_->GetDepth(), match_str, distance_metric_->GetScore(), + traverser_ptr_->GetFsa(), traverser_ptr_->GetStateValue()); + + (*traverser_ptr_)++; + return m; + } + } + + return Match(); + } + + void SetMinWeight(uint32_t min_weight) { traverser_ptr_->SetMinWeight(min_weight); } + + private: + FuzzyMultiwordCompletionMatching(std::unique_ptr&& traverser, Match&& first_match, + std::unique_ptr&& distance_metric, + const int32_t max_edit_distance, const size_t prefix_length, size_t number_of_tokens, + const unsigned char multiword_separator) + : traverser_ptr_(std::move(traverser)), + first_match_(std::move(first_match)), + distance_metric_(std::move(distance_metric)), + max_edit_distance_(max_edit_distance), + prefix_length_(prefix_length), + number_of_tokens_(number_of_tokens), + multiword_separator_(static_cast(multiword_separator)) {} + + FuzzyMultiwordCompletionMatching() {} + + private: + std::unique_ptr traverser_ptr_; + const Match first_match_; + std::unique_ptr distance_metric_; + const int32_t max_edit_distance_ = 0; + const size_t prefix_length_ = 0; + const size_t number_of_tokens_ = 0; + const uint64_t multiword_separator_ = 0; + std::vector token_start_positions_; + size_t multiword_boundary_ = 0; + + // reset method for the index in the special case the match is deleted + template + friend Match index::internal::NextFilteredMatchSingle(const MatcherT&, const DeletedT&); + template + friend Match index::internal::NextFilteredMatch(const MatcherT&, const DeletedT&); + + void ResetLastMatch() {} +}; + +} /* namespace matching */ +} /* namespace dictionary */ +} /* namespace keyvi */ +#endif // KEYVI_DICTIONARY_MATCHING_FUZZY_MULTIWORD_COMPLETION_MATCHING_H_ diff --git a/keyvi/include/keyvi/stringdistance/needleman_wunsch.h b/keyvi/include/keyvi/stringdistance/needleman_wunsch.h index 80dfd8099..11a9ec694 100644 --- a/keyvi/include/keyvi/stringdistance/needleman_wunsch.h +++ b/keyvi/include/keyvi/stringdistance/needleman_wunsch.h @@ -61,7 +61,6 @@ class NeedlemanWunsch final { : max_distance_(other.max_distance_), compare_sequence_(std::move(other.compare_sequence_)), intermediate_scores_(std::move(other.intermediate_scores_)), - completion_row_(other.completion_row_), last_put_position_(other.last_put_position_), latest_calculated_row_(other.latest_calculated_row_), input_sequence_(std::move(other.input_sequence_)), @@ -70,7 +69,6 @@ class NeedlemanWunsch final { other.max_distance_ = 0; other.last_put_position_ = 0; other.latest_calculated_row_ = 0; - other.completion_row_ = std::numeric_limits::max(); } ~NeedlemanWunsch() {} @@ -83,12 +81,6 @@ class NeedlemanWunsch final { EnsureCapacity(row + 1); compare_sequence_[position] = codepoint; - // reset completion row if we walked backwards - if (row <= completion_row_) { - TRACE("reset completion row"); - completion_row_ = std::numeric_limits::max(); - } - last_put_position_ = position; size_t columns = distance_matrix_.Columns(); @@ -113,12 +105,8 @@ class NeedlemanWunsch final { // if left_cutoff >= columns, the candidate string is longer than the input + max edit distance, we can make a // shortcut if (left_cutoff >= columns) { - // last character == exact match? - if (row > completion_row_ || compare_sequence_[columns - 2] == input_sequence_.back()) { - intermediate_scores_[row] = intermediate_scores_[row - 1] + cost_function_.GetCompletionCost(); - } else { - intermediate_scores_[row] = intermediate_scores_[row - 1] + cost_function_.GetInsertionCost(codepoint); - } + intermediate_scores_[row] = intermediate_scores_[row - 1] + std::min(cost_function_.GetCompletionCost(), + cost_function_.GetInsertionCost(codepoint)); return intermediate_scores_[row]; } @@ -128,31 +116,24 @@ class NeedlemanWunsch final { int32_t field_result; for (size_t column = left_cutoff; column < right_cutoff; ++column) { + TRACE("calculating column %d", column); // 1. check for exact match according to the substitution cost // function int32_t substitution_cost = cost_function_.GetSubstitutionCost(input_sequence_[column - 1], codepoint); int32_t substitution_result = substitution_cost + distance_matrix_.Get(row - 1, column - 1); if (substitution_cost == 0) { - // codePoints match + // short cut: codePoints match field_result = substitution_result; } else { - // 2. calculate costs for deletion, insertion and transposition + // 2. calculate costs for deletion int32_t deletion_result = distance_matrix_.Get(row, column - 1) + cost_function_.GetDeletionCost(input_sequence_[column - 1]); - int32_t completion_result = std::numeric_limits::max(); - - if (row > completion_row_) { - completion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost(); - } else if (column + 1 == columns && columns > 1 && - compare_sequence_[last_put_position_ - 1] == input_sequence_.back()) { - completion_row_ = row; - completion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost(); - } - + // 3. calculate costs for insertion, transposition int32_t insertion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetInsertionCost(codepoint); + // 4. calculate costs for transposition (swap of 2 characters: house <--> huose) int32_t transposition_result = std::numeric_limits::max(); if (row > 1 && column > 1 && input_sequence_[column - 1] == compare_sequence_[position - 1] && @@ -162,21 +143,23 @@ class NeedlemanWunsch final { cost_function_.GetTranspositionCost(input_sequence_[column - 1], input_sequence_[column - 2]); } - // 4. take the minimum cost - // field_result = std::min( { deletion_result, insertion_result, - // transposition_result, substitution_result }); + TRACE("deletion: %d vs. insertion %d vs. transposition %d vs. substitution %d", deletion_result, + insertion_result, transposition_result, substitution_result); - field_result = std::min(deletion_result, transposition_result); + // 5. take the minimum cost + field_result = std::min({deletion_result, insertion_result, transposition_result, substitution_result}); + } - field_result = std::min(field_result, substitution_result); - field_result = std::min(field_result, insertion_result); - field_result = std::min(field_result, completion_result); + // 6. check if we have a completion case, only calculated on the last column + if (column + 1 == columns) { + field_result = + std::min(distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost(), field_result); } - // put cost into matrix + // 7. put cost into matrix distance_matrix_.Set(row, column, field_result); - // take the best intermediate result from the possible cells in the matrix + // 8. keep track of the best intermediate result from the possible cells in the matrix if ((column + 1 == columns || column + max_distance_ >= row) && field_result <= intermediate_score) { intermediate_score = field_result; } @@ -202,9 +185,9 @@ class NeedlemanWunsch final { int32_t GetScore() const { return distance_matrix_.Get(latest_calculated_row_, distance_matrix_.Columns() - 1); } - std::string GetCandidate() { + std::string GetCandidate(size_t pos = 0) { std::vector utf8result; - utf8::utf32to8(compare_sequence_.begin(), compare_sequence_.begin() + last_put_position_ + 1, + utf8::utf32to8(compare_sequence_.begin() + pos, compare_sequence_.begin() + last_put_position_ + 1, back_inserter(utf8result)); return std::string(utf8result.begin(), utf8result.end()); @@ -217,7 +200,6 @@ class NeedlemanWunsch final { std::vector compare_sequence_; std::vector intermediate_scores_; - size_t completion_row_ = 0; size_t last_put_position_ = 0; size_t latest_calculated_row_ = 0; @@ -232,7 +214,6 @@ class NeedlemanWunsch final { } latest_calculated_row_ = 1; - completion_row_ = std::numeric_limits::max(); // initialize compare Sequence and immediateScore compare_sequence_.reserve(rows); @@ -246,9 +227,7 @@ class NeedlemanWunsch final { if (compare_sequence_.size() < capacity) { compare_sequence_.resize(capacity); - compare_sequence_.resize(compare_sequence_.capacity()); intermediate_scores_.resize(capacity); - intermediate_scores_.resize(intermediate_scores_.capacity()); } } diff --git a/keyvi/tests/keyvi/dictionary/completion/prefix_completion_test.cpp b/keyvi/tests/keyvi/dictionary/completion/prefix_completion_test.cpp index 9891d409f..1e14794f6 100644 --- a/keyvi/tests/keyvi/dictionary/completion/prefix_completion_test.cpp +++ b/keyvi/tests/keyvi/dictionary/completion/prefix_completion_test.cpp @@ -123,7 +123,7 @@ BOOST_AUTO_TEST_CASE(approx1) { std::vector expected_output; expected_output.push_back("aabc"); - // not matching aabcül because of last character mismatch + expected_output.push_back("aabcül"); expected_output.push_back("aabcdefghijklmnop"); // this matches because aab_c_d, "c" is an insert auto expected_it = expected_output.begin(); diff --git a/python/src/pxds/dictionary.pxd b/python/src/pxds/dictionary.pxd index a96dc6014..3fc011509 100644 --- a/python/src/pxds/dictionary.pxd +++ b/python/src/pxds/dictionary.pxd @@ -71,6 +71,26 @@ cdef extern from "keyvi/dictionary/dictionary.h" namespace "keyvi::dictionary": # In case the used dictionary supports inner weights, the # completer traverses the dictionary according to weights. If weights # are not available the dictionary gets traversed in byte-order. + _MatchIteratorPair GetFuzzyMultiwordCompletion (libcpp_utf8_string key, int32_t max_edit_distance) except + # wrap-as:complete_fuzzy_multiword + # wrap-doc: + # complete the given key to full matches by matching the given key as + # multiword allowing up to max_edit_distance distance(Levenshtein). + # The key can consist of multiple tokens separated by space. + # For matching it gets tokenized put back together bag-of-words style. + # The dictionary must be created the same way. + # In case the used dictionary supports inner weights, the + # completer traverses the dictionary according to weights. If weights + # are not available the dictionary gets traversed in byte-order. + _MatchIteratorPair GetFuzzyMultiwordCompletion (libcpp_utf8_string key, int32_t max_edit_distance, size_t minimum_exact_prefix) except + # wrap-as:complete_fuzzy_multiword + # wrap-doc: + # complete the given key to full matches by matching the given key as + # multiword allowing up to max_edit_distance distance(Levenshtein). + # The key can consist of multiple tokens separated by space. + # For matching it gets tokenized put back together bag-of-words style. + # The dictionary must be created the same way. + # In case the used dictionary supports inner weights, the + # completer traverses the dictionary according to weights. If weights + # are not available the dictionary gets traversed in byte-order. _MatchIteratorPair GetAllItems () # wrap-ignore _MatchIteratorPair Lookup(libcpp_utf8_string key) # wrap-as:search _MatchIteratorPair LookupText(libcpp_utf8_string text) # wrap-as:search_tokenized diff --git a/python/tests/completion/fuzzy_completion_test.py b/python/tests/completion/fuzzy_completion_test.py index 0c54a24e4..0d1a45288 100644 --- a/python/tests/completion/fuzzy_completion_test.py +++ b/python/tests/completion/fuzzy_completion_test.py @@ -44,13 +44,13 @@ def test_fuzzy_completion(): assert len(matches) == 9 matches = [m.matched_string for m in completer.GetFuzzyCompletions('tue', 1)] - assert len(matches) == 1 + assert len(matches) == 21 matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuv h', 1)] - assert len(matches) == 2 + assert len(matches) == 8 matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuv h', 2)] - assert len(matches) == 7 + assert len(matches) == 12 matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuk töffnungszeiten', 2)] assert len(matches) == 1 diff --git a/python/tests/dictionary/dictionary_fuzzy_multiword_completion_test.py b/python/tests/dictionary/dictionary_fuzzy_multiword_completion_test.py new file mode 100644 index 000000000..f86e95733 --- /dev/null +++ b/python/tests/dictionary/dictionary_fuzzy_multiword_completion_test.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +# Usage: py.test tests + +import sys +import os +from functools import reduce + +from keyvi.compiler import CompletionDictionaryCompiler + +root = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(root, "../")) +from test_tools import tmp_dictionary + + +multiword_data = { + "80s action film released 2013": {"w": 43, "id": "a1"}, + "80s action heros": {"w": 72, "id": "a2"}, + "80s back": {"w": 1, "id": "a3"}, + "80s baladen": {"w": 37, "id": "a4"}, + "80s cartoon with cars": {"w": 2, "id": "a5"}, + "80s game of thrones theme": {"w": 3, "id": "a6"}, + "80s girl group": {"w": 4, "id": "a7"}, + "80s hard rock berlin": {"w": 66, "id": "a8"}, + "80s indie songs": {"w": 108, "id": "a9"}, + "80s jack the ripper documentary": {"w": 39, "id": "aa"}, + "80s monsters tribute art": {"w": 13, "id": "ab"}, + "80s movie with zombies": {"w": 67, "id": "ac"}, + "80s overall": {"w": 5, "id": "ad"}, + "80s punk oi last fm": {"w": 33, "id": "ae"}, + "80s retro shop los angeles": {"w": 13, "id": "af"}, + "80s singer roger": {"w": 13, "id": "b1"}, + "80s techno fashion": {"w": 6, "id": "b2"}, + "80s theme party cupcake": {"w": 42, "id": "b3"}, + "80s video megamix": {"w": 96, "id": "b4"}, +} + +multiword_data_non_ascii = { + "bäder öfen übelkeit": {"w": 43, "id": "a1"}, + "übelkeit kräuterschnapps alles gut": {"w": 72, "id": "a2"}, + "öfen übelkeit rauchvergiftung": {"w": 372, "id": "a3"}, +} + +multiword_data_stack_corner_case = { + "a b c d e f": {"w": 43, "id": "a1"}, + "a": {"w": 12, "id": "a2"}, +} + +PERMUTATION_LOOKUP_TABLE = { + 2: [ + [0], + [1], + [0, 1], + [1, 0], + ], + 3: [ + [0], + [1], + [2], + [0, 1], + [0, 2], + [1, 0], + [1, 2], + [2, 0], + [2, 1], + [0, 1, 2], + [0, 2, 1], + [1, 2, 0], + ], + 4: [ + [0], + [1], + [2], + [3], + [0, 1], + [0, 2], + [0, 3], + [1, 0], + [1, 2], + [1, 3], + [2, 0], + [2, 1], + [2, 3], + [3, 0], + [3, 1], + [3, 2], + [0, 1, 2], + [0, 1, 3], + [0, 2, 1], + [0, 2, 3], + [0, 3, 1], + [0, 3, 2], + [1, 2, 0], + [1, 2, 3], + [1, 3, 0], + [1, 3, 2], + [2, 3, 0], + [2, 3, 1], + [0, 1, 2, 3], + ], +} + +MULTIWORD_QUERY_SEPARATOR = "\x1b" + + +class MultiWordPermutation: + def __init__(self): + pass + + def __call__(self, key_value): + key, value = key_value + tokens = key.split(" ") + tokens_bow = sorted(tokens) + + length = len(tokens_bow) + if length not in PERMUTATION_LOOKUP_TABLE: + yield " ".join(tokens) + MULTIWORD_QUERY_SEPARATOR + value + return + + for permutation in PERMUTATION_LOOKUP_TABLE[len(tokens_bow)]: + if len(permutation) < 3: + first_token = tokens_bow[permutation[0]] + if first_token != tokens[permutation[0]] and len(first_token) == 1: + continue + yield ( + " ".join([tokens_bow[i] for i in permutation]) + + MULTIWORD_QUERY_SEPARATOR + + value + ) + + +def create_dict(data): + pipeline = [] + pipeline.append(MultiWordPermutation()) + c = CompletionDictionaryCompiler() + + for key, value in data.items(): + weight = value["w"] + + for e in reduce(lambda x, y: y(x), pipeline, (key, key)): + c.Add(e, weight) + + return c + + +def test_multiword_simple(): + with tmp_dictionary(create_dict(multiword_data), "completion.kv") as d: + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("zonbies 8", 1) + ] == ["80s movie with zombies"] + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("80th mo", 2, 2) + ] == [ + "80s movie with zombies", + "80s monsters tribute art", + ] + + # matches 80s movie with zombies twice: 80th -> 80s, 80th -> with + # note: order comes from depth first traversal + assert [m.matched_string for m in d.complete_fuzzy_multiword("80th mo", 2)] == [ + "80s movie with zombies", + "80s monsters tribute art", + "80s movie with zombies", + ] + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("witsah 80s", 3) + ] == ["80s movie with zombies", "80s cartoon with cars"] + + assert [m.matched_string for m in d.complete_fuzzy_multiword("80ts mo", 1)] == [ + "80s movie with zombies", + "80s monsters tribute art", + ] + + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("tehno fa", 1) + ] == [ + "80s techno fashion", + ] + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("teschno fa", 1) + ] == [ + "80s techno fashion", + ] + + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("90s", 10, 2) + ] == [] + + # no exact prefix: match all + assert ( + len([m.matched_string for m in d.complete_fuzzy_multiword("90s", 10)]) == 44 + ) + + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("80s xxxf", 3) + ] == ["80s techno fashion"] + + assert [m.matched_string for m in d.complete_fuzzy_multiword("", 10, 2)] == [] + + # no exact prefix: match all + assert len([m.matched_string for m in d.complete_fuzzy_multiword("", 10)]) == 44 + + +def test_multiword_nonascii(): + with tmp_dictionary(create_dict(multiword_data_non_ascii), "completion.kv") as d: + assert [m.matched_string for m in d.complete_fuzzy_multiword("öfen", 0)] == [ + "öfen übelkeit rauchvergiftung", + "bäder öfen übelkeit", + ] + assert [m.matched_string for m in d.complete_fuzzy_multiword("ofen", 1, 0)] == [ + "öfen übelkeit rauchvergiftung", + "bäder öfen übelkeit", + ] + + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("krauterlc", 2) + ] == [ + "übelkeit kräuterschnapps alles gut", + ] + + assert [ + m.matched_string for m in d.complete_fuzzy_multiword("krauterl", 2) + ] == [ + "übelkeit kräuterschnapps alles gut", + ] + + +def test_multiword_stack_corner_case(): + with tmp_dictionary( + create_dict(multiword_data_stack_corner_case), "completion.kv" + ) as d: + assert [m.matched_string for m in d.complete_fuzzy_multiword("a", 0)] == [ + "a", + ] diff --git a/python/tests/dictionary/dictionary_multiword_completion_test.py b/python/tests/dictionary/dictionary_multiword_completion_test.py index a4216c0b3..28e5b8078 100644 --- a/python/tests/dictionary/dictionary_multiword_completion_test.py +++ b/python/tests/dictionary/dictionary_multiword_completion_test.py @@ -159,3 +159,18 @@ def test_multiword_simple(): ) if len(k.split(" ")) < 5 ] + + assert set( + m.matched_string + for m in sorted( + [m for m in d.complete_multiword("")], + key=lambda m: m.weight, + reverse=True, + ) + ) == set( + k + for k, v in sorted( + multiword_data.items(), key=lambda item: item[1]["w"], reverse=True + ) + if len(k.split(" ")) < 5 + )