Skip to content

Commit

Permalink
revert filter callback changes and instead introduce SetMinWeight in MatchIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmuhs committed Feb 27, 2024
1 parent 2a1b7ad commit 37e1eb1
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 202 deletions.
49 changes: 24 additions & 25 deletions keyvi/include/keyvi/dictionary/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
#include "keyvi/dictionary/fsa/traverser_types.h"
#include "keyvi/dictionary/match.h"
#include "keyvi/dictionary/match_iterator.h"
#include "keyvi/dictionary/matching/filter.h"
#include "keyvi/dictionary/matching/fuzzy_matching.h"
#include "keyvi/dictionary/matching/near_matching.h"
#include "keyvi/dictionary/matching/prefix_completion_matching.h"
#include "keyvi/dictionary/util/bounded_priority_queue.h"

// #define ENABLE_TRACING
#include "keyvi/dictionary/util/trace.h"
Expand Down Expand Up @@ -326,40 +326,39 @@ class Dictionary final {
return MatchIterator::MakeIteratorPair(func, data->FirstMatch());
}

MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query,
const matching::filter_t filter = matching::accept_all) const {
MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query) const {
auto data = std::make_shared<matching::PrefixCompletionMatching<>>(
matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query, filter));
matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query));

auto func = [data]() { return data->NextMatch(); };
return MatchIterator::MakeIteratorPair(func, data->FirstMatch());
return MatchIterator::MakeIteratorPair(
func, data->FirstMatch(),
std::bind(&matching::PrefixCompletionMatching<>::SetMinWeight, &(*data), std::placeholders::_1));
}

MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query, size_t top_n) const {
auto top_results = std::make_shared<matching::filter::TopN>(top_n);

auto data =
std::make_shared<matching::PrefixCompletionMatching<>>(matching::PrefixCompletionMatching<>::FromSingleFsa(
fsa_, query, std::bind(&matching::filter::TopN::filter, &(*top_results), std::placeholders::_1)));
auto data = std::make_shared<matching::PrefixCompletionMatching<>>(
matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query));

auto func = [data, top_results]() { return data->NextMatch(); };
return MatchIterator::MakeIteratorPair(func, data->FirstMatch());
}
// TODO(hendrik): C++14, use a unique pointer and move this into func
auto best_weights = std::make_shared<util::BoundedPriorityQueue<uint32_t>>(top_n);

/**
* Complete a prefix.
*/
MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query, matching::filter_wrapper_t filter,
void* user_data) const {
auto filter_wrapper = std::make_shared<matching::filter::FilterWrapper>(filter, user_data);
auto func = [data, best_weights]() {
auto m = data->NextMatch();
while (!m.IsEmpty()) {
if (m.GetWeight() >= best_weights->Back()) {
best_weights->Put(m.GetWeight());
return m;
}

auto data =
std::make_shared<matching::PrefixCompletionMatching<>>(matching::PrefixCompletionMatching<>::FromSingleFsa(
fsa_, query,
std::bind(&matching::filter::FilterWrapper::filter, &(*filter_wrapper), std::placeholders::_1)));
m = data->NextMatch();
}
return Match();
};

auto func = [data, filter_wrapper]() { return data->NextMatch(); };
return MatchIterator::MakeIteratorPair(func, data->FirstMatch());
return MatchIterator::MakeIteratorPair(
func, data->FirstMatch(),
std::bind(&matching::PrefixCompletionMatching<>::SetMinWeight, &(*data), std::placeholders::_1));
}

std::string GetManifest() const { return fsa_->GetManifest(); }
Expand Down
1 change: 1 addition & 0 deletions keyvi/include/keyvi/dictionary/fsa/state_traverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ inline uint64_t StateTraverser<traversal::WeightedTransition>::FilterByMinWeight
*/
template <>
inline void StateTraverser<traversal::WeightedTransition>::SetMinWeight(uint32_t min_weight) {
  // min_weight is unsigned: trace it with %u, %d would print values above INT_MAX as negative numbers
  TRACE("set min weight for weighted transition specialization %u", min_weight);
  // store the threshold in the payload of the traversal stack
  stack_.traversal_stack_payload.min_weight = min_weight;
}

Expand Down
16 changes: 11 additions & 5 deletions keyvi/include/keyvi/dictionary/match_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,26 @@ class MatchIterator : public boost::iterator_facade<MatchIterator, Match const,
public:
typedef util::iterator_pair<MatchIterator> MatchIteratorPair;

