diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index c3a1b269..8fce67ca 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional -from redisvl.redis.utils import hashify - class BaseLLMCache: def __init__(self, ttl: Optional[int] = None): @@ -79,14 +77,3 @@ async def astore( """Async store the specified key-value pair in the cache along with metadata.""" raise NotImplementedError - - def hash_input(self, prompt: str) -> str: - """Hashes the input prompt using SHA256. - - Args: - prompt (str): Input string to be hashed. - - Returns: - str: Hashed string. - """ - return hashify(prompt) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 515b1421..95fe753a 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -32,7 +32,7 @@ class CacheEntry(BaseModel): def generate_id(cls, values): # Ensure entry_id is set if not values.get("entry_id"): - values["entry_id"] = hashify(values["prompt"]) + values["entry_id"] = hashify(values["prompt"], values.get("filters")) return values @validator("metadata") diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index a421022b..28d15509 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -1,5 +1,5 @@ import hashlib -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np @@ -40,6 +40,9 @@ def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: return np.frombuffer(buffer, dtype=dtype).tolist() -def hashify(content: str) -> str: - """Create a secure hash of some arbitrary input text.""" +def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str: + """Create a secure hash of some arbitrary input text and optional dictionary.""" + if extras: + extra_string = " ".join([str(k) + str(v) for k, v in sorted(extras.items())]) + content = content + extra_string return hashlib.sha256(content.encode("utf-8")).hexdigest() diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index e3aef3dc..6c106e87 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -800,3 +800,46 @@ def test_index_updating(redis_url): filter_expression=tag_filter, ) assert len(response) == 1 + + +def test_no_key_collision_on_identical_prompts(redis_url): + private_cache = SemanticCache( + name="private_cache", + redis_url=redis_url, + filterable_fields=[ + {"name": "user_id", "type": "tag"}, + {"name": "zip_code", "type": "numeric"}, + ], + ) + + private_cache.store( + prompt="What is the phone number linked to my account?", + response="The number on file is 123-555-0000", + filters={"user_id": "gabs"}, + ) + + private_cache.store( + prompt="What's the phone number linked in my account?", + response="The number on file is 123-555-9999", + ###filters={"user_id": "cerioni"}, + filters={"user_id": "cerioni", "zip_code": 90210}, + ) + + private_cache.store( + prompt="What's the phone number linked in my account?", + response="The number on file is 123-555-1111", + filters={"user_id": "bart"}, + ) + + results = private_cache.check( + "What's the phone number linked in my account?", num_results=5 + ) + assert len(results) == 3 + + zip_code_filter = Num("zip_code") != 90210 + filtered_results = private_cache.check( + "what's the phone number linked in my account?", + num_results=5, + filter_expression=zip_code_filter, + ) + assert len(filtered_results) == 2 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index e3961e6b..aa3a3add 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -48,7 +48,7 @@ def test_cache_entry_to_dict(): filters={"category": "technology"}, ) result = entry.to_dict() - assert result["entry_id"] == hashify("What is AI?") + assert result["entry_id"] == hashify("What is AI?", {"category": "technology"}) assert result["metadata"] == json.dumps({"author": "John"}) assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) assert result["category"] == "technology"