From 69892fbdde076658b09bef0003f297963bbaac14 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 13 Feb 2024 19:09:03 +0000 Subject: [PATCH] Fix: Clamping bounded Levenshtein --- include/stringzilla/stringzilla.h | 12 ++++++------ scripts/test.py | 12 ++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index e3977a1c..619d6739 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -2239,7 +2239,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // typedef sz_size_t _distance_t; // Compute the number of columns in our Levenshtein matrix. - sz_size_t n = shorter_length + 1; + sz_size_t const n = shorter_length + 1; // If a buffering memory-allocator is provided, this operation is practically free, // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. @@ -2259,8 +2259,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // // Let's export the UTF8 sequence into the newly allocated buffer at the end. if (can_be_unicode == sz_true_k) { - sz_u32_t *longer_utf32 = (sz_u32_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_u32_t *shorter_utf32 = longer_utf32 + longer_length; + sz_u32_t *const longer_utf32 = (sz_u32_t *)(buffer + sizeof(_distance_t) * (n * 2)); + sz_u32_t *const shorter_utf32 = longer_utf32 + longer_length; // Export the UTF8 sequences into the newly allocated buffer. longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); @@ -2342,7 +2342,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // } \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ - return result; + return sz_min_of_two(result, bound); // Dispatch the actual computation. if (!bound) { @@ -2379,8 +2379,8 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial( // // Bounded computations may exit early. if (bound) { // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return shorter_length <= bound ? shorter_length : bound; - if (shorter_length == 0) return longer_length <= bound ? longer_length : bound; + if (longer_length == 0) return sz_min_of_two(shorter_length, bound); + if (shorter_length == 0) return sz_min_of_two(longer_length, bound); // If the difference in length is beyond the `bound`, there is no need to check at all. if (longer_length - shorter_length > bound) return bound; } diff --git a/scripts/test.py b/scripts/test.py index 172dd7f1..c965b7b1 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -120,16 +120,12 @@ def test_unit_len(): def test_slice_of_split(): - def impl(native_str): + def impl(native_str: str): native_split = native_str.split() text = sz.Str(native_str) - split = text.split() - for split_idx in range(len(native_split)): - native_slice = native_split[split_idx:] - idx = split_idx - for word in split[split_idx:]: - assert str(word) == native_split[idx] - idx += 1 + sz_split = text.split() + for slice_idx in range(len(native_split)): + assert str(sz_split[slice_idx]) == native_split[slice_idx] native_str = "Weebles wobble before they fall down, don't they?" impl(native_str)