Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 104 additions & 83 deletions strings/knuth_morris_pratt.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,122 @@
from __future__ import annotations
"""
Implementation of the Knuth-Morris-Pratt (KMP) string searching algorithm.
The KMP algorithm searches for all occurrences of a "pattern" within a main "text"
by employing the observation that when a mismatch occurs, the pattern itself
embodies sufficient information to determine where the next match could begin,
thus bypassing re-examination of previously matched characters.

This results in an optimal time complexity of O(n + m), where n is the length
of the text and m is the length of the pattern.

def knuth_morris_pratt(text: str, pattern: str) -> int:
"""
The Knuth-Morris-Pratt Algorithm for finding a pattern within a piece of text
with complexity O(n + m)
Source: https://en.wikipedia.org/wiki/Knuth-Morris-Pratt_algorithm
"""

from __future__ import annotations

1) Preprocess pattern to identify any suffixes that are identical to prefixes

This tells us where to continue from if we get a mismatch between a character
in our pattern and the text.
def _compute_lps_array(pattern: str) -> list[int]:
"""
Computes the Longest Proper Prefix Suffix (LPS) array for the KMP algorithm.
The LPS array for a pattern of length m is an array lps of size m where lps[i]
is the length of the longest proper prefix of pattern[0...i] that is also a
suffix of pattern[0...i].

A "proper prefix" is a prefix of the string, but not the whole string.
A "proper suffix" is a suffix of the string, but not the whole string.

:param pattern: The pattern string to compute the LPS array for.
:return: The LPS array, which is used to guide the search.

>>> _compute_lps_array("aabaabaaa")
[0, 1, 0, 1, 2, 3, 4, 5, 2]
>>> _compute_lps_array("ababaca")
[0, 0, 1, 2, 3, 0, 1]
>>> _compute_lps_array("AAAA")
[0, 1, 2, 3]
>>> _compute_lps_array("abcde")
[0, 0, 0, 0, 0]
"""
m = len(pattern)
lps = [0] * m
length = 0 # Length of the previous longest prefix suffix
i = 1

while i < m:
if pattern[i] == pattern[length]:
length += 1
lps[i] = length
i += 1
elif length != 0:
length = lps[length - 1]
else:
lps[i] = 0
i += 1
return lps

2) Step through the text one character at a time and compare it to a character in
the pattern updating our location within the pattern if necessary

>>> kmp = "knuth_morris_pratt"
>>> all(
... knuth_morris_pratt(kmp, s) == kmp.find(s)
... for s in ("kn", "h_m", "rr", "tt", "not there")
... )
True
def knuth_morris_pratt_search(text: str, pattern: str) -> list[int]:
"""
Finds all occurrences of a pattern in a text using the KMP algorithm.

:param text: The text to search in.
:param pattern: The pattern to search for.
:return: A list of starting indices of all occurrences of the pattern.
Returns an empty list if the pattern is not found or is empty.

>>> # Test cases from the original file
>>> knuth_morris_pratt_search("alskfjaldsabc1abc1abc12k23adsfabcabc", "abc1abc12")
[14]
>>> knuth_morris_pratt_search("alskfjaldsk23adsfabcabc", "abc1abc12")
[]
>>> knuth_morris_pratt_search("ABABZABABYABABX", "ABABX")
[10]
>>> knuth_morris_pratt_search("ABAAAAAB", "AAAB")
[4]
>>> knuth_morris_pratt_search("abcxabcdabxabcdabcdabcy", "abcdabcy")
[15]
>>> # More comprehensive test cases
>>> knuth_morris_pratt_search("AABAACAADAABAABA", "AABA")
[0, 9, 12]
>>> knuth_morris_pratt_search("knuth_morris_pratt", "kn")
[0]
>>> knuth_morris_pratt_search("knuth_morris_pratt", "h_m")
[4]
>>> knuth_morris_pratt_search("knuth_morris_pratt", "rr")
[8]
>>> knuth_morris_pratt_search("knuth_morris_pratt", "tt")
[16]
>>> knuth_morris_pratt_search("knuth_morris_pratt", "not there")
[]
>>> knuth_morris_pratt_search("test", "")
[]
"""
n = len(text)
m = len(pattern)
if m == 0:
return []

# 1) Construct the failure array
failure = get_failure_array(pattern)
lps = _compute_lps_array(pattern)
found_indices = []
i = 0 # index for text
j = 0 # index for pattern

# 2) Step through text searching for pattern
i, j = 0, 0 # index into text, pattern
while i < len(text):
while i < n:
if pattern[j] == text[i]:
if j == (len(pattern) - 1):
return i - j
i += 1
j += 1

# if this is a prefix in our pattern
# just go back far enough to continue
elif j > 0:
j = failure[j - 1]
continue
i += 1
return -1


def get_failure_array(pattern: str) -> list[int]:
"""
Calculates the new index we should go to if we fail a comparison
:param pattern:
:return:
"""
failure = [0]
i = 0
j = 1
while j < len(pattern):
if pattern[i] == pattern[j]:
i += 1
elif i > 0:
i = failure[i - 1]
continue
j += 1
failure.append(i)
return failure
if j == m:
found_indices.append(i - j)
j = lps[j - 1]
elif i < n and pattern[j] != text[i]:
if j != 0:
j = lps[j - 1]
else:
i += 1
return found_indices


if __name__ == "__main__":
import doctest

doctest.testmod()

# Test 1)
pattern = "abc1abc12"
text1 = "alskfjaldsabc1abc1abc12k23adsfabcabc"
text2 = "alskfjaldsk23adsfabcabc"
assert knuth_morris_pratt(text1, pattern)
assert knuth_morris_pratt(text2, pattern)

# Test 2)
pattern = "ABABX"
text = "ABABZABABYABABX"
assert knuth_morris_pratt(text, pattern)

# Test 3)
pattern = "AAAB"
text = "ABAAAAAB"
assert knuth_morris_pratt(text, pattern)

# Test 4)
pattern = "abcdabcy"
text = "abcxabcdabxabcdabcdabcy"
assert knuth_morris_pratt(text, pattern)

# Test 5) -> Doctests
kmp = "knuth_morris_pratt"
assert all(
knuth_morris_pratt(kmp, s) == kmp.find(s)
for s in ("kn", "h_m", "rr", "tt", "not there")
)

# Test 6)
pattern = "aabaabaaa"
assert get_failure_array(pattern) == [0, 1, 0, 1, 2, 3, 4, 5, 2]