diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py
index 5ca7abb9..a1c88466 100644
--- a/redisvl/extensions/llmcache/base.py
+++ b/redisvl/extensions/llmcache/base.py
@@ -58,11 +58,3 @@ def store(
     def hash_input(self, prompt: str):
         """Hashes the input using SHA256."""
         return hashify(prompt)
-
-    def serialize(self, metadata: Dict[str, Any]) -> str:
-        """Serlize the input into a string."""
-        return json.dumps(metadata)
-
-    def deserialize(self, metadata: str) -> Dict[str, Any]:
-        """Deserialize the input from a string."""
-        return json.loads(metadata)
diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 8ec167b7..b17c18c9 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -1,21 +1,53 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from redis import Redis
 
 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.index import SearchIndex
 from redisvl.query import RangeQuery
+from redisvl.query.filter import FilterExpression, Tag
 from redisvl.redis.utils import array_to_buffer
-from redisvl.schema.schema import IndexSchema
+from redisvl.schema import IndexSchema
+from redisvl.utils.utils import current_timestamp, deserialize, serialize
 from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
 
 
+class SemanticCacheIndexSchema(IndexSchema):
+
+    @classmethod
+    def from_params(cls, name: str, vector_dims: int):
+
+        return cls(
+            index={"name": name, "prefix": name},  # type: ignore
+            fields=[  # type: ignore
+                {"name": "prompt", "type": "text"},
+                {"name": "response", "type": "text"},
+                {"name": "inserted_at", "type": "numeric"},
+                {"name": "updated_at", "type": "numeric"},
+                {"name": "label", "type": "tag"},
+                {
+                    "name": "prompt_vector",
+                    "type": "vector",
+                    "attrs": {
+                        "dims": vector_dims,
+                        "datatype": "float32",
+                        "distance_metric": "cosine",
+                        "algorithm": "flat",
+                    },
+                },
+            ],
+        )
+
+
 class SemanticCache(BaseLLMCache):
     """Semantic Cache for Large Language Models."""
 
     entry_id_field_name: str = "_id"
     prompt_field_name: str = "prompt"
     vector_field_name: str = "prompt_vector"
+    inserted_at_field_name: str = "inserted_at"
+    updated_at_field_name: str = "updated_at"
+    tag_field_name: str = "label"
     response_field_name: str = "response"
     metadata_field_name: str = "metadata"
 
@@ -69,27 +101,7 @@ def __init__(
                 model="sentence-transformers/all-mpnet-base-v2"
             )
 
-        # build cache index schema
-        schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
-        # add fields
-        schema.add_fields(
-            [
-                {"name": self.prompt_field_name, "type": "text"},
-                {"name": self.response_field_name, "type": "text"},
-                {
-                    "name": self.vector_field_name,
-                    "type": "vector",
-                    "attrs": {
-                        "dims": vectorizer.dims,
-                        "datatype": "float32",
-                        "distance_metric": "cosine",
-                        "algorithm": "flat",
-                    },
-                },
-            ]
-        )
-
-        # build search index
+        schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims)
         self._index = SearchIndex(schema=schema)
 
         # handle redis connection
@@ -103,12 +115,12 @@ def __init__(
             self.entry_id_field_name,
             self.prompt_field_name,
             self.response_field_name,
+            self.tag_field_name,
             self.vector_field_name,
             self.metadata_field_name,
         ]
         self.set_vectorizer(vectorizer)
         self.set_threshold(distance_threshold)
-
         self._index.create(overwrite=False)
 
     @property
@@ -182,6 +194,14 @@ def delete(self) -> None:
         index."""
         self._index.delete(drop=True)
 
+    def drop(self, document_ids: Union[str, List[str]]) -> None:
+        """Remove a specific entry or entries from the cache by its ID.
+
+        Args:
+            document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
+        """
+        self._index.drop_keys(document_ids)
+
     def _refresh_ttl(self, key: str) -> None:
         """Refresh the time-to-live for the specified key."""
         if self._ttl:
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
         return self._vectorizer.embed(prompt)
 
     def _search_cache(
-        self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
+        self,
+        vector: List[float],
+        num_results: int,
+        return_fields: Optional[List[str]],
+        tag_filter: Optional[FilterExpression],
     ) -> List[Dict[str, Any]]:
         """Searches the semantic cache for similar prompt vectors and returns
         the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
             num_results=num_results,
             return_score=True,
         )
+        if tag_filter:
+            query.set_filter(tag_filter)  # type: ignore
 
         # Gather and return the cache hits
         cache_hits: List[Dict[str, Any]] = self._index.query(query)
@@ -226,7 +252,7 @@ def _search_cache(
             self._refresh_ttl(key)
             # Check for metadata and deserialize
             if self.metadata_field_name in hit:
-                hit[self.metadata_field_name] = self.deserialize(
+                hit[self.metadata_field_name] = deserialize(
                     hit[self.metadata_field_name]
                 )
         return cache_hits
@@ -248,6 +274,7 @@ def check(
         vector: Optional[List[float]] = None,
         num_results: int = 1,
         return_fields: Optional[List[str]] = None,
+        tag_filter: Optional[FilterExpression] = None,
    ) -> List[Dict[str, Any]]:
         """Checks the semantic cache for results similar to the specified prompt
         or vector.
@@ -267,6 +294,8 @@ def check(
             return_fields (Optional[List[str]], optional): The fields to include
                 in each returned result. If None, defaults to all available
                 fields in the cached entry.
+            tag_filter (Optional[FilterExpression], optional): The tag filter to
+                filter results by. Defaults to None, in which case the full cache is searched.
 
         Returns:
             List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@ def check(
         self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
-        cache_hits = self._search_cache(vector, num_results, return_fields)
+        cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter)
         return cache_hits
 
     def store(
         self,
         prompt: str,
         response: str,
         vector: Optional[List[float]] = None,
         metadata: Optional[dict] = None,
+        tag: Optional[str] = None,
     ) -> str:
         """Stores the specified key-value pair in the cache along with metadata.
 
@@ -311,6 +341,8 @@ def store(
                 demand.
             metadata (Optional[dict], optional): The optional metadata to cache
                 alongside the prompt and response. Defaults to None.
+            tag (Optional[str]): The optional tag to assign to the cache entry.
+                Defaults to None.
 
         Returns:
             str: The Redis key for the entries added to the semantic cache.
@@ -333,19 +365,67 @@ def store(
         self._check_vector_dims(vector)
 
         # Construct semantic cache payload
+        now = current_timestamp()
         id_field = self.entry_id_field_name
         payload = {
             id_field: self.hash_input(prompt),
             self.prompt_field_name: prompt,
             self.response_field_name: response,
             self.vector_field_name: array_to_buffer(vector),
+            self.inserted_at_field_name: now,
+            self.updated_at_field_name: now,
         }
         if metadata is not None:
             if not isinstance(metadata, dict):
                 raise TypeError("If specified, cached metadata must be a dictionary.")
             # Serialize the metadata dict and add to cache payload
-            payload[self.metadata_field_name] = self.serialize(metadata)
+            payload[self.metadata_field_name] = serialize(metadata)
+        if tag is not None:
+            payload[self.tag_field_name] = tag
 
         # Load LLMCache entry with TTL
         keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
         return keys[0]
+
+    def update(self, key: str, **kwargs) -> None:
+        """Update specific fields within an existing cache entry. If no fields
+        are passed, then only the document TTL is refreshed.
+
+        Args:
+            key (str): the key of the document to update.
+            kwargs: the field-value pairs to update on the cache entry.
+
+        Raises:
+            ValueError: if an invalid field name is provided as a kwarg.
+            TypeError: if provided metadata is not a dictionary.
+
+        .. code-block:: python
+
+            key = cache.store('this is a prompt', 'this is a response')
+            cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
+        """
+        if not kwargs:
+            self._refresh_ttl(key)
+            return
+
+        for _key, val in kwargs.items():
+            if _key not in {
+                self.prompt_field_name,
+                self.vector_field_name,
+                self.response_field_name,
+                self.tag_field_name,
+                self.metadata_field_name,
+            }:
+                raise ValueError(f"{_key} is not a valid field within document")
+
+            # Check for metadata and serialize
+            if _key == self.metadata_field_name:
+                if isinstance(val, dict):
+                    kwargs[_key] = serialize(val)
+                else:
+                    raise TypeError(
+                        "If specified, cached metadata must be a dictionary."
+                    )
+        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        self._index.client.hset(key, mapping=kwargs)  # type: ignore
+        self._refresh_ttl(key)
diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py
index 5f5cc882..eafb47ad 100644
--- a/redisvl/utils/utils.py
+++ b/redisvl/utils/utils.py
@@ -54,6 +54,6 @@ def serialize(data: Dict[str, Any]) -> str:
     return json.dumps(data)
 
 
-def deserialize(self, data: str) -> Dict[str, Any]:
+def deserialize(data: str) -> Dict[str, Any]:
     """Deserialize the input from a string."""
     return json.loads(data)
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index ef4ad7fe..b272ac30 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -1,11 +1,12 @@
 from collections import namedtuple
-from time import sleep
+from time import sleep, time
 
 import pytest
 from redis.exceptions import ConnectionError
 
 from redisvl.extensions.llmcache import SemanticCache
 from redisvl.index.index import SearchIndex
+from redisvl.query.filter import Num, Tag, Text
 from redisvl.utils.vectorize import HFTextVectorizer
 
 
@@ -89,6 +90,46 @@ def test_store_and_check(cache, vectorizer):
     assert "metadata" not in check_result[0]
 
 
+def test_return_fields(cache, vectorizer):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    vector = vectorizer.embed(prompt)
+
+    cache.store(prompt, response, vector=vector)
+
+    # check default return fields
+    check_result = cache.check(vector=vector)
+    assert set(check_result[0].keys()) == {
+        "id",
+        "_id",
+        "prompt",
+        "response",
+        "prompt_vector",
+        "vector_distance",
+    }
+
+    # check all return fields
+    fields = [
+        "id",
+        "_id",
+        "prompt",
+        "response",
+        "inserted_at",
+        "updated_at",
+        "prompt_vector",
+        "vector_distance",
+    ]
+    check_result = cache.check(vector=vector, return_fields=fields[:])
+    assert set(check_result[0].keys()) == set(fields)
+
+    # check only some return fields
+    fields = ["inserted_at", "updated_at"]
+    check_result = cache.check(vector=vector, return_fields=fields[:])
+    fields.extend(["id", "vector_distance"])  # id and vector_distance always returned
+    assert set(check_result[0].keys()) == set(fields)
+
+
+# Test clearing the cache
 def test_clear(cache, vectorizer):
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -128,6 +169,65 @@ def test_ttl_refresh(cache_with_ttl, vectorizer):
     assert len(check_result) == 1
 
+
+# Test manual expiration of single document
+def test_drop_document(cache, vectorizer):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    vector = vectorizer.embed(prompt)
+
+    cache.store(prompt, response, vector=vector)
+    check_result = cache.check(vector=vector)
+
+    cache.drop(check_result[0]["id"])
+    recheck_result = cache.check(vector=vector)
+    assert len(recheck_result) == 0
+
+
+# Test manual expiration of multiple documents
+def test_drop_documents(cache, vectorizer):
+    prompts = [
+        "This is a test prompt.",
+        "This is also a test prompt.",
+        "This is another test prompt.",
+    ]
+    responses = [
+        "This is a test response.",
+        "This is also a test response.",
+        "This is another test response.",
+    ]
+    for prompt, response in zip(prompts, responses):
+        vector = vectorizer.embed(prompt)
+        cache.store(prompt, response, vector=vector)
+
+    check_result = cache.check(vector=vector, num_results=3)
+    keys = [r["id"] for r in check_result[0:2]]  # drop first 2 entries
+    cache.drop(keys)
+
+    recheck_result = cache.check(vector=vector, num_results=3)
+    assert len(recheck_result) == 1
+
+
+# Test updating document fields
+def test_updating_document(cache):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    cache.store(prompt=prompt, response=response)
+
+    check_result = cache.check(prompt=prompt, return_fields=["updated_at"])
+    key = check_result[0]["id"]
+
+    sleep(1)
+
+    metadata = {"foo": "bar"}
+    cache.update(key=key, metadata=metadata)
+
+    updated_result = cache.check(
+        prompt=prompt, return_fields=["updated_at", "metadata"]
+    )
+    assert updated_result[0]["metadata"] == metadata
+    assert updated_result[0]["updated_at"] > check_result[0]["updated_at"]
+
 
 def test_ttl_expiration_after_update(cache_with_ttl, vectorizer):
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -279,3 +379,60 @@ def test_vector_size(cache, vectorizer):
 
     with pytest.raises(ValueError):
         cache.check(vector=[1, 2, 3])
+
+
+# test we can filter by tags, combining multiple tag filters to include all results that match
+def test_multiple_tags(cache):
+    tag_1 = "group 0"
+    tag_2 = "group 1"
+    tag_3 = "group 2"
+    tag_4 = "group 3"
+    tags = [tag_1, tag_2, tag_3, tag_4]
+
+    filter_1 = Tag("label") == tag_1
+    filter_2 = Tag("label") == tag_2
+    filter_3 = Tag("label") == tag_3
+
+    for i in range(4):
+        prompt = f"test prompt {i}"
+        response = f"test response {i}"
+        cache.store(prompt, response, tag=tags[i])
+
+    # test we can filter by one specific tag
+    results = cache.check("test prompt 1", tag_filter=filter_1, num_results=5)
+    assert len(results) == 1
+    assert results[0]["prompt"] == "test prompt 0"
+
+    # test we can combine multiple tag filters
+    combined_filter = filter_1 | filter_2 | filter_3
+    results = cache.check("test prompt 1", tag_filter=combined_filter, num_results=5)
+    assert len(results) == 3
+
+    # test that default tag param searches full cache
+    results = cache.check("test prompt 1", num_results=5)
+    assert len(results) == 4
+
+    # test no results are returned if we pass a nonexistent tag
+    bad_filter = Tag("label") == "bad tag"
+    results = cache.check("test prompt 1", tag_filter=bad_filter, num_results=5)
+    assert len(results) == 0
+
+
+def test_complex_filters(cache):
+    cache.store(prompt="prompt 1", response="response 1")
+    cache.store(prompt="prompt 2", response="response 2")
+    sleep(1)
+    current_timestamp = time()
+    cache.store(prompt="prompt 3", response="response 3")
+
+    # test we can do range filters on inserted_at and updated_at fields
+    range_filter = Num("inserted_at") < current_timestamp
+    results = cache.check("prompt 1", tag_filter=range_filter, num_results=5)
+    assert len(results) == 2
+
+    # test we can combine range filters and text filters
+    prompt_filter = Text("prompt") % "*pt 1"
+    combined_filter = prompt_filter & range_filter
+
+    results = cache.check("prompt 1", tag_filter=combined_filter, num_results=5)
+    assert len(results) == 1
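
Taken together, the changes above add a "label" tag field, inserted_at/updated_at timestamps, filtered lookups via check(tag_filter=...), in-place updates via update(), and manual eviction via drop(). Below is a minimal usage sketch, not part of the diff: the Redis URL and constructor arguments are illustrative assumptions, while the store/check/update/drop calls mirror the API introduced in this changeset.

from time import time

from redisvl.extensions.llmcache import SemanticCache
from redisvl.query.filter import Num, Tag

# Assumption: a local Redis Stack instance and default constructor arguments;
# swap in your own vectorizer, distance_threshold, or ttl as needed.
cache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")

# Store an entry with an optional tag; inserted_at/updated_at are filled in
# automatically by store() via current_timestamp().
key = cache.store(
    prompt="What is the capital of France?",
    response="Paris",
    tag="geography",
)

# Restrict cache lookups to a tag group with a FilterExpression.
hits = cache.check(
    prompt="What's France's capital city?",
    tag_filter=Tag("label") == "geography",
    num_results=3,
)

# Filters compose: e.g. only tagged entries inserted before "now".
older_hits = cache.check(
    prompt="What's France's capital city?",
    tag_filter=(Tag("label") == "geography") & (Num("inserted_at") < time()),
)

# Update fields in place (bumps updated_at and refreshes the TTL),
# then evict the entry manually by its document ID.
cache.update(key, metadata={"hit_count": 1})
cache.drop(key)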