explicit MatchIterator(std::function<Match()> match_functor, const Match& first_match = Match())
: match_functor_(match_functor) {
explicit MatchIterator(std::function<Match()> match_functor, const Match& first_match = Match(),
std::function<void(uint32_t)> set_min_weight = {})
: match_functor_(match_functor), set_min_weight_(set_min_weight) {
current_match_ = first_match;
if (first_match.IsEmpty()) {
increment();
}
}

static MatchIteratorPair MakeIteratorPair(std::function<Match()> f, const Match& first_match = Match()) {
return MatchIteratorPair(MatchIterator(f, first_match), MatchIterator());
static MatchIteratorPair MakeIteratorPair(std::function<Match()> f, const Match& first_match = Match(),
std::function<void(uint32_t)> set_min_weight = {}) {
return MatchIteratorPair(MatchIterator(f, first_match, set_min_weight), MatchIterator());
}

// Returns a pair made of two default-constructed iterators, i.e. begin is the same kind of
// iterator that MakeIteratorPair uses as the end of the range.
static MatchIteratorPair EmptyIteratorPair() { return MatchIteratorPair(MatchIterator(), MatchIterator()); }

MatchIterator() : match_functor_(0) {}
MatchIterator() : match_functor_(0), set_min_weight_({}) {}

/**
 * Forward a minimum weight to the setter that was registered at construction, if any.
 *
 * The setter is optional: the constructors and MakeIteratorPair default it to an empty
 * std::function, in which case this call does nothing.
 *
 * @param min_weight the minimum weight handed to the registered setter
 */
void SetMinWeight(uint32_t min_weight) {
  // invoking an empty std::function throws std::bad_function_call, so treat "no setter" as a no-op
  if (set_min_weight_) {
    set_min_weight_(min_weight);
  }
}

// What we implement is determined by the boost::forward_traversal_tag
// template parameter
private:
Expand Down Expand Up @@ -105,6 +110,7 @@ class MatchIterator : public boost::iterator_facade<MatchIterator, Match const,
private:
std::function<Match()> match_functor_;
Match current_match_;
std::function<void(uint32_t)> set_min_weight_;
};

} /* namespace dictionary */
Expand Down
81 changes: 0 additions & 81 deletions keyvi/include/keyvi/dictionary/matching/filter.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "keyvi/dictionary/fsa/traverser_types.h"
#include "keyvi/dictionary/fsa/zip_state_traverser.h"
#include "keyvi/dictionary/match.h"
#include "keyvi/dictionary/matching/filter.h"
#include "keyvi/dictionary/util/utf8_utils.h"
#include "keyvi/stringdistance/levenshtein.h"
#include "utf8.h"
Expand Down Expand Up @@ -62,9 +61,9 @@ class PrefixCompletionMatching final {
* @param fsa the fsa
* @param query the query
*/
static PrefixCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const std::string& query,
const filter_t filter = accept_all) {
return FromSingleFsa(fsa, fsa->GetStartState(), query, filter);
static PrefixCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const std::string& query
) {
return FromSingleFsa(fsa, fsa->GetStartState(), query);
}

/**
Expand All @@ -75,7 +74,7 @@ class PrefixCompletionMatching final {
* @param query the query
*/
static PrefixCompletionMatching FromSingleFsa(const fsa::automata_t& fsa, const uint64_t start_state,
const std::string& query, const filter_t filter = accept_all) {
const std::string& query) {
if (start_state == 0) {
return PrefixCompletionMatching();
}
Expand Down Expand Up @@ -115,7 +114,7 @@ class PrefixCompletionMatching final {

TRACE("create matcher");
return PrefixCompletionMatching(std::move(traverser), std::move(first_match), std::move(traversal_stack),
query_length, filter);
query_length);
}

