diff --git a/docs/user-guides/advanced/embedding-search-providers.md b/docs/user-guides/advanced/embedding-search-providers.md index d26e6d4ca..1df5d5f9f 100644 --- a/docs/user-guides/advanced/embedding-search-providers.md +++ b/docs/user-guides/advanced/embedding-search-providers.md @@ -19,7 +19,7 @@ core: search_threshold: None cache: enabled: False - key_generator: md5 + key_generator: sha256 store: filesystem store_config: {} @@ -35,7 +35,7 @@ knowledge_base: search_threshold: None cache: enabled: False - key_generator: md5 + key_generator: sha256 store: filesystem store_config: {} ``` @@ -51,7 +51,7 @@ core: embedding_model: text-embedding-ada-002 cache: enabled: False - key_generator: md5 + key_generator: sha256 store: filesystem store_config: {} @@ -63,7 +63,7 @@ knowledge_base: embedding_model: text-embedding-ada-002 cache: enabled: False - key_generator: md5 + key_generator: sha256 store: filesystem store_config: {} ``` @@ -71,7 +71,7 @@ knowledge_base: The default implementation is also designed to support asynchronous execution of the embedding computation process, thereby enhancing the efficiency of the search functionality. The `cache` configuration is optional. If enabled, it uses the specified `key_generator` and `store` to cache the embeddings. The `store_config` can be used to provide additional configuration options required for the store. -The default `cache` configuration uses the `md5` key generator and the `filesystem` store. The cache is disabled by default. +The default `cache` configuration uses the `sha256` key generator and the `filesystem` store. The cache is disabled by default. 
## Batch Implementation diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index d8c0a77a0..e8a348049 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -64,6 +64,15 @@ def generate_key(self, text: str) -> str: return hashlib.md5(text.encode("utf-8")).hexdigest() +class SHA256KeyGenerator(KeyGenerator): + """SHA256-based key generator.""" + + name = "sha256" + + def generate_key(self, text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + class CacheStore(ABC): """Abstract class for cache stores.""" diff --git a/nemoguardrails/kb/kb.py b/nemoguardrails/kb/kb.py index 47d90d6e0..685f8a9a0 100644 --- a/nemoguardrails/kb/kb.py +++ b/nemoguardrails/kb/kb.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib import logging import os from time import time @@ -22,6 +21,7 @@ from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem from nemoguardrails.kb.utils import split_markdown_in_topic_chunks from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, KnowledgeBaseConfig +from nemoguardrails.utils import compute_hash log = logging.getLogger(__name__) @@ -114,18 +114,16 @@ async def build(self): if not index_items: return - # We compute the md5 + # We compute the hash using the default hash algorithm # As part of the hash, we also include the embedding engine and the model # to prevent the cache being used incorrectly when the embedding model changes. 
hash_prefix = self.config.embedding_search_provider.parameters.get( "embedding_engine", "" ) + self.config.embedding_search_provider.parameters.get("embedding_model", "") - md5_hash = hashlib.md5( - (hash_prefix + "".join(all_text_items)).encode("utf-8") - ).hexdigest() - cache_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.ann") - embedding_size_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.esize") + hash_value = compute_hash(hash_prefix + "".join(all_text_items)) + cache_file = os.path.join(CACHE_FOLDER, f"{hash_value}.ann") + embedding_size_file = os.path.join(CACHE_FOLDER, f"{hash_value}.esize") # If we have already computed this before, we use it if ( diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 688e51be5..57d784001 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -247,7 +247,7 @@ class EmbeddingsCacheConfig(BaseModel): description="Whether caching of the embeddings should be enabled or not.", ) key_generator: str = Field( - default="md5", + default="sha256", description="The method to use for generating the cache keys.", ) store: str = Field( diff --git a/nemoguardrails/utils.py b/nemoguardrails/utils.py index 6c84933d9..5d3b3b451 100644 --- a/nemoguardrails/utils.py +++ b/nemoguardrails/utils.py @@ -15,6 +15,7 @@ import asyncio import dataclasses import fnmatch +import hashlib import importlib.resources as pkg_resources import json import os @@ -408,3 +409,25 @@ def safe_eval(input_value: str) -> str: escaped_value = input_value.replace("'", "\\'").replace('"', '\\"') input_value = f"'{escaped_value}'" return literal_eval(input_value) + + +def compute_hash(text: str) -> str: + """ + Return the hash of the given text using MD5 if available, + otherwise use SHA256. + + Args: + text (str): The input text to hash. + + Returns: + str: The hexadecimal digest of the hash. + """ + try: + # Attempt to use MD5 by doing a dummy call. 
+ hashlib.md5(b"") + hash_func = hashlib.md5 + except (AttributeError, ValueError): + # MD5 is not available; use SHA256 instead + hash_func = hashlib.sha256 + + return hash_func(text.encode("utf-8")).hexdigest() diff --git a/tests/test_cache_embeddings.py b/tests/test_cache_embeddings.py index ef5bb5fd2..d6daf6dcf 100644 --- a/tests/test_cache_embeddings.py +++ b/tests/test_cache_embeddings.py @@ -30,6 +30,7 @@ KeyGenerator, MD5KeyGenerator, RedisCacheStore, + SHA256KeyGenerator, cache_embeddings, ) from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig @@ -58,6 +59,29 @@ def test_md5_key_generator(): assert len(key) == 32 # MD5 hash is 32 characters long +def test_sha256_key_generator(): + key_gen = SHA256KeyGenerator() + key = key_gen.generate_key("test") + assert isinstance(key, str) + assert len(key) == 64 # SHA256 hash is 64 characters long + + +@pytest.mark.parametrize( + "name, expected_class", + [ + ("hash", HashKeyGenerator), + ("md5", MD5KeyGenerator), + ("sha256", SHA256KeyGenerator), + ], +) +def test_key_generator_class(name, expected_class): + assert KeyGenerator.from_name(name) == expected_class + + +def test_embedding_cache_config_default(): + assert EmbeddingsCacheConfig().key_generator == "sha256" + + def test_in_memory_cache_store(): cache = InMemoryCacheStore() cache.set("key", "value") diff --git a/tests/test_utils.py b/tests/test_utils.py index 1ce893a2d..79d506beb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest.mock import patch import pytest -from nemoguardrails.utils import new_event_dict, safe_eval +from nemoguardrails.utils import compute_hash, new_event_dict, safe_eval def test_event_generation(): @@ -144,3 +145,21 @@ def test_safe_eval(input_value, expected_result): """Test safe_eval with various input values.""" result = safe_eval(input_value) assert result == expected_result + + +@pytest.fixture(params=[AttributeError, ValueError]) +def md5_not_available(request): + with patch("hashlib.md5", side_effect=request.param): + yield + + +def test_hash_without_md5(md5_not_available): + hash_value = compute_hash("test") + assert isinstance(hash_value, str) + assert len(hash_value) == 64 # SHA256 hash is 64 characters long + + +def test_hash_with_md5(): + hash_value = compute_hash("test") + assert isinstance(hash_value, str) + assert len(hash_value) == 32 # MD5 hash is 32 characters long