From 1228718b9762a06f49aefeef3f3419d641c5d6b1 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 23 Aug 2024 13:27:44 -0400 Subject: [PATCH 1/3] check for existing index first --- redisvl/extensions/llmcache/semantic.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 9845d925..cd6c99a1 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -37,6 +37,7 @@ def __init__( redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, + overwrite: bool = False, **kwargs, ): """Semantic Cache for Large Language Models. @@ -57,6 +58,8 @@ def __init__( redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. connection_kwargs (Dict[str, Any]): The connection arguments for the redis client. Defaults to empty {}. + overwrite (bool): Whether or not to force overwrite the schema for + the semantic cache index. Defaults to false. Raises: TypeError: If an invalid vectorizer is provided. @@ -99,10 +102,24 @@ def __init__( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) + # Check for existing cache index? + if not overwrite: + if self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic cache. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + # Initialize other components self._set_vectorizer(vectorizer) self.set_threshold(distance_threshold) - self._index.create(overwrite=False) + + # Create the index + self._index.create(overwrite=overwrite, drop=False) def _modify_schema( self, From 3e73b74e49b80d343d27817e49d86c7cbc217704 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 23 Aug 2024 14:37:55 -0700 Subject: [PATCH 2/3] adds index update test. minor index check formatting --- redisvl/extensions/llmcache/semantic.py | 20 ++++++------ tests/integration/test_llmcache.py | 43 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index cd6c99a1..17856196 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -65,6 +65,7 @@ def __init__( TypeError: If an invalid vectorizer is provided. TypeError: If the TTL value is not an int. ValueError: If the threshold is not between 0 and 1. + ValueError: If existing schema does not match new schema and overwrite is False. """ super().__init__(ttl) @@ -102,17 +103,16 @@ def __init__( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - # Check for existing cache index? - if not overwrite: - if self._index.exists(): - existing_index = SearchIndex.from_existing( - name, redis_client=self._index.client + # Check for existing cache index + if not overwrite and self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic cache. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." ) - if existing_index.schema != self._index.schema: - raise ValueError( - f"Existing index {name} schema does not match the user provided schema for the semantic cache. " - "If you wish to overwrite the index schema, set overwrite=True during initialization." - ) # Initialize other components self._set_vectorizer(vectorizer) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 23a41299..e4fac763 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -513,3 +513,46 @@ def test_complex_filters(cache_with_filters): "prompt 1", filter_expression=combined_filter, num_results=5 ) assert len(results) == 1 + + +def test_index_updating(redis_url): + cache_no_tags = SemanticCache( + name="test_cache", + redis_url=redis_url, + ) + + cache_no_tags.store( + prompt="this prompt has no tags", + response="this response has no tags", + filters={"some_tag": "abc"}, + ) + + # filterable_fields not defined in schema, so no tags will match + tag_filter = Tag("some_tag") == "abc" + + response = cache_no_tags.check( + prompt="this prompt has no tag", + filter_expression=tag_filter, + ) + assert response == [] + + with pytest.raises(ValueError): + cache_with_tags = SemanticCache( + name="test_cache", + redis_url=redis_url, + filterable_fields=[{"name": "some_tag", "type": "tag"}], + ) + + cache_overwrite = SemanticCache( + name="test_cache", + redis_url=redis_url, + filterable_fields=[{"name": "some_tag", "type": "tag"}], + overwrite=True, + ) + + response = cache_overwrite.check( + prompt="this prompt has no tag", + filter_expression=tag_filter, + ) + + assert len(response) == 1 From 1c019cfb5e5fadfc7888a051466b4914b7daca8b Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 23 Aug 2024 14:42:30 -0700 Subject: [PATCH 3/3] formatting --- tests/integration/test_llmcache.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index e4fac763..2263b745 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -522,8 +522,8 @@ def test_index_updating(redis_url): ) cache_no_tags.store( - prompt="this prompt has no tags", - response="this response has no tags", + prompt="this prompt has tags", + response="this response has tags", filters={"some_tag": "abc"}, ) @@ -531,7 +531,7 @@ def test_index_updating(redis_url): tag_filter = Tag("some_tag") == "abc" response = cache_no_tags.check( - prompt="this prompt has no tag", + prompt="this prompt has a tag", filter_expression=tag_filter, ) assert response == [] @@ -551,8 +551,7 @@ def test_index_updating(redis_url): ) response = cache_overwrite.check( - prompt="this prompt has no tag", + prompt="this prompt has a tag", filter_expression=tag_filter, ) - assert len(response) == 1