Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/user-guides/advanced/embedding-search-providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ core:
search_threshold: None
cache:
enabled: False
key_generator: md5
key_generator: sha256
store: filesystem
store_config: {}

Expand All @@ -35,7 +35,7 @@ knowledge_base:
search_threshold: None
cache:
enabled: False
key_generator: md5
key_generator: sha256
store: filesystem
store_config: {}
```
Expand All @@ -51,7 +51,7 @@ core:
embedding_model: text-embedding-ada-002
cache:
enabled: False
key_generator: md5
key_generator: sha256
store: filesystem
store_config: {}

Expand All @@ -63,15 +63,15 @@ knowledge_base:
embedding_model: text-embedding-ada-002
cache:
enabled: False
key_generator: md5
key_generator: sha256
store: filesystem
store_config: {}
```

The default implementation is also designed to support asynchronous execution of the embedding computation process, thereby enhancing the efficiency of the search functionality.

The `cache` configuration is optional. If enabled, it uses the specified `key_generator` and `store` to cache the embeddings. The `store_config` can be used to provide additional configuration options required for the store.
The default `cache` configuration uses the `md5` key generator and the `filesystem` store. The cache is disabled by default.
The default `cache` configuration uses the `sha256` key generator and the `filesystem` store. The cache is disabled by default.

## Batch Implementation

Expand Down
9 changes: 9 additions & 0 deletions nemoguardrails/embeddings/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def generate_key(self, text: str) -> str:
return hashlib.md5(text.encode("utf-8")).hexdigest()


class SHA256KeyGenerator(KeyGenerator):
"""SHA256-based key generator."""

name = "sha256"

def generate_key(self, text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()


class CacheStore(ABC):
"""Abstract class for cache stores."""

Expand Down
12 changes: 5 additions & 7 deletions nemoguardrails/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
from time import time
Expand All @@ -22,6 +21,7 @@
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.utils import split_markdown_in_topic_chunks
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, KnowledgeBaseConfig
from nemoguardrails.utils import compute_hash

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,18 +114,16 @@ async def build(self):
if not index_items:
return

# We compute the md5
# We compute the hash using default hash algorithm
# As part of the hash, we also include the embedding engine and the model
# to prevent the cache being used incorrectly when the embedding model changes.
hash_prefix = self.config.embedding_search_provider.parameters.get(
"embedding_engine", ""
) + self.config.embedding_search_provider.parameters.get("embedding_model", "")

md5_hash = hashlib.md5(
(hash_prefix + "".join(all_text_items)).encode("utf-8")
).hexdigest()
cache_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.ann")
embedding_size_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.esize")
hash_value = compute_hash(hash_prefix + "".join(all_text_items))
cache_file = os.path.join(CACHE_FOLDER, f"{hash_value}.ann")
embedding_size_file = os.path.join(CACHE_FOLDER, f"{hash_value}.esize")

# If we have already computed this before, we use it
if (
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class EmbeddingsCacheConfig(BaseModel):
description="Whether caching of the embeddings should be enabled or not.",
)
key_generator: str = Field(
default="md5",
default="sha256",
description="The method to use for generating the cache keys.",
)
store: str = Field(
Expand Down
23 changes: 23 additions & 0 deletions nemoguardrails/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import asyncio
import dataclasses
import fnmatch
import hashlib
import importlib.resources as pkg_resources
import json
import os
Expand Down Expand Up @@ -408,3 +409,25 @@ def safe_eval(input_value: str) -> str:
escaped_value = input_value.replace("'", "\\'").replace('"', '\\"')
input_value = f"'{escaped_value}'"
return literal_eval(input_value)


def compute_hash(text: str) -> str:
"""
Return the hash of the given text using MD5 if available,
otherwise use SHA256.

Args:
text (str): The input text to hash.

Returns:
str: The hexadecimal digest of the hash.
"""
try:
# Attempt to use MD5 by doing a dummy call.
hashlib.md5(b"")
hash_func = hashlib.md5
except (AttributeError, ValueError):
# MD5 is not available use sha256 instead
hash_func = hashlib.sha256

return hash_func(text.encode("utf-8")).hexdigest()
24 changes: 24 additions & 0 deletions tests/test_cache_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
KeyGenerator,
MD5KeyGenerator,
RedisCacheStore,
SHA256KeyGenerator,
cache_embeddings,
)
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
Expand Down Expand Up @@ -58,6 +59,29 @@ def test_md5_key_generator():
assert len(key) == 32 # MD5 hash is 32 characters long


def test_sha256_key_generator():
key_gen = SHA256KeyGenerator()
key = key_gen.generate_key("test")
assert isinstance(key, str)
assert len(key) == 64 # SHA256 hash is 64 characters long


@pytest.mark.parametrize(
"name, expected_class",
[
("hash", HashKeyGenerator),
("md5", MD5KeyGenerator),
("sha256", SHA256KeyGenerator),
],
)
def test_key_generator_class(name, expected_class):
assert KeyGenerator.from_name(name) == expected_class


def test_embedding_cache_config_default():
assert EmbeddingsCacheConfig().key_generator == "sha256"


def test_in_memory_cache_store():
cache = InMemoryCacheStore()
cache.set("key", "value")
Expand Down
21 changes: 20 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

import pytest

from nemoguardrails.utils import new_event_dict, safe_eval
from nemoguardrails.utils import compute_hash, new_event_dict, safe_eval


def test_event_generation():
Expand Down Expand Up @@ -144,3 +145,21 @@ def test_safe_eval(input_value, expected_result):
"""Test safe_eval with various input values."""
result = safe_eval(input_value)
assert result == expected_result


@pytest.fixture(params=[AttributeError, ValueError])
def md5_not_available(request):
with patch("hashlib.md5", side_effect=request.param):
yield


def test_hash_without_md5(md5_not_available):
hash_value = compute_hash("test")
assert isinstance(hash_value, str)
assert len(hash_value) == 64 # SHA256 hash is 64 characters long


def test_hash_with_md5():
hash_value = compute_hash("test")
assert isinstance(hash_value, str)
assert len(hash_value) == 32 # MD5 hash is 32 characters long