Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix completions with custom python filter #287

Merged
merged 27 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
29a2ccc
implement prefix completion the new way
hendrikmuhs Feb 3, 2024
c6ce35d
introduce a better way to set the minimum weight during traversal, re…
hendrikmuhs Feb 4, 2024
a9581a3
implement filtering functor
hendrikmuhs Feb 6, 2024
366af27
simplify
hendrikmuhs Feb 6, 2024
7d4d826
add code doc
hendrikmuhs Feb 6, 2024
1db1355
implement filter callbacks, including a top-n filter
hendrikmuhs Feb 11, 2024
8091ea7
add test for custom callback
hendrikmuhs Feb 15, 2024
beb4e8e
implement python filter callback for prefix completion
hendrikmuhs Feb 15, 2024
8bfe68d
add all prefix completions to python wrapper
hendrikmuhs Feb 16, 2024
22200f4
add code docs and tests
hendrikmuhs Feb 19, 2024
6332f7e
revert filter callback changes and instead introduce SetMinWeight in …
hendrikmuhs Feb 27, 2024
7738ecf
fix style
hendrikmuhs Feb 27, 2024
a4e0dcd
reactivate tests
hendrikmuhs Feb 27, 2024
a61c913
fix leftovers
hendrikmuhs Feb 27, 2024
4ae5864
std::binary_function got removed in C++17
hendrikmuhs Feb 28, 2024
1d58246
reset min weight functor, too
hendrikmuhs Feb 28, 2024
520c4c0
use make_unique
hendrikmuhs Mar 1, 2024
4840270
fix corner cases
hendrikmuhs Mar 1, 2024
57302d3
add more tests
hendrikmuhs Mar 1, 2024
106eba0
expose weight
hendrikmuhs Mar 5, 2024
659a2b5
fix top-n testcase
hendrikmuhs Mar 5, 2024
80602b3
re-factor inner method to get the inner weight and add support to get…
hendrikmuhs Mar 5, 2024
f959732
expose the dictionary weight, only applies to completion dicts
hendrikmuhs Mar 5, 2024
3ac42b7
fix completion matching
hendrikmuhs Mar 5, 2024
ef51bb3
improve completion testcase
hendrikmuhs Mar 5, 2024
65396f1
ensure weight and value are equal
hendrikmuhs Mar 5, 2024
ea7290d
fix format
hendrikmuhs Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ForwardBackwardCompletion final {
ForwardBackwardCompletion(dictionary_t forward_dictionary, dictionary_t backward_dictionary)
: forward_completions_(forward_dictionary), backward_completions_(backward_dictionary) {}

struct result_compare : public std::binary_function<Match, Match, bool> {
struct result_compare {
hendrikmuhs marked this conversation as resolved.
Show resolved Hide resolved
bool operator()(const Match& m1, const Match& m2) const { return m1.GetScore() < m2.GetScore(); }
};

Expand Down
36 changes: 36 additions & 0 deletions keyvi/include/keyvi/dictionary/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#include "keyvi/dictionary/match_iterator.h"
#include "keyvi/dictionary/matching/fuzzy_matching.h"
#include "keyvi/dictionary/matching/near_matching.h"
#include "keyvi/dictionary/matching/prefix_completion_matching.h"
#include "keyvi/dictionary/util/bounded_priority_queue.h"

// #define ENABLE_TRACING
#include "keyvi/dictionary/util/trace.h"
Expand Down Expand Up @@ -324,6 +326,40 @@ class Dictionary final {
return MatchIterator::MakeIteratorPair(func, data->FirstMatch());
}

/**
 * Complete the given prefix.
 *
 * Builds a PrefixCompletionMatching instance for this dictionary's FSA and
 * returns an iterator pair over its matches. The iterator pair also receives
 * a callback that forwards a minimum weight to the matcher.
 *
 * @param query the prefix to complete
 * @return iterator pair over the matches
 */
MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query) const {
  auto data = std::make_shared<matching::PrefixCompletionMatching<>>(
      matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query));

  auto func = [data]() { return data->NextMatch(); };

  // Capture the shared_ptr instead of binding a raw pointer, so the callback
  // keeps the matcher alive on its own and cannot dangle if it ever outlives
  // the match function.
  auto set_min_weight = [data](uint32_t min_weight) { data->SetMinWeight(min_weight); };

  return MatchIterator::MakeIteratorPair(func, data->FirstMatch(), set_min_weight);
}

