Add a dictionary factory backed by MARISA-tries (adbar#133)
* Add a dictionary factory backed by MARISA-tries

This adds an additional dictionary factory backed by MARISA-tries. This
dictionary factory on average offers 20x lower memory usage and 100x
faster initialization time, in exchange for reduced lemmatization and
language detection performance.

The first time loading a dictionary with the `TrieDictionaryFactory`
requires more memory and will take a few seconds, as the trie-backed
dictionary has to be generated on-the-fly from the pickled dict-based
dictionary first.

* Change dictionary format to use strings again

This changes the format of the dictionary returned by
`DictionaryFactory().get_dictionary()` from
`Dict[ByteString, ByteString]` to `Mapping[str, str]` to accommodate
alternative dictionary factory implementations better and to ease the
dictionary handling again. This keeps the storage of pickled
dictionaries with byte strings though, as they're smaller than when
using strings.
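The str-over-bytes view can be sketched as a small read-only `Mapping` wrapper that decodes on access, mirroring the approach taken in this commit (class and variable names here are illustrative, not the commit's):

``` python
from typing import Dict, Iterator, Mapping

class StrView(Mapping[str, str]):
    """Read-only str view over a Dict[bytes, bytes] (illustrative sketch)."""

    def __init__(self, dictionary: Dict[bytes, bytes]) -> None:
        self._dict = dictionary

    def __getitem__(self, key: str) -> str:
        # Encode the lookup key, decode the stored value on access.
        return self._dict[key.encode()].decode()

    def __iter__(self) -> Iterator[str]:
        return (key.decode() for key in self._dict)

    def __len__(self) -> int:
        return len(self._dict)

view = StrView({b"doughnuts": b"doughnut"})
print(view["doughnuts"])  # doughnut
```

Because the wrapper subclasses `Mapping`, the usual mixin methods such as `.get()` and `in` work for free, while the pickled storage stays byte-based.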
Dunedan authored Jun 26, 2024
1 parent 5f4fa16 commit 63933fc
Showing 14 changed files with 450 additions and 38 deletions.
55 changes: 55 additions & 0 deletions README.md
@@ -259,6 +259,61 @@ LANG_CACHE_SIZE = 5 # How many language dictionaries to keep in memory at once
For more information see the
[extended documentation](https://adbar.github.io/simplemma/).

### Reducing memory usage

For situations where low memory usage and fast initialization time are
more important than lemmatization and language detection performance,
Simplemma ships another `DictionaryFactory`, which uses a trie as the
underlying data structure instead of a Python dict.

Using the `TrieDictionaryFactory` reduces memory usage on average by
20x and initialization time by 100x, but lemmatization and language
detection performance can drop by 50% or more compared to what
Simplemma otherwise achieves, depending on the specific usage.

To use the `TrieDictionaryFactory` you have to install Simplemma with
the `marisa-trie` extra dependency:

``` bash
pip install simplemma[marisa-trie]
```

Then you have to create a custom strategy using the
`TrieDictionaryFactory` and use that for `Lemmatizer` and
`LanguageDetector` instances:

``` python
>>> from simplemma import LanguageDetector, Lemmatizer
>>> from simplemma.strategies import DefaultStrategy
>>> from simplemma.strategies.dictionaries import TrieDictionaryFactory

>>> lemmatization_strategy = DefaultStrategy(dictionary_factory=TrieDictionaryFactory())

>>> lemmatizer = Lemmatizer(lemmatization_strategy=lemmatization_strategy)
>>> lemmatizer.lemmatize('doughnuts', lang='en')
'doughnut'

>>> language_detector = LanguageDetector('la', lemmatization_strategy=lemmatization_strategy)
>>> language_detector.proportion_in_target_languages("opera post physica posita (τὰ μετὰ τὰ φυσικά)")
0.5
```

While memory usage and initialization time with the
`TrieDictionaryFactory` are significantly lower than with the
`DefaultDictionaryFactory`, that only holds once the trie dictionaries
are available on disk. This isn't the case the first time the
`TrieDictionaryFactory` is used for a language, as Simplemma only ships
the dictionaries as Python dicts. The trie dictionary is then generated
on the fly, which takes a few seconds and uses as much memory as
loading the Python dict for that language. The generated trie
dictionaries are cached on disk for subsequent invocations.

If the computer that is supposed to run Simplemma doesn't have enough
memory to generate the trie dictionaries, they can also be generated on
another computer with the same CPU architecture and copied over to the
cache directory.
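Conceptually this amounts to warming the on-disk cache on one machine and copying it over. A minimal stand-in sketch of the idea (the real generation happens via `TrieDictionaryFactory`; the `<lang>.dic` file names follow the factory's convention, while the directory layout and placeholder contents here are illustrative):

``` python
import shutil
import tempfile
from pathlib import Path

def generate_trie_cache(cache_dir: Path, langs: list) -> None:
    # Stand-in for TrieDictionaryFactory: the real factory builds a
    # MARISA-trie per language and saves it as "<lang>.dic".
    cache_dir.mkdir(parents=True, exist_ok=True)
    for lang in langs:
        (cache_dir / f"{lang}.dic").write_bytes(b"trie-data-placeholder")

# "Machine A": generate the tries once (the memory-hungry step).
source_cache = Path(tempfile.mkdtemp()) / "simplemma" / "marisa_trie"
generate_trie_cache(source_cache, ["de", "en"])

# "Machine B": copy the cache directory instead of regenerating.
target_cache = Path(tempfile.mkdtemp()) / "simplemma" / "marisa_trie"
shutil.copytree(source_cache, target_cache, dirs_exist_ok=True)
print(sorted(p.name for p in target_cache.iterdir()))  # ['de.dic', 'en.dic']
```

On the target machine Simplemma then finds the `.dic` files in its cache directory and loads them directly instead of rebuilding the tries.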

## Supported languages

2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -1,6 +1,8 @@
black==24.4.2
flake8==7.0.0
marisa_trie==1.2.0
mypy==1.10.0
platformdirs==4.2.2
pytest==8.2.1
pytest-cov==5.0.0
types-requests==2.32.0.20240523
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@ def get_version(package):
],
description="A simple multilingual lemmatizer for Python.",
install_requires=requirements,
extras_require={"marisa-trie": ["marisa-trie", "platformdirs"]},
license="MIT license",
long_description=readme, # + '\n\n' + history,
long_description_content_type="text/markdown",
6 changes: 5 additions & 1 deletion simplemma/strategies/__init__.py
@@ -2,7 +2,11 @@

from .affix_decomposition import AffixDecompositionStrategy
from .default import DefaultStrategy
from .dictionaries import DefaultDictionaryFactory, DictionaryFactory
from .dictionaries import (
DefaultDictionaryFactory,
DictionaryFactory,
TrieDictionaryFactory,
)
from .dictionary_lookup import DictionaryLookupStrategy
from .fallback.lemmatization_fallback_strategy import LemmatizationFallbackStrategy
from .fallback.raise_error import RaiseErrorFallbackStrategy
1 change: 1 addition & 0 deletions simplemma/strategies/dictionaries/__init__.py
@@ -1,3 +1,4 @@
"""Dictionary-based lemmatization strategy."""

from .dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .trie_directory_factory import TrieDictionaryFactory
34 changes: 26 additions & 8 deletions simplemma/strategies/dictionaries/dictionary_factory.py
@@ -14,7 +14,7 @@
from functools import lru_cache
from os import listdir, path
from pathlib import Path
from typing import ByteString, Dict, Protocol
from typing import ByteString, Dict, Mapping, Protocol

DATA_FOLDER = str(Path(__file__).parent / "data")
SUPPORTED_LANGUAGES = [
@@ -62,22 +62,41 @@ class DictionaryFactory(Protocol):
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.
Args:
lang (str): The language code.
Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.
Raises:
ValueError: If the specified language is not supported.
"""
raise NotImplementedError


class MappingStrToByteString(Mapping[str, str]):
"""Wrapper around ByString dict to make them behave like str dict."""

__slots__ = ["_dict"]

def __init__(self, dictionary: Dict[bytes, bytes]):
self._dict = dictionary

def __getitem__(self, item: str):
return self._dict[item.encode()].decode()

def __iter__(self):
for key in self._dict:
yield key.decode()

def __len__(self):
return len(self._dict)


class DefaultDictionaryFactory(DictionaryFactory):
"""
Default Dictionary Factory.
@@ -86,7 +105,7 @@ class DefaultDictionaryFactory):
It provides functionality for loading and caching dictionaries from disk that are included in Simplemma.
"""

__slots__ = ["_data", "_load_dictionary_from_disk"]
__slots__ = ["_load_dictionary_from_disk"]

def __init__(self, cache_max_size: int = 8):
"""
@@ -96,27 +115,26 @@ def __init__(self, cache_max_size: int = 8):
cache_max_size (int): The maximum size of the cache for loaded dictionaries.
Defaults to `8`.
"""
self._data: Dict[str, Dict[ByteString, ByteString]] = {}
self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)(
_load_dictionary_from_disk
)

def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.
Args:
lang (str): The language code.
Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.
Raises:
ValueError: If the specified language is not supported.
"""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")
return self._load_dictionary_from_disk(lang)
return MappingStrToByteString(self._load_dictionary_from_disk(lang))
123 changes: 123 additions & 0 deletions simplemma/strategies/dictionaries/trie_directory_factory.py
@@ -0,0 +1,123 @@
import logging
from collections.abc import MutableMapping
from functools import lru_cache
from pathlib import Path
from typing import List, Mapping, Optional

from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found]
from platformdirs import user_cache_dir

from simplemma import __version__ as SIMPLEMMA_VERSION
from simplemma.strategies.dictionaries.dictionary_factory import (
DefaultDictionaryFactory,
DictionaryFactory,
SUPPORTED_LANGUAGES,
)

logger = logging.getLogger(__name__)


class TrieWrapDict(MutableMapping):
"""Wrapper around BytesTrie to make them behave like dicts."""

def __init__(self, trie: BytesTrie):
self._trie = trie

def __getitem__(self, item):
return self._trie[item][0].decode()

def __setitem__(self, key, value):
raise NotImplementedError

def __delitem__(self, key):
raise NotImplementedError

def __iter__(self):
for key in self._trie.iterkeys():
yield key

def __len__(self):
return len(self._trie)


class TrieDictionaryFactory(DictionaryFactory):
"""Memory optimized DictionaryFactory backed by MARISA-tries.
This dictionary factory creates dictionaries backed by a MARISA-trie
instead of a dict, so they consume very little memory compared to
the DefaultDictionaryFactory. The trade-off is that lookup
performance isn't as good as with dicts.
"""

__slots__: List[str] = []

def __init__(
self,
cache_max_size: int = 8,
use_disk_cache: bool = True,
disk_cache_dir: Optional[str] = None,
) -> None:
"""Initialize the TrieDictionaryFactory.
Args:
cache_max_size (int): The maximum number of dictionaries to
keep in memory. Defaults to `8`.
use_disk_cache (bool): Whether to cache the tries on disk to
speed up loading time. Defaults to `True`.
disk_cache_dir (Optional[str]): Path where the generated
tries should be stored. Defaults to a Simplemma-specific
subdirectory of the user's cache directory.
"""

if disk_cache_dir:
self._cache_dir = Path(disk_cache_dir)
else:
self._cache_dir = (
Path(user_cache_dir("simplemma")) / "marisa_trie" / SIMPLEMMA_VERSION
)
self._use_disk_cache = use_disk_cache
self._get_dictionary = lru_cache(maxsize=cache_max_size)(
self._get_dictionary_uncached
)

def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie:
"""Create a trie from a pickled dictionary."""
unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang)
return BytesTrie(
zip(
unpickled_dict.keys(),
[value.encode() for value in unpickled_dict.values()],
),
cache_size=HUGE_CACHE,
)

def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None:
"""Persist the trie to disk for later usage.
The persisted trie can be loaded by subsequent runs to speed up
loading times.
"""
logger.debug("Caching trie on disk. This might take a second.")
self._cache_dir.mkdir(parents=True, exist_ok=True)

trie.save(self._cache_dir / f"{lang}.dic")

def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]:
"""Get the dictionary for the given language."""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")

if self._use_disk_cache and (self._cache_dir / f"{lang}.dic").exists():
trie = BytesTrie().load(str(self._cache_dir / f"{lang}.dic"))
else:
trie = self._create_trie_from_pickled_dict(lang)
if self._use_disk_cache:
self._write_trie_to_disk(lang, trie)

return TrieWrapDict(trie)

def get_dictionary(
self,
lang: str,
) -> Mapping[str, str]:
return self._get_dictionary(lang)
13 changes: 3 additions & 10 deletions simplemma/strategies/dictionary_lookup.py
@@ -3,7 +3,7 @@
It provides lemmatization using dictionary lookup.
"""

from typing import ByteString, Dict, Optional
from typing import Optional

from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .lemmatization_strategy import LemmatizationStrategy
@@ -26,13 +26,6 @@ def __init__(
"""
self._dictionary_factory = dictionary_factory

def _get(
self, token: str, dictionary: Dict[ByteString, ByteString]
) -> Optional[str]:
"Convenience function to handle bytestring to string conversion."
result = dictionary.get(token.encode("utf-8"))
return result.decode("utf-8") if result else None # type: ignore[union-attr]

def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
Get Lemma using Dictionary Lookup
@@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
# Search the language data, reverse case to extend coverage.
dictionary = self._dictionary_factory.get_dictionary(lang)
result = self._get(token, dictionary)
result = dictionary.get(token)
if result:
return result
# Try upper or lowercase.
token = token.lower() if token[0].isupper() else token.capitalize()
return self._get(token, dictionary)
return dictionary.get(token)
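The simplified lookup-with-case-fallback logic above can be exercised in isolation with a plain dict standing in for the language dictionary (a sketch; in the real code the mapping comes from a `DictionaryFactory` and the function is a method of the strategy class):

``` python
from typing import Mapping, Optional

def get_lemma(token: str, dictionary: Mapping[str, str]) -> Optional[str]:
    # First try the token as-is, then reverse its case to extend coverage.
    result = dictionary.get(token)
    if result:
        return result
    token = token.lower() if token[0].isupper() else token.capitalize()
    return dictionary.get(token)

lemmas = {"doughnuts": "doughnut"}
print(get_lemma("Doughnuts", lemmas))  # doughnut  (found after lowercasing)
print(get_lemma("muffins", lemmas))    # None      (no match either way)
```

Note that with `Mapping[str, str]` this works identically for the dict-backed and the trie-backed dictionaries, which is what the format change enables.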
4 changes: 2 additions & 2 deletions simplemma/strategies/greedy_dictionary_lookup.py
@@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str:
return token

dictionary = self._dictionary_factory.get_dictionary(lang)
candidate = token.encode("utf-8")
candidate = token
for _ in range(self._steps):
if candidate not in dictionary:
break
@@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str:

candidate = new_candidate

return candidate.decode("utf-8")
return candidate
8 changes: 2 additions & 6 deletions simplemma/utils.py
@@ -6,7 +6,7 @@
- [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple.
"""

from typing import ByteString, Tuple, Union
from typing import Tuple, Union


def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
@@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
return lang # type: ignore[return-value]


def levenshtein_dist(
first: Union[ByteString, str], second: Union[ByteString, str]
) -> int:
def levenshtein_dist(str1: str, str2: str) -> int:
"""
Calculate the Levenshtein distance between two strings.
@@ -49,8 +47,6 @@ def levenshtein_dist(
int: The Levenshtein distance between the two strings.
"""
str1 = first.encode("utf-8") if isinstance(first, str) else first
str2 = second.encode("utf-8") if isinstance(second, str) else second
# inspired by this noticeably faster code:
# https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b
if str1 == str2: