Skip to content

Commit

Permalink
fix fuzzy completion matching
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmuhs committed Mar 19, 2024
1 parent 49d97c1 commit 7408ba9
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 52 deletions.
55 changes: 19 additions & 36 deletions keyvi/include/keyvi/stringdistance/needleman_wunsch.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class NeedlemanWunsch final {
: max_distance_(other.max_distance_),
compare_sequence_(std::move(other.compare_sequence_)),
intermediate_scores_(std::move(other.intermediate_scores_)),
completion_row_(other.completion_row_),
last_put_position_(other.last_put_position_),
latest_calculated_row_(other.latest_calculated_row_),
input_sequence_(std::move(other.input_sequence_)),
Expand All @@ -70,7 +69,6 @@ class NeedlemanWunsch final {
other.max_distance_ = 0;
other.last_put_position_ = 0;
other.latest_calculated_row_ = 0;
other.completion_row_ = std::numeric_limits<int32_t>::max();
}

~NeedlemanWunsch() {}
Expand All @@ -83,12 +81,6 @@ class NeedlemanWunsch final {
EnsureCapacity(row + 1);
compare_sequence_[position] = codepoint;

// reset completion row if we walked backwards
if (row <= completion_row_) {
TRACE("reset completion row");
completion_row_ = std::numeric_limits<int32_t>::max();
}

last_put_position_ = position;
size_t columns = distance_matrix_.Columns();

Expand All @@ -113,13 +105,8 @@ class NeedlemanWunsch final {
// if left_cutoff >= columns, the candidate string is longer than the input + max edit distance, we can make a
// shortcut
if (left_cutoff >= columns) {
// last character == exact match?
if (row > completion_row_ || input_sequence_.empty() ||
compare_sequence_[columns - 2] == input_sequence_.back()) {
intermediate_scores_[row] = intermediate_scores_[row - 1] + cost_function_.GetCompletionCost();
} else {
intermediate_scores_[row] = intermediate_scores_[row - 1] + cost_function_.GetInsertionCost(codepoint);
}
intermediate_scores_[row] = intermediate_scores_[row - 1] + std::min(cost_function_.GetCompletionCost(),
cost_function_.GetInsertionCost(codepoint));
return intermediate_scores_[row];
}

Expand All @@ -129,32 +116,24 @@ class NeedlemanWunsch final {
int32_t field_result;

for (size_t column = left_cutoff; column < right_cutoff; ++column) {
TRACE("calculating column %d", column);
// 1. check for exact match according to the substitution cost
// function
int32_t substitution_cost = cost_function_.GetSubstitutionCost(input_sequence_[column - 1], codepoint);
int32_t substitution_result = substitution_cost + distance_matrix_.Get(row - 1, column - 1);

if (substitution_cost == 0) {
// codePoints match
// short cut: codePoints match
field_result = substitution_result;
} else {
// 2. calculate costs for deletion, insertion and transposition
// 2. calculate costs for deletion
int32_t deletion_result =
distance_matrix_.Get(row, column - 1) + cost_function_.GetDeletionCost(input_sequence_[column - 1]);

int32_t completion_result = std::numeric_limits<int32_t>::max();

if (row > completion_row_) {
completion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost();
} else if (column + 1 == columns && columns > 1 &&
compare_sequence_[last_put_position_ - 1] == input_sequence_.back()) {
completion_row_ = row;
TRACE("set completion row %d columns: %d", row, columns);
completion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost();
}

// 3. calculate costs for insertion, transposition
int32_t insertion_result = distance_matrix_.Get(row - 1, column) + cost_function_.GetInsertionCost(codepoint);

// 4. calculate costs for transposition (swap of 2 characters: house <--> huose)
int32_t transposition_result = std::numeric_limits<int32_t>::max();

if (row > 1 && column > 1 && input_sequence_[column - 1] == compare_sequence_[position - 1] &&
Expand All @@ -164,15 +143,23 @@ class NeedlemanWunsch final {
cost_function_.GetTranspositionCost(input_sequence_[column - 1], input_sequence_[column - 2]);
}

// 4. take the minimum cost
TRACE("deletion: %d vs. insertion %d vs. transposition %d vs. substitution %d", deletion_result,
insertion_result, transposition_result, substitution_result);

// 5. take the minimum cost
field_result = std::min({deletion_result, insertion_result, transposition_result, substitution_result});
}

// 6. check if we have a completion case, only calculated on the last column
if (column + 1 == columns) {
field_result =
std::min({deletion_result, insertion_result, transposition_result, substitution_result, completion_result});
std::min(distance_matrix_.Get(row - 1, column) + cost_function_.GetCompletionCost(), field_result);
}

// put cost into matrix
// 7. put cost into matrix
distance_matrix_.Set(row, column, field_result);

// take the best intermediate result from the possible cells in the matrix
// 8. keep track of the best intermediate result from the possible cells in the matrix
if ((column + 1 == columns || column + max_distance_ >= row) && field_result <= intermediate_score) {
intermediate_score = field_result;
}
Expand Down Expand Up @@ -213,7 +200,6 @@ class NeedlemanWunsch final {
std::vector<uint32_t> compare_sequence_;
std::vector<int32_t> intermediate_scores_;

size_t completion_row_ = 0;
size_t last_put_position_ = 0;
size_t latest_calculated_row_ = 0;

Expand All @@ -228,7 +214,6 @@ class NeedlemanWunsch final {
}

latest_calculated_row_ = 1;
completion_row_ = std::numeric_limits<int32_t>::max();

// initialize compare Sequence and immediateScore
compare_sequence_.reserve(rows);
Expand All @@ -242,9 +227,7 @@ class NeedlemanWunsch final {

if (compare_sequence_.size() < capacity) {
compare_sequence_.resize(capacity);
compare_sequence_.resize(compare_sequence_.capacity());
intermediate_scores_.resize(capacity);
intermediate_scores_.resize(intermediate_scores_.capacity());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ BOOST_AUTO_TEST_CASE(approx1) {

std::vector<std::string> expected_output;
expected_output.push_back("aabc");
// not matching aabcül because of last character mismatch
expected_output.push_back("aabcül");
expected_output.push_back("aabcdefghijklmnop"); // this matches because aab_c_d, "c" is an insert

auto expected_it = expected_output.begin();
Expand Down
6 changes: 3 additions & 3 deletions python/tests/completion/fuzzy_completion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def test_fuzzy_completion():
assert len(matches) == 9

matches = [m.matched_string for m in completer.GetFuzzyCompletions('tue', 1)]
assert len(matches) == 1
assert len(matches) == 21

matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuv h', 1)]
assert len(matches) == 2
assert len(matches) == 8

matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuv h', 2)]
assert len(matches) == 7
assert len(matches) == 12

matches = [m.matched_string for m in completer.GetFuzzyCompletions('tuk töffnungszeiten', 2)]
assert len(matches) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
"80s video megamix": {"w": 96, "id": "b4"},
}

multiword_data_non_ascii = {
"bäder öfen übelkeit": {"w": 43, "id": "a1"},
"übelkeit kräuterschnapps alles gut": {"w": 72, "id": "a2"},
"öfen übelkeit rauchvergiftung ": {"w": 372, "id": "a3"},
}

PERMUTATION_LOOKUP_TABLE = {
2: [
[0],
Expand Down Expand Up @@ -117,37 +123,40 @@ def __call__(self, key_value):
)


def test_multiword_simple():
def create_dict(data):
pipeline = []
pipeline.append(MultiWordPermutation())
c = CompletionDictionaryCompiler()

for key, value in multiword_data.items():
for key, value in data.items():
weight = value["w"]

for e in reduce(lambda x, y: y(x), pipeline, (key, key)):
c.Add(e, weight)

with tmp_dictionary(c, "completion.kv") as d:
return c


def test_multiword_simple():
with tmp_dictionary(create_dict(multiword_data), "completion.kv") as d:
assert [
m.matched_string for m in d.complete_fuzzy_multiword("zonbies 8", 1)
] == ["80s movie with zombies"]
#assert [m.matched_string for m in d.complete_fuzzy_multiword("80th mo", 2)] == [
# "80s movie with zombies",
# "80s monsters tribute art",
#]
#assert [
# m.matched_string for m in d.complete_fuzzy_multiword("witsah 80s", 3)
#] == ["80s movie with zombies", "80s cartoon with cars"]
assert [m.matched_string for m in d.complete_fuzzy_multiword("80th mo", 2)] == [
"80s movie with zombies",
"80s monsters tribute art",
]
assert [
m.matched_string for m in d.complete_fuzzy_multiword("witsah 80s", 3)
] == ["80s movie with zombies", "80s cartoon with cars"]

assert [m.matched_string for m in d.complete_fuzzy_multiword("80ts mo", 1)] == [
"80s movie with zombies",
"80s monsters tribute art",
]

# todo: this should work with edit distance 1
assert [
m.matched_string for m in d.complete_fuzzy_multiword("tehno fa", 2)
m.matched_string for m in d.complete_fuzzy_multiword("tehno fa", 1)
] == [
"80s techno fashion",
]
Expand All @@ -164,3 +173,27 @@ def test_multiword_simple():
] == ["80s techno fashion"]

assert [m.matched_string for m in d.complete_fuzzy_multiword("", 10)] == []


def test_multiword_nonascii():
with tmp_dictionary(create_dict(multiword_data_non_ascii), "completion.kv") as d:
assert [m.matched_string for m in d.complete_fuzzy_multiword("öfen", 0)] == [
"öfen übelkeit rauchvergiftung ",
"bäder öfen übelkeit",
]
assert [m.matched_string for m in d.complete_fuzzy_multiword("ofen", 1, 0)] == [
"öfen übelkeit rauchvergiftung ",
"bäder öfen übelkeit",
]

assert [
m.matched_string for m in d.complete_fuzzy_multiword("krauterlc", 2)
] == [
"übelkeit kräuterschnapps alles gut",
]

assert [
m.matched_string for m in d.complete_fuzzy_multiword("krauterl", 2)
] == [
"übelkeit kräuterschnapps alles gut",
]

0 comments on commit 7408ba9

Please sign in to comment.