/**
 * Complete the given prefix, filtering matches against a bounded queue of
 * the best weights seen so far.
 *
 * A match is returned only if its weight is greater than or equal to the
 * value reported by the queue's Back(); every returned match's weight is
 * put into the queue. The first match is taken from the matcher as is.
 * NOTE(review): presumably Back() is the lowest of the top_n weights kept,
 * confirm against BoundedPriorityQueue.
 *
 * @param query the prefix to complete
 * @param top_n capacity of the bounded weight queue
 * @return iterator pair over the matches that pass the filter
 */
MatchIterator::MatchIteratorPair GetPrefixCompletion(const std::string& query, size_t top_n) const {
  auto data = std::make_shared<matching::PrefixCompletionMatching<>>(
      matching::PrefixCompletionMatching<>::FromSingleFsa(fsa_, query));

  auto best_weights = std::make_shared<util::BoundedPriorityQueue<uint32_t>>(top_n);

  auto func = [data, best_weights = std::move(best_weights)]() {
    auto m = data->NextMatch();
    while (!m.IsEmpty()) {
      if (m.GetWeight() >= best_weights->Back()) {
        best_weights->Put(m.GetWeight());
        return m;
      }

      // weight too low, try the next candidate
      m = data->NextMatch();
    }
    // matcher exhausted
    return Match();
  };

  // Capture the shared_ptr instead of binding a raw pointer, so the callback
  // keeps the matcher alive on its own and cannot dangle if it ever outlives
  // the match function.
  auto set_min_weight = [data](uint32_t min_weight) { data->SetMinWeight(min_weight); };

  return MatchIterator::MakeIteratorPair(func, data->FirstMatch(), set_min_weight);
}

std::string GetManifest() const { return fsa_->GetManifest(); }

private:
Expand Down
19 changes: 19 additions & 0 deletions keyvi/include/keyvi/dictionary/fsa/comparable_state_traverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ class ComparableStateTraverser final {

size_t GetOrder() const { return order_; }

/**
* Set the minimum weight states must be greater or equal to.
*
* Only available for WeightedTransition specialization.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this method is only available to WeightedTransition, should we prevent calls to it with something like throw std::logic_error("Not implemented") instead of having a method that is a no-op?

Am good with this as well, just wanna see if this was intentional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had it like this in an earlier version, but I reverted, because it broke completions on dictionaries without weights (test case should exist, so it is easy to find out).

I will try this once more and make a new PR if it doesn't break this case.

So many changes, at some point I lost track.

*
* @param weight minimum transition weight
*/
// Intentionally empty in the generic template rather than throwing: per the
// review discussion, an earlier throwing version broke completions on
// dictionaries without weights.
inline void SetMinWeight(uint32_t weight) {}

private:
innerTraverserType state_traverser_;
std::vector<label_t> label_stack_;
Expand Down Expand Up @@ -267,6 +276,16 @@ inline bool ComparableStateTraverser<NearStateTraverser>::operator<(const Compar
return order_ > rhs.order_;
}

/**
* Set the minimum weight that states must be greater than or equal to.
*
* Specialization for weighted traversal: forwards the value to the wrapped
* state traverser.
*
* @param weight minimum transition weight
*/
template <>
inline void ComparableStateTraverser<WeightedStateTraverser>::SetMinWeight(uint32_t weight) {
state_traverser_.SetMinWeight(weight);
}

} /* namespace fsa */
} /* namespace dictionary */
} /* namespace keyvi */
Expand Down
48 changes: 46 additions & 2 deletions keyvi/include/keyvi/dictionary/fsa/state_traverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "keyvi/dictionary/fsa/automata.h"
#include "keyvi/dictionary/fsa/traversal/traversal_base.h"
#include "keyvi/dictionary/fsa/traversal/weighted_traversal.h"