/**
Expand All @@ -124,8 +123,8 @@ class PrefixCompletionMatching final {
* @param fsas a vector of fsas
* @param query the query
*/
static PrefixCompletionMatching FromMulipleFsas(const std::vector<fsa::automata_t>& fsas, const std::string& query,
const filter_t filter = accept_all) {
static PrefixCompletionMatching FromMulipleFsas(const std::vector<fsa::automata_t>& fsas, const std::string& query
) {
const size_t query_length = query.size();
std::vector<std::pair<fsa::automata_t, uint64_t>> fsa_start_state_pairs;

Expand Down Expand Up @@ -170,17 +169,11 @@ class PrefixCompletionMatching final {
traverser.reset(new innerTraverserType(fsa_start_state_pairs));

return PrefixCompletionMatching(std::move(traverser), std::move(first_match), std::move(traversal_stack),
query_length, filter);
query_length);
}

Match FirstMatch() const {
if (first_match_.IsEmpty()) {
return first_match_;
}

filter_result_t fr = filter_(first_match_);
traverser_ptr_->SetMinWeight(fr.second);
return fr.first ? first_match_ : Match();
return first_match_;
}

Match NextMatch() {
Expand All @@ -196,15 +189,6 @@ class PrefixCompletionMatching final {
Match m(0, prefix_length_ + traverser_ptr_->GetDepth(), match_str, 0, traverser_ptr_->GetFsa(),
traverser_ptr_->GetStateValue(), traverser_ptr_->GetInnerWeight());

filter_result_t fr = filter_(m);

TRACE("filter result: %s, min weight: %d", fr.first ? "true" : "false", fr.second);

traverser_ptr_->SetMinWeight(fr.second);
if (!fr.first) {
continue;
}

(*traverser_ptr_)++;
return m;
}
Expand All @@ -213,15 +197,18 @@ class PrefixCompletionMatching final {
return Match();
}

/**
 * Set the minimum weight on the underlying traverser.
 *
 * Does nothing if this matcher has no traverser.
 *
 * @param min_weight the minimum weight forwarded to the traverser
 */
void SetMinWeight(uint32_t min_weight) {
  // a default-constructed matcher (returned by FromSingleFsa when start_state is 0) never gets a traverser
  if (traverser_ptr_) {
    traverser_ptr_->SetMinWeight(min_weight);
  }
}

private:
PrefixCompletionMatching(std::unique_ptr<innerTraverserType>&& traverser, Match&& first_match,
std::unique_ptr<std::vector<unsigned char>>&& traversal_stack, const size_t prefix_length,
const filter_t filter)
std::unique_ptr<std::vector<unsigned char>>&& traversal_stack, const size_t prefix_length
)
: traverser_ptr_(std::move(traverser)),
first_match_(std::move(first_match)),
traversal_stack_(std::move(traversal_stack)),
prefix_length_(prefix_length),
filter_(filter) {}
prefix_length_(prefix_length) {}

// Creates a matcher with all members left at their defaults (no traverser is set);
// FromSingleFsa returns such an instance when start_state is 0.
PrefixCompletionMatching() {}

Expand All @@ -230,7 +217,6 @@ class PrefixCompletionMatching final {
const Match first_match_;
std::unique_ptr<std::vector<unsigned char>> traversal_stack_;
const size_t prefix_length_ = 0;
const filter_t filter_;

// reset method for the index in the special case the match is deleted
template <class MatcherT, class DeletedT>
Expand Down
3 changes: 2 additions & 1 deletion keyvi/tests/keyvi/dictionary/dictionary_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ BOOST_AUTO_TEST_CASE(DictGetPrefixCompletion) {
}
}

/*
BOOST_AUTO_TEST_CASE(DictGetPrefixCompletionCustomFilter) {
std::vector<std::pair<std::string, uint32_t>> test_data = {
{"mr. eric a", 331}, {"mr. eric b", 1331}, {"mr. max b", 1431}, {"mr. stefan b", 231}, {"mr. stefan e", 431},
Expand All @@ -286,7 +287,7 @@ BOOST_AUTO_TEST_CASE(DictGetPrefixCompletionCustomFilter) {
}
BOOST_CHECK_EQUAL(expected_matches.size(), i);
}
}*/

BOOST_AUTO_TEST_SUITE_END()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <boost/test/unit_test.hpp>

#include "keyvi/dictionary/dictionary.h"
#include "keyvi/dictionary/matching/filter.h"
#include "keyvi/testing/temp_dictionary.h"

namespace keyvi {
Expand Down Expand Up @@ -114,6 +113,7 @@ void test_prefix_completion_matching(std::vector<std::pair<std::string, uint32_t
BOOST_CHECK(expected_it == expected_sorted.end());
}

/*
void test_prefix_completion_matching_with_filter(std::vector<std::pair<std::string, uint32_t>>* test_data,
const std::string& query, const std::vector<std::string> expected,
const filter_t filter) {
Expand Down Expand Up @@ -170,7 +170,7 @@ void test_prefix_completion_matching_with_filter_multi_fsa(std::vector<std::pair
BOOST_CHECK_EQUAL(*expected_it++, m.GetMatchedString());
}
BOOST_CHECK(expected_it == expected.end());
}
}*/

BOOST_AUTO_TEST_CASE(prefix_0) {
std::vector<std::pair<std::string, uint32_t>> test_data = {
Expand Down Expand Up @@ -205,13 +205,18 @@ BOOST_AUTO_TEST_CASE(prefix_completion_cjk) {
std::vector<std::string>{"あsだ", "あsだs", "あsだsっd", "あsaだsっdさ"});
}

/*
BOOST_AUTO_TEST_CASE(prefix_top_3) {
std::vector<std::pair<std::string, uint32_t>> test_data = {
{"eric a", 331}, {"eric b", 1331}, {"eric c", 1431}, {"eric d", 231}, {"eric e", 431},
{"eric f", 531}, {"eric g", 631}, {"eric h", 731}, {"eric i", 831}, {"eric j", 131},
};
filter::TopN top3(3);
test_prefix_completion_matching_with_filter(&test_data, "eric",
std::vector<std::string>{"eric c", "eric b", "eric i"},
std::bind(&filter::TopN::filter, &top3, std::placeholders::_1));
Expand All @@ -227,7 +232,7 @@ BOOST_AUTO_TEST_CASE(prefix_top_3) {
test_prefix_completion_matching_with_filter(&test_data, "steven", std::vector<std::string>{},
std::bind(&filter::TopN::filter, &top3_3, std::placeholders::_1));
}

*/
BOOST_AUTO_TEST_SUITE_END()

} /* namespace matching */
Expand Down
Loading

0 comments on commit 37e1eb1

Please sign in to comment.