From 42988d55a722a5cdeaeee4e061150fc8fcb421fa Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 12 Aug 2024 13:08:15 -0400 Subject: [PATCH 1/4] update cache to accept empty metadata --- redisvl/extensions/llmcache/schema.py | 4 ++-- redisvl/extensions/llmcache/semantic.py | 2 +- tests/integration/test_llmcache.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 8075496b..515b1421 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -44,9 +44,9 @@ def non_empty_metadata(cls, v): def to_dict(self) -> Dict: data = self.dict(exclude_none=True) data["prompt_vector"] = array_to_buffer(self.prompt_vector) - if self.metadata: + if self.metadata is not None: data["metadata"] = serialize(self.metadata) - if self.filters: + if self.filters is not None: data.update(self.filters) del data["filters"] return data diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index f0f38a04..309e1747 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -300,7 +300,7 @@ def check( key = cache_search_result["id"] self._refresh_ttl(key) - print(cache_search_result, flush=True) + # print(cache_search_result, flush=True) # Create cache hit cache_hit = CacheHit(**cache_search_result) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 34c15113..03a6f0eb 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -295,6 +295,22 @@ def test_store_with_metadata(cache, vectorizer): assert check_result[0]["prompt"] == prompt +def test_store_with_empty_metadata(cache, vectorizer): + prompt = "This is another test prompt." + response = "This is another test response." + metadata = {} + vector = vectorizer.embed(prompt) + + cache.store(prompt, response, vector=vector, metadata=metadata) + check_result = cache.check(vector=vector, num_results=1) + + assert len(check_result) == 1 + print(check_result, flush=True) + assert check_result[0]["response"] == response + assert check_result[0]["metadata"] == metadata + assert check_result[0]["prompt"] == prompt + + def test_store_with_invalid_metadata(cache, vectorizer): prompt = "This is another test prompt." response = "This is another test response." From 16050b7ad84aa5912f01160af53caa2598af7169 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 12 Aug 2024 13:21:12 -0400 Subject: [PATCH 2/4] support dynamic distance threshold --- redisvl/extensions/llmcache/semantic.py | 10 ++++++++-- .../extensions/session_manager/semantic_session.py | 12 +++++++++--- tests/integration/test_llmcache.py | 2 +- tests/integration/test_session_manager.py | 4 ++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 309e1747..50522055 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -233,6 +233,7 @@ def check( num_results: int = 1, return_fields: Optional[List[str]] = None, filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. 
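# Editor's note -- illustrative sketch, not part of the patch series.
# PATCH 1/4 changes `if self.metadata:` to `if self.metadata is not None:` in
# schema.py's to_dict(). The reason is that an empty dict is falsy, so the old
# truthiness check silently skipped serialization when metadata={} was passed.
# A self-contained demonstration of the difference:
metadata = {}

if metadata:
    print("truthiness check: serialized")        # never runs for an empty dict
else:
    print("truthiness check: metadata dropped")   # the old behavior for {}

if metadata is not None:
    print("is-not-None check: serialized")        # runs, so {} round-trips intact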
@@ -255,6 +256,8 @@ def check( filter_expression (Optional[FilterExpression]) : Optional filter expression that can be used to filter cache results. Defaults to None and the full cache will be searched. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. Returns: List[Dict[str, Any]]: A list of dicts containing the requested @@ -274,9 +277,12 @@ def check( if not (prompt or vector): raise ValueError("Either prompt or vector must be specified.") + # overrides + distance_threshold = distance_threshold or self._distance_threshold + return_fields = return_fields or self.return_fields vector = vector or self._vectorize_prompt(prompt) + self._check_vector_dims(vector) - return_fields = return_fields or self.return_fields if not isinstance(return_fields, list): raise TypeError("return_fields must be a list of field names") @@ -285,7 +291,7 @@ def check( vector=vector, vector_field_name=self.vector_field_name, return_fields=self.return_fields, - distance_threshold=self._distance_threshold, + distance_threshold=distance_threshold, num_results=num_results, return_score=True, filter_expression=filter_expression, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 5ce0318d..773f3fc5 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -137,6 +137,7 @@ def get_relevant( fall_back: bool = False, session_tag: Optional[str] = None, raw: bool = False, + distance_threshold: Optional[float] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Searches the chat history for information semantically related to the specified prompt. @@ -151,10 +152,12 @@ def get_relevant( as_text (bool): Whether to return the prompts and responses as text or as JSON top_k (int): The number of previous messages to return. Default is 5. - fall_back (bool): Whether to drop back to recent conversation history - if no relevant context is found. session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + fall_back (bool): Whether to drop back to recent conversation history + if no relevant context is found. raw (bool): Whether to return the full Redis hash entry or just the message. 
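# Editor's note -- illustrative call sites, not part of the patch series.
# With PATCH 2/4 applied, the semantic distance threshold can be overridden per
# call in addition to the instance-level setting; the calls below mirror the
# updated tests in this patch (assumes `cache`, `semantic_session`, and `vector`
# are set up as in the test fixtures). Note that PATCH 3/4 reverts this change.
hits = cache.check(vector=vector, distance_threshold=0.4)
context = semantic_session.get_relevant(
    "list of fruits and vegetables", distance_threshold=0.5
)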
@@ -169,6 +172,9 @@ def get_relevant( if top_k == 0: return [] + # override distance threshold + distance_threshold = distance_threshold or self._distance_threshold + return_fields = [ self.session_field_name, self.role_field_name, @@ -187,7 +193,7 @@ def get_relevant( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, return_fields=return_fields, - distance_threshold=self._distance_threshold, + distance_threshold=distance_threshold, num_results=top_k, return_score=True, filter_expression=session_filter, diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 03a6f0eb..23a41299 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer): vector = vectorizer.embed(prompt) cache.store(prompt, response, vector=vector) - check_result = cache.check(vector=vector) + check_result = cache.check(vector=vector, distance_threshold=0.4) assert len(check_result) == 1 print(check_result, flush=True) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index d21e6651..56943447 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session): semantic_session.set_distance_threshold(0.5) default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system + assert default_context == semantic_session.get_relevant( + "list of fruits and vegetables", + distance_threshold=0.5 + ) # test tool calls can also be returned context = semantic_session.get_relevant("winter sports like skiing") From 40012b06c892a747cb1d0d52aa1331bf8fadced3 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 12 Aug 2024 13:22:12 -0400 Subject: [PATCH 3/4] Revert "support dynamic distance threshold" This reverts commit 16050b7ad84aa5912f01160af53caa2598af7169. --- redisvl/extensions/llmcache/semantic.py | 10 ++-------- .../extensions/session_manager/semantic_session.py | 12 +++--------- tests/integration/test_llmcache.py | 2 +- tests/integration/test_session_manager.py | 4 ---- 4 files changed, 6 insertions(+), 22 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 50522055..309e1747 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -233,7 +233,6 @@ def check( num_results: int = 1, return_fields: Optional[List[str]] = None, filter_expression: Optional[FilterExpression] = None, - distance_threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. @@ -256,8 +255,6 @@ def check( filter_expression (Optional[FilterExpression]) : Optional filter expression that can be used to filter cache results. Defaults to None and the full cache will be searched. - distance_threshold (Optional[float]): The threshold for semantic - vector distance. 
Returns: List[Dict[str, Any]]: A list of dicts containing the requested @@ -277,12 +274,9 @@ def check( if not (prompt or vector): raise ValueError("Either prompt or vector must be specified.") - # overrides - distance_threshold = distance_threshold or self._distance_threshold - return_fields = return_fields or self.return_fields vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) + return_fields = return_fields or self.return_fields if not isinstance(return_fields, list): raise TypeError("return_fields must be a list of field names") @@ -291,7 +285,7 @@ def check( vector=vector, vector_field_name=self.vector_field_name, return_fields=self.return_fields, - distance_threshold=distance_threshold, + distance_threshold=self._distance_threshold, num_results=num_results, return_score=True, filter_expression=filter_expression, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 773f3fc5..5ce0318d 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -137,7 +137,6 @@ def get_relevant( fall_back: bool = False, session_tag: Optional[str] = None, raw: bool = False, - distance_threshold: Optional[float] = None, ) -> Union[List[str], List[Dict[str, str]]]: """Searches the chat history for information semantically related to the specified prompt. @@ -152,12 +151,10 @@ def get_relevant( as_text (bool): Whether to return the prompts and responses as text or as JSON top_k (int): The number of previous messages to return. Default is 5. - session_tag (Optional[str]): Tag to be added to entries to link to a specific - session. Defaults to instance uuid. - distance_threshold (Optional[float]): The threshold for semantic - vector distance. fall_back (bool): Whether to drop back to recent conversation history if no relevant context is found. + session_tag (Optional[str]): Tag to be added to entries to link to a specific + session. Defaults to instance uuid. raw (bool): Whether to return the full Redis hash entry or just the message. 
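# Editor's note -- illustrative only, not part of the patch series.
# After the revert in PATCH 3/4, the search threshold is again controlled only
# at the instance level, as in the existing test_semantic_add_and_get_relevant
# test (assumes `semantic_session` is set up as in the test fixtures):
semantic_session.set_distance_threshold(0.5)
context = semantic_session.get_relevant("list of fruits and vegetables")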
@@ -172,9 +169,6 @@ def get_relevant( if top_k == 0: return [] - # override distance threshold - distance_threshold = distance_threshold or self._distance_threshold - return_fields = [ self.session_field_name, self.role_field_name, @@ -193,7 +187,7 @@ def get_relevant( vector=self._vectorizer.embed(prompt), vector_field_name=self.vector_field_name, return_fields=return_fields, - distance_threshold=distance_threshold, + distance_threshold=self._distance_threshold, num_results=top_k, return_score=True, filter_expression=session_filter, diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 23a41299..03a6f0eb 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer): vector = vectorizer.embed(prompt) cache.store(prompt, response, vector=vector) - check_result = cache.check(vector=vector, distance_threshold=0.4) + check_result = cache.check(vector=vector) assert len(check_result) == 1 print(check_result, flush=True) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..d21e6651 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -463,10 +463,6 @@ def test_semantic_add_and_get_relevant(semantic_session): semantic_session.set_distance_threshold(0.5) default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system - assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 - ) # test tool calls can also be returned context = semantic_session.get_relevant("winter sports like skiing") From d958fdc8168ee6f2f19e931676917224d8a8c673 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 16 Aug 2024 09:23:36 -0400 Subject: [PATCH 4/4] remove debug line --- redisvl/extensions/llmcache/semantic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 309e1747..1b78ea9e 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -300,8 +300,6 @@ def check( key = cache_search_result["id"] self._refresh_ttl(key) - # print(cache_search_result, flush=True) - # Create cache hit cache_hit = CacheHit(**cache_search_result) cache_hit_dict = {