// #define ENABLE_TRACING
#include "keyvi/dictionary/util/trace.h"
Expand Down Expand Up @@ -119,13 +120,14 @@ class StateTraverser final {

void operator++(int) {
TRACE("statetraverser++");

// ignore cases where we are already at the end
if (current_state_ == 0) {
TRACE("at the end");
return;
}

current_state_ = stack_.GetStates().GetNextState();
current_state_ = FilterByMinWeight(stack_.GetStates().GetNextState());
TRACE("next state: %ld depth: %ld", current_state_, stack_.GetDepth());

while (current_state_ == 0) {
Expand All @@ -139,7 +141,7 @@ class StateTraverser final {
TRACE("state is 0, go up");
--stack_;
stack_.GetStates()++;
current_state_ = stack_.GetStates().GetNextState();
current_state_ = FilterByMinWeight(stack_.GetStates().GetNextState());
TRACE("next state %ld depth %ld", current_state_, stack_.GetDepth());
}

Expand All @@ -155,6 +157,15 @@ class StateTraverser final {

operator bool() const { return !at_end_; }

/**
* Set the minimum weight that states must be greater than or equal to.
*
* The generic implementation does nothing; only the WeightedTransition
* specialization stores the value.
*
* @param min_weight minimum transition weight
*/
inline void SetMinWeight(uint32_t min_weight) {}

bool AtEnd() const { return at_end_; }

private:
Expand All @@ -165,6 +176,13 @@ class StateTraverser final {
bool at_end_;
traversal::TraversalStack<TransitionT> stack_;

/**
* Filter hook that lets the weighted traversal drop states whose weight is lower than the minimum weight
* (see the WeightedTransition specialization).
*
* Default: no filter, the state is returned unchanged.
*
* @param state the state that is about to become the current state
* @return the given state
*/
inline uint64_t FilterByMinWeight(uint64_t state) { return state; }

template <class innerTraverserType>
friend class ComparableStateTraverser;
const traversal::TraversalStack<TransitionT> &GetStack() const { return stack_; }
Expand All @@ -176,6 +194,32 @@ class StateTraverser final {
const traversal::TraversalPayload<TransitionT> &GetTraversalPayload() const { return stack_.traversal_stack_payload; }
};

/**
 * Drop a state that does not meet the minimum weight requirement.
 *
 * Needed because SetMinWeight can be called after the transitions of the
 * current level have already been read.
 *
 * @param state the state we currently look at.
 *
 * @return the given state if its weight is greater than or equal to the
 *         minimum weight, 0 otherwise (0 is passed through as 0).
 */
template <>
inline uint64_t StateTraverser<traversal::WeightedTransition>::FilterByMinWeight(uint64_t state) {
  TRACE("filter min weight for weighted transition specialization");

  // nothing to filter, and the weight must not be queried in this case
  if (state == 0) {
    return 0;
  }

  const bool meets_minimum = stack_.GetStates().GetNextInnerWeight() >= stack_.traversal_stack_payload.min_weight;
  return meets_minimum ? state : 0;
}

/**
 * Set the minimum weight that states must be greater than or equal to.
 *
 * Specialization for weighted traversal: stores the value in the traversal
 * stack payload.
 *
 * @param min_weight minimum transition weight
 */
template <>
inline void StateTraverser<traversal::WeightedTransition>::SetMinWeight(uint32_t min_weight) {
  // %u, not %d: min_weight is unsigned
  TRACE("set min weight for weighted transition specialization %u", min_weight);
  stack_.traversal_stack_payload.min_weight = min_weight;
}

} /* namespace fsa */
} /* namespace dictionary */
} /* namespace keyvi */
Expand Down

This file was deleted.

14 changes: 14 additions & 0 deletions keyvi/include/keyvi/dictionary/fsa/traversal/weighted_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ static bool WeightedTransitionCompare(const WeightedTransition& a, const Weighte
return a.weight > b.weight;
}

/**
 * Traversal payload for weighted traversal.
 *
 * In addition to the depth it carries the minimum weight that
 * TraversalState<WeightedTransition>::Add compares transition weights against.
 */
template <>
struct TraversalPayload<WeightedTransition> {
  // initialized so that a default-initialized payload never holds an indeterminate depth
  size_t current_depth = 0;
  // transitions with a lower weight are not added; 0 lets every weight pass
  uint32_t min_weight = 0;
};

/**
 * Add a transition, unless its weight is below the minimum weight of the payload.
 *
 * @param s target state of the transition
 * @param weight weight of the transition
 * @param l label of the transition
 * @param payload traversal payload holding the current minimum weight
 */
template <>
inline void TraversalState<WeightedTransition>::Add(uint64_t s, uint32_t weight, unsigned char l,
                                                    TraversalPayload<WeightedTransition>* payload) {
  // too light: skip it right away
  if (weight < payload->min_weight) {
    return;
  }

  traversal_state_payload.transitions.push_back(WeightedTransition(s, weight, l));
}

template <>
inline void TraversalState<WeightedTransition>::PostProcess(TraversalPayload<WeightedTransition>* payload) {
if (traversal_state_payload.transitions.size() > 0) {
Expand Down
2 changes: 0 additions & 2 deletions keyvi/include/keyvi/dictionary/fsa/traverser_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@
#define KEYVI_DICTIONARY_FSA_TRAVERSER_TYPES_H_

#include "keyvi/dictionary/fsa/state_traverser.h"
#include "keyvi/dictionary/fsa/traversal/bounded_weighted_traversal.h"
#include "keyvi/dictionary/fsa/traversal/near_traversal.h"
#include "keyvi/dictionary/fsa/traversal/weighted_traversal.h"

namespace keyvi {
namespace dictionary {
namespace fsa {

using BoundedWeightedStateTraverser2 = StateTraverser<traversal::BoundedWeightedTransition>;
using NearStateTraverser = StateTraverser<traversal::NearTransition>;
using WeightedStateTraverser = StateTraverser<traversal::WeightedTransition>;

Expand Down
23 changes: 23 additions & 0 deletions keyvi/include/keyvi/dictionary/fsa/zip_state_traverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ class ZipStateTraverser final {

const std::vector<label_t> &GetStateLabels() const { return traverser_queue_.top()->GetStateLabels(); }

/**
* Set the minimum weight that states must be greater than or equal to.
*
* The generic implementation does nothing; only the WeightedStateTraverser
* specialization forwards the value to the queued traversers.
*
* @param weight minimum transition weight
*/
inline void SetMinWeight(uint32_t weight) {}

private:
heap_t traverser_queue_;
bool final_ = false;
Expand All @@ -226,6 +235,8 @@ class ZipStateTraverser final {
automata_t fsa_;
size_t equal_states_ = 1;
bool pruned = false;
// this field only used for weighted traversal, ignored otherwise
uint32_t min_weight_ = 0;

inline void PreIncrement() {}

Expand Down Expand Up @@ -500,6 +511,18 @@ inline ZipStateTraverser<WeightedStateTraverser>::ZipStateTraverser(
FillInValues();
}

/**
 * Set the minimum weight that states must be greater than or equal to.
 *
 * Specialization for weighted traversal: forwards the value to every
 * traverser in the queue.
 *
 * NOTE(review): the min_weight_ member of this class is not updated here;
 * confirm whether that is intended.
 *
 * @param weight minimum transition weight
 */
template <>
inline void ZipStateTraverser<WeightedStateTraverser>::SetMinWeight(uint32_t weight) {
  // iterate by reference: no need to copy each queue entry just to call through it
  for (const auto& t : traverser_queue_) {
    t->SetMinWeight(weight);
  }
}

} // namespace fsa
} // namespace dictionary
} // namespace keyvi
Expand Down
Loading
